import pandas as pd
import matplotlib.pyplot as plt


def visualize_training_history(threshold, *, loss_label='CE', show=True) -> None:
    """Plot training and validation loss curves for one threshold run.

    Reads ``training_history_threshold_{threshold}.csv`` from the current
    working directory and plots its ``train_loss`` and ``val_loss`` columns
    against ``epoch``. It saves the figure as
    ``training_visualization_threshold_{threshold}.png``, optionally shows it,
    and then closes it.

    Args:
        threshold: Value interpolated into both the input CSV name and the
            output PNG name.
        loss_label: Text placed in parentheses at the end of the plot title.
            The default ``'CE'`` reproduces the original title.
        show: If True (the default), call ``plt.show()`` after saving. Pass
            False for headless/batch runs so each call does not block.

    Raises:
        FileNotFoundError: If the CSV file does not exist.
        KeyError: If the CSV lacks an ``epoch``, ``train_loss`` or
            ``val_loss`` column.
    """
    csv_path = f'training_history_threshold_{threshold}.csv'
    png_path = f'training_visualization_threshold_{threshold}.png'

    df = pd.read_csv(csv_path)
    epochs = df['epoch']
    train_loss = df['train_loss']
    val_loss = df['val_loss']

    # Create the figure explicitly so it can be closed afterwards. Otherwise
    # repeated calls (e.g. the loop below) keep every figure alive in
    # pyplot's registry.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(epochs, train_loss, 'b-', label='Training Loss')
        ax.plot(epochs, val_loss, 'r-', label='Validation Loss')
        ax.set_title(f'Training and Validation Loss Over Epochs ({loss_label})')
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Loss')
        ax.legend()
        ax.grid(True)
        fig.savefig(png_path)
        if show:
            plt.show()
    finally:
        plt.close(fig)

    # Message: "Visualization complete, chart saved as <png_path>".
    print(f'可视化完成,图表已保存为 {png_path}')


if __name__ == "__main__":
    # Plot every threshold run from 0 to 10 inclusive.
    for threshold in range(11):
        visualize_training_history(threshold)