Delete my_dataset.py
This commit is contained in:
372
my_dataset.py
372
my_dataset.py
@@ -1,372 +0,0 @@
|
||||
import os
|
||||
import glob
|
||||
import json
|
||||
import numpy as np
|
||||
import open3d as o3d
|
||||
from typing import Dict, List, Optional, Union
|
||||
from dataclasses import dataclass, field
|
||||
from tqdm import tqdm
|
||||
|
||||
@dataclass
|
||||
class FrameData:
|
||||
"""单帧原始数据,完全未经过数学转换"""
|
||||
frame_index: int # 帧编号(从0开始)
|
||||
lidar_points: np.ndarray # (N, 3) float32,原始点云坐标
|
||||
image_paths: Dict[str, str] # 视角名 -> 图像路径
|
||||
camera_raw: Dict[str, dict] # 视角名称 -> {
|
||||
# "projection_matrix": List[float], # 4x4 投影矩阵
|
||||
# "image_size": List[int] # [width, height]
|
||||
# }
|
||||
labels: List[dict] # 原始3D框标注,每个框包含:
|
||||
# - "num": int # 框ID
|
||||
# - "color": str
|
||||
# - "class": str # 原始类别名
|
||||
# - "points": List[List[float]] # 8个点的xyz坐标
|
||||
# - "center": List[float] # [x,y,z]
|
||||
# - "rotateZ": float # 绕Z轴旋转角,弧度
|
||||
# # 正方向:从X正半轴转向Y正半轴(逆时针)
|
||||
# - "dx": float
|
||||
# - "dy": float
|
||||
# - "dz": float
|
||||
contain_labels: bool = False # 是否包含标注框
|
||||
|
||||
@dataclass
|
||||
class SampleData:
|
||||
"""一个 sample 文件夹包含的全部帧"""
|
||||
sample_id: str # 数字ID
|
||||
frames: Dict[int, FrameData] # 帧索引 -> FrameData
|
||||
|
||||
class MyDataset:
|
||||
f"""
|
||||
私有数据集加载。点云坐标系(世界坐标系)Y轴正方向为前,X轴正方向为右,Z轴正方向为上。
|
||||
数据组织规范:
|
||||
root_folder/
|
||||
sample{id}/
|
||||
├── pcd_sequence{id}/ # 点云帧
|
||||
├── 0/, 1/, .../ # 各帧图像
|
||||
├── test{id}.json # 相机参数
|
||||
└── test{id}-mark.json # 3D标注
|
||||
"""
|
||||
def __init__(self, root_folder: str):
|
||||
self.root = root_folder
|
||||
self._samples: Dict[str, SampleData] = {}
|
||||
|
||||
self.color2class = {
|
||||
"#5414ED": "car",
|
||||
"#F6EE64": "pick-up-truck",
|
||||
"#F6A087": "small-truck",
|
||||
"#BC4EF1": "truck",
|
||||
"#4E9AF1": "bus",
|
||||
"#F1A94E": "special-vehicle",
|
||||
"#E1DFDD": "ignore",
|
||||
"#F91906": "tricyclist-withrider",
|
||||
"#FA5F51": "tricyclist-withoutrider",
|
||||
"#B8CB30": "bicycle-withrider",
|
||||
"#E6FD4E": "bicycle-withoutrider",
|
||||
"#876363": "people",
|
||||
"#2CBDF5": "crowd-people",
|
||||
"#C9F52C": "crowd-bicycle",
|
||||
"#DC6788": "crowd-car",
|
||||
"#6EC913": "traffic-cone",
|
||||
"#0DDE69": "plastic-barrier",
|
||||
"#8260D2": "crash-barrels",
|
||||
"#F1D1D1": "warning-triangle",
|
||||
"#FE6DF4": "crowd-traffic-cone",
|
||||
"#D1AA35": "crowd-plastic-barrier",
|
||||
"#3BE8D0": "crowd-crash-barrels",
|
||||
"#2B7567": "crowd-warning-triangle"
|
||||
}
|
||||
self.class2color = {cls: color for color, cls in self.color2class.items()}
|
||||
|
||||
self.cam2filename = {
|
||||
"front_120": "scanofilm_surround_front_120_8M.jpg",
|
||||
"front_left_100": "scanofilm_surround_front_left_100_2M.jpg",
|
||||
"front_right_100": "scanofilm_surround_front_right_100_2M.jpg",
|
||||
"rear_100": "scanofilm_surround_rear_100_2M.jpg",
|
||||
"rear_left_100": "scanofilm_surround_rear_left_100_2M.jpg",
|
||||
"rear_right_100": "scanofilm_surround_rear_right_100_2M.jpg",
|
||||
}
|
||||
|
||||
self._load_all_samples()
|
||||
# ------------------------------------------------------------------
|
||||
# 路径与ID辅助方法
|
||||
# ------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def _extract_id(folder_name: str) -> str:
|
||||
"""从文件夹名提取连续数字作为sample_id"""
|
||||
return ''.join(c for c in os.path.basename(folder_name) if c.isdigit())
|
||||
|
||||
def _find_sample_folders(self) -> List[str]:
|
||||
"""返回所有 sample{id} 文件夹路径"""
|
||||
pattern = os.path.join(self.root, "sample*")
|
||||
return [f for f in glob.glob(pattern) if os.path.isdir(f)]
|
||||
|
||||
def _find_paths_in_sample(self, sample_folder: str, sample_id: str) -> dict:
|
||||
"""返回 sample 内部各资源路径字典"""
|
||||
pcd_folder = os.path.join(sample_folder, f"pcd_sequence{sample_id}")
|
||||
img_folders = []
|
||||
if os.path.isdir(pcd_folder):
|
||||
pcd_files = glob.glob(os.path.join(pcd_folder, "*.pcd"))
|
||||
num_frames = len(pcd_files)
|
||||
img_folders = [os.path.join(sample_folder, str(i)) for i in range(num_frames)]
|
||||
page_json = os.path.join(sample_folder, f"test{sample_id}.json")
|
||||
mark_json = os.path.join(sample_folder, f"test{sample_id}-mark.json")
|
||||
return {
|
||||
"pcd_folder": pcd_folder,
|
||||
"img_folders": img_folders,
|
||||
"page_json": page_json,
|
||||
"mark_json": mark_json
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 原始数据读取(无转换)
|
||||
# ------------------------------------------------------------------
|
||||
def _read_lidar_points(self, pcd_path: str) -> np.ndarray:
|
||||
try:
|
||||
# Convert path to bytes for better Chinese path support on Windows
|
||||
pcd_path_bytes = pcd_path.encode('utf-8') if isinstance(pcd_path, str) else pcd_path
|
||||
pcd = o3d.io.read_point_cloud(
|
||||
pcd_path_bytes,
|
||||
remove_nan_points=True,
|
||||
remove_infinite_points=True,
|
||||
format="auto"
|
||||
)
|
||||
pcd_np = np.asarray(pcd.points, dtype=np.float32) # [N, 3]
|
||||
return pcd_np
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading point cloud {pcd_path}: {e}")
|
||||
return None
|
||||
|
||||
def _load_all_lidar_frames(self, pcd_folder: str) -> Dict[int, np.ndarray]:
|
||||
idx2pcd_np = {}
|
||||
pcd_files = glob.glob(os.path.join(pcd_folder, "*.pcd"))
|
||||
for pcd_file in pcd_files:
|
||||
idx = int(os.path.basename(pcd_file).split(".")[0].split("n")[-1])
|
||||
pcd_np = self._read_lidar_points(pcd_file)
|
||||
if pcd_np is not None:
|
||||
idx2pcd_np[idx] = pcd_np
|
||||
return idx2pcd_np
|
||||
|
||||
def _load_image_paths(self, img_folders: List[str]) -> Dict[int, Dict[str, str]]:
|
||||
"""返回 帧索引 -> {视角名: 图像路径}"""
|
||||
idx2imgs = {}
|
||||
for folder in img_folders:
|
||||
if not os.path.isdir(folder):
|
||||
continue
|
||||
idx = int(os.path.basename(folder))
|
||||
paths = {}
|
||||
for cam, fname in self.cam2filename.items():
|
||||
full_path = os.path.join(folder, fname)
|
||||
assert os.path.isfile(full_path), f"Image file not found: {full_path}"
|
||||
paths[cam] = full_path
|
||||
if paths:
|
||||
idx2imgs[idx] = paths
|
||||
return idx2imgs
|
||||
|
||||
def _load_camera_raw(self, page_json_path: str) -> Dict[int, Dict[str, dict]]:
|
||||
"""返回 帧索引 -> {视角名: {"projection_matrix": list, "image_size": list}}"""
|
||||
if not os.path.isfile(page_json_path):
|
||||
return {}
|
||||
with open(page_json_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
try:
|
||||
extend_source = data['data']['files'][0]['extendSources']
|
||||
except (KeyError, IndexError):
|
||||
return {}
|
||||
|
||||
idx2cam = {}
|
||||
for d in extend_source:
|
||||
page_element = d.get('pageElement', {})
|
||||
if "sensor" not in page_element:
|
||||
continue
|
||||
# 提取帧索引
|
||||
try:
|
||||
idx = int(d['fileName'].split('.')[0])
|
||||
except:
|
||||
continue
|
||||
sensor = page_element['sensor']
|
||||
|
||||
cam_name = sensor.replace("ofilm_surround_", "")
|
||||
for suffix in ["_2M", "_8M"]:
|
||||
if cam_name.endswith(suffix):
|
||||
cam_name = cam_name[:-len(suffix)]
|
||||
break
|
||||
|
||||
if cam_name not in self.cam2filename:
|
||||
continue
|
||||
|
||||
matrix = page_element.get('mtx', [])
|
||||
image_size = page_element.get('imageSize', [])
|
||||
if not matrix or not image_size:
|
||||
continue
|
||||
|
||||
matrix = np.array(matrix, dtype=np.float64)
|
||||
assert matrix.shape == (3, 4), f"Invalid projection matrix shape: {matrix.shape}"
|
||||
cam_info = {
|
||||
"projection_matrix": matrix,
|
||||
"image_size": image_size
|
||||
}
|
||||
if idx not in idx2cam:
|
||||
idx2cam[idx] = {}
|
||||
idx2cam[idx][cam_name] = cam_info
|
||||
return idx2cam
|
||||
|
||||
def extract_text_label_from_json(self, json_path: str) -> List[Dict]:
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)['data']['list'][0]
|
||||
|
||||
type_dict = data['rate']
|
||||
idx2rate = {}
|
||||
for key, value in type_dict.items():
|
||||
float_value = float(value['accuracy'].replace('%', '')) / 100.0
|
||||
idx2rate[key] = float_value
|
||||
|
||||
# 找出最大准确率对应的索引
|
||||
max_idx = max(idx2rate, key=idx2rate.get)
|
||||
max_rate = idx2rate[max_idx]
|
||||
|
||||
if max_rate == 1.0:
|
||||
result = data['result']
|
||||
label_list = result[max_idx]
|
||||
|
||||
return label_list
|
||||
else:
|
||||
return []
|
||||
|
||||
def process_text_label(self, json_path: str):
|
||||
label_list = self.extract_text_label_from_json(json_path)
|
||||
|
||||
if len(label_list) == 0:
|
||||
return {}
|
||||
|
||||
idx2label: Dict[int, List[Dict]] = {}
|
||||
for label in label_list:
|
||||
idx2label[label['index']] = self.process_label_for_one_frame(label['value'])
|
||||
return idx2label
|
||||
|
||||
def process_label_for_one_frame(self, label_list: List[Dict]) -> List[Dict]:
|
||||
if len(label_list) == 0:
|
||||
return []
|
||||
|
||||
# there could be multiple 3D bounding boxes in one frame
|
||||
labels = []
|
||||
for label in label_list:
|
||||
new_label = {}
|
||||
|
||||
# num is the id of the 3D bounding box
|
||||
new_label['num'] = label['num']
|
||||
|
||||
# label
|
||||
new_label['label'] = label['label']
|
||||
|
||||
# class
|
||||
new_label['class'] = label['label']['class-name']
|
||||
|
||||
# is_moving
|
||||
new_label['is_moving'] = not bool(label['label']['static'])
|
||||
|
||||
# isolation
|
||||
# True: 被建筑物,护栏,绿化带等隔离开
|
||||
# False: 可⾏驶区域
|
||||
new_label['isolation'] = bool(label['label']['isolation'])
|
||||
|
||||
# color
|
||||
new_label['color'] = self.class2color[new_label['class']]
|
||||
|
||||
# corners
|
||||
assert len(label['newPoints']) == 8, f"Points number for one 3D bounding box is 8, but got {len(label['newPoints'])}."
|
||||
eight_points = []
|
||||
for i in range(8):
|
||||
eight_points.append(
|
||||
[
|
||||
float(label['newPoints'][i]['x']),
|
||||
float(label['newPoints'][i]['y']),
|
||||
float(label['newPoints'][i]['z'])
|
||||
]
|
||||
)
|
||||
new_label['points'] = eight_points
|
||||
|
||||
# center
|
||||
new_label['center'] = [
|
||||
float(label['x']),
|
||||
float(label['y']),
|
||||
float(label['z'])
|
||||
]
|
||||
|
||||
# rotate - 只保留绕Z轴的旋转
|
||||
# 单位:弧度
|
||||
# 正方向:从X轴正半轴转向Y轴正半轴(逆时针)
|
||||
new_label['rotateZ'] = float(label['rotateZ'])
|
||||
|
||||
new_label['dx'] = float(label['width'])
|
||||
new_label['dy'] = float(label['height'])
|
||||
new_label['dz'] = float(label['depth'])
|
||||
|
||||
labels.append(new_label)
|
||||
return labels
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Sample 级组装
|
||||
# ------------------------------------------------------------------
|
||||
def _load_one_sample(self, sample_folder: str) -> Optional[SampleData]:
|
||||
"""加载单个 sample 文件夹,返回 SampleData 对象"""
|
||||
sample_id = self._extract_id(sample_folder)
|
||||
paths = self._find_paths_in_sample(sample_folder, sample_id)
|
||||
|
||||
# 读取各模态原始数据
|
||||
lidar_dict = self._load_all_lidar_frames(paths["pcd_folder"])
|
||||
img_dict = self._load_image_paths(paths["img_folders"])
|
||||
cam_dict = self._load_camera_raw(paths["page_json"])
|
||||
label_dict = self.process_text_label(paths["mark_json"])
|
||||
if not (len(label_dict) == len(lidar_dict) == len(img_dict) == len(cam_dict)):
|
||||
print(f"Sample {sample_id}: Mismatch in frame counts between lidar, images, cameras, and labels.")
|
||||
return None
|
||||
|
||||
frames = {}
|
||||
for idx in range(len(lidar_dict)):
|
||||
frames[idx] = FrameData(
|
||||
frame_index=idx,
|
||||
lidar_points=lidar_dict[idx],
|
||||
image_paths=img_dict[idx],
|
||||
camera_raw=cam_dict[idx],
|
||||
labels=label_dict[idx],
|
||||
contain_labels=len(label_dict[idx]) > 0
|
||||
)
|
||||
if not frames:
|
||||
return None
|
||||
return SampleData(sample_id=sample_id, frames=frames)
|
||||
|
||||
def _load_all_samples(self):
|
||||
"""遍历所有 sample 文件夹并加载"""
|
||||
sample_folders = self._find_sample_folders()
|
||||
process_bar = tqdm(sample_folders, desc="Loading samples")
|
||||
for sf in process_bar:
|
||||
sample = self._load_one_sample(sf)
|
||||
if sample is not None:
|
||||
self._samples[sample.sample_id] = sample
|
||||
|
||||
print(f"Loaded {len(self._samples)} samples.")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 公开接口
|
||||
# ------------------------------------------------------------------
|
||||
@property
|
||||
def sample_ids(self) -> List[str]:
|
||||
return list(self._samples.keys())
|
||||
|
||||
def get_sample(self, sample_id: str) -> Optional[SampleData]:
|
||||
return self._samples.get(sample_id)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._samples)
|
||||
|
||||
def __getitem__(self, idx: int) -> SampleData:
|
||||
"""通过整数索引访问 sample(按sample_id排序)"""
|
||||
sids = sorted(self.sample_ids)
|
||||
if idx < 0 or idx >= len(sids):
|
||||
raise IndexError(f"Index {idx} out of range [0, {len(sids)-1}]")
|
||||
return self._samples[sids[idx]]
|
||||
|
||||
if __name__ == "__main__":
|
||||
dataset = MyDataset(root_folder="sample")
|
||||
Reference in New Issue
Block a user