# Global RNGs are seeded at import time, interleaved with the imports below;
# main() re-seeds them again per run.
import random
random.seed(0)
import numpy as np
np.random.seed(0)

import hydra
import torch
# Seed torch on the CPU and on every CUDA device.
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
import random  # NOTE(review): duplicate of the `import random` above; harmless.
from src.multitask_modulated_net import MultiTaskModulatedNet
from src.cnn_controller import TaskController

from src.datasets.cifar100 import CIFAR100,TASK_GROUPS_CIFAR100
import os
import json
import matplotlib.pyplot as plt
import matplotlib
import os  # NOTE(review): duplicate of the `import os` above; harmless.
from src.noise_corruption import compute_corruptions_results_model
from src.classification_metrics import  MetricAggregator,Accuracy
from src.classification_criterions import LossAggregator,CrossEntropyLoss
# Number of DataLoader worker processes used by get_dataloaders_cifar100.
NUM_CPUS =4
from src.model_calibration import plot_model_calibration
from src.adversarial_robustness import evaluate_adversarial_robustness, pgd
from src.pca_lda import plot_pca_lda

# NOTE(review): appears unused in this file — the code below uses DEVICE, which
# is chosen from torch.cuda.is_available(). Possibly imported elsewhere; confirm
# before removing.
device = "cuda"
        
def load_except_classif(now, pre):
    """Copy every entry of ``pre``'s state dict into ``now`` except the classifier head.

    The ``classifiers.weight`` and ``classifiers.bias`` entries are dropped, so the
    two modules may have differently sized classifier heads. Loading is
    non-strict, so the missing head keys in ``now`` are tolerated.

    Args:
        now: module receiving the weights (modified in place).
        pre: module providing the weights.

    Returns:
        ``now``, for call chaining.

    Raises:
        KeyError: if ``pre`` has no ``classifiers.bias`` / ``classifiers.weight`` entry.
    """
    transferable = pre.state_dict()
    for head_key in ("classifiers.bias", "classifiers.weight"):
        transferable.pop(head_key)

    now.load_state_dict(transferable, strict=False)
    return now

class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also serialises NumPy arrays and NumPy scalars.

    ``np.ndarray`` values become (nested) lists, and NumPy scalar types
    (``np.float32``, ``np.int64``, ``np.bool_``, ...) become the matching Python
    scalar. Any other unsupported object is passed to the base encoder, which
    raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # NumPy scalars such as np.int64 / np.float32 are not subclasses of the
        # Python int/float the json module understands; .item() converts them.
        if isinstance(obj, np.generic):
            return obj.item()
        return super().default(obj)


class TorchEncoder(json.JSONEncoder):
    """JSON encoder for torch tensors that also accepts NumPy values.

    Tensors are converted with ``Tensor.tolist()`` (0-d tensors become Python
    scalars). NumPy arrays become lists and NumPy scalars become Python scalars,
    so result dictionaries that mix torch and NumPy values can be dumped with
    one encoder. Any other unsupported object raises ``TypeError`` via the base
    encoder.
    """

    def default(self, obj):
        if isinstance(obj, torch.Tensor):
            return obj.tolist()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # np.float32 / np.int64 etc. are not natively JSON-serialisable.
        if isinstance(obj, np.generic):
            return obj.item()
        return super().default(obj)

def convert_dict_for_json(d):
    """Recursively convert a nested result structure into JSON-friendly values.

    Conversion rules:
      * dict -> new dict with every value converted (keys kept as-is);
      * list / tuple -> list with every element converted;
      * str -> unchanged;
      * torch.Tensor / np.ndarray -> ``float`` when it holds exactly one
        element, otherwise a nested list via ``tolist()``;
      * None -> None;
      * anything else -> ``float(value)`` (ints, bools, NumPy scalars, ...).

    Previously only dict/list containers holding dicts were handled: a list of
    numbers or strings (e.g. ``{"acc": [0.1, 0.2]}``) crashed by calling
    ``.items()`` on each element.
    """
    if isinstance(d, dict):
        return {k: convert_dict_for_json(v) for k, v in d.items()}
    if isinstance(d, (list, tuple)):
        return [convert_dict_for_json(x) for x in d]
    if isinstance(d, str):
        return d
    if isinstance(d, torch.Tensor):
        return float(d.item()) if d.numel() == 1 else d.tolist()
    if isinstance(d, np.ndarray):
        return float(d.item()) if d.size == 1 else d.tolist()
    if d is None:
        return None
    return float(d)

def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA devices).

    Args:
        seed: integer seed applied to every generator, in that order.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)

def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: re-seed every RNG inside a worker process.

    ``worker_id`` is required by the callback signature but not used. The
    per-worker seed is derived from ``torch.initial_seed()``, which PyTorch sets
    to a distinct value in each worker.
    """
    # NumPy only accepts seeds in [0, 2**32), hence the reduction.
    seed_for_worker = torch.initial_seed() % (1 << 32)
    set_seed(seed_for_worker)

# torch.use_deterministic_algorithms(True) (below) makes cuBLAS matmuls such as
# nn.Linear raise RuntimeError on CUDA >= 10.2 unless a fixed cuBLAS workspace
# is configured. The variable must be set before the first cuBLAS call;
# setdefault keeps any value the user already exported.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE_TYPE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(DEVICE_TYPE)

# Reproducibility: no cuDNN autotuning, deterministic kernels only
# (operations without a deterministic implementation raise an error).
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

print("Using",DEVICE)
import logging
logger = logging.getLogger(__name__)

def plot_seeds_mean(ax,x,accuracies, label,color):
    """Draw the mean of ``accuracies`` over axis 0 with a shaded ±1 std band.

    Args:
        ax: matplotlib-style axes providing ``plot`` and ``fill_between``.
        x: x coordinates of the curve.
        accuracies: array-like reduced over its first axis (rows are presumably
            seeds/runs).
        label: legend label of the mean curve.
        color: colour used for both the curve and the band.
    """
    center = np.mean(accuracies, axis=0)
    spread = np.std(accuracies, axis=0)
    lower = center - spread
    upper = center + spread

    ax.plot(x, center, label=label, color=color)
    ax.fill_between(x, lower, upper, alpha=0.2, edgecolor=color, facecolor=color)

    
def copy_important_files(files_to_save):
    """Copy each path in ``files_to_save`` into ``./code/`` under the current directory.

    Only the file's base name is kept, so files with the same name overwrite
    each other. The destination directory is created if needed and may already
    exist, so a re-run in the same working directory no longer crashes.

    Args:
        files_to_save: iterable of source file paths.
    """
    import shutil
    # Copy important files to the output directory
    print("Copying important files to the output directory")
    print("Files to save:", files_to_save)
    dest_dir = "./code/"
    os.makedirs(dest_dir, exist_ok=True)
    for f in files_to_save:
        shutil.copyfile(f, os.path.join(dest_dir, os.path.basename(f)))

def plot_corruptions_difference(results,folder="corruptions",corruptions_results_path=None):
    """Plot Comodulation - Attention for every corruption on one shared figure.

    Each corruption gets one curve: the mean ± std across seeds of the
    per-level difference between the comodulation and attention results.

    Args:
        results: list with one dict per seed, each holding ``"attention"`` and
            ``"comodulation"`` dicts that map a corruption name to per-level
            values. If ``None``, the list is loaded from ``{folder}/results.json``.
        folder: directory read from (when ``results`` is None) and written to.
        corruptions_results_path: unused; kept for backward compatibility.

    Writes ``{folder}/all_corrupt_differences.pdf``.
    """
    if results is None:
        # Context manager so the file handle is closed deterministically.
        with open(f"{folder}/results.json", 'r') as fh:
            results = json.load(fh)
    corruptions = list(results[0]["comodulation"].keys())
    colors = matplotlib.colormaps.get_cmap("tab20").colors

    fig, ax = plt.subplots()
    for idx, corruption in enumerate(corruptions):
        corr_results_att = np.array([c["attention"][corruption] for c in results])
        corr_results_comod = np.array([c["comodulation"][corruption] for c in results])
        x = np.arange(0, corr_results_comod.shape[-1], 1)
        # tab20 only has 20 colours; wrap around instead of raising IndexError.
        plot_seeds_mean(ax, x, corr_results_comod - corr_results_att, corruption, colors[idx % len(colors)])

    # Labels are set once for the shared figure. The title was previously
    # overwritten with the last corruption's name, although every corruption is
    # plotted here and identified by its legend entry.
    ax.set_xlabel("Corruption level")
    ax.set_ylabel("Comodulation - Attention")
    ax.legend()

    fig.savefig(f"{folder}/all_corrupt_differences.pdf",format="pdf",bbox_inches="tight")
    plt.close(fig)

def plot_corruptions_results(results,folder="corruptions"):
    """Plot per-corruption curves for each model variant and save the raw results.

    For every corruption name in ``results[0]["comodulation"]`` two PDFs are
    written to ``folder``:
      * ``corruption_{name}.pdf``: mean ± std across seeds for Comodulation,
        Attention and, when ``results[0]`` has a ``"pretrained"`` entry,
        Pretraining;
      * ``difference_corruption_{name}.pdf``: Comodulation - Attention.
    Finally ``results`` is dumped to ``{folder}/results.json``.

    Args:
        results: list with one dict per seed mapping a model variant
            (``"attention"``, ``"comodulation"``, optionally ``"pretrained"``)
            to a dict of corruption name -> per-level values.
        folder: existing output directory.
    """
    has_pretrained = "pretrained" in results[0]
    for corruption in results[0]["comodulation"].keys():
        if has_pretrained:
            corr_results_pre = np.array([c["pretrained"][corruption] for c in results])
        corr_results_att = np.array([c["attention"][corruption] for c in results])
        corr_results_comod = np.array([c["comodulation"][corruption] for c in results])
        x = np.arange(0, corr_results_comod.shape[-1], 1)

        # Absolute results for every variant.
        fig, ax = plt.subplots()
        if has_pretrained:
            plot_seeds_mean(ax, x, corr_results_pre, "Pretraining", "blue")
        plot_seeds_mean(ax, x, corr_results_comod, "Comodulation", "green")
        plot_seeds_mean(ax, x, corr_results_att, "Attention", "red")
        ax.set_xlabel("Corruption level")
        ax.set_ylabel("Accuracy")
        ax.set_title(corruption)
        ax.legend()
        fig.savefig(f"{folder}/corruption_{corruption}.pdf",format="pdf",bbox_inches="tight")
        plt.close(fig)

        # Difference between the two modulation schemes.
        fig, ax = plt.subplots()
        plot_seeds_mean(ax, x, corr_results_comod - corr_results_att, "delt", "green")
        ax.set_xlabel("Corruption level")
        ax.set_ylabel("Comodulation - Attention")
        ax.set_title(corruption)
        ax.legend()
        fig.savefig(f"{folder}/difference_corruption_{corruption}.pdf",format="pdf",bbox_inches="tight")
        plt.close(fig)

    # The file was previously opened inline and never closed, so the dump was
    # only flushed whenever the handle happened to be garbage collected.
    with open(f"{folder}/results.json","w") as fh:
        json.dump(results, fh, cls=NumpyEncoder)

def plot_eval_batch(results, dict_key="eval_during_train",plot_name="results",y_axis="F1 Score",plot_cumsum=False,make_individual_plots=False):
    """Plot Attention vs Comodulation evaluation curves across seeds.

    Args:
        results: dict with ``"ATTENTION"`` and ``"COMODULATION"`` lists of
            per-seed dicts, plus ``"PRETRAIN"`` when
            ``dict_key == "eval_during_train"``.
        dict_key: entry of each per-seed dict to plot.
        plot_name: output path prefix (and directory name for per-seed plots).
        y_axis: y-axis label.
        plot_cumsum: also plot the cumulative sum along the evaluation axis.
        make_individual_plots: also write one figure per seed.

    Writes ``{plot_name}_std.pdf``, optionally ``{plot_name}_cumsum.pdf`` and
    ``{plot_name}/seed_{i}.pdf``.
    """
    att_evals_train = np.array([d[dict_key] for d in results["ATTENTION"]])
    com_evals_train = np.array([d[dict_key] for d in results["COMODULATION"]])

    if att_evals_train.ndim == 3:
        # Flatten everything after the first (seed) axis into one curve per
        # seed, e.g. a list of evaluations per epoch -> a single sequence.
        att_evals_train = att_evals_train.reshape(att_evals_train.shape[0],-1)
        com_evals_train = com_evals_train.reshape(com_evals_train.shape[0],-1)

    finetune_x = np.arange(0, com_evals_train.shape[-1], 1)

    fig, ax = plt.subplots()
    plot_seeds_mean(ax,finetune_x, att_evals_train, "Attention", "blue")
    plot_seeds_mean(ax,finetune_x, com_evals_train, "Comodulation", "red")

    if dict_key == "eval_during_train":
        # Pretraining is drawn as a flat reference line at its overall mean.
        pre_evals_train = np.array([d[dict_key] for d in results["PRETRAIN"]])
        pre_evals_rep = np.ones((finetune_x.shape[0]))*pre_evals_train.mean()
        ax.plot(finetune_x, pre_evals_rep, label="Pretraining", color="black", linestyle="--")

    ax.set_xlabel("# of training batches")
    ax.set_ylabel(y_axis)
    ax.legend()
    fig.savefig(f"{plot_name}_std.pdf",format="pdf",bbox_inches="tight")
    # Close this figure explicitly: a bare plt.close() only closes the *current*
    # figure, so this one leaked whenever plot_cumsum opened a second figure.
    plt.close(fig)

    if plot_cumsum:
        fig, ax = plt.subplots()
        att_cumsum = np.cumsum(att_evals_train,axis=1)
        com_cumsum = np.cumsum(com_evals_train,axis=1)
        plot_seeds_mean(ax,finetune_x, att_cumsum, "Attention", "blue")
        plot_seeds_mean(ax,finetune_x, com_cumsum, "Comodulation", "red")
        ax.set_xlabel("# of training batches")
        ax.set_ylabel(y_axis)
        ax.legend()
        fig.savefig(f"{plot_name}_cumsum.pdf",format="pdf",bbox_inches="tight")
        plt.close(fig)

    if make_individual_plots:
        os.makedirs(plot_name, exist_ok=True)
        for i, (att_curve, com_curve) in enumerate(zip(att_evals_train, com_evals_train)):
            fig, ax = plt.subplots()
            ax.plot(finetune_x, att_curve, label="Attention", color="blue")
            ax.plot(finetune_x, com_curve, label="Comodulation", color="red")
            ax.set_xlabel("# of training batches")
            ax.set_ylabel(y_axis)
            ax.legend()
            fig.savefig(f"{plot_name}/seed_{i}.pdf",format="pdf",bbox_inches="tight")
            plt.close(fig)
    
    
def get_dataloaders_cifar100(task_groups, batch_size,image_size,data_dir=None,add_augmentations=False,generator=None,coarse_labels=False,seed=0,test_batch_size=200):
    """Build the CIFAR-100 train, validation and test DataLoaders.

    The validation and test loaders iterate over the same ``split="test"``
    dataset with identical settings (no augmentation, no shuffling).

    Args:
        task_groups: not used here.
        batch_size: training batch size; the incomplete last batch is dropped.
        image_size, data_dir, coarse_labels: forwarded to ``CIFAR100``.
        add_augmentations: enable augmentations on the training split only.
        generator: ``torch.Generator`` controlling the training shuffle order.
        seed: not used here.
        test_batch_size: batch size of the validation and test loaders.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    def build_split(split, augment):
        return CIFAR100(data_dir=data_dir, split=split, add_augmentations=augment, image_size=image_size, coarse_labels=coarse_labels)

    train_dataset = build_split("train", add_augmentations)
    test_dataset = build_split("test", False)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=NUM_CPUS,
        pin_memory=True,
        drop_last=True,
        generator=generator,
        # Re-seeds random / numpy / torch inside every worker process.
        worker_init_fn=seed_worker,
    )
    eval_loader_kwargs = dict(batch_size=test_batch_size, shuffle=False, num_workers=NUM_CPUS, pin_memory=False)
    val_loader = torch.utils.data.DataLoader(test_dataset, **eval_loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, **eval_loader_kwargs)

    print("Train dataset size:",len(train_dataset))
    print("Test dataset size:",len(test_dataset))
    return train_loader, val_loader, test_loader


def evaluate(model, val_loader, criterion, num_tasks, num_classes, class_to_class_idx,train_config=None):
    """Run ``model`` over ``val_loader`` in eval mode and accumulate accuracy.

    Two batch layouts are accepted:
      * ``(data, targets)``: ``targets`` is an iterable of tensors, each moved
        to ``DEVICE``; the model is called with ``tasks=None``;
      * ``(data, targets, tasks)``: all three tensors are moved to ``DEVICE``.

    Args:
        model: called as ``model(data, tasks=tasks)``; the first of its three
            outputs is fed to the metric.
        val_loader: iterable of batches.
        criterion, num_tasks, num_classes, class_to_class_idx, train_config:
            accepted for interface compatibility; not used here.

    Returns:
        ``(metrics, other_val_metrics)``: the updated ``MetricAggregator`` (callers
        query it with ``get_mean()`` / ``get_string()``) and an empty dict.

    Raises:
        ValueError: if a batch has neither 2 nor 3 elements. Previously such a
            batch reused stale ``tasks`` from an earlier batch or failed with an
            unrelated error.
    """
    model.eval()
    other_val_metrics = {}

    metrics = MetricAggregator([Accuracy()])

    with torch.no_grad():
        for batch in val_loader:
            if len(batch) == 2:
                data, targets = batch[0], batch[1]
                tasks = None
                data, targets = data.to(DEVICE), [elt.to(DEVICE) for elt in targets]
            elif len(batch) == 3:
                data, targets, tasks = batch[0], batch[1], batch[2]
                data, targets, tasks = data.to(DEVICE), targets.to(DEVICE), tasks.to(DEVICE)
            else:
                raise ValueError(f"evaluate: expected batches of 2 or 3 elements, got {len(batch)}")

            outputs, _, _ = model(data, tasks=tasks)
            metrics.update(outputs, targets)

        # The computed value is discarded; callers read results from the
        # aggregator itself. The call is kept in case compute() finalises
        # internal state — TODO confirm whether it can be removed.
        metrics.compute()
        return metrics, other_val_metrics
        
    
def train_test(train_config, modulation_params, network, controller, task_groups, is_attention=False, is_comodulation=False, seed=0, is_output_weights=False):
    """Wrap ``network``/``controller`` in a MultiTaskModulatedNet, train it on CIFAR-100 and test it.

    The mode is selected by the flags: attention (``is_attention``),
    comodulation (``is_comodulation``) or, when neither is set, plain
    (pre)training. ``is_output_weights`` additionally keeps
    ``model.network.resnet`` in eval mode during training.

    Optimisation uses Adam and a StepLR schedule that halves the learning rate
    every 20 epochs. When modulating with ``model.retrain_last``, the
    classifier head (``classifier_lr``) and the controller (``lr``) get
    separate learning rates. After every epoch the model is evaluated on the
    validation loader; after the last epoch it is evaluated once on the test
    loader.

    Args:
        train_config: config node. Reads lr, classifier_lr, data_dir,
            batch_size, image_size, add_augmentations, coarse_labels,
            max_epoch, eval_every_batch, eval_every_n_batch,
            adversarial_training, adversarial_eps, retrain_everything,
            do_l1_loss_on_controller, l1_loss_weight, log_freq and max_batch.
        modulation_params: keyword arguments forwarded to MultiTaskModulatedNet.
        network: backbone to train / modulate.
        controller: task controller (``None`` for pretraining).
        task_groups: one list of classes per task.
        seed: seeds the global RNGs and the training DataLoader's generator.

    Returns:
        ``(model, acc_dict)``. ``acc_dict`` contains:
          * ``"train"`` and ``"val"``: per-epoch means;
          * ``"test"``: a one-element list;
          * ``"eval_during_train"``: one list of evaluations per epoch when
            ``eval_every_batch`` is set, otherwise one value per epoch.
        The remaining keys are created empty here.
    """
    generator = torch.Generator()
    generator = generator.manual_seed(seed)
    set_seed(seed)

    num_tasks = len(task_groups)
    num_classes = sum(len(elt) for elt in task_groups)

    acc_dict = {"train": [], "val": [], "test": [], "eval_during_train":[],"num_conflicting_gradients":[],"num_informative_neurons":[],"conflicting_gradients_avg_norm":[]}

    is_pretrain = not (is_attention or is_comodulation)
    # Task t owns a contiguous range of global class indices.
    class_to_class_idx = []
    last_class = 0
    for task in range(num_tasks):
        class_to_class_idx.append(torch.arange(last_class, last_class+len(task_groups[task])).int())
        last_class += len(task_groups[task])

    model = MultiTaskModulatedNet(network=network, controller=controller,device=DEVICE, is_attention=is_attention,is_comodulation=is_comodulation,**modulation_params)
    model.set_modulation_training(is_attention or is_comodulation)
    model.to(DEVICE)

    WHATS_GOING_ON = "ATTENTION" if is_attention else "COMODULATION" if is_comodulation else "PRETRAINING"

    logger.info("Training with %s", WHATS_GOING_ON)

    if not is_pretrain and model.retrain_last:
        optim_dicts = [{"params": model.network.classifiers.parameters(),"lr":train_config.classifier_lr},
                       {"params": controller.parameters(),"lr":train_config.lr}]
        optimizer = torch.optim.Adam(optim_dicts)
    else:
        optimizer = torch.optim.Adam(model.get_parameters_for_optimizer(), lr=train_config.lr)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=20, gamma=0.5)

    train_loader, val_loader, test_loader = get_dataloaders_cifar100(task_groups=task_groups,data_dir=train_config.data_dir,batch_size=train_config.batch_size,image_size=train_config.image_size,add_augmentations=train_config.add_augmentations,generator=generator,coarse_labels=train_config.coarse_labels,seed=seed)

    # Currently just an alias of the full validation loader.
    smaller_val_loader = val_loader

    criterion = LossAggregator([CrossEntropyLoss()])

    for epoch in range(1, train_config.max_epoch + 1):
        if train_config.eval_every_batch:
            # Open this epoch's list in every mode: the per-batch evaluation
            # below appends to acc_dict["eval_during_train"][-1], which raised
            # IndexError during pretraining when no list had been opened.
            acc_dict["eval_during_train"].append([])
            if not is_pretrain:
                # Baseline evaluation before the epoch's first update.
                val_task_metrics, other_val_metrics = evaluate(model,val_loader=smaller_val_loader,criterion=criterion,num_tasks=num_tasks,num_classes=num_classes,class_to_class_idx=class_to_class_idx,train_config=train_config)
                acc_dict["eval_during_train"][-1].append(val_task_metrics.get_mean())

        ##############
        # Train loop #
        ##############
        train_metrics = MetricAggregator([Accuracy()])

        for batch_num, batch in enumerate(train_loader):
            data, targets, tasks = batch[0], batch[1], batch[2]
            data, targets, tasks = data.to(DEVICE), targets.to(DEVICE), tasks.to(DEVICE)

            if train_config.adversarial_training:
                data = pgd(model,data,tasks,targets,k=5,eps=train_config.adversarial_eps)
            model.train()
            # When only the modulation is trained, keep the backbone in eval mode.
            if (is_comodulation or is_attention) and not train_config.retrain_everything:
                model.controller.train()
                model.network.resnet.eval()
            if is_output_weights:
                model.network.resnet.eval()

            outputs,decoder_activity,controller_params = model(data,tasks=tasks)
            task_losses = criterion.compute(outputs,targets)

            loss = torch.sum(task_losses)
            # Classification loss alone, logged before the optional L1 term.
            bce_loss = loss.item()
            if train_config.do_l1_loss_on_controller and not is_pretrain:
                loss += train_config.l1_loss_weight*controller_params.abs().mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_item = loss.item()
            del loss

            train_metrics.update(outputs, targets)

            if train_config.eval_every_batch and (batch_num % train_config.eval_every_n_batch == 0):
                val_task_metrics, other_val_metrics = evaluate(model,val_loader=smaller_val_loader,criterion=criterion,num_tasks=num_tasks,num_classes=num_classes,class_to_class_idx=class_to_class_idx,train_config=train_config)
                acc_dict["eval_during_train"][-1].append(val_task_metrics.get_mean())

            if batch_num % train_config.log_freq == 0:
                logger.info('Epoch {}, iter {}/{}, Loss : {:.4f}, Class loss {:.4f}, fscores {:.4f}'.format(epoch, batch_num+1, len(train_loader), loss_item,float(bce_loss), train_metrics.get_mean()))

            # NOTE(review): with `>` this break fires after batch index
            # max_batch+1, i.e. max_batch+2 batches run per epoch — confirm intended.
            if (train_config.max_batch > 0) and (batch_num > train_config.max_batch):
                break
        acc_dict["train"].append(train_metrics.get_mean())
        lr_scheduler.step()

        #############
        # Eval loop #
        #############
        if model.is_comodulation and model.compute_gain_once_with_train_set:
            model.compute_gain_on_training_set(train_loader)

        if not train_config.eval_every_batch:
            # smaller_val_loader is val_loader, so this repeats the evaluation
            # just below on the same data. Kept so the RNG stream between runs is
            # unchanged — TODO confirm whether the duplicate can be dropped.
            val_task_metrics, other_val_metrics = evaluate(model,val_loader=smaller_val_loader,criterion=criterion,num_tasks=num_tasks,num_classes=num_classes,class_to_class_idx=class_to_class_idx,train_config=train_config)
            acc_dict["eval_during_train"].append(val_task_metrics.get_mean())

        val_task_metrics, other_val_metrics = evaluate(model,val_loader=val_loader,criterion=criterion,num_tasks=num_tasks,num_classes=num_classes,class_to_class_idx=class_to_class_idx,train_config=train_config)

        eval_string ='{} EVAL EPOCH {}, '.format(WHATS_GOING_ON,epoch)
        eval_string += val_task_metrics.get_string()
        eval_string += f"MEAN {val_task_metrics.get_mean():.4f}"
        logger.info(eval_string)

        acc_dict["val"].append(val_task_metrics.get_mean())

    ##############
    # Final test #
    ##############
    model.eval()
    test_task_metrics, other_val_metrics = evaluate(model,val_loader=test_loader,criterion=criterion,num_tasks=num_tasks,num_classes=num_classes,class_to_class_idx=class_to_class_idx,train_config=train_config)
    test_string ='{} TEST, '.format(WHATS_GOING_ON)
    test_string += test_task_metrics.get_string()
    test_string += f"MEAN {test_task_metrics.get_mean():.4f}"
    logger.info(test_string)

    acc_dict["test"].append(test_task_metrics.get_mean())

    return model, acc_dict

@hydra.main(config_path="configs", config_name="train_mt_cifar")
def main(config):
    """Run the full experiment for ``config.n_train_test`` seeds and save the results.

    For each seed:
      1. pretrain a network;
      2. optionally train only the output weights
         (``config.do_only_output_weights``, which skips the rest of that seed);
      3. otherwise fine-tune an attention model and a comodulation model from
         the pretrained weights;
      4. run the enabled post-hoc analyses (PCA/LDA, adversarial robustness,
         corruptions, calibration).
    Per-seed and mean test scores are logged, the enabled plots are produced,
    and all accuracy dictionaries are written to ``results.json`` in the
    working directory.
    """
    # set_seed also seeds every CUDA device.
    set_seed(config.seed)
    print(config)

    task_groups = TASK_GROUPS_CIFAR100

    config.network.task_groups = task_groups
    if config.copy_files_to_save:
        copy_important_files(config.files_to_save)

    # Accumulators for the optional analyses. They are filled and plotted under
    # the config.comodulation_config.plot_* flags, so they are always defined.
    # Previously they existed only when the matching top-level flag was set,
    # which raised NameError otherwise.
    corruptions_results = []
    adversarial_results = []
    pca_lda_dicts = []
    calibration_dicts = []

    if config.make_embeddings_plot:
        os.mkdir("./plots/")
    # Corruption / adversarial plots are written whenever either flag is set.
    if config.plot_corruptions or config.comodulation_config.plot_corruptions:
        os.mkdir("./corruptions/")
    if config.plot_adversarial or config.comodulation_config.plot_adversarial:
        os.mkdir("./adversarial/")
    if config.plot_pca_lda:
        os.mkdir("./pca_lda/")
    if config.plot_model_calibration:
        os.mkdir("./calibrations/")

    def _analysis_test_loader(seed):
        """Return an un-augmented CIFAR-100 test DataLoader for the post-hoc analyses."""
        # Was `comodulation_config.data_dir`, an undefined name (NameError).
        _, _, loader = get_dataloaders_cifar100(task_groups=None,data_dir=config.comodulation_config.data_dir,batch_size=50,image_size=32,generator=None,coarse_labels=False,add_augmentations=False,seed=seed,test_batch_size=200)
        return loader

    model_variant = "bias" if config.network.dec_normal_bias_init > 0 else "residual" if config.network.use_residual_connection else "base"
    comod_bp = "comod_bp" if config.modulation_params.comod_backprop else ""
    fix_gain_ts = "fix_gain_ts" if config.modulation_params.compute_gain_once_with_train_set else ""
    config_help_name = f"{model_variant}_{comod_bp}_{fix_gain_ts}"
    # Marker file whose name summarises the model variant of this run.
    with open(config_help_name, 'w') as f:
        f.write("!")

    COMODULATION_RESULTS = []
    ATTENTION_RESULTS = []
    PRETRAIN_RESULTS = []
    OUT_WEIGHTS_RESULTS = []

    for seed in range(config.n_train_test):

        corruptions_results_seed = {}
        adversarial_results_seed = {}
        print("Train test from scratch {}".format(seed))
        set_seed(seed)
        config.network.coarse_labels = config.pretrain_coarse_labels

        # Pretraining
        network = hydra.utils.instantiate(config.network).to(DEVICE)
        pretrained_model, pretrain_acc_dict = train_test(config.pretrain_config,config.modulation_params, network, None, task_groups,is_attention=False,is_comodulation=False,seed=seed)
        if config.pretrain_config.plot_corruptions:
            # NOTE(review): unlike the attention/comodulation calls below, no
            # dataset_class is passed here — confirm the helper's default is intended.
            corruptions_results_seed["pretrained"] = (compute_corruptions_results_model(pretrained_model,task_groups=task_groups,seed=seed,device=DEVICE))

        PRETRAIN_RESULTS.append(pretrain_acc_dict)
        pretrained_network = pretrained_model.network
        set_seed(seed)
        config.network.coarse_labels = False

        # Output-weights-only baseline: retrain just the classifier head, then
        # skip attention/comodulation for this seed.
        if config.do_only_output_weights:
            output_weights_network = hydra.utils.instantiate(config.network)
            output_weights_network = load_except_classif(output_weights_network,pretrained_network)
            config.modulation_params.only_retrain_last = True
            _, ow_acc_dict = train_test(config.attention_config,config.modulation_params, output_weights_network, None,task_groups,is_attention=False,is_comodulation=False,is_output_weights=True,seed=seed)
            OUT_WEIGHTS_RESULTS.append(ow_acc_dict)
            config.modulation_params.only_retrain_last = False

            continue

        # Attention, starting from the pretrained weights
        set_seed(seed)
        controller = TaskController(network.get_layer_to_modulate().conv_out_shape,
                                    task_groups=task_groups,
                                    device=DEVICE,
                                    **config.controller).to(DEVICE)

        att_network = hydra.utils.instantiate(config.network)
        if config.pretrain_coarse_labels:
            att_network = load_except_classif(att_network,pretrained_network)
        else:
            att_network.load_state_dict(pretrained_network.state_dict())
        att_network = att_network.to(DEVICE)
        att_model, att_acc_dict = train_test(config.attention_config,config.modulation_params, att_network, controller,task_groups,is_attention=True,is_comodulation=False,seed=seed)

        ATTENTION_RESULTS.append(att_acc_dict)

        # Comodulation, starting from the same pretrained weights
        set_seed(seed)
        controller = TaskController(network.get_layer_to_modulate().conv_out_shape,
                                    task_groups=task_groups,
                                    device=DEVICE,
                                    **config.controller).to(DEVICE)

        com_network = hydra.utils.instantiate(config.network)
        if config.pretrain_coarse_labels:
            com_network = load_except_classif(com_network,pretrained_network)
        else:
            com_network.load_state_dict(pretrained_network.state_dict())
        com_network = com_network.to(DEVICE)

        com_model, com_acc_dict = train_test(config.comodulation_config,config.modulation_params, com_network, controller,task_groups,is_attention=False,is_comodulation=True,seed=seed)
        COMODULATION_RESULTS.append(com_acc_dict)

        # Post-hoc analyses
        if config.plot_pca_lda:
            test_dataloader = _analysis_test_loader(seed)
            p_l_dict = plot_pca_lda(pretrained_model,att_model,com_model,test_dataloader,folder="pca_lda",seed=seed)
            pca_lda_dicts.append(p_l_dict)

        if config.comodulation_config.plot_adversarial:
            test_dataloader = _analysis_test_loader(seed)
            adversarial_results_seed["attention"] = evaluate_adversarial_robustness(att_model,test_dataloader)
            adversarial_results_seed["comodulation"] = evaluate_adversarial_robustness(com_model,test_dataloader)
            adversarial_results.append(adversarial_results_seed)

        if config.comodulation_config.plot_corruptions:
            corruptions_results_seed["attention"] = (compute_corruptions_results_model(att_model,dataset_class=CIFAR100,task_groups=task_groups,seed=seed,device=DEVICE))
            corruptions_results_seed["comodulation"] = compute_corruptions_results_model(com_model,seed=seed,task_groups=task_groups,dataset_class=CIFAR100,device=DEVICE)
            corruptions_results.append(corruptions_results_seed)

        if config.plot_model_calibration:
            test_dataloader = _analysis_test_loader(seed)
            calibration_dicts.append(plot_model_calibration(att_model,com_model,test_dataloader,plot_name=f"calibrations/seed_{seed}.pdf"))

    logger.info("F1 Scores FOR EVERY SEED:")

    PRETRAIN_TEST = [x.get("test",0) for x in PRETRAIN_RESULTS]
    ATTENTION_TEST = [x.get("test",0) for x in ATTENTION_RESULTS]
    COMODULATION_TEST = [x.get("test",0) for x in COMODULATION_RESULTS]
    if config.do_only_output_weights:
        OUT_WEIGHT_TEST = [x["test"] for x in OUT_WEIGHTS_RESULTS]
        PRETRAIN_TEST = [x["test"] for x in PRETRAIN_RESULTS]
        ATTENTION_TEST = [0 for _ in OUT_WEIGHT_TEST]
        COMODULATION_TEST = [0 for _ in OUT_WEIGHT_TEST]

    for seed_resu in range(config.n_train_test):
        logger.info("SEED {}".format(seed_resu))
        str_out_weights = ","
        if config.do_only_output_weights:
            str_out_weights = ",OUT WEIGHTS {},".format(OUT_WEIGHTS_RESULTS[seed_resu]["test"])

        logger.info("PRETRAIN  {} {}  ATTENTION {},  COMODULATION {}".format(PRETRAIN_TEST[seed_resu],str_out_weights,ATTENTION_TEST[seed_resu],COMODULATION_TEST[seed_resu]))

    logger.info("Classification MEAN:")
    if config.do_only_output_weights:
        logger.info("PRETRAIN  {}, OUTPUT_WEIGHTS {}, ATTENTION {},  COMODULATION {} ".format(np.mean(PRETRAIN_TEST),np.mean(OUT_WEIGHT_TEST),np.mean(ATTENTION_TEST),np.mean(COMODULATION_TEST)))
    else:
        logger.info("PRETRAIN  {}, ATTENTION {},  COMODULATION {} ".format(np.mean(PRETRAIN_TEST),np.mean(ATTENTION_TEST),np.mean(COMODULATION_TEST)))
    save_dict = {"PRETRAIN":PRETRAIN_RESULTS,"ATTENTION":ATTENTION_RESULTS,"COMODULATION":COMODULATION_RESULTS,"OUT_WEIGHTS":OUT_WEIGHTS_RESULTS}

    if config.comodulation_config.eval_every_batch:
        plot_eval_batch(save_dict,dict_key="eval_during_train",plot_name="results",y_axis="F1 score",make_individual_plots=True)

    if config.comodulation_config.plot_corruptions:
        plot_corruptions_results(corruptions_results)
        plot_corruptions_difference(corruptions_results,folder="corruptions")

    if config.comodulation_config.plot_adversarial:
        plot_corruptions_results(adversarial_results,folder="adversarial")
        plot_corruptions_difference(adversarial_results,folder="adversarial")

    if config.plot_pca_lda:
        with open("./pca_lda/pca_lda_results.json","w") as f:
            json.dump(pca_lda_dicts,f,cls=NumpyEncoder)

    if config.plot_model_calibration:
        with open("./calibrations/calibration_results.json","w") as f:
            json.dump(calibration_dicts,f,cls=NumpyEncoder)

    # Save all accuracy dictionaries as JSON.
    with open("results.json","w") as f:
        json.dump(save_dict,f,cls=TorchEncoder)
            

# Script entry point. main is decorated with hydra.main, which composes the
# "train_mt_cifar" config from the "configs" directory and passes it in.
if __name__ == '__main__':
    main()


