import glob
import json
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

import numpy as np
import open3d as o3d
from tqdm import tqdm


@dataclass
class FrameData:
    """Raw data of a single frame; no mathematical transformation is applied."""

    # Frame number (the original author documents it as starting from 0).
    frame_index: int
    # (N, 3) float32 array with the raw point-cloud coordinates.
    lidar_points: np.ndarray
    # Camera view name -> image file path.
    image_paths: Dict[str, str]
    # Camera view name -> {
    #     "projection_matrix": np.ndarray,  # float64, shape (3, 4)
    #     "image_size": list,  # taken verbatim from the JSON; presumably
    #                          # [width, height] - TODO confirm
    # }
    camera_raw: Dict[str, dict]
    # Raw 3D box annotations. Each box is a dict with:
    #   - "num": box id
    #   - "label": the raw label dict from the annotation file
    #   - "class": str, raw class name
    #   - "is_moving": bool, negation of the raw "static" flag
    #   - "isolation": bool
    #   - "color": str, colour code looked up from the class name
    #   - "points": List[List[float]], xyz of the 8 corner points
    #   - "center": List[float], [x, y, z]
    #   - "rotateZ": float, rotation around the Z axis in radians;
    #         positive direction is from +X towards +Y (counter-clockwise)
    #   - "dx", "dy", "dz": float, taken from the raw "width", "height"
    #         and "depth" fields respectively
    labels: List[dict]
    # Whether this frame has at least one annotated box.
    contain_labels: bool = False


@dataclass
class SampleData:
    """All frames contained in one sample folder."""

    # Digits extracted from the sample folder name.
    sample_id: str
    # Frame index -> FrameData.
    frames: Dict[int, FrameData]


class MyDataset:
    """
    Loader for a private dataset.

    Per the original author, the point-cloud (world) coordinate system has
    +Y pointing forward, +X pointing right and +Z pointing up.

    Expected directory layout:
    root_folder/
        sample{id}/
            ├── pcd_sequence{id}/     # point-cloud frames
            ├── 0/, 1/, .../          # images of each frame
            ├── test{id}.json         # camera parameters
            └── test{id}-mark.json    # 3D annotations
    """

    def __init__(self, root_folder: str):
        """Index and eagerly load every sample found under ``root_folder``."""
        self.root = root_folder
        self._samples: Dict[str, SampleData] = {}
        # Annotation colour code -> class name.
        self.color2class = {
            "#5414ED": "car",
            "#F6EE64": "pick-up-truck",
            "#F6A087": "small-truck",
            "#BC4EF1": "truck",
            "#4E9AF1": "bus",
            "#F1A94E": "special-vehicle",
            "#E1DFDD": "ignore",
            "#F91906": "tricyclist-withrider",
            "#FA5F51": "tricyclist-withoutrider",
            "#B8CB30": "bicycle-withrider",
            "#E6FD4E": "bicycle-withoutrider",
            "#876363": "people",
            "#2CBDF5": "crowd-people",
            "#C9F52C": "crowd-bicycle",
            "#DC6788": "crowd-car",
            "#6EC913": "traffic-cone",
            "#0DDE69": "plastic-barrier",
            "#8260D2": "crash-barrels",
            "#F1D1D1": "warning-triangle",
            "#FE6DF4": "crowd-traffic-cone",
            "#D1AA35": "crowd-plastic-barrier",
            "#3BE8D0": "crowd-crash-barrels",
            "#2B7567": "crowd-warning-triangle"
        }
        # Inverse mapping: class name -> colour code.
        self.class2color = {cls: color for color, cls in self.color2class.items()}
        # Camera view name -> image file name inside each frame folder.
        self.cam2filename = {
            "front_120": "scanofilm_surround_front_120_8M.jpg",
            "front_left_100": "scanofilm_surround_front_left_100_2M.jpg",
            "front_right_100": "scanofilm_surround_front_right_100_2M.jpg",
            "rear_100": "scanofilm_surround_rear_100_2M.jpg",
            "rear_left_100": "scanofilm_surround_rear_left_100_2M.jpg",
            "rear_right_100": "scanofilm_surround_rear_right_100_2M.jpg",
        }
        self._load_all_samples()

    # ------------------------------------------------------------------
    # Path and ID helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _extract_id(folder_name: str) -> str:
        """Return every digit character of the folder's base name, concatenated."""
        return ''.join(c for c in os.path.basename(folder_name) if c.isdigit())

    def _find_sample_folders(self) -> List[str]:
        """Return the paths of all ``sample*`` directories under the root.

        The result is sorted so that the load order does not depend on the
        arbitrary order in which ``glob`` reports directory entries.
        """
        pattern = os.path.join(self.root, "sample*")
        return sorted(f for f in glob.glob(pattern) if os.path.isdir(f))

    def _find_paths_in_sample(self, sample_folder: str, sample_id: str) -> dict:
        """Return a dict with the resource paths inside one sample folder.

        Keys: ``"pcd_folder"``, ``"img_folders"``, ``"page_json"`` and
        ``"mark_json"``. The paths are built from naming conventions only;
        apart from the point-cloud folder they are not checked for existence.
        """
        pcd_folder = os.path.join(sample_folder, f"pcd_sequence{sample_id}")
        img_folders = []
        if os.path.isdir(pcd_folder):
            # One image folder (named 0, 1, ...) is expected per .pcd file.
            pcd_files = glob.glob(os.path.join(pcd_folder, "*.pcd"))
            num_frames = len(pcd_files)
            img_folders = [os.path.join(sample_folder, str(i)) for i in range(num_frames)]

        page_json = os.path.join(sample_folder, f"test{sample_id}.json")
        mark_json = os.path.join(sample_folder, f"test{sample_id}-mark.json")
        return {
            "pcd_folder": pcd_folder,
            "img_folders": img_folders,
            "page_json": page_json,
            "mark_json": mark_json
        }

    # ------------------------------------------------------------------
    # Raw data reading (no transformation)
    # ------------------------------------------------------------------
    def _read_lidar_points(self, pcd_path: str) -> Optional[np.ndarray]:
        """Read one point-cloud file into an (N, 3) float32 array.

        NaN and infinite points are dropped by the reader. Returns ``None``
        (after printing the error) if the file cannot be loaded.
        """
        try:
            # The path is handed over as UTF-8 bytes; the original author
            # notes this is meant to improve support for Chinese paths on
            # Windows. NOTE(review): not verified here - confirm.
            pcd_path_bytes = pcd_path.encode('utf-8') if isinstance(pcd_path, str) else pcd_path
            pcd = o3d.io.read_point_cloud(
                pcd_path_bytes,
                remove_nan_points=True,
                remove_infinite_points=True,
                format="auto"
            )
            pcd_np = np.asarray(pcd.points, dtype=np.float32)  # [N, 3]
            return pcd_np
        except Exception as e:
            # Deliberate best-effort: a broken frame is reported and skipped.
            print(f"Error loading point cloud {pcd_path}: {e}")
            return None

    def _load_all_lidar_frames(self, pcd_folder: str) -> Dict[int, np.ndarray]:
        """Return frame index -> point array for every readable ``*.pcd`` file.

        The frame index is the integer that follows the last ``"n"`` in the
        file name, before the first ``"."``. Files that fail to load are
        left out of the result.
        """
        idx2pcd_np = {}
        pcd_files = glob.glob(os.path.join(pcd_folder, "*.pcd"))
        for pcd_file in pcd_files:
            idx = int(os.path.basename(pcd_file).split(".")[0].split("n")[-1])
            pcd_np = self._read_lidar_points(pcd_file)
            if pcd_np is not None:
                idx2pcd_np[idx] = pcd_np
        return idx2pcd_np

    def _load_image_paths(self, img_folders: List[str]) -> Dict[int, Dict[str, str]]:
        """Return frame index -> {view name: image path}.

        Folders that do not exist are skipped; the frame index is the
        folder's base name parsed as an integer. Every camera listed in
        ``self.cam2filename`` must have its image in an existing folder
        (checked with ``assert``).
        """
        idx2imgs = {}
        for folder in img_folders:
            if not os.path.isdir(folder):
                continue
            idx = int(os.path.basename(folder))
            paths = {}
            for cam, fname in self.cam2filename.items():
                full_path = os.path.join(folder, fname)
                assert os.path.isfile(full_path), f"Image file not found: {full_path}"
                paths[cam] = full_path
            if paths:
                idx2imgs[idx] = paths
        return idx2imgs

    def _load_camera_raw(self, page_json_path: str) -> Dict[int, Dict[str, dict]]:
        """Return frame index -> {view name: {"projection_matrix", "image_size"}}.

        ``projection_matrix`` is a float64 array of shape (3, 4) and
        ``image_size`` is kept exactly as found in the JSON. Returns an empty
        dict when the file is missing or lacks the expected top-level
        structure. Entries without a sensor, with an unparsable frame index,
        with an unknown camera, or without matrix/size are skipped.
        """
        if not os.path.isfile(page_json_path):
            return {}

        with open(page_json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        try:
            extend_source = data['data']['files'][0]['extendSources']
        except (KeyError, IndexError):
            return {}

        idx2cam = {}
        for d in extend_source:
            page_element = d.get('pageElement', {})
            if "sensor" not in page_element:
                continue

            # The frame index is the part of the file name before the first dot.
            try:
                idx = int(d['fileName'].split('.')[0])
            except (KeyError, AttributeError, TypeError, ValueError):
                # Missing, non-string or non-numeric file name: skip the entry.
                continue

            # Reduce the sensor name to the short view name used as key in
            # self.cam2filename (strip the prefix and the resolution suffix).
            sensor = page_element['sensor']
            cam_name = sensor.replace("ofilm_surround_", "")
            for suffix in ["_2M", "_8M"]:
                if cam_name.endswith(suffix):
                    cam_name = cam_name[:-len(suffix)]
                    break

            if cam_name not in self.cam2filename:
                continue

            matrix = page_element.get('mtx', [])
            image_size = page_element.get('imageSize', [])
            if not matrix or not image_size:
                continue

            matrix = np.array(matrix, dtype=np.float64)
            assert matrix.shape == (3, 4), f"Invalid projection matrix shape: {matrix.shape}"

            cam_info = {
                "projection_matrix": matrix,
                "image_size": image_size
            }

            if idx not in idx2cam:
                idx2cam[idx] = {}
            idx2cam[idx][cam_name] = cam_info

        return idx2cam

    def extract_text_label_from_json(self, json_path: str) -> List[Dict]:
        """Return the raw per-frame label list of the best annotation result.

        Reads ``data.list[0]`` of the JSON file, picks the entry of ``rate``
        with the highest ``accuracy`` percentage and returns the matching
        entry of ``result`` only if that accuracy is exactly 100%. Returns an
        empty list otherwise, including when ``rate`` is empty.
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)['data']['list'][0]

        type_dict = data['rate']
        # Accuracy strings such as "100%" are converted to fractions.
        idx2rate = {
            key: float(value['accuracy'].replace('%', '')) / 100.0
            for key, value in type_dict.items()
        }
        if not idx2rate:
            # max() would raise ValueError on an empty mapping.
            return []

        # Key of the entry with the highest accuracy.
        max_idx = max(idx2rate, key=idx2rate.get)
        max_rate = idx2rate[max_idx]

        if max_rate == 1.0:
            result = data['result']
            label_list = result[max_idx]
            return label_list
        else:
            return []

    def process_text_label(self, json_path: str) -> Dict[int, List[Dict]]:
        """Return frame index -> list of processed 3D box labels.

        Returns an empty dict when the annotation file does not exist (the
        same convention ``_load_camera_raw`` uses for a missing file) or
        when it holds no fully accurate annotation result.
        """
        if not os.path.isfile(json_path):
            return {}

        label_list = self.extract_text_label_from_json(json_path)
        if len(label_list) == 0:
            return {}

        idx2label: Dict[int, List[Dict]] = {}
        for label in label_list:
            idx2label[label['index']] = self.process_label_for_one_frame(label['value'])

        return idx2label

    def process_label_for_one_frame(self, label_list: List[Dict]) -> List[Dict]:
        """Convert the raw boxes of one frame into the dict schema used here.

        Raises ``KeyError`` if a box's class name is not in
        ``self.class2color`` or if a required field is missing.
        """
        if not label_list:
            return []

        # There can be several 3D bounding boxes in one frame.
        labels = []
        for label in label_list:
            new_label = {}
            # "num" is the id of the 3D bounding box.
            new_label['num'] = label['num']
            # Raw label dict, kept as-is.
            new_label['label'] = label['label']
            # Class name.
            new_label['class'] = label['label']['class-name']
            # Moving flag: negation of the raw "static" value.
            new_label['is_moving'] = not bool(label['label']['static'])
            # Isolation flag.
            #   True:  separated by buildings, guardrails, green belts, etc.
            #   False: inside the drivable area
            new_label['isolation'] = bool(label['label']['isolation'])
            # Colour code looked up from the class name.
            new_label['color'] = self.class2color[new_label['class']]
            # The 8 corner points.
            corners = label['newPoints']
            assert len(label['newPoints']) == 8, f"Points number for one 3D bounding box is 8, but got {len(label['newPoints'])}."
            new_label['points'] = [
                [
                    float(corners[i]['x']),
                    float(corners[i]['y']),
                    float(corners[i]['z'])
                ]
                for i in range(8)
            ]
            # Box centre.
            new_label['center'] = [
                float(label['x']),
                float(label['y']),
                float(label['z'])
            ]
            # Rotation: only the rotation around the Z axis is kept.
            # Unit: radians.
            # Positive direction: from +X towards +Y (counter-clockwise).
            new_label['rotateZ'] = float(label['rotateZ'])
            new_label['dx'] = float(label['width'])
            new_label['dy'] = float(label['height'])
            new_label['dz'] = float(label['depth'])

            labels.append(new_label)

        return labels

    # ------------------------------------------------------------------
    # Sample-level assembly
    # ------------------------------------------------------------------
    def _load_one_sample(self, sample_folder: str) -> Optional[SampleData]:
        """Load one sample folder and return it as a ``SampleData``.

        Returns ``None`` (after printing the reason, where there is one) if
        lidar, images, camera parameters and labels do not cover exactly the
        same frames, or if the sample has no frames at all.
        """
        sample_id = self._extract_id(sample_folder)
        paths = self._find_paths_in_sample(sample_folder, sample_id)

        # Read the raw data of every modality.
        lidar_dict = self._load_all_lidar_frames(paths["pcd_folder"])
        img_dict = self._load_image_paths(paths["img_folders"])
        cam_dict = self._load_camera_raw(paths["page_json"])
        label_dict = self.process_text_label(paths["mark_json"])

        if not (len(label_dict) == len(lidar_dict) == len(img_dict) == len(cam_dict)):
            print(f"Sample {sample_id}: Mismatch in frame counts between lidar, images, cameras, and labels.")
            return None

        # Equal counts do not guarantee equal indices (e.g. a skipped or
        # 1-based frame), so check the keys before indexing the other dicts.
        frame_ids = sorted(lidar_dict)
        if not all(
            idx in img_dict and idx in cam_dict and idx in label_dict
            for idx in frame_ids
        ):
            print(f"Sample {sample_id}: Frame indices differ between lidar, images, cameras, and labels.")
            return None

        frames = {}
        for idx in frame_ids:
            frames[idx] = FrameData(
                frame_index=idx,
                lidar_points=lidar_dict[idx],
                image_paths=img_dict[idx],
                camera_raw=cam_dict[idx],
                labels=label_dict[idx],
                contain_labels=len(label_dict[idx]) > 0
            )

        if not frames:
            return None
        return SampleData(sample_id=sample_id, frames=frames)

    def _load_all_samples(self):
        """Walk through every sample folder and load it into ``self._samples``."""
        sample_folders = self._find_sample_folders()
        process_bar = tqdm(sample_folders, desc="Loading samples")
        for sf in process_bar:
            sample = self._load_one_sample(sf)
            if sample is not None:
                self._samples[sample.sample_id] = sample
        print(f"Loaded {len(self._samples)} samples.")

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------
    @property
    def sample_ids(self) -> List[str]:
        """Ids of all loaded samples, in load order."""
        return list(self._samples.keys())

    def get_sample(self, sample_id: str) -> Optional[SampleData]:
        """Return the sample with the given id, or ``None`` if it is not loaded."""
        return self._samples.get(sample_id)

    def __len__(self) -> int:
        """Number of loaded samples."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> SampleData:
        """Access a sample by integer position, ordered by sample_id.

        The ids are strings and are sorted as such, so "10" comes before
        "2". Negative indices are not supported and raise ``IndexError``.
        """
        sids = sorted(self.sample_ids)
        if idx < 0 or idx >= len(sids):
            raise IndexError(f"Index {idx} out of range [0, {len(sids)-1}]")
        return self._samples[sids[idx]]


if __name__ == "__main__":
    dataset = MyDataset(root_folder="sample")