diff --git a/visualize_training.py b/visualize_training.py new file mode 100644 index 0000000..e4c8622 --- /dev/null +++ b/visualize_training.py @@ -0,0 +1,29 @@ +import pandas as pd +import matplotlib.pyplot as plt + +def visualize_training_history(threshold): + csv_path = f'training_history_threshold_{threshold}.csv' + + df = pd.read_csv(csv_path) + epochs = df['epoch'] + train_loss = df['train_loss'] + val_loss = df['val_loss'] + + plt.figure(figsize=(10, 6)) + plt.plot(epochs, train_loss, 'b-', label='Training Loss') + plt.plot(epochs, val_loss, 'r-', label='Validation Loss') + + plt.title('Training and Validation Loss Over Epochs (CE)') + plt.xlabel('Epochs') + plt.ylabel('Loss') + plt.legend() + plt.grid(True) + + plt.savefig(f'training_visualization_threshold_{threshold}.png') + plt.show() + + print(f'可视化完成,图表已保存为 training_visualization_threshold_{threshold}.png') + +if __name__ == "__main__": + for i in range(11): + visualize_training_history(i) \ No newline at end of file