Delete my_dataset.py

This commit is contained in:
2026-03-29 19:24:31 +08:00
parent 1750a1fb4f
commit f659fe5fb4

View File

@@ -1,372 +0,0 @@
import os
import glob
import json
import numpy as np
import open3d as o3d
from typing import Dict, List, Optional, Union
from dataclasses import dataclass, field
from tqdm import tqdm
@dataclass
class FrameData:
"""单帧原始数据,完全未经过数学转换"""
frame_index: int # 帧编号从0开始
lidar_points: np.ndarray # (N, 3) float32原始点云坐标
image_paths: Dict[str, str] # 视角名 -> 图像路径
camera_raw: Dict[str, dict] # 视角名称 -> {
# "projection_matrix": List[float], # 4x4 投影矩阵
# "image_size": List[int] # [width, height]
# }
labels: List[dict] # 原始3D框标注每个框包含
# - "num": int # 框ID
# - "color": str
# - "class": str # 原始类别名
# - "points": List[List[float]] # 8个点的xyz坐标
# - "center": List[float] # [x,y,z]
# - "rotateZ": float # 绕Z轴旋转角弧度
# # 正方向从X正半轴转向Y正半轴逆时针
# - "dx": float
# - "dy": float
# - "dz": float
contain_labels: bool = False # 是否包含标注框
@dataclass
class SampleData:
"""一个 sample 文件夹包含的全部帧"""
sample_id: str # 数字ID
frames: Dict[int, FrameData] # 帧索引 -> FrameData
class MyDataset:
f"""
私有数据集加载。点云坐标系世界坐标系Y轴正方向为前X轴正方向为右Z轴正方向为上。
数据组织规范:
root_folder/
sample{id}/
├── pcd_sequence{id}/ # 点云帧
├── 0/, 1/, .../ # 各帧图像
├── test{id}.json # 相机参数
└── test{id}-mark.json # 3D标注
"""
def __init__(self, root_folder: str):
self.root = root_folder
self._samples: Dict[str, SampleData] = {}
self.color2class = {
"#5414ED": "car",
"#F6EE64": "pick-up-truck",
"#F6A087": "small-truck",
"#BC4EF1": "truck",
"#4E9AF1": "bus",
"#F1A94E": "special-vehicle",
"#E1DFDD": "ignore",
"#F91906": "tricyclist-withrider",
"#FA5F51": "tricyclist-withoutrider",
"#B8CB30": "bicycle-withrider",
"#E6FD4E": "bicycle-withoutrider",
"#876363": "people",
"#2CBDF5": "crowd-people",
"#C9F52C": "crowd-bicycle",
"#DC6788": "crowd-car",
"#6EC913": "traffic-cone",
"#0DDE69": "plastic-barrier",
"#8260D2": "crash-barrels",
"#F1D1D1": "warning-triangle",
"#FE6DF4": "crowd-traffic-cone",
"#D1AA35": "crowd-plastic-barrier",
"#3BE8D0": "crowd-crash-barrels",
"#2B7567": "crowd-warning-triangle"
}
self.class2color = {cls: color for color, cls in self.color2class.items()}
self.cam2filename = {
"front_120": "scanofilm_surround_front_120_8M.jpg",
"front_left_100": "scanofilm_surround_front_left_100_2M.jpg",
"front_right_100": "scanofilm_surround_front_right_100_2M.jpg",
"rear_100": "scanofilm_surround_rear_100_2M.jpg",
"rear_left_100": "scanofilm_surround_rear_left_100_2M.jpg",
"rear_right_100": "scanofilm_surround_rear_right_100_2M.jpg",
}
self._load_all_samples()
# ------------------------------------------------------------------
# 路径与ID辅助方法
# ------------------------------------------------------------------
@staticmethod
def _extract_id(folder_name: str) -> str:
"""从文件夹名提取连续数字作为sample_id"""
return ''.join(c for c in os.path.basename(folder_name) if c.isdigit())
def _find_sample_folders(self) -> List[str]:
"""返回所有 sample{id} 文件夹路径"""
pattern = os.path.join(self.root, "sample*")
return [f for f in glob.glob(pattern) if os.path.isdir(f)]
def _find_paths_in_sample(self, sample_folder: str, sample_id: str) -> dict:
"""返回 sample 内部各资源路径字典"""
pcd_folder = os.path.join(sample_folder, f"pcd_sequence{sample_id}")
img_folders = []
if os.path.isdir(pcd_folder):
pcd_files = glob.glob(os.path.join(pcd_folder, "*.pcd"))
num_frames = len(pcd_files)
img_folders = [os.path.join(sample_folder, str(i)) for i in range(num_frames)]
page_json = os.path.join(sample_folder, f"test{sample_id}.json")
mark_json = os.path.join(sample_folder, f"test{sample_id}-mark.json")
return {
"pcd_folder": pcd_folder,
"img_folders": img_folders,
"page_json": page_json,
"mark_json": mark_json
}
# ------------------------------------------------------------------
# 原始数据读取(无转换)
# ------------------------------------------------------------------
def _read_lidar_points(self, pcd_path: str) -> np.ndarray:
try:
# Convert path to bytes for better Chinese path support on Windows
pcd_path_bytes = pcd_path.encode('utf-8') if isinstance(pcd_path, str) else pcd_path
pcd = o3d.io.read_point_cloud(
pcd_path_bytes,
remove_nan_points=True,
remove_infinite_points=True,
format="auto"
)
pcd_np = np.asarray(pcd.points, dtype=np.float32) # [N, 3]
return pcd_np
except Exception as e:
print(f"Error loading point cloud {pcd_path}: {e}")
return None
def _load_all_lidar_frames(self, pcd_folder: str) -> Dict[int, np.ndarray]:
idx2pcd_np = {}
pcd_files = glob.glob(os.path.join(pcd_folder, "*.pcd"))
for pcd_file in pcd_files:
idx = int(os.path.basename(pcd_file).split(".")[0].split("n")[-1])
pcd_np = self._read_lidar_points(pcd_file)
if pcd_np is not None:
idx2pcd_np[idx] = pcd_np
return idx2pcd_np
def _load_image_paths(self, img_folders: List[str]) -> Dict[int, Dict[str, str]]:
"""返回 帧索引 -> {视角名: 图像路径}"""
idx2imgs = {}
for folder in img_folders:
if not os.path.isdir(folder):
continue
idx = int(os.path.basename(folder))
paths = {}
for cam, fname in self.cam2filename.items():
full_path = os.path.join(folder, fname)
assert os.path.isfile(full_path), f"Image file not found: {full_path}"
paths[cam] = full_path
if paths:
idx2imgs[idx] = paths
return idx2imgs
def _load_camera_raw(self, page_json_path: str) -> Dict[int, Dict[str, dict]]:
"""返回 帧索引 -> {视角名: {"projection_matrix": list, "image_size": list}}"""
if not os.path.isfile(page_json_path):
return {}
with open(page_json_path, "r", encoding="utf-8") as f:
data = json.load(f)
try:
extend_source = data['data']['files'][0]['extendSources']
except (KeyError, IndexError):
return {}
idx2cam = {}
for d in extend_source:
page_element = d.get('pageElement', {})
if "sensor" not in page_element:
continue
# 提取帧索引
try:
idx = int(d['fileName'].split('.')[0])
except:
continue
sensor = page_element['sensor']
cam_name = sensor.replace("ofilm_surround_", "")
for suffix in ["_2M", "_8M"]:
if cam_name.endswith(suffix):
cam_name = cam_name[:-len(suffix)]
break
if cam_name not in self.cam2filename:
continue
matrix = page_element.get('mtx', [])
image_size = page_element.get('imageSize', [])
if not matrix or not image_size:
continue
matrix = np.array(matrix, dtype=np.float64)
assert matrix.shape == (3, 4), f"Invalid projection matrix shape: {matrix.shape}"
cam_info = {
"projection_matrix": matrix,
"image_size": image_size
}
if idx not in idx2cam:
idx2cam[idx] = {}
idx2cam[idx][cam_name] = cam_info
return idx2cam
def extract_text_label_from_json(self, json_path: str) -> List[Dict]:
with open(json_path, 'r', encoding='utf-8') as f:
data = json.load(f)['data']['list'][0]
type_dict = data['rate']
idx2rate = {}
for key, value in type_dict.items():
float_value = float(value['accuracy'].replace('%', '')) / 100.0
idx2rate[key] = float_value
# 找出最大准确率对应的索引
max_idx = max(idx2rate, key=idx2rate.get)
max_rate = idx2rate[max_idx]
if max_rate == 1.0:
result = data['result']
label_list = result[max_idx]
return label_list
else:
return []
def process_text_label(self, json_path: str):
label_list = self.extract_text_label_from_json(json_path)
if len(label_list) == 0:
return {}
idx2label: Dict[int, List[Dict]] = {}
for label in label_list:
idx2label[label['index']] = self.process_label_for_one_frame(label['value'])
return idx2label
def process_label_for_one_frame(self, label_list: List[Dict]) -> List[Dict]:
if len(label_list) == 0:
return []
# there could be multiple 3D bounding boxes in one frame
labels = []
for label in label_list:
new_label = {}
# num is the id of the 3D bounding box
new_label['num'] = label['num']
# label
new_label['label'] = label['label']
# class
new_label['class'] = label['label']['class-name']
# is_moving
new_label['is_moving'] = not bool(label['label']['static'])
# isolation
# True: 被建筑物,护栏,绿化带等隔离开
# False: 可⾏驶区域
new_label['isolation'] = bool(label['label']['isolation'])
# color
new_label['color'] = self.class2color[new_label['class']]
# corners
assert len(label['newPoints']) == 8, f"Points number for one 3D bounding box is 8, but got {len(label['newPoints'])}."
eight_points = []
for i in range(8):
eight_points.append(
[
float(label['newPoints'][i]['x']),
float(label['newPoints'][i]['y']),
float(label['newPoints'][i]['z'])
]
)
new_label['points'] = eight_points
# center
new_label['center'] = [
float(label['x']),
float(label['y']),
float(label['z'])
]
# rotate - 只保留绕Z轴的旋转
# 单位:弧度
# 正方向从X轴正半轴转向Y轴正半轴逆时针
new_label['rotateZ'] = float(label['rotateZ'])
new_label['dx'] = float(label['width'])
new_label['dy'] = float(label['height'])
new_label['dz'] = float(label['depth'])
labels.append(new_label)
return labels
# ------------------------------------------------------------------
# Sample 级组装
# ------------------------------------------------------------------
def _load_one_sample(self, sample_folder: str) -> Optional[SampleData]:
"""加载单个 sample 文件夹,返回 SampleData 对象"""
sample_id = self._extract_id(sample_folder)
paths = self._find_paths_in_sample(sample_folder, sample_id)
# 读取各模态原始数据
lidar_dict = self._load_all_lidar_frames(paths["pcd_folder"])
img_dict = self._load_image_paths(paths["img_folders"])
cam_dict = self._load_camera_raw(paths["page_json"])
label_dict = self.process_text_label(paths["mark_json"])
if not (len(label_dict) == len(lidar_dict) == len(img_dict) == len(cam_dict)):
print(f"Sample {sample_id}: Mismatch in frame counts between lidar, images, cameras, and labels.")
return None
frames = {}
for idx in range(len(lidar_dict)):
frames[idx] = FrameData(
frame_index=idx,
lidar_points=lidar_dict[idx],
image_paths=img_dict[idx],
camera_raw=cam_dict[idx],
labels=label_dict[idx],
contain_labels=len(label_dict[idx]) > 0
)
if not frames:
return None
return SampleData(sample_id=sample_id, frames=frames)
def _load_all_samples(self):
"""遍历所有 sample 文件夹并加载"""
sample_folders = self._find_sample_folders()
process_bar = tqdm(sample_folders, desc="Loading samples")
for sf in process_bar:
sample = self._load_one_sample(sf)
if sample is not None:
self._samples[sample.sample_id] = sample
print(f"Loaded {len(self._samples)} samples.")
# ------------------------------------------------------------------
# 公开接口
# ------------------------------------------------------------------
@property
def sample_ids(self) -> List[str]:
return list(self._samples.keys())
def get_sample(self, sample_id: str) -> Optional[SampleData]:
return self._samples.get(sample_id)
def __len__(self) -> int:
return len(self._samples)
def __getitem__(self, idx: int) -> SampleData:
"""通过整数索引访问 sample按sample_id排序"""
sids = sorted(self.sample_ids)
if idx < 0 or idx >= len(sids):
raise IndexError(f"Index {idx} out of range [0, {len(sids)-1}]")
return self._samples[sids[idx]]
if __name__ == "__main__":
dataset = MyDataset(root_folder="sample")