import torch
import numpy as np
import statistics as st
import os
import pickle
import argparse

from model import LeNet, TwoLayerNeuralNet, ThreeLayerNeuralNet, FourLayerNeuralNet


def summarize(model_name, sample_num, data_num):
    """Aggregate the test loss stored in the per-seed SGD checkpoints.

    For every ``seed`` in ``range(sample_num)`` this loads
    ``large_experiment_result/sgd/{model_name}/{data_num}/model_weights_{seed}.pth``
    and reads its ``"test_loss"`` entry.

    Args:
        model_name: Path component selecting the model's result directory.
        sample_num: Number of checkpoints to read (seeds ``0 .. sample_num - 1``).
        data_num: Path component selecting the data-size result directory.

    Returns:
        A dict with ``"test_loss_mean"`` (arithmetic mean) and
        ``"test_loss_std"`` (sample standard deviation). With a single
        checkpoint the sample standard deviation is undefined, so
        ``"test_loss_std"`` is ``nan`` instead of raising.

    Raises:
        statistics.StatisticsError: If ``sample_num`` yields no checkpoints.
        Any error raised by ``torch.load`` or a missing ``"test_loss"`` key
        propagates unchanged.
    """
    test_loss_list = []

    for seed in range(sample_num):
        # Only a scalar is read from the checkpoint, so map everything to the
        # CPU; this lets checkpoints written on a GPU machine be summarized on
        # a machine without one.
        # NOTE: torch.load unpickles the file -- only point this at
        # checkpoints produced by your own experiments.
        checkpoint = torch.load(
            f"large_experiment_result/sgd/{model_name}/{data_num}/model_weights_{seed}.pth",
            map_location="cpu",
        )
        # float() accepts a one-element tensor as well as a plain number, so
        # the loss may have been saved either way.
        test_loss_list.append(float(checkpoint["test_loss"]))

    test_loss_mean = st.mean(test_loss_list)
    # statistics.stdev needs at least two data points; report NaN for a
    # single sample rather than crashing the whole summary.
    if len(test_loss_list) > 1:
        test_loss_std = st.stdev(test_loss_list)
    else:
        test_loss_std = float("nan")
    return {"test_loss_mean": test_loss_mean, "test_loss_std": test_loss_std}


def main():
    """Command-line entry point: summarize one experiment and pickle the result.

    Parses ``--model_name``, ``--sample_num`` and ``--data_num`` (all
    required), calls ``summarize`` with them, and writes the returned dict to
    ``large_experiment_result/summary/sgd/{model_name}/{data_num}/result.pcl``,
    creating the directory if needed.
    """
    parser = argparse.ArgumentParser()
    # All three options are required: without them summarize() would receive
    # None and fail later with an unrelated-looking error, so let argparse
    # report the missing option up front.
    parser.add_argument('--model_name', type=str, required=True,
                        help="model sub-directory under large_experiment_result/sgd")
    parser.add_argument('--sample_num', type=int, required=True,
                        help="number of seeds (checkpoints 0..sample_num-1) to aggregate")
    parser.add_argument('--data_num', type=int, required=True,
                        help="data-size sub-directory under the model directory")

    args = parser.parse_args()
    model_name = args.model_name
    sample_num = args.sample_num
    data_num = args.data_num

    result = summarize(model_name, sample_num, data_num)
    folder = f"large_experiment_result/summary/sgd/{model_name}/{data_num}"
    os.makedirs(folder, exist_ok=True)
    with open(f"{folder}/result.pcl", 'wb') as f:
        pickle.dump(result, f)


# Run the command-line entry point only when executed as a script, not when
# this module is imported.
if __name__ == "__main__":
    main()
