import random
random.seed(0)
import numpy as np
np.random.seed(0)
import hydra
import torch
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
import random
from src.multitask_modulated_net import MultiTaskModulatedNet
from src.cnn_controller import TaskController
from src.datasets.celeba import CelebaGroupedDataset
from src.datasets.celeba import task_groups_classess_names
import json
import matplotlib.pyplot as plt
import os
from src.umap_celeba import plot_controller_embedding
from src.classification_metrics import F1, MetricAggregator
from src.classification_criterions import LossAggregator, MaskedBCEWithLogitsLoss
# Worker-process count passed as `num_workers` to the DataLoaders built in this file.
NUM_CPUS = 4

# Label indices grouped into 8 tasks. Together the groups contain every
# index in 0..39 exactly once.
# NOTE(review): presumably these are CelebA attribute indices -- confirm
# against CelebaGroupedDataset.
TASK_GROUPS_CELEBA = [
    [2, 10, 13, 14, 18, 20, 25, 26, 39],
    [3, 15, 23, 1, 12],
    [4, 5, 8, 9, 11, 17, 28, 32, 33],
    [6, 21, 31, 36],
    [7, 27],
    [0, 16, 22, 24, 30],
    [19, 29],
    [34, 35, 37, 38],
]

class TorchEncoder(json.JSONEncoder):
    """JSON encoder that serialises ``torch.Tensor`` values as nested lists."""

    def default(self, obj):
        """Return a JSON-compatible stand-in for ``obj``.

        Tensors are converted with ``Tensor.tolist()``; every other type is
        delegated to the base encoder, which raises ``TypeError``.
        """
        if not isinstance(obj, torch.Tensor):
            return super().default(obj)
        return obj.tolist()

def convert_dict_for_json(d):
    """Recursively convert a results structure into JSON-serialisable values.

    Conversion rules, applied at every nesting level:

    * mappings (anything exposing ``items()``) -> ``dict`` with converted values;
    * lists and tuples -> ``list`` with converted elements;
    * strings are kept unchanged;
    * single-element tensors -> Python ``float``; other tensors -> nested lists;
    * anything else is passed through ``float()``.

    Args:
        d: A mapping, list/tuple, string, tensor or number (possibly nested).

    Returns:
        The converted structure.

    Raises:
        TypeError / ValueError: propagated from ``float()`` when a leaf value
            cannot be interpreted as a number.
    """
    # Strings and tensors are checked first so they are never mistaken for
    # containers.
    if isinstance(d, str):
        return d
    if isinstance(d, torch.Tensor):
        return float(d.item()) if d.numel() == 1 else d.tolist()
    if isinstance(d, (list, tuple)):
        # Elements may be scalars, strings or tensors, not only dicts.
        return [convert_dict_for_json(x) for x in d]
    # Duck-typed mapping check so dict subclasses and other mapping-like
    # containers are accepted as well as plain dicts.
    if callable(getattr(d, "items", None)):
        return {k: convert_dict_for_json(v) for k, v in d.items()}
    return float(d)

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all-CUDA) random generators.

    Args:
        seed: Integer seed forwarded unchanged to each library.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def seed_worker(worker_id):
    """Re-seed all RNGs from torch's current initial seed.

    Passed as ``worker_init_fn`` to the DataLoaders in this file.

    Args:
        worker_id: Worker index supplied by the DataLoader; not used.
    """
    # Reduce the torch seed modulo 2**32 before handing it to set_seed.
    set_seed(torch.initial_seed() % (1 << 32))

# Use the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE_TYPE = 'cuda' if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(DEVICE_TYPE)

# Favour reproducibility over speed: no cuDNN autotuning, deterministic
# cuDNN kernels, and deterministic algorithms requested from torch globally.
# NOTE(review): on CUDA, deterministic mode may additionally require the
# CUBLAS_WORKSPACE_CONFIG environment variable -- confirm for your setup.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

print("Using", DEVICE)

import logging
logger = logging.getLogger(__name__)


    
def copy_important_files(files_to_save):
    """Copy the given files into ``./code/`` under the current working directory.

    Each file is stored under its base name, so two inputs with the same
    base name overwrite each other. The target directory is created when
    missing and reused when it already exists.

    Args:
        files_to_save: Iterable of paths of the files to copy.

    Raises:
        OSError: propagated from ``shutil.copyfile`` (e.g. missing source).
    """
    import shutil

    print("Copying important files to the output directory")
    print("Files to save:", files_to_save)

    target_dir = "./code/"
    # exist_ok avoids a FileExistsError when the directory is already there.
    os.makedirs(target_dir, exist_ok=True)
    for f in files_to_save:
        shutil.copyfile(f, os.path.join(target_dir, os.path.basename(f)))


def plot_eval_batch(results, dict_key="eval_during_train",plot_name="results",y_axis="F1 Score",plot_cumsum=False,make_individual_plots=False,metric="f1"):
    """Plot attention vs. comodulation evaluation curves aggregated over seeds.

    Writes ``{plot_name}_std.png`` (mean curve with a +/- one standard
    deviation band), optionally ``{plot_name}_cumsum.png`` and, when
    requested, one ``{plot_name}/seed_{i}.png`` per run.

    Args:
        results: Mapping with ``"ATTENTION"`` and ``"COMODULATION"`` entries
            (and ``"PRETRAIN"`` when ``dict_key == "eval_during_train"``),
            each a list of per-run dicts containing ``dict_key``.
        dict_key: Key of the curve inside each per-run dict. Only the first
            element along the second axis (``[:, 0]``) of the stacked curves
            is plotted.
        plot_name: Prefix of the output files / name of the per-seed folder.
        y_axis: Label of the y axis.
        plot_cumsum: Also plot the cumulative sum of the curves.
        make_individual_plots: Also write one figure per run.
        metric: Key used to reduce mapping-valued curve entries to a scalar.
            Numeric entries are used unchanged, so previously valid inputs
            behave as before.
            NOTE(review): entries produced by ``train_test`` come from
            ``get_mean()`` and are indexed with ``["f1"]`` there; confirm the
            value is scalar-like.
    """
    from collections.abc import Mapping

    def select_metric(node):
        # Replace mapping leaves by their scalar `metric` value so that
        # NumPy can average them; leave everything else untouched.
        if isinstance(node, Mapping):
            return float(node[metric])
        if isinstance(node, (list, tuple)):
            return [select_metric(child) for child in node]
        return node

    def curves_for(phase):
        return np.array([select_metric(run[dict_key]) for run in results[phase]])

    def plot_seeds_mean(ax,x, accuracies, label,color):
        # Mean over runs (axis 0) with a shaded +/- std band.
        accuracies = np.array(accuracies)
        means = np.mean(accuracies, axis=0)
        stds = np.std(accuracies, axis=0)
        ax.plot(x, means, label=label, color=color)
        ax.fill_between(x,means-stds, means+stds,
                alpha=0.2, edgecolor=color, facecolor=color)

    def label_axes(ax):
        ax.set_xlabel("# of training batches")
        ax.set_ylabel(y_axis)
        ax.legend()

    att_evals_train = curves_for("ATTENTION")[:,0]
    com_evals_train = curves_for("COMODULATION")[:,0]

    finetune_x = np.arange(0, com_evals_train.shape[-1], 1)

    fig, ax = plt.subplots()
    plot_seeds_mean(ax,finetune_x, att_evals_train, "Attention", "blue")
    plot_seeds_mean(ax,finetune_x, com_evals_train, "Comodulation", "red")

    if dict_key == "eval_during_train":
        # Pretraining is drawn as a flat reference line at its overall mean.
        pre_evals_train = curves_for("PRETRAIN")
        pre_evals_rep = np.ones((finetune_x.shape[0]))*pre_evals_train.mean()
        ax.plot(finetune_x, pre_evals_rep, label="Pretraining", color="black", linestyle="--")

    label_axes(ax)
    fig.savefig(f"{plot_name}_std.png")
    # Close this specific figure; a bare plt.close() would only close the
    # most recently created one and leak this figure when plot_cumsum is set.
    plt.close(fig)

    if plot_cumsum:
        fig, ax = plt.subplots()
        att_cumsum = np.cumsum(att_evals_train,axis=1)
        com_cumsum = np.cumsum(com_evals_train,axis=1)
        plot_seeds_mean(ax,finetune_x, att_cumsum, "Attention", "blue")
        plot_seeds_mean(ax,finetune_x, com_cumsum, "Comodulation", "red")
        label_axes(ax)
        fig.savefig(f"{plot_name}_cumsum.png")
        plt.close(fig)

    if make_individual_plots:
        os.makedirs(plot_name, exist_ok=True)
        for i in range(att_evals_train.shape[0]):
            fig, ax = plt.subplots()
            ax.plot(finetune_x, att_evals_train[i], label="Attention", color="blue")
            ax.plot(finetune_x, com_evals_train[i], label="Comodulation", color="red")
            label_axes(ax)
            fig.savefig(f"{plot_name}/seed_{i}.png")
            plt.close(fig)
    
    
def get_dataloaders_celeba(task_groups, batch_size,image_size,add_augmentations=False,generator=None,data_dir="/celeba",num_workers=None):
    """Build the train / validation / test DataLoaders for the grouped CelebA dataset.

    Args:
        task_groups: Task grouping forwarded to ``CelebaGroupedDataset``.
        batch_size: Batch size used by all three loaders.
        image_size: Image size forwarded to ``CelebaGroupedDataset``.
        add_augmentations: Forwarded to the *training* dataset only.
        generator: Optional ``torch.Generator`` given to the train and
            validation loaders.
        data_dir: Dataset root directory. Defaults to the previously
            hard-coded ``"/celeba"``.
        num_workers: Worker processes per loader; ``None`` (default) uses the
            module-level ``NUM_CPUS``.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles.
    """
    if num_workers is None:
        num_workers = NUM_CPUS

    train_dataset = CelebaGroupedDataset(data_dir=data_dir, split='train', image_size=image_size, task_groups=task_groups,add_augmentations=add_augmentations)
    val_dataset = CelebaGroupedDataset(data_dir=data_dir, split='val', image_size=image_size, task_groups=task_groups)
    test_dataset = CelebaGroupedDataset(data_dir=data_dir, split='test', image_size=image_size, task_groups=task_groups)

    # Train and validation loaders share the seeded generator and the worker
    # seeding hook.
    seeded = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True, generator=generator, worker_init_fn=seed_worker)
    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **seeded)
    val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **seeded)
    # NOTE(review): the test loader is kept as originally configured (no
    # generator / worker_init_fn, pin_memory=False) -- confirm this asymmetry
    # is intended.
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=False)
    return train_loader, val_loader,test_loader


def evaluate(model, val_loader, criterion, num_tasks, num_classes, class_to_class_idx,train_config=None):
    """Run ``model`` over ``val_loader`` without gradients and accumulate F1 metrics.

    Args:
        model: Called as ``model(data, tasks=tasks)``; must return a 3-tuple
            whose first element is the outputs. Switched to eval mode here.
        val_loader: Iterable of batches. A 2-element batch is
            ``(data, list_of_targets)``; a 3-element batch is
            ``(data, targets, tasks)``.
        criterion: Accepted for call-site compatibility; not used.
        num_tasks: Number of per-task ``F1`` metrics to aggregate.
        num_classes: Accepted for call-site compatibility; not used.
        class_to_class_idx: Accepted for call-site compatibility; not used.
        train_config: Accepted for call-site compatibility; not used.

    Returns:
        Tuple ``(metrics, other_val_metrics)`` where ``metrics`` is the
        populated ``MetricAggregator`` and ``other_val_metrics`` is an empty
        dict.
    """
    model.eval()
    other_val_metrics = {}
    metrics = MetricAggregator([F1(ignore_val=-1) for _ in range(num_tasks)])

    with torch.no_grad():
        for batch in val_loader:
            inputs, labels = batch[0], batch[1]
            n_fields = len(batch)
            if n_fields == 2:
                # Targets arrive as a list of tensors; no task ids.
                tasks = None
                inputs = inputs.to(DEVICE)
                labels = [label.to(DEVICE) for label in labels]
            elif n_fields == 3:
                # Targets are a single tensor and task ids are provided.
                tasks = batch[2]
                inputs = inputs.to(DEVICE)
                labels = labels.to(DEVICE)
                tasks = tasks.to(DEVICE)

            predictions, _, _ = model(inputs, tasks=tasks)
            metrics.update(predictions, labels)

    return metrics, other_val_metrics
        
        
def train_test( train_config,modulation_params,network, controller, task_groups,is_attention=False,is_comodulation=False,seed=0):
    """Train one phase (pretraining, attention or comodulation), then test the best checkpoint.

    The model is trained for ``train_config.max_epoch`` epochs and validated
    on the full validation set after each epoch. The weights with the highest
    validation F1 are restored before the final test evaluation.

    Args:
        train_config: Config object; the fields read here are ``lr``,
            ``batch_size``, ``image_size``, ``add_augmentations``,
            ``max_epoch``, ``eval_every_batch``, ``eval_every_n_batch``,
            ``retrain_everything``, ``do_l1_loss_on_controller``,
            ``l1_loss_weight``, ``log_freq`` and ``max_batch``.
        modulation_params: Mapping expanded as keyword arguments into
            ``MultiTaskModulatedNet``.
        network: Backbone passed to ``MultiTaskModulatedNet``.
        controller: Controller passed to ``MultiTaskModulatedNet`` (``None``
            is passed for pretraining by the caller in this file).
        task_groups: List of label-index groups, one per task.
        is_attention: Train in attention mode.
        is_comodulation: Train in comodulation mode.
        seed: Seed for the global RNGs and the DataLoader generator.

    Returns:
        Tuple ``(model, acc_dict)``. ``acc_dict`` holds the ``"train"``,
        ``"val"`` and ``"eval_during_train"`` histories, the ``"test"``
        metrics and, once at least one epoch ran, ``"best_val"``.
    """
    generator = torch.Generator()
    generator = generator.manual_seed(seed)
    set_seed(seed)

    num_tasks = len(task_groups)
    num_classes = sum(len(group) for group in task_groups)

    acc_dict = {"train": [], "val": [], "test": [], "eval_during_train":[],"num_informative_neurons":[]}

    is_modulated = is_attention or is_comodulation

    # Contiguous index range of every task inside the concatenated outputs.
    class_to_class_idx = []
    last_class = 0
    for group in task_groups:
        class_to_class_idx.append(torch.arange(last_class, last_class + len(group)).int())
        last_class += len(group)

    model = MultiTaskModulatedNet(network=network, controller=controller,device=DEVICE, is_attention=is_attention,is_comodulation=is_comodulation,**modulation_params)
    model.set_modulation_training(is_modulated)
    model.to(DEVICE)

    WHATS_GOING_ON = "ATTENTION" if is_attention else "COMODULATION" if is_comodulation else "PRETRAINING"
    logger.info("Training with %s", WHATS_GOING_ON)

    optimizer = torch.optim.Adam(model.get_parameters_for_optimizer(), lr=train_config.lr)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=20, gamma=0.5)

    train_loader, val_loader, test_loader = get_dataloaders_celeba(task_groups=task_groups,batch_size=train_config.batch_size,image_size=train_config.image_size,add_augmentations=(train_config.add_augmentations),generator=generator)

    # Fixed (seed 0) 20% subset of the validation set for the cheap,
    # frequent evaluations during training.
    smaller_val_ds = torch.utils.data.Subset(val_loader.dataset, random.Random(0).sample( list(range(0, len(val_loader.dataset))), len(val_loader.dataset)//5) )
    smaller_val_loader = torch.utils.data.DataLoader(smaller_val_ds, batch_size=64, drop_last=False, shuffle=False, num_workers=NUM_CPUS, pin_memory=True)

    criterion = LossAggregator([MaskedBCEWithLogitsLoss(-1) for _ in range(num_tasks)])

    def run_eval(loader):
        # Single place for the long evaluate() call; returns the metrics only.
        task_metrics, _ = evaluate(model,val_loader=loader,criterion=criterion,num_tasks=num_tasks,num_classes=num_classes,class_to_class_idx=class_to_class_idx,train_config=train_config)
        return task_metrics

    best_f1_score = 0
    best_model = None
    for epoch in range(1, train_config.max_epoch + 1):
        if train_config.eval_every_batch:
            # The per-epoch list must exist for every phase: the in-loop
            # evaluation below appends to it and would otherwise raise
            # IndexError during pretraining.
            acc_dict["eval_during_train"].append([])
            if is_modulated:
                # Extra evaluation before the first update of the epoch.
                acc_dict["eval_during_train"][-1].append(run_eval(smaller_val_loader).get_mean())

        train_metrics = MetricAggregator([F1(ignore_val=-1) for _ in range(num_tasks)])

        for batch_num, batch in enumerate(train_loader):
            # evaluate() leaves the model in eval mode, so the training mode
            # is re-established on every batch.
            model.train()
            if is_modulated and not train_config.retrain_everything:
                # Only the controller is in train mode; the backbone stays in
                # eval mode.
                model.controller.train()
                model.network.eval()

            data, targets = batch[0], batch[1]
            if len(batch) == 2:
                tasks = None
                data, targets = data.to(DEVICE), [elt.to(DEVICE) for elt in targets]

            if len(batch) == 3:
                tasks = batch[2]
                data, targets,tasks = data.to(DEVICE),targets.to(DEVICE),tasks.to(DEVICE)

            # Forward pass
            outputs, _, controller_params = model(data,tasks=tasks)
            task_losses = criterion.compute(outputs,targets)

            loss = torch.sum(task_losses)
            bce_loss = loss.item()
            if train_config.do_l1_loss_on_controller:
                loss += train_config.l1_loss_weight*controller_params.abs().mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_item = loss.item()
            del loss
            # Scoring
            train_metrics.update(outputs, targets)

            if train_config.eval_every_batch and (batch_num % train_config.eval_every_n_batch == 0):
                acc_dict["eval_during_train"][-1].append(run_eval(smaller_val_loader).get_mean())

            if batch_num % train_config.log_freq == 0:
                logger.info('Epoch {}, iter {}/{}, Loss : {:.4f}, Class loss {:.4f}, fscores {:.4f}'.format(epoch, batch_num+1, len(train_loader), loss_item,float(bce_loss), train_metrics.get_mean()["f1"]))

            # NOTE(review): this check runs after the batch was processed and
            # uses '>', so max_batch + 2 batches are trained per epoch --
            # confirm whether that is intended.
            if (train_config.max_batch > 0) and (batch_num > train_config.max_batch):
                break
        acc_dict["train"].append(train_metrics.get_mean())
        lr_scheduler.step()

        #############
        # Eval loop #
        #############
        if not train_config.eval_every_batch:
            acc_dict["eval_during_train"].append(run_eval(smaller_val_loader).get_mean())

        val_task_metrics = run_eval(val_loader)
        eval_string ='{} EVAL EPOCH {}, '.format(WHATS_GOING_ON,epoch)
        eval_string += f"MEAN {val_task_metrics.get_string()}"
        logger.info(eval_string)

        mean_metric = val_task_metrics.get_mean()
        # The first epoch is always recorded so that a checkpoint and
        # "best_val" exist even if the F1 never rises above 0.
        if best_model is None or best_f1_score < mean_metric["f1"]:
            logger.info("New best model")
            best_f1_score = mean_metric["f1"]
            # state_dict() returns references to the live tensors; clone them
            # so later optimizer steps do not overwrite the snapshot.
            best_model = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
            acc_dict["best_val"] = mean_metric

        acc_dict["val"].append(mean_metric)

    del val_loader
    # best_model stays None when no epoch ran (max_epoch < 1); the current
    # weights are then tested as they are.
    if best_model is not None:
        model.load_state_dict(best_model)
    model.eval()

    test_task_metrics = run_eval(test_loader)
    test_string ='{} TEST, '.format(WHATS_GOING_ON)
    test_string += f"MEAN {test_task_metrics.get_string()}"
    logger.info(test_string)

    acc_dict["test"] = (test_task_metrics.get_mean())

    return model, acc_dict

def _log_split_scores(split_key, metric_names, pretrain_runs, attention_runs, comodulation_runs):
    """Log per-seed values and the across-seed mean of every metric of one split."""
    for met in metric_names:
        logger.info(f"{met} Scores:")
        for seed_resu, (pre, att, com) in enumerate(zip(pretrain_runs, attention_runs, comodulation_runs)):
            logger.info("SEED {}".format(seed_resu))
            logger.info("PRETRAIN  {}, ATTENTION {},  COMODULATION {}".format(pre[split_key][met], att[split_key][met], com[split_key][met]))

        logger.info(f"{met} Scores MEAN:")
        logger.info("PRETRAIN  {}, ATTENTION {},  COMODULATION {} ".format(
            np.mean([run[split_key][met] for run in pretrain_runs]),
            np.mean([run[split_key][met] for run in attention_runs]),
            np.mean([run[split_key][met] for run in comodulation_runs])))


@hydra.main(config_path="configs", config_name="train_mt_classification")
def main(config):
    """Run pretraining, attention and comodulation training for every seed.

    For each of ``config.n_train_test`` seeds a network is pretrained; its
    weights then initialise two fresh networks, one fine-tuned with an
    attention controller and one with a comodulation controller. Test and
    best-validation scores are logged, all results are written to
    ``results.json`` and, when ``config.comodulation_config.eval_every_batch``
    is set, evaluation curves are plotted.

    Args:
        config: Hydra configuration object.
    """
    set_seed(config.seed)
    print(config)

    task_groups = TASK_GROUPS_CELEBA
    config.network.task_groups = task_groups

    if config.copy_files_to_save:
        copy_important_files(config.files_to_save)
    # exist_ok: the run must not crash if an output directory already exists
    # in the current working directory.
    if config.make_embeddings_plot:
        os.makedirs("./plots/", exist_ok=True)
    if config.make_controller_embedding_plot:
        os.makedirs("./controller_embedding/", exist_ok=True)

    COMODULATION_RESULTS = []
    ATTENTION_RESULTS = []
    PRETRAIN_RESULTS = []

    for seed in range(config.n_train_test):
        print("Train test from scratch {}".format(seed))
        set_seed(seed)
        network = hydra.utils.instantiate(config.network).to(DEVICE)
        pretrained_model, pretrain_acc_dict = train_test(config.pretrain_config,config.modulation_params, network, None, task_groups,is_attention=False,is_comodulation=False,seed=seed)
        PRETRAIN_RESULTS.append(pretrain_acc_dict)
        pretrained_network = pretrained_model.network.to("cpu")

        def fresh_controller_and_network():
            # Both fine-tuning phases start from the same RNG state, a newly
            # built controller and a fresh copy of the pretrained weights.
            set_seed(seed)
            new_controller = TaskController(network.get_layer_to_modulate().conv_out_shape,
                                            task_groups=task_groups,
                                            device=DEVICE,
                                            **config.controller).to(DEVICE)
            new_network = hydra.utils.instantiate(config.network)
            new_network.load_state_dict(pretrained_network.state_dict())
            return new_controller, new_network.to(DEVICE)

        controller, att_network = fresh_controller_and_network()
        att_model, att_acc_dict = train_test(config.attention_config,config.modulation_params, att_network, controller,task_groups,is_attention=True,is_comodulation=False,seed=seed)

        if config.make_controller_embedding_plot:
            plot_controller_embedding(att_model,plot_name=f"./controller_embedding/attention_{seed}.pdf")

        ATTENTION_RESULTS.append(att_acc_dict)
        del att_network

        controller, com_network = fresh_controller_and_network()
        com_model, com_acc_dict = train_test(config.comodulation_config,config.modulation_params, com_network, controller,task_groups,is_attention=False,is_comodulation=True,seed=seed)
        COMODULATION_RESULTS.append(com_acc_dict)
        if config.make_controller_embedding_plot and config.modulation_params.comod_backprop:
            plot_controller_embedding(com_model,plot_name=f"./controller_embedding/comodulation_{seed}.pdf")

        del com_network

    logger.info("Metrics Scores FOR EVERY SEED:")

    logger.info("TEST:")
    # Metric names are taken from the first comodulation test result.
    metrics = list(COMODULATION_RESULTS[0]["test"].keys())
    _log_split_scores("test", metrics, PRETRAIN_RESULTS, ATTENTION_RESULTS, COMODULATION_RESULTS)

    logger.info("BEST VALIDATION:")
    _log_split_scores("best_val", metrics, PRETRAIN_RESULTS, ATTENTION_RESULTS, COMODULATION_RESULTS)

    save_dict = {"PRETRAIN":PRETRAIN_RESULTS,"ATTENTION":ATTENTION_RESULTS,"COMODULATION":COMODULATION_RESULTS}
    ## Save results in json dict:
    with open("results.json","w") as f:
        json.dump(save_dict,f,cls=TorchEncoder)

    if config.comodulation_config.eval_every_batch:
        plot_eval_batch(save_dict,dict_key="eval_during_train",plot_name="results",y_axis="F1 score",make_individual_plots=True)

if __name__ == '__main__':
    main()


