Add visualize_training.py

This commit is contained in:
2026-02-27 11:41:49 +08:00
parent e7202bcdd7
commit 98eb31bf69

29
visualize_training.py Normal file
View File

@@ -0,0 +1,29 @@
import pandas as pd
import matplotlib.pyplot as plt
def visualize_training_history(threshold):
csv_path = f'training_history_threshold_{threshold}.csv'
df = pd.read_csv(csv_path)
epochs = df['epoch']
train_loss = df['train_loss']
val_loss = df['val_loss']
plt.figure(figsize=(10, 6))
plt.plot(epochs, train_loss, 'b-', label='Training Loss')
plt.plot(epochs, val_loss, 'r-', label='Validation Loss')
plt.title('Training and Validation Loss Over Epochs (CE)')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.savefig(f'training_visualization_threshold_{threshold}.png')
plt.show()
print(f'可视化完成,图表已保存为 training_visualization_threshold_{threshold}.png')
if __name__ == "__main__":
for i in range(11):
visualize_training_history(i)