Add visualize_training.py
This commit is contained in:
29
visualize_training.py
Normal file
29
visualize_training.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def visualize_training_history(threshold):
|
||||
csv_path = f'training_history_threshold_{threshold}.csv'
|
||||
|
||||
df = pd.read_csv(csv_path)
|
||||
epochs = df['epoch']
|
||||
train_loss = df['train_loss']
|
||||
val_loss = df['val_loss']
|
||||
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(epochs, train_loss, 'b-', label='Training Loss')
|
||||
plt.plot(epochs, val_loss, 'r-', label='Validation Loss')
|
||||
|
||||
plt.title('Training and Validation Loss Over Epochs (CE)')
|
||||
plt.xlabel('Epochs')
|
||||
plt.ylabel('Loss')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
|
||||
plt.savefig(f'training_visualization_threshold_{threshold}.png')
|
||||
plt.show()
|
||||
|
||||
print(f'可视化完成,图表已保存为 training_visualization_threshold_{threshold}.png')
|
||||
|
||||
if __name__ == "__main__":
|
||||
for i in range(11):
|
||||
visualize_training_history(i)
|
||||
Reference in New Issue
Block a user