import torch
import numpy as np
import os

from data import get_the_data
from train import train
from model import define_the_model
from test import test
from optim import define_optimizer
import input_args
from stamp import param_stamp
from store import storeit
from make_data_fig2 import run

def read_the_result(column, seed, noise_train, noise_test):
    """Return the path of the stored result file for one run.

    Despite the name, nothing is read here: the function only assembles the
    file name under ``tab1/`` from its arguments.

    Args:
        column: result column prefix, e.g. ``"c1"`` or ``"c2"``.
        seed: seed of the run.
        noise_train: training noise level; rounded to one decimal in the name.
        noise_test: test noise level; rounded to one decimal in the name.

    Returns:
        A string of the form ``tab1/<column>_res-se_<seed>-ntr_<x>-nte_<y>``.
    """
    # Rounding keeps float artefacts such as 0.30000000000000004 out of the name.
    train_tag = round(noise_train, 1)
    test_tag = round(noise_test, 1)
    return f"tab1/{column}_res-se_{seed}-ntr_{train_tag}-nte_{test_tag}"


def mean_sems(dataset, noise_train, noise_test, seeds=range(11, 20)):
    """Summarise the stored accuracies of one noise setting as "mean +- SEM".

    Args:
        dataset: ``"CIFAR10"`` selects result column ``"c1"``; any other value
            selects ``"c2"``.
        noise_train: training noise level used to locate the result files.
        noise_test: test noise level used to locate the result files.
        seeds: seeds whose result files are aggregated (default 11..19, the
            range previously hard-coded).

    Returns:
        A string ``"(noise_train, noise_test): <mean> +- <sem> "``. Both values
        are multiplied by 100 and rounded to 2 decimals, which presumes the
        stored values are fractions; confirm against the writer.

    Raises:
        ValueError: if ``seeds`` is empty.
        OSError: if a result file cannot be opened.
    """
    column = "c1" if dataset == "CIFAR10" else "c2"
    values = []
    for seed in seeds:
        # read_the_result only builds the path; the file itself is opened here.
        path = read_the_result(column, seed, noise_train, noise_test)
        # NOTE(review): assumes each result file holds a single accuracy written
        # as plain text (presumably by store.storeit) — confirm the format.
        with open(path, encoding="utf-8") as handle:
            values.append(float(handle.read().strip()))
    if not values:
        raise ValueError("mean_sems needs at least one seed")
    accs = np.array(values)
    mean = np.mean(accs)
    # Population variance (ddof=0) divided by the number of seeds, matching the
    # original computation. The textbook SEM would use np.var(accs, ddof=1).
    sem = np.sqrt(np.var(accs) / len(accs))
    return f"({noise_train}, {noise_test}): {round(mean*100, 2)} +- {round(sem*100, 2)} "


def run(args=None):
    """Print mean +- SEM for every (train, test) noise pair on both datasets.

    For each of "CIFAR10" and "CIFAR20", prints a header line followed by one
    line per noise pair, as produced by ``mean_sems``.

    Args:
        args: not used by this function. It is optional so that the
            no-argument call from the ``__main__`` guard works; passing a
            value is still accepted.
    """
    list_of_noise = [
        (0.0, 0.0), (0.0, 0.5), (0.5, 0.0), (0.5, 0.5), (0.9, 0.9),
        (0.5, 1.5), (0.9, 1.5), (1.5, 0.5), (1.5, 0.9), (1.5, 1.5),
    ]
    datasets = ["CIFAR10", "CIFAR20"]
    for dataset in datasets:
        print("Dataset: " + dataset)
        for noise_train, noise_test in list_of_noise:
            # A single-argument print needs no ``sep``; output is unchanged.
            print(mean_sems(dataset, noise_train, noise_test))


if __name__ == '__main__':
    # run() does not use its argument. Pass it explicitly so the call also
    # matches a signature where ``args`` has no default.
    run(args=None)