"""Loader for a private LiDAR + surround-camera dataset with 3D box labels.

The module only reads raw data from disk and repackages it; no geometric
conversion is applied to point clouds, projection matrices or boxes.
"""

import os
import glob
import json
import numpy as np
from typing import Dict, List, Optional, Union
from dataclasses import dataclass, field

# NOTE: ``open3d`` and ``tqdm`` are imported inside the methods that use them
# (``_read_lidar_points`` and ``_load_all_samples``) so that the label / camera
# parsing helpers stay importable on machines without those packages.


@dataclass
class FrameData:
    """Raw data of a single frame; nothing here has been transformed."""

    # Frame number (the original annotation spec says it starts at 0).
    frame_index: int
    # Point coordinates as returned by ``_read_lidar_points``: float32, (N, 3).
    lidar_points: np.ndarray
    # Camera view name -> image file path.
    image_paths: Dict[str, str]
    # Camera view name -> {
    #     "projection_matrix": np.ndarray, float64, shape (3, 4),
    #     "image_size": value of "imageSize" from the camera JSON
    #                   (documented by the data provider as [width, height]),
    # }
    camera_raw: Dict[str, dict]
    # 3D box annotations of the frame. Each dict contains:
    #   "num":       box id
    #   "label":     the raw label dict from the annotation file
    #   "class":     class name (label["class-name"])
    #   "is_moving": bool, ``not label["static"]``
    #   "isolation": bool, label["isolation"]
    #   "color":     hex color string looked up from the class name
    #   "points":    8 corners, each [x, y, z]
    #   "center":    [x, y, z]
    #   "rotateZ":   rotation around Z in radians; positive direction is from
    #                +X towards +Y (counter-clockwise)
    #   "dx", "dy", "dz": box extents taken from "width", "height", "depth"
    labels: List[dict]
    # True when the frame has at least one annotated box.
    contain_labels: bool = False


@dataclass
class SampleData:
    """All frames found in one ``sample{id}`` folder."""

    # Digits extracted from the folder name.
    sample_id: str
    # Frame index -> FrameData.
    frames: Dict[int, FrameData]


class MyDataset:
    """Private dataset loader.

    Point-cloud (world) coordinate system: +Y is forward, +X is right,
    +Z is up.

    Expected directory layout::

        root_folder/
            sample{id}/
            ├── pcd_sequence{id}/      # point-cloud frames (*.pcd)
            ├── 0/, 1/, .../           # images of each frame
            ├── test{id}.json          # camera parameters
            └── test{id}-mark.json     # 3D annotations

    All samples are loaded eagerly in ``__init__``.
    """

    def __init__(self, root_folder: str):
        """Index and load every ``sample*`` folder under *root_folder*."""
        self.root = root_folder
        self._samples: Dict[str, SampleData] = {}

        # Annotation color -> class name.
        self.color2class = {
            "#5414ED": "car",
            "#F6EE64": "pick-up-truck",
            "#F6A087": "small-truck",
            "#BC4EF1": "truck",
            "#4E9AF1": "bus",
            "#F1A94E": "special-vehicle",
            "#E1DFDD": "ignore",
            "#F91906": "tricyclist-withrider",
            "#FA5F51": "tricyclist-withoutrider",
            "#B8CB30": "bicycle-withrider",
            "#E6FD4E": "bicycle-withoutrider",
            "#876363": "people",
            "#2CBDF5": "crowd-people",
            "#C9F52C": "crowd-bicycle",
            "#DC6788": "crowd-car",
            "#6EC913": "traffic-cone",
            "#0DDE69": "plastic-barrier",
            "#8260D2": "crash-barrels",
            "#F1D1D1": "warning-triangle",
            "#FE6DF4": "crowd-traffic-cone",
            "#D1AA35": "crowd-plastic-barrier",
            "#3BE8D0": "crowd-crash-barrels",
            "#2B7567": "crowd-warning-triangle",
        }
        self.class2color = {cls: color for color, cls in self.color2class.items()}

        # Camera view name -> image file name inside each frame folder.
        self.cam2filename = {
            "front_120": "scanofilm_surround_front_120_8M.jpg",
            "front_left_100": "scanofilm_surround_front_left_100_2M.jpg",
            "front_right_100": "scanofilm_surround_front_right_100_2M.jpg",
            "rear_100": "scanofilm_surround_rear_100_2M.jpg",
            "rear_left_100": "scanofilm_surround_rear_left_100_2M.jpg",
            "rear_right_100": "scanofilm_surround_rear_right_100_2M.jpg",
        }

        self._load_all_samples()

    # ------------------------------------------------------------------
    # Path and ID helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _extract_id(folder_name: str) -> str:
        """Return every digit of the folder's base name, concatenated.

        The path is normalized first so that a trailing separator
        (``"root/sample12/"``) does not yield an empty base name.
        """
        base = os.path.basename(os.path.normpath(folder_name))
        return ''.join(c for c in base if c.isdigit())

    def _find_sample_folders(self) -> List[str]:
        """Return the paths of all ``sample*`` directories under the root."""
        pattern = os.path.join(self.root, "sample*")
        return [f for f in glob.glob(pattern) if os.path.isdir(f)]

    def _find_paths_in_sample(self, sample_folder: str, sample_id: str) -> dict:
        """Return the resource paths that belong to one sample folder.

        The per-frame image folders are derived from the number of ``*.pcd``
        files: one folder named ``0 .. num_frames-1`` per point-cloud frame.
        The list is empty when the point-cloud folder does not exist.
        """
        pcd_folder = os.path.join(sample_folder, f"pcd_sequence{sample_id}")
        img_folders = []
        if os.path.isdir(pcd_folder):
            pcd_files = glob.glob(os.path.join(pcd_folder, "*.pcd"))
            num_frames = len(pcd_files)
            img_folders = [os.path.join(sample_folder, str(i)) for i in range(num_frames)]
        page_json = os.path.join(sample_folder, f"test{sample_id}.json")
        mark_json = os.path.join(sample_folder, f"test{sample_id}-mark.json")
        return {
            "pcd_folder": pcd_folder,
            "img_folders": img_folders,
            "page_json": page_json,
            "mark_json": mark_json,
        }

    # ------------------------------------------------------------------
    # Raw data readers (no conversion)
    # ------------------------------------------------------------------
    def _read_lidar_points(self, pcd_path: str) -> Optional[np.ndarray]:
        """Read one point-cloud file and return its points as float32.

        Returns ``None`` (after printing the error) when reading raises.
        """
        # Imported here, outside the ``try``, so a missing Open3D install is
        # reported once as ImportError instead of once per file.
        import open3d as o3d

        try:
            # Convert path to bytes for better Chinese path support on Windows.
            # NOTE(review): confirm the installed Open3D accepts a bytes path.
            pcd_path_bytes = pcd_path.encode('utf-8') if isinstance(pcd_path, str) else pcd_path
            pcd = o3d.io.read_point_cloud(
                pcd_path_bytes,
                remove_nan_points=True,
                remove_infinite_points=True,
                format="auto",
            )
            return np.asarray(pcd.points, dtype=np.float32)  # [N, 3]
        except Exception as e:
            print(f"Error loading point cloud {pcd_path}: {e}")
            return None

    def _load_all_lidar_frames(self, pcd_folder: str) -> Dict[int, np.ndarray]:
        """Return frame index -> points for every readable ``*.pcd`` file.

        The frame index is the integer that follows the last ``"n"`` in the
        part of the file name before the first dot. Files whose name does not
        parse, or whose content cannot be read, are skipped.
        """
        idx2pcd_np = {}
        for pcd_file in glob.glob(os.path.join(pcd_folder, "*.pcd")):
            name = os.path.basename(pcd_file)
            try:
                idx = int(name.split(".")[0].split("n")[-1])
            except ValueError:
                # A stray file must not abort loading of the whole dataset.
                print(f"Skipping point cloud with unparsable frame index: {pcd_file}")
                continue
            pcd_np = self._read_lidar_points(pcd_file)
            if pcd_np is not None:
                idx2pcd_np[idx] = pcd_np
        return idx2pcd_np

    def _load_image_paths(self, img_folders: List[str]) -> Dict[int, Dict[str, str]]:
        """Return frame index -> {view name: image path}.

        Folders that do not exist are skipped.

        Raises:
            AssertionError: an existing frame folder lacks one of the
                expected camera images.
        """
        idx2imgs = {}
        for folder in img_folders:
            if not os.path.isdir(folder):
                continue
            idx = int(os.path.basename(folder))
            paths = {}
            for cam, fname in self.cam2filename.items():
                full_path = os.path.join(folder, fname)
                if not os.path.isfile(full_path):
                    # Raised explicitly (not via ``assert``) so the check is
                    # not stripped when Python runs with -O.
                    raise AssertionError(f"Image file not found: {full_path}")
                paths[cam] = full_path
            if paths:
                idx2imgs[idx] = paths
        return idx2imgs

    def _load_camera_raw(self, page_json_path: str) -> Dict[int, Dict[str, dict]]:
        """Return frame index -> {view name: camera info}.

        Camera info is ``{"projection_matrix": (3, 4) float64 ndarray,
        "image_size": <"imageSize" from the JSON>}``. Returns ``{}`` when the
        file is missing or lacks ``data.files[0].extendSources``. Entries
        without a sensor, a parsable frame index, a known view name, a matrix
        or an image size are skipped.

        Raises:
            AssertionError: a projection matrix is not of shape (3, 4).
        """
        if not os.path.isfile(page_json_path):
            return {}
        with open(page_json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        try:
            extend_source = data['data']['files'][0]['extendSources']
        except (KeyError, IndexError):
            return {}

        idx2cam = {}
        for d in extend_source:
            page_element = d.get('pageElement', {})
            if "sensor" not in page_element:
                continue
            # Frame index is the part of "fileName" before the first dot.
            try:
                idx = int(d['fileName'].split('.')[0])
            except (KeyError, ValueError, AttributeError):
                continue
            sensor = page_element['sensor']

            # Reduce the sensor name to the view name used in cam2filename,
            # e.g. "ofilm_surround_front_120_8M" -> "front_120".
            cam_name = sensor.replace("ofilm_surround_", "")
            for suffix in ["_2M", "_8M"]:
                if cam_name.endswith(suffix):
                    cam_name = cam_name[:-len(suffix)]
                    break

            if cam_name not in self.cam2filename:
                continue

            matrix = page_element.get('mtx', [])
            image_size = page_element.get('imageSize', [])
            if not matrix or not image_size:
                continue

            matrix = np.array(matrix, dtype=np.float64)
            if matrix.shape != (3, 4):
                # Explicit raise so the check survives ``python -O``.
                raise AssertionError(f"Invalid projection matrix shape: {matrix.shape}")
            idx2cam.setdefault(idx, {})[cam_name] = {
                "projection_matrix": matrix,
                "image_size": image_size,
            }
        return idx2cam

    def extract_text_label_from_json(self, json_path: str) -> List[Dict]:
        """Return the raw per-frame label list of the best-rated result.

        The file's ``data.list[0].rate`` maps a result key to an
        ``"accuracy"`` percentage string. The result with the highest
        accuracy is returned only when that accuracy is exactly 100%;
        otherwise (or when no rating exists) an empty list is returned.
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)['data']['list'][0]

        idx2rate = {
            key: float(value['accuracy'].replace('%', '')) / 100.0
            for key, value in data['rate'].items()
        }
        if not idx2rate:
            # ``max`` on an empty mapping would raise ValueError.
            return []

        # Key of the result with the highest accuracy.
        max_idx = max(idx2rate, key=idx2rate.get)
        if idx2rate[max_idx] == 1.0:
            return data['result'][max_idx]
        return []

    def process_text_label(self, json_path: str) -> Dict[int, List[Dict]]:
        """Return frame index -> processed box list for an annotation file.

        Returns ``{}`` when the file has no fully accurate result.
        """
        label_list = self.extract_text_label_from_json(json_path)
        if len(label_list) == 0:
            return {}

        idx2label: Dict[int, List[Dict]] = {}
        for label in label_list:
            idx2label[label['index']] = self.process_label_for_one_frame(label['value'])
        return idx2label

    def process_label_for_one_frame(self, label_list: List[Dict]) -> List[Dict]:
        """Convert the raw boxes of one frame into flat label dicts.

        See ``FrameData.labels`` for the produced keys.

        Raises:
            AssertionError: a box does not have exactly 8 corner points.
            KeyError: a required field is missing or the class name is not
                in ``self.class2color``.
        """
        # There can be several 3D bounding boxes in one frame.
        labels = []
        for label in label_list:
            raw = label['label']
            corners = label['newPoints']
            if len(corners) != 8:
                # Explicit raise so the check survives ``python -O``.
                raise AssertionError(
                    f"Points number for one 3D bounding box is 8, but got {len(label['newPoints'])}."
                )

            new_label = {}
            # Id of the 3D bounding box.
            new_label['num'] = label['num']
            new_label['label'] = raw
            new_label['class'] = raw['class-name']
            new_label['is_moving'] = not bool(raw['static'])
            # isolation — True: separated by buildings, guardrails, green
            # belts, etc.; False: drivable area.
            new_label['isolation'] = bool(raw['isolation'])
            new_label['color'] = self.class2color[new_label['class']]
            new_label['points'] = [
                [float(p['x']), float(p['y']), float(p['z'])] for p in corners
            ]
            new_label['center'] = [
                float(label['x']),
                float(label['y']),
                float(label['z']),
            ]
            # Only the rotation around Z is kept. Unit: radians. Positive
            # direction: from +X towards +Y (counter-clockwise).
            new_label['rotateZ'] = float(label['rotateZ'])
            new_label['dx'] = float(label['width'])
            new_label['dy'] = float(label['height'])
            new_label['dz'] = float(label['depth'])

            labels.append(new_label)
        return labels

    # ------------------------------------------------------------------
    # Sample-level assembly
    # ------------------------------------------------------------------
    @staticmethod
    def _assemble_frames(
        lidar_dict: Dict[int, np.ndarray],
        img_dict: Dict[int, Dict[str, str]],
        cam_dict: Dict[int, Dict[str, dict]],
        label_dict: Dict[int, List[Dict]],
    ) -> Optional[Dict[int, FrameData]]:
        """Zip the four per-frame mappings into ``FrameData`` objects.

        Returns ``None`` when the mappings do not cover exactly the same
        frame indices. Comparing index sets (not just lengths) prevents a
        ``KeyError`` when counts agree but indices differ.
        """
        frame_ids = set(lidar_dict)
        if not (set(img_dict) == set(cam_dict) == set(label_dict) == frame_ids):
            return None
        return {
            idx: FrameData(
                frame_index=idx,
                lidar_points=lidar_dict[idx],
                image_paths=img_dict[idx],
                camera_raw=cam_dict[idx],
                labels=label_dict[idx],
                contain_labels=len(label_dict[idx]) > 0,
            )
            for idx in sorted(frame_ids)
        }

    def _load_one_sample(self, sample_folder: str) -> Optional[SampleData]:
        """Load one sample folder.

        Returns ``None`` when the annotation file is missing, when the
        modalities do not cover the same frames, or when there are no frames.
        """
        sample_id = self._extract_id(sample_folder)
        paths = self._find_paths_in_sample(sample_folder, sample_id)

        # A missing annotation file skips this sample instead of aborting the
        # whole dataset load (a missing camera JSON is already tolerated).
        if not os.path.isfile(paths["mark_json"]):
            print(f"Sample {sample_id}: annotation file not found: {paths['mark_json']}")
            return None

        # Read the raw data of every modality.
        lidar_dict = self._load_all_lidar_frames(paths["pcd_folder"])
        img_dict = self._load_image_paths(paths["img_folders"])
        cam_dict = self._load_camera_raw(paths["page_json"])
        label_dict = self.process_text_label(paths["mark_json"])

        frames = self._assemble_frames(lidar_dict, img_dict, cam_dict, label_dict)
        if frames is None:
            print(f"Sample {sample_id}: Mismatch in frame counts between lidar, images, cameras, and labels.")
            return None
        if not frames:
            return None
        return SampleData(sample_id=sample_id, frames=frames)

    def _load_all_samples(self):
        """Load every sample folder and keep the ones that load cleanly."""
        from tqdm import tqdm

        # ``glob`` order is filesystem-dependent; sort for reproducible runs.
        sample_folders = sorted(self._find_sample_folders())
        for sf in tqdm(sample_folders, desc="Loading samples"):
            sample = self._load_one_sample(sf)
            if sample is not None:
                self._samples[sample.sample_id] = sample

        print(f"Loaded {len(self._samples)} samples.")

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------
    @staticmethod
    def _sample_sort_key(sample_id: str):
        """Sort key that orders numeric ids by value ("2" before "10")."""
        if sample_id.isdecimal():
            return (int(sample_id), sample_id)
        # Non-numeric (e.g. empty) ids sort first, then lexicographically.
        return (-1, sample_id)

    @property
    def sample_ids(self) -> List[str]:
        """Ids of all loaded samples, in load order."""
        return list(self._samples.keys())

    def get_sample(self, sample_id: str) -> Optional[SampleData]:
        """Return the sample with the given id, or ``None`` if not loaded."""
        return self._samples.get(sample_id)

    def __len__(self) -> int:
        """Number of loaded samples."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> SampleData:
        """Return the *idx*-th sample, ordered by numeric sample id.

        Raises:
            IndexError: *idx* is negative or not smaller than ``len(self)``.
        """
        sids = sorted(self.sample_ids, key=self._sample_sort_key)
        if idx < 0 or idx >= len(sids):
            raise IndexError(f"Index {idx} out of range [0, {len(sids)-1}]")
        return self._samples[sids[idx]]


if __name__ == "__main__":
    dataset = MyDataset(root_folder="sample")