Add visualize_training.py
This commit is contained in:
29
visualize_training.py
Normal file
29
visualize_training.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
def visualize_training_history(threshold):
|
||||||
|
csv_path = f'training_history_threshold_{threshold}.csv'
|
||||||
|
|
||||||
|
df = pd.read_csv(csv_path)
|
||||||
|
epochs = df['epoch']
|
||||||
|
train_loss = df['train_loss']
|
||||||
|
val_loss = df['val_loss']
|
||||||
|
|
||||||
|
plt.figure(figsize=(10, 6))
|
||||||
|
plt.plot(epochs, train_loss, 'b-', label='Training Loss')
|
||||||
|
plt.plot(epochs, val_loss, 'r-', label='Validation Loss')
|
||||||
|
|
||||||
|
plt.title('Training and Validation Loss Over Epochs (CE)')
|
||||||
|
plt.xlabel('Epochs')
|
||||||
|
plt.ylabel('Loss')
|
||||||
|
plt.legend()
|
||||||
|
plt.grid(True)
|
||||||
|
|
||||||
|
plt.savefig(f'training_visualization_threshold_{threshold}.png')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
print(f'可视化完成,图表已保存为 training_visualization_threshold_{threshold}.png')
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
for i in range(11):
|
||||||
|
visualize_training_history(i)
|
||||||
Reference in New Issue
Block a user