import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import random, os
import argparse
import numpy as np
from collections import OrderedDict
from scipy.special import softmax
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy import spatial
from graphviz import Digraph

import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

from load_data import EMNIST
import layers as layers
from train_mtl import Trainer
from donn_model_mtl import DiffractiveClassifier_Raw


###########################################################

def _build_dataloaders(data_root, batch_size, transform):
    """Create the train/validation DataLoaders for every task.

    Args:
        data_root: prefix that is string-concatenated with each dataset's
            sub-directory ("data", "Fdata", "Kdata", "Edata"), so it should
            end with a path separator.
        batch_size: batch size shared by all loaders.
        transform: transform applied to every image.

    Returns:
        (train_loaders, val_loaders): dicts keyed by task name, in the order
        mnist, fmnist, kmnist, emnist.
    """
    # (task name, dataset class, sub-directory under data_root, extra kwargs)
    specs = [
        ("mnist", torchvision.datasets.MNIST, "data", {}),
        ("fmnist", torchvision.datasets.FashionMNIST, "Fdata", {}),
        ("kmnist", torchvision.datasets.KMNIST, "Kdata", {}),
        ("emnist", EMNIST, "Edata", {"divide": 0}),
    ]
    train_loaders, val_loaders = {}, {}
    for task, dataset_cls, subdir, extra in specs:
        root = data_root + subdir
        train_set = dataset_cls(root, train=True, transform=transform, download=True, **extra)
        val_set = dataset_cls(root, train=False, transform=transform, download=True, **extra)
        train_loaders[task] = DataLoader(dataset=train_set, batch_size=batch_size,
                                         num_workers=8, shuffle=True, pin_memory=False)
        val_loaders[task] = DataLoader(dataset=val_set, batch_size=batch_size,
                                       num_workers=4, shuffle=False, pin_memory=False)
    return train_loaders, val_loaders


def _build_model(args, heads_dict, device):
    """Instantiate ``DiffractiveClassifier_Raw`` from CLI ``args`` and move it to ``device``."""
    model = DiffractiveClassifier_Raw(num_layers=args.depth,
                                      wavelength=args.wavelength,
                                      pixel_size=args.pixel_size,
                                      sys_size=args.sys_size,
                                      pad=args.pad,
                                      distance=args.distance,
                                      amp_factor=args.amp_factor,
                                      approx=args.approx,
                                      heads_dict=heads_dict)
    return model.to(device)


def _collect_policies(model, tasks):
    """Gather the non-zero ``policy`` parameters of ``model``, grouped by task.

    A parameter is considered when ``'policy'`` occurs in its name and not all
    of its elements are zero (all-zero policies are skipped - presumably
    inactive ones; TODO confirm against the model definition). It is assigned
    to the first task ``t`` in ``tasks`` for which ``'.' + t`` occurs in the
    name; parameters matching no task are ignored.

    Args:
        model: an ``nn.Module``.
        tasks: ordered sequence of task names.

    Returns:
        (policy_list, name_list): dicts mapping each task to the raw policy
        values (numpy arrays on the CPU) and to the matching parameter names,
        both in ``named_parameters()`` order.
    """
    policy_list = {task: [] for task in tasks}
    name_list = {task: [] for task in tasks}
    for name, param in model.named_parameters():
        # An element-wise zero test works for any policy length and device,
        # unlike comparing against a hard-coded 3- or 5-element CUDA tensor.
        if 'policy' not in name or torch.all(param == 0):
            continue
        for task in tasks:
            if '.' + task in name:
                policy_list[task].append(param.detach().cpu().numpy())
                name_list[task].append(name)
                break
    return policy_list, name_list


def _sample_policy(policy_list, name_list, tasks, shared, device):
    """Draw one discrete option per policy and encode it as a one-hot tensor.

    For each task, the first ``shared`` policies are forced to option 0; the
    remaining ones are sampled from ``softmax(policy)``. Every one-hot vector
    has the same length as its policy, so forced and sampled entries agree.

    Args:
        policy_list, name_list: as returned by ``_collect_policies``.
        tasks: ordered sequence of task names.
        shared: number of leading policies per task forced to option 0.
        device: device on which the one-hot tensors are created.

    Returns:
        OrderedDict mapping parameter name -> float one-hot tensor.
    """
    sample_policy_dict = OrderedDict()
    for task in tasks:
        for count, (name, policy) in enumerate(zip(name_list[task], policy_list[task])):
            if count < shared:
                choice = 0
            else:
                distribution = softmax(policy, axis=-1)
                # Re-normalise to absorb float drift before np.random.choice.
                distribution /= sum(distribution)
                choice = np.random.choice(len(distribution), p=distribution)
            one_hot = torch.zeros(len(policy), device=device)
            one_hot[int(choice)] = 1.0
            sample_policy_dict[name] = one_hot
    return sample_policy_dict


def _sample_and_save_policy(model, args, tasks, savepath, device):
    """Sample a discrete policy from ``model`` and save it under ``savepath``.

    If ``args.alter_model`` is set, that checkpoint (relative to
    ``savepath``) is loaded into ``model`` first.

    Returns:
        The saved file name (relative to ``savepath``).
    """
    print(">>>>>>>> Sample Policy <<<<<<<<<<")
    if args.alter_model is not None:
        state = torch.load(savepath + args.alter_model, map_location=device)
        model.load_state_dict(state['state_dict'])

    policy_list, name_list = _collect_policies(model, tasks)
    sample_policy_dict = _sample_policy(policy_list, name_list, tasks, args.shared, device)

    policy_file = f'sample_policy{args.ext}.model'
    torch.save({'state_dict': sample_policy_dict}, savepath + policy_file)
    return policy_file


def _visualize_policies(model, tasks, vis_savepath):
    """Plot the learned policies of ``model`` into ``vis_savepath``.

    Writes one heat-map ``spect_<task>.png`` per task, a task-similarity
    matrix ``task_cor`` (cosine similarity of the flattened softmax policies)
    and a Graphviz render ``Best`` of each task's arg-max path.
    """
    raw_policies, _ = _collect_policies(model, tasks)
    policy_list = {task: [softmax(p, axis=-1) for p in raw_policies[task]] for task in tasks}
    print(policy_list)

    ylabels = {'mnist': 'MNIST', 'fmnist': 'FMNIST', 'kmnist': 'KMNIST', 'emnist': 'EMNIST'}
    tick_size = 15
    label_size = 16

    ### per-task policy spectrum: rows = options, columns = layers
    for task in tasks:
        spectrum = np.stack(policy_list[task])  # (num_policies, num_options)
        fig, ax = plt.subplots(figsize=(10, 5))
        im = ax.imshow(spectrum.T)
        ax.set_xlabel('Layer No.', fontsize=label_size)
        ax.set_ylabel(ylabels.get(task, task.upper()), fontsize=label_size)
        # NOTE(review): only three options are labelled; confirm the option
        # count/order against the model's policy definition.
        ax.set_yticks(np.arange(3))
        ax.set_yticklabels(['shared', 'specific', 'skip'])
        ax.tick_params(labelsize=tick_size)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="2%", pad=0.05)
        cb = fig.colorbar(im, cax=cax)
        cb.ax.tick_params(labelsize=tick_size)
        fig.savefig(f"{vis_savepath}/spect_{task}.png")
        plt.close(fig)

    ### task correlation (cosine similarity of flattened policies)
    num_tasks = len(tasks)
    policy_array = np.array([np.array(policy_list[task]).ravel() for task in tasks])
    sim = np.zeros((num_tasks, num_tasks))
    for i in range(num_tasks):
        for j in range(num_tasks):
            sim[i, j] = 1 - spatial.distance.cosine(policy_array[i, :], policy_array[j, :])

    fig, ax = plt.subplots(figsize=(10, 10))
    # Pass the colormap locally instead of mutating the global rc params.
    im = ax.imshow(sim, cmap='Blues')
    ax.set_xticks(np.arange(num_tasks))
    ax.set_yticks(np.arange(num_tasks))
    ax.set_xticklabels(tasks)
    ax.set_yticklabels(tasks)
    ax.tick_params(labelsize=tick_size)
    ax.tick_params(axis='x', labelrotation=90)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="4%", pad=0.05)
    # Anchor "high"/"low" to the actual data range rather than a hard-coded
    # 0.61, which falls off the colorbar for other checkpoints.
    cb = fig.colorbar(im, cax=cax, ticks=[sim.max(), sim.min()])
    cb.ax.set_yticklabels(['high', 'low'])
    cb.ax.tick_params(labelsize=tick_size)
    fig.savefig(f"{vis_savepath}/task_cor")
    plt.close(fig)

    ### policy graph: one cluster per layer, one coloured edge chain per task
    dot = Digraph(comment='Policy')
    layer_num = len(policy_list[tasks[0]])
    for i in range(layer_num):
        with dot.subgraph(name=f'cluster_L{i}', node_attr={'rank': 'same'}) as c:
            c.attr(rankdir='LR')
            # NOTE(review): only options 0-2 get named nodes; an arg-max of 3+
            # creates an unnamed node implicitly.
            c.node(f'L{i}B0', 'Shared')
            c.node(f'L{i}B1', 'Specific')
            c.node(f'L{i}B2', 'Skip')
    colors = {'mnist': 'blue', 'fmnist': 'yellow', 'emnist': 'green', 'kmnist': 'grey'}
    for task in tasks:
        for i in range(layer_num - 1):
            prev = np.argmax(policy_list[task][i])
            nxt = np.argmax(policy_list[task][i + 1])
            dot.edge(f'L{i}B{prev}', f'L{i + 1}B{nxt}', color=colors[task])
    dot.render(f'{vis_savepath}/Best', view=False)


def main(args):
    """Run the multi-task DONN pipeline selected by the CLI flags in ``args``.

    * ``--evaluate``: load the checkpoint, validate, optionally visualise the
      policies (``--visualize``), then return without training.
    * ``--pretrain``: pre-train via ``Trainer.pre_train``.
    * ``--alter-train``: alternate policy/network training.
    * ``--post-train``: sample a discrete policy, save it and retrain.
    * ``--post-train-rotate``: sample (or ``--reload-policy``) a policy and
      retrain with ``post_train_rotate_auto`` (``--auto``) or
      ``post_train_rotate_alt`` (``--alt``).

    Checkpoints go to ``checkpoint/<save_dir>/``.
    """
    print("---------------")
    print(args)
    print("---------------")

    # --seed used to be parsed but never applied, so sampled policies were
    # not reproducible between runs.
    seed = getattr(args, 'seed', None)
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    torch.backends.cudnn.benchmark = True

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    ### load data
    transform = transforms.Compose([transforms.Resize((200, 200), interpolation=2), transforms.ToTensor()])
    train_dataloader, val_dataloader = _build_dataloaders(args.data_root, args.batch_size, transform)

    tasks = ['mnist', 'fmnist', 'kmnist', 'emnist']

    headsDict = nn.ModuleDict()
    trainDataloaderDict = {}
    valDataloaderDict = {}
    criterionDict = {}
    metricDict = {}

    # Use the selected device instead of .cuda() so CPU-only runs work.
    criterion = torch.nn.MSELoss(reduction='sum').to(device)
    for task in tasks:
        headsDict[task] = layers.Detector(x_loc=[40, 40, 40, 90, 90, 90, 90, 140, 140, 140],
                                          y_loc=[40, 90, 140, 30, 70, 110, 150, 40, 90, 140],
                                          det_size=20, size=args.sys_size)
        trainDataloaderDict[task] = train_dataloader[task]
        valDataloaderDict[task] = val_dataloader[task]
        criterionDict[task] = criterion
        metricDict[task] = []

    ### Define MTL model
    mtlmodel = _build_model(args, headsDict, device)

    ### Define training framework
    trainer = Trainer(mtlmodel,
                      trainDataloaderDict, valDataloaderDict,
                      criterionDict, metricDict,
                      print_iters=10, val_iters=100,
                      save_iters=100, save_num=1,
                      policy_update_iters=100)

    # ----------------
    ### validation
    if args.evaluate:
        print(">>>>>>>> Validation <<<<<<<<<<")
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        ckpt = torch.load(args.evaluate, map_location=device)
        mtlmodel.load_state_dict(ckpt["state_dict"])
        trainer.validate('mtl', hard=True)

        if args.visualize:
            name = args.evaluate.split("/")[-1].split(".")[0]
            vis_savepath = f"result/{name}"
            os.makedirs(vis_savepath, exist_ok=True)
            print(f"All visualization save to {vis_savepath}")
            _visualize_policies(mtlmodel, tasks, vis_savepath)

        return

    # ==================================
    ### Train
    savepath = 'checkpoint/' + args.save_dir + "/"
    os.makedirs(savepath, exist_ok=True)
    print(f"All ckpts save to {savepath}")

    # ----------------
    ### Step 1: pre-train
    if args.pretrain:
        print(">>>>>>>> pre-train <<<<<<<<<<")
        trainer.pre_train(iters=args.pretrain_iters, lr=args.lr,
                          savePath=savepath, writerPath=savepath)

    # ----------------
    ### Step 2: alter-train
    if args.alter_train:
        print(">>>>>>>> alter-train <<<<<<<<<<")
        loss_lambda = {'mnist': 1,
                       'fmnist': 1,
                       'kmnist': 1,
                       'emnist': 1,
                       'policy': 0.0005}
        trainer.alter_train_with_reg(iters=args.alter_iters, policy_network_iters=(50, 200),
                                     policy_lr=0.01, network_lr=0.001,
                                     loss_lambda=loss_lambda,
                                     savePath=savepath, writerPath=savepath,
                                     reload=args.pretrain_model,
                                     ext=args.ext)

    # ----------------
    if args.post_train:
        ### Step 3: sample policy from trained policy distribution and save
        reload_policy = _sample_and_save_policy(mtlmodel, args, tasks, savepath, device)

        ### Step 4: post train from scratch
        print(">>>>>>>> Post-train <<<<<<<<<<")
        loss_lambda = {'mnist': 1, 'fmnist': 1, 'kmnist': 1, 'emnist': 1}
        print("Loss lambda: ", loss_lambda)
        trainer.post_train(iters=args.post_iters, lr=args.post_lr,
                           decay_lr_freq=args.decay_lr_freq, decay_lr_rate=0.5,
                           loss_lambda=loss_lambda,
                           savePath=savepath, writerPath=savepath,
                           reload=reload_policy,
                           ext=args.ext)

    # ----------------
    if args.post_train_rotate:
        ### Step 3: sample a policy, unless an existing one is supplied
        if not args.reload_policy:
            reload_policy = _sample_and_save_policy(mtlmodel, args, tasks, savepath, device)
        else:
            reload_policy = args.reload_policy

        print("Policy file:", reload_policy)

        ### Step 4: post train from scratch
        print(">>>>>>>> Post-train <<<<<<<<<<")
        loss_lambda = {'mnist': 1, 'fmnist': 1, 'kmnist': 1, 'emnist': 1}
        print("Loss Lambda:", loss_lambda)
        # A second model instance sharing the same detector heads.
        mtlmodel_shrd = _build_model(args, headsDict, device)

        rotate_kwargs = dict(iters=args.post_iters, model_copy=mtlmodel_shrd, depth=args.depth,
                             lr=args.post_lr, decay_lr_freq=args.decay_lr_freq,
                             decay_lr_rate=0.5, loss_lambda=loss_lambda,
                             savePath=savepath, writerPath=savepath,
                             reload=reload_policy,
                             ext=args.ext)
        if args.auto:
            trainer.post_train_rotate_auto(**rotate_kwargs)
        elif args.alt:
            trainer.post_train_rotate_alt(**rotate_kwargs)
        else:
            print("!!! NOT ROTATE!")
            
# # --------------------------------------------------------------

def build_parser():
    """Build the command-line parser for the multi-task DONN script.

    Defaults are rendered automatically in ``--help`` (via
    ``ArgumentDefaultsHelpFormatter``), so help texts no longer quote
    hard-coded, stale default values.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    ### data / bookkeeping
    parser.add_argument('--data-root', type=str, default='../../../', help="data path")
    parser.add_argument('--batch-size', type=int, default=200)
    parser.add_argument('--save-dir', type=str, default='avg_L5', help="sub-directory of checkpoint/ where models are saved")
    parser.add_argument('--evaluate', type=str, help="checkpoint path to evaluate (skips training)")
    parser.add_argument('--visualize', action='store_true', default=False, help="visualize the learned policies during evaluation")

    ### pre-train
    parser.add_argument('--pretrain', action='store_true', default=False, help='whether to run the pre-train part')
    parser.add_argument('--pretrain-iters', type=int, default=3000, help='number of pre-training iterations')
    parser.add_argument('--lr', type=float, default=0.01, help='pre-train learning rate')

    ### alter-train
    parser.add_argument('--alter-train', action='store_true', default=False, help='whether to run the alter-train part')
    parser.add_argument('--pretrain-model', type=str, default=None, help="pre-trained model to reload in alter-train")
    parser.add_argument('--alter-iters', type=int, default=6000, help='number of alter-train iterations')

    ### post-train
    parser.add_argument('--post-train', action='store_true', default=False, help='whether to run the post-train part')
    parser.add_argument('--post-train-rotate', action='store_true', default=False, help='whether to run the rotated post-train part (needs --auto or --alt)')
    parser.add_argument('--auto', action='store_true', default=False, help='with --post-train-rotate: use Trainer.post_train_rotate_auto')
    parser.add_argument('--alt', action='store_true', default=False, help='with --post-train-rotate: use Trainer.post_train_rotate_alt (ignored if --auto is set)')
    parser.add_argument('--reload-policy', type=str, default=None, help='existing policy file to use in --post-train-rotate instead of sampling a new one')
    parser.add_argument('--alter-model', type=str, default=None, help="alter-train checkpoint (relative to the save dir) used to sample the policy")
    parser.add_argument('--post-iters', type=int, default=30000, help='number of post-train iterations')
    parser.add_argument('--post-lr', type=float, default=0.1, help='post-train learning rate')
    parser.add_argument('--decay-lr-freq', type=float, default=3000, help='post-train learning rate decay frequency')
    parser.add_argument('--shared', type=float, default=0, help='number of leading policy layers per task forced to the shared option when sampling the policy')

    parser.add_argument('--ext', type=str, default='', help="suffix appended to saved model file names")

    ### DONN parameters
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--depth', type=int, default=5, help='number of fourier optic transformations/num of layers')
    parser.add_argument('--sys-size', type=int, default=200, help='system size (dim of each diffractive layer)')
    parser.add_argument('--distance', type=float, default=0.3, help='distance between diffractive layers (meters)')
    parser.add_argument('--amp-factor', type=float, default=4, help='regularization factor balancing phase and amplitude, which share the same downstream gradients')
    parser.add_argument('--pixel-size', type=float, default=0.000036, help='the size of pixel in diffractive layers')
    parser.add_argument('--pad', type=int, default=50, help='the padding size')
    parser.add_argument('--approx', type=str, default='Fresnel3', help="which diffraction approximation to use (e.g. Sommerfeld, Fresnel, Fraunhofer)")
    parser.add_argument('--wavelength', type=float, default=5.32e-7, help='wavelength')

    return parser


if __name__ == '__main__':

    args_ = build_parser().parse_args()

    main(args_)
