Files
puyang_time_sequence/my_dataset.py
2026-03-12 15:51:21 +08:00

372 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import glob
import json
import numpy as np
import open3d as o3d
from typing import Dict, List, Optional, Union
from dataclasses import dataclass, field
from tqdm import tqdm
@dataclass
class FrameData:
    """Raw data of a single frame, stored without any mathematical transformation.

    The per-field comments below describe what ``MyDataset`` in this file
    puts into each attribute.
    """

    # Frame number. ``MyDataset`` parses it from file/folder names.
    frame_index: int
    # (N, 3) float32 array of raw LiDAR point coordinates.
    lidar_points: np.ndarray
    # Camera view name -> image file path.
    image_paths: Dict[str, str]
    # Camera view name -> {
    #     "projection_matrix": np.ndarray,  # float64, shape (3, 4)
    #     "image_size": list,               # presumably [width, height] -- confirm
    # }
    camera_raw: Dict[str, dict]
    # Raw 3D box annotations. Each box is a dict with the keys:
    #   - "num": box id
    #   - "label": the original nested label dict from the annotation file
    #   - "class": class name (taken from label["class-name"])
    #   - "is_moving": bool, negation of label["static"]
    #   - "isolation": bool, from label["isolation"]
    #   - "color": hex colour string associated with the class
    #   - "points": list of 8 [x, y, z] corner coordinates
    #   - "center": [x, y, z]
    #   - "rotateZ": rotation about the Z axis in radians; positive direction
    #                goes from +X towards +Y (counter-clockwise)
    #   - "dx", "dy", "dz": box extents (annotation "width", "height", "depth")
    labels: List[dict]
    # True when ``labels`` holds at least one box.
    contain_labels: bool = False
@dataclass
class SampleData:
    """All frames that belong to one sample folder."""

    # Numeric id of the sample, kept as a string.
    sample_id: str
    # Frame index -> FrameData for that frame.
    frames: Dict[int, FrameData]
class MyDataset:
    """Loader for a private LiDAR + surround-camera dataset with 3D box labels.

    Point clouds are expressed in the world frame: +Y is forward, +X is right
    and +Z is up.

    Expected on-disk layout::

        root_folder/
            sample<id>/
                pcd_sequence<id>/     # point-cloud frames (*.pcd)
                0/, 1/, ...           # one image folder per frame
                test<id>.json         # camera parameters
                test<id>-mark.json    # 3D annotations

    All samples are loaded eagerly in ``__init__``. Samples whose modalities
    do not cover the same set of frame indices are skipped with a message.
    """

    def __init__(self, root_folder: str):
        """Scan ``root_folder`` and load every ``sample*`` sub-folder."""
        self.root = root_folder
        self._samples: Dict[str, "SampleData"] = {}
        # Annotation colour (hex) -> class name.
        self.color2class = {
            "#5414ED": "car",
            "#F6EE64": "pick-up-truck",
            "#F6A087": "small-truck",
            "#BC4EF1": "truck",
            "#4E9AF1": "bus",
            "#F1A94E": "special-vehicle",
            "#E1DFDD": "ignore",
            "#F91906": "tricyclist-withrider",
            "#FA5F51": "tricyclist-withoutrider",
            "#B8CB30": "bicycle-withrider",
            "#E6FD4E": "bicycle-withoutrider",
            "#876363": "people",
            "#2CBDF5": "crowd-people",
            "#C9F52C": "crowd-bicycle",
            "#DC6788": "crowd-car",
            "#6EC913": "traffic-cone",
            "#0DDE69": "plastic-barrier",
            "#8260D2": "crash-barrels",
            "#F1D1D1": "warning-triangle",
            "#FE6DF4": "crowd-traffic-cone",
            "#D1AA35": "crowd-plastic-barrier",
            "#3BE8D0": "crowd-crash-barrels",
            "#2B7567": "crowd-warning-triangle"
        }
        # Reverse lookup: class name -> annotation colour.
        self.class2color = {cls: color for color, cls in self.color2class.items()}
        # Camera view name -> image file name inside each frame folder.
        self.cam2filename = {
            "front_120": "scanofilm_surround_front_120_8M.jpg",
            "front_left_100": "scanofilm_surround_front_left_100_2M.jpg",
            "front_right_100": "scanofilm_surround_front_right_100_2M.jpg",
            "rear_100": "scanofilm_surround_rear_100_2M.jpg",
            "rear_left_100": "scanofilm_surround_rear_left_100_2M.jpg",
            "rear_right_100": "scanofilm_surround_rear_right_100_2M.jpg",
        }
        self._load_all_samples()

    # ------------------------------------------------------------------
    # Path and id helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _extract_id(folder_name: str) -> str:
        """Return every digit found in the last path component, joined.

        The path is normalised first so that a trailing separator
        (``"root/sample3/"``) does not produce an empty id.
        """
        base = os.path.basename(os.path.normpath(folder_name))
        return ''.join(c for c in base if c.isdigit())

    def _find_sample_folders(self) -> List[str]:
        """Return the paths of all ``sample*`` directories under the root.

        The result is sorted so the loading order does not depend on the
        arbitrary order in which ``glob`` lists directory entries.
        """
        pattern = os.path.join(self.root, "sample*")
        return sorted(f for f in glob.glob(pattern) if os.path.isdir(f))

    def _find_paths_in_sample(self, sample_folder: str, sample_id: str) -> dict:
        """Return the resource paths inside one sample folder.

        Returns:
            Dict with keys ``"pcd_folder"``, ``"img_folders"``, ``"page_json"``
            and ``"mark_json"``. ``"img_folders"`` has one entry per ``.pcd``
            file (folders named ``0 .. N-1``) and is empty when the
            point-cloud folder does not exist. Paths are not checked for
            existence here.
        """
        pcd_folder = os.path.join(sample_folder, f"pcd_sequence{sample_id}")
        img_folders = []
        if os.path.isdir(pcd_folder):
            num_frames = len(glob.glob(os.path.join(pcd_folder, "*.pcd")))
            img_folders = [os.path.join(sample_folder, str(i)) for i in range(num_frames)]
        return {
            "pcd_folder": pcd_folder,
            "img_folders": img_folders,
            "page_json": os.path.join(sample_folder, f"test{sample_id}.json"),
            "mark_json": os.path.join(sample_folder, f"test{sample_id}-mark.json")
        }

    # ------------------------------------------------------------------
    # Raw data reading (no transformation)
    # ------------------------------------------------------------------
    def _read_lidar_points(self, pcd_path: str) -> Optional[np.ndarray]:
        """Read one ``.pcd`` file into an ``(N, 3)`` float32 array.

        Returns:
            The point coordinates, or ``None`` when reading raised an
            exception (the error is printed, not propagated).
        """
        try:
            # Convert path to bytes for better Chinese path support on Windows
            pcd_path_bytes = pcd_path.encode('utf-8') if isinstance(pcd_path, str) else pcd_path
            pcd = o3d.io.read_point_cloud(
                pcd_path_bytes,
                remove_nan_points=True,
                remove_infinite_points=True,
                format="auto"
            )
            return np.asarray(pcd.points, dtype=np.float32)  # [N, 3]
        except Exception as e:
            # Deliberate best-effort: a broken frame must not abort the scan.
            print(f"Error loading point cloud {pcd_path}: {e}")
            return None

    def _load_all_lidar_frames(self, pcd_folder: str) -> Dict[int, np.ndarray]:
        """Load every ``.pcd`` file in ``pcd_folder``.

        The frame index is parsed from the file name: the part before the
        first ``"."``, then the text after the last ``"n"`` in that part.

        Returns:
            Frame index -> point array. Frames that failed to load are omitted.

        Raises:
            ValueError: If a file name does not yield an integer index.
        """
        idx2pcd_np = {}
        for pcd_file in glob.glob(os.path.join(pcd_folder, "*.pcd")):
            stem = os.path.basename(pcd_file).split(".")[0]
            idx = int(stem.split("n")[-1])
            pcd_np = self._read_lidar_points(pcd_file)
            if pcd_np is not None:
                idx2pcd_np[idx] = pcd_np
        return idx2pcd_np

    def _load_image_paths(self, img_folders: List[str]) -> Dict[int, Dict[str, str]]:
        """Return frame index -> {view name: image path}.

        Folders that do not exist are skipped. The frame index is the folder
        name converted to ``int``.

        Raises:
            AssertionError: If an existing frame folder lacks one of the
                expected camera images.
        """
        idx2imgs = {}
        for folder in img_folders:
            if not os.path.isdir(folder):
                continue
            idx = int(os.path.basename(folder))
            paths = {}
            for cam, fname in self.cam2filename.items():
                full_path = os.path.join(folder, fname)
                # Raised explicitly so the check survives `python -O`.
                if not os.path.isfile(full_path):
                    raise AssertionError(f"Image file not found: {full_path}")
                paths[cam] = full_path
            if paths:
                idx2imgs[idx] = paths
        return idx2imgs

    def _load_camera_raw(self, page_json_path: str) -> Dict[int, Dict[str, dict]]:
        """Return frame index -> {view name: camera info}.

        Each camera info dict holds ``"projection_matrix"`` (float64 array of
        shape ``(3, 4)``) and ``"image_size"`` (taken verbatim from the JSON).

        Entries are skipped when they have no ``sensor``, no parsable frame
        index in ``fileName``, an unknown camera name, or an empty matrix or
        image size. A missing file or missing top-level keys yields ``{}``.

        Raises:
            AssertionError: If a projection matrix is not of shape ``(3, 4)``.
        """
        if not os.path.isfile(page_json_path):
            return {}
        with open(page_json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        try:
            extend_source = data['data']['files'][0]['extendSources']
        except (KeyError, IndexError):
            return {}

        idx2cam = {}
        for d in extend_source:
            page_element = d.get('pageElement', {})
            if "sensor" not in page_element:
                continue
            # Frame index is the part of the file name before the first dot.
            try:
                idx = int(d['fileName'].split('.')[0])
            except (KeyError, AttributeError, TypeError, ValueError):
                continue
            # Reduce the sensor name to the view name used in cam2filename:
            # drop the vendor prefix and the resolution suffix.
            cam_name = page_element['sensor'].replace("ofilm_surround_", "")
            for suffix in ["_2M", "_8M"]:
                if cam_name.endswith(suffix):
                    cam_name = cam_name[:-len(suffix)]
                    break
            if cam_name not in self.cam2filename:
                continue
            matrix = page_element.get('mtx', [])
            image_size = page_element.get('imageSize', [])
            if not matrix or not image_size:
                continue
            matrix = np.array(matrix, dtype=np.float64)
            # Raised explicitly so the check survives `python -O`.
            if matrix.shape != (3, 4):
                raise AssertionError(f"Invalid projection matrix shape: {matrix.shape}")
            idx2cam.setdefault(idx, {})[cam_name] = {
                "projection_matrix": matrix,
                "image_size": image_size
            }
        return idx2cam

    def extract_text_label_from_json(self, json_path: str) -> List[Dict]:
        """Return the raw label list of the annotation pass rated 100%.

        The file's ``data.list[0].rate`` maps a key to an ``accuracy`` string
        such as ``"100%"``. The key with the highest accuracy is selected and
        ``data.list[0].result[key]`` is returned, but only when that accuracy
        is exactly 100%.

        Returns:
            The selected label list, or ``[]`` when the file is missing, no
            rates are present, or the best accuracy is below 100%.
        """
        # A missing annotation file is treated like missing camera
        # parameters: report "nothing" and let the caller skip the sample.
        if not os.path.isfile(json_path):
            return []
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)['data']['list'][0]
        idx2rate = {
            key: float(value['accuracy'].replace('%', '')) / 100.0
            for key, value in data['rate'].items()
        }
        if not idx2rate:
            # max() would raise ValueError on an empty mapping.
            return []
        # Key of the annotation pass with the highest accuracy.
        max_idx = max(idx2rate, key=idx2rate.get)
        if idx2rate[max_idx] == 1.0:
            return data['result'][max_idx]
        return []

    def process_text_label(self, json_path: str) -> Dict[int, List[Dict]]:
        """Return frame index -> processed box list for one annotation file.

        Returns ``{}`` when ``extract_text_label_from_json`` yields no labels.
        """
        label_list = self.extract_text_label_from_json(json_path)
        if len(label_list) == 0:
            return {}
        idx2label: Dict[int, List[Dict]] = {}
        for label in label_list:
            idx2label[label['index']] = self.process_label_for_one_frame(label['value'])
        return idx2label

    def process_label_for_one_frame(self, label_list: List[Dict]) -> List[Dict]:
        """Convert the raw boxes of one frame into flat label dicts.

        Each output dict has the keys ``num``, ``label``, ``class``,
        ``is_moving``, ``isolation``, ``color``, ``points``, ``center``,
        ``rotateZ``, ``dx``, ``dy`` and ``dz``.

        Raises:
            AssertionError: If a box does not have exactly 8 corner points.
            KeyError: If a required key is missing or the class name has no
                entry in ``class2color``.
        """
        if len(label_list) == 0:
            return []
        # A frame may contain several 3D bounding boxes.
        labels = []
        for label in label_list:
            attrs = label['label']
            corners = label['newPoints']
            # Raised explicitly so the check survives `python -O`.
            if len(corners) != 8:
                raise AssertionError(
                    f"Points number for one 3D bounding box is 8, but got {len(corners)}."
                )
            new_label = {}
            # Id of the 3D bounding box.
            new_label['num'] = label['num']
            # Original nested label dict, kept untouched.
            new_label['label'] = attrs
            new_label['class'] = attrs['class-name']
            new_label['is_moving'] = not bool(attrs['static'])
            # isolation: True  -> separated by buildings, guard rails,
            #                     green belts, etc.
            #            False -> drivable area
            new_label['isolation'] = bool(attrs['isolation'])
            new_label['color'] = self.class2color[new_label['class']]
            # Eight corner points as [x, y, z].
            new_label['points'] = [
                [float(p['x']), float(p['y']), float(p['z'])] for p in corners
            ]
            new_label['center'] = [
                float(label['x']),
                float(label['y']),
                float(label['z'])
            ]
            # Only the rotation about the Z axis is kept (radians). Positive
            # direction goes from +X towards +Y (counter-clockwise).
            new_label['rotateZ'] = float(label['rotateZ'])
            new_label['dx'] = float(label['width'])
            new_label['dy'] = float(label['height'])
            new_label['dz'] = float(label['depth'])
            labels.append(new_label)
        return labels

    # ------------------------------------------------------------------
    # Sample-level assembly
    # ------------------------------------------------------------------
    @staticmethod
    def _aligned_frame_indices(*modalities: Dict[int, object]) -> Optional[List[int]]:
        """Return the sorted frame indices shared by all modalities.

        Returns ``None`` when the modalities do not cover exactly the same
        set of indices, and ``[]`` when they are all empty.
        """
        if not modalities:
            return []
        reference = set(modalities[0])
        if any(set(other) != reference for other in modalities[1:]):
            return None
        return sorted(reference)

    def _load_one_sample(self, sample_folder: str) -> Optional["SampleData"]:
        """Load one sample folder.

        Returns:
            The assembled ``SampleData``, or ``None`` when the sample has no
            frames or when lidar, images, cameras and labels do not cover the
            same set of frame indices.
        """
        sample_id = self._extract_id(sample_folder)
        paths = self._find_paths_in_sample(sample_folder, sample_id)
        # Read the raw data of every modality.
        lidar_dict = self._load_all_lidar_frames(paths["pcd_folder"])
        img_dict = self._load_image_paths(paths["img_folders"])
        cam_dict = self._load_camera_raw(paths["page_json"])
        label_dict = self.process_text_label(paths["mark_json"])

        # Compare the index sets, not just their sizes: equal counts with
        # different or non-zero-based indices would otherwise fail on lookup.
        indices = self._aligned_frame_indices(label_dict, lidar_dict, img_dict, cam_dict)
        if indices is None:
            print(f"Sample {sample_id}: Mismatch in frame indices between lidar, images, cameras, and labels.")
            return None

        frames = {
            idx: FrameData(
                frame_index=idx,
                lidar_points=lidar_dict[idx],
                image_paths=img_dict[idx],
                camera_raw=cam_dict[idx],
                labels=label_dict[idx],
                contain_labels=len(label_dict[idx]) > 0
            )
            for idx in indices
        }
        if not frames:
            return None
        return SampleData(sample_id=sample_id, frames=frames)

    def _load_all_samples(self):
        """Load every sample folder and keep those that loaded successfully."""
        sample_folders = self._find_sample_folders()
        process_bar = tqdm(sample_folders, desc="Loading samples")
        for sf in process_bar:
            sample = self._load_one_sample(sf)
            if sample is not None:
                self._samples[sample.sample_id] = sample
        print(f"Loaded {len(self._samples)} samples.")

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------
    @property
    def sample_ids(self) -> List[str]:
        """Ids of all loaded samples, in loading order."""
        return list(self._samples.keys())

    def get_sample(self, sample_id: str) -> Optional["SampleData"]:
        """Return the sample with the given id, or ``None`` if not loaded."""
        return self._samples.get(sample_id)

    def __len__(self) -> int:
        """Number of loaded samples."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> "SampleData":
        """Return the ``idx``-th sample, ordered by ``sample_id``.

        Ids are strings, so the order is lexicographic (``"10"`` sorts before
        ``"2"``). Negative indices are not supported.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self) - 1]``.
        """
        sids = sorted(self.sample_ids)
        if idx < 0 or idx >= len(sids):
            raise IndexError(f"Index {idx} out of range [0, {len(sids)-1}]")
        return self._samples[sids[idx]]
if __name__ == "__main__":
    # Manual smoke run: load every sample found under the relative
    # folder "sample" (loading happens inside the constructor).
    dataset = MyDataset("sample")