29 lines
858 B
Python
29 lines
858 B
Python
import pandas as pd
|
|
import matplotlib.pyplot as plt
|
|
|
|
def visualize_training_history(threshold, *, show=True):
    """Plot training and validation loss curves for one threshold run.

    Reads ``training_history_threshold_{threshold}.csv`` (resolved relative to
    the current working directory). The CSV must provide ``epoch``,
    ``train_loss`` and ``val_loss`` columns. Both loss curves are drawn on one
    chart, saved as ``training_visualization_threshold_{threshold}.png``, and
    optionally displayed.

    Args:
        threshold: Value interpolated into the input CSV and output PNG names.
        show: If True (the default, matching the original behaviour), call
            ``plt.show()`` after saving. Pass False for headless/batch runs.

    Returns:
        The path of the saved PNG file.

    Raises:
        FileNotFoundError: If the CSV file does not exist.
        KeyError: If one of the required columns is missing from the CSV.
    """
    csv_path = f'training_history_threshold_{threshold}.csv'
    png_path = f'training_visualization_threshold_{threshold}.png'

    df = pd.read_csv(csv_path)
    epochs = df['epoch']
    train_loss = df['train_loss']
    val_loss = df['val_loss']

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.plot(epochs, train_loss, 'b-', label='Training Loss')
        plt.plot(epochs, val_loss, 'r-', label='Validation Loss')

        # "(CE)" is hard-coded in the title; presumably cross-entropy loss —
        # NOTE(review): confirm against the training script.
        plt.title('Training and Validation Loss Over Epochs (CE)')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        # Save before show(): with interactive backends the figure may be
        # torn down once the window is closed.
        fig.savefig(png_path)
        if show:
            plt.show()
    finally:
        # Release the figure. This function is called in a loop, and without
        # an explicit close each call leaves another figure registered in
        # pyplot, leaking memory across iterations.
        plt.close(fig)

    # Message reads: "Visualization complete, chart saved as <png>".
    print(f'可视化完成,图表已保存为 training_visualization_threshold_{threshold}.png')
    return png_path
|
|
|
|
if __name__ == "__main__":
    # Render one chart per threshold run, thresholds 0 through 10 inclusive.
    for threshold in range(11):
        try:
            visualize_training_history(threshold)
        except FileNotFoundError as err:
            # A missing history file for one threshold should not abort the
            # remaining runs; report it and move on.
            print(f'Skipping threshold {threshold}: {err}')