import os, sys, time, warnings
import argparse
from copy import deepcopy

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision
import torch.utils.data as data_utils

from mnist_loader import MnistRotated
from autoencoders.ae_model import AE
from utils.metrics import part_wd, evaluate_fid_score

sys.path.append('/iaf')
from iaf_torch import add_one_layer, IterAlignFlow
from utils.inb_utils import prepare_data, inb_translate




class wrap_enc(nn.Module):
    """Expose the per-domain auto-encoder encoders behind one callable.

    For every entry of ``args.list_train_domains`` an ``AE(args)`` is built,
    its weights are read from
    ``args.ae_dir + args.ae_model + '-' + <domain> + '.pt'`` and only its
    ``encoder`` attribute is kept.  Encoders are stored in the same order as
    ``args.list_train_domains`` and are selected by position in ``forward``.
    """

    def __init__(self, args, device):
        super().__init__()
        # NOTE(review): kept as a plain list, so the encoders are not
        # registered as sub-modules of this wrapper.
        self.ae_list = []
        for dd in args.list_train_domains:
            ae = AE(args)
            ae_path = args.ae_dir + args.ae_model + '-' + str(dd) + '.pt'
            # map_location lets a checkpoint written on a GPU be restored on
            # whatever device this run uses (e.g. CPU with --no-cuda).
            ae.load_state_dict(torch.load(ae_path, map_location=device))
            ae = ae.to(device)
            self.ae_list.append(ae.encoder)
            print(f'Finish loading encoder from {ae_path}')

    def forward(self, X, y):
        """Encode ``X`` with the ``y``-th encoder.

        ``X`` is reshaped to ``(-1, 1, 28, 28)`` before encoding and the
        encoder output is flattened to one row per sample.
        """
        X = X.view(-1, 1, 28, 28)
        return self.ae_list[y](X).view(X.shape[0], -1)

class wrap_dec(nn.Module):
    """Expose the per-domain auto-encoder decoders behind one callable.

    For every entry of ``args.list_train_domains`` an ``AE(args)`` is built,
    its weights are read from
    ``args.ae_dir + args.ae_model + '-' + <domain> + '.pt'`` and only its
    ``decoder`` attribute is kept.  Decoders are stored in the same order as
    ``args.list_train_domains`` and are selected by position in ``forward``.
    """

    def __init__(self, args, device):
        super().__init__()
        # NOTE(review): kept as a plain list, so the decoders are not
        # registered as sub-modules of this wrapper.
        self.ae_list = []
        for dd in args.list_train_domains:
            ae = AE(args)
            # str(dd) keeps the path construction identical to wrap_enc and
            # also works when the domain ids are not already strings.
            ae_path = args.ae_dir + args.ae_model + '-' + str(dd) + '.pt'
            # map_location lets a checkpoint written on a GPU be restored on
            # whatever device this run uses (e.g. CPU with --no-cuda).
            ae.load_state_dict(torch.load(ae_path, map_location=device))
            ae = ae.to(device)
            self.ae_list.append(ae.decoder)
            print(f'Finish loading decoder from {ae_path}')

    def forward(self, X, y):
        """Decode latent codes ``X`` with the ``y``-th decoder.

        ``X`` is reshaped to ``(-1, 8, 7, 7)`` before decoding and the decoder
        output is flattened to one row per sample.
        """
        X = X.view(-1, 8, 7, 7)
        return self.ae_list[y](X).view(X.shape[0], -1)

def prepare_data_domains(imgs, labels, domains, label, domain_list, train=True):
    """Collect the samples of one ``label`` across several domains.

    ``prepare_data`` is called once per entry of ``domain_list``; the first and
    third items of each returned triple are concatenated in ``domain_list``
    order.  The concatenated samples are flattened to one row per sample.

    When ``train`` is true, a trailing sample is dropped if necessary so the
    number of rows is even.

    Returns:
        ``(x, d)`` -- the flattened samples and the matching concatenated third
        items (presumably per-sample domain ids; confirm against
        ``prepare_data``).
    """
    pairs = [
        (samples, domain_ids)
        for samples, _, domain_ids in (
            prepare_data(imgs, labels, domains, label, dom) for dom in domain_list
        )
    ]
    x = torch.cat([samples for samples, _ in pairs])
    d = torch.cat([domain_ids for _, domain_ids in pairs])
    x = x.view(x.shape[0], -1)

    if not train:
        return x, d

    # keep an even number of samples
    n_even = x.shape[0] - x.shape[0] % 2
    return x[:n_even], d[:n_even]





def eval_fid_wd_init(x, d, cd_enc, cd_dec, domain_list, mat=False, fid=True, wd=True):
    """Score every ordered pair of domains on the untouched samples.

    Entry ``[i, j]`` of each matrix compares the samples of
    ``domain_list[j]`` (passed first to the metric) with the samples of
    ``domain_list[i]`` (passed second).  Rows and columns follow the
    *position* in ``domain_list``, so the list may be any subset of the domain
    ids present in ``d``.

    Args:
        x: samples, one row per sample; values must lie in ``[0, 1]``.
        d: per-sample domain ids, compared with ``==`` against the entries of
            ``domain_list``.
        cd_enc, cd_dec: unused here; kept so the signature matches the other
            evaluation helpers.
        domain_list: domain ids to compare.
        mat: when true, also return the full matrices.
        fid: compute the FID matrix (left at zero otherwise).
        wd: compute the WD matrix (left at zero otherwise).

    Returns:
        ``(avg_wd, avg_fid)``, or ``(avg_wd, wd_mat, avg_fid, fid_mat)`` when
        ``mat`` is true.  The averages are the means over all matrix entries.
    """
    n_domains = len(domain_list)
    wd_mat = torch.zeros(n_domains, n_domains)
    fid_mat = torch.zeros(n_domains, n_domains)

    # Slice each domain once instead of once per (row, column) pair.
    per_domain = [x[d == dom] for dom in domain_list]

    for row, xt in enumerate(per_domain):
        for col, xr in enumerate(per_domain):
            assert torch.max(xr) <= 1 and torch.max(xt) <= 1 and torch.min(xr) >= 0\
                   and torch.min(xt) >= 0, 'Check range of output'
            if wd:
                wd_mat[row, col] = part_wd(xr.cpu(), xt.cpu())
            if fid:
                fid_mat[row, col] = evaluate_fid_score(
                    xr.view(-1, 1, 28, 28).cpu().detach().numpy().reshape(xr.shape[0], 28, 28, 1),
                    xt.view(-1, 1, 28, 28).cpu().detach().numpy().reshape(xt.shape[0], 28, 28, 1))

    avg_wd = torch.mean(wd_mat).item()
    avg_fid = torch.mean(fid_mat).item()
    if mat:
        return avg_wd, wd_mat, avg_fid, fid_mat
    return avg_wd, avg_fid

def eval_fid_wd_init_enc(x, d, cd_enc, cd_dec, domain_list, mat=False, fid=True, wd=True):
    """Score auto-encoder reconstructions against the real samples of every domain.

    For each source domain the samples are passed through ``cd_enc`` and then
    ``cd_dec`` with the *same* domain id (no translation).  Entry ``[i, j]`` of
    each matrix compares the real samples of ``domain_list[j]`` (passed first
    to the metric) with the reconstructions of ``domain_list[i]`` (passed
    second).  Rows and columns follow the *position* in ``domain_list``, so the
    list may be any subset of the domain ids present in ``d``.

    Args:
        x: samples, one row per sample; values must lie in ``[0, 1]``.
        d: per-sample domain ids, compared with ``==`` against the entries of
            ``domain_list``.
        cd_enc: callable ``(samples, domain_id) -> codes``.
        cd_dec: callable ``(codes, domain_id) -> samples``; its output must
            also lie in ``[0, 1]``.
        domain_list: domain ids to compare.
        mat: when true, also return the full matrices.
        fid: compute the FID matrix (left at zero otherwise).
        wd: compute the WD matrix (left at zero otherwise).

    Returns:
        ``(avg_wd, avg_fid)``, or ``(avg_wd, wd_mat, avg_fid, fid_mat)`` when
        ``mat`` is true.  The averages are the means over all matrix entries.
    """
    n_domains = len(domain_list)
    wd_mat = torch.zeros(n_domains, n_domains)
    fid_mat = torch.zeros(n_domains, n_domains)

    # Slice each domain once instead of once per (row, column) pair.
    references = [x[d == dom] for dom in domain_list]

    for row, (dom, xc) in enumerate(zip(domain_list, references)):
        x_enc = cd_enc(xc, dom)
        xt = cd_dec(x_enc, dom)
        for col, xr in enumerate(references):
            assert torch.max(xr) <= 1 and torch.max(xt) <= 1 and torch.min(xr) >= 0\
                   and torch.min(xt) >= 0, 'Check range of output'
            if wd:
                wd_mat[row, col] = part_wd(xr.cpu(), xt.cpu())
            if fid:
                fid_mat[row, col] = evaluate_fid_score(
                    xr.view(-1, 1, 28, 28).cpu().detach().numpy().reshape(xr.shape[0], 28, 28, 1),
                    xt.view(-1, 1, 28, 28).cpu().detach().numpy().reshape(xt.shape[0], 28, 28, 1))

    avg_wd = torch.mean(wd_mat).item()
    avg_fid = torch.mean(fid_mat).item()
    if mat:
        return avg_wd, wd_mat, avg_fid, fid_mat
    return avg_wd, avg_fid


def eval_fid_wd(x, d, cd, cd_enc, cd_dec, domain_list, mat=True, fid=True,wd =True):
    """Score translated samples against the real samples of every target domain.

    For each source domain the samples are encoded with ``cd_enc``; for each
    target domain those codes are mapped with ``inb_translate(cd, ...)`` and
    decoded with the target's decoder.  Entry ``[i, j]`` of each matrix
    compares the real samples of ``domain_list[j]`` (passed first to the
    metric) with the samples of ``domain_list[i]`` translated to
    ``domain_list[j]`` (passed second).  Rows and columns follow the *position*
    in ``domain_list``, so the list may be any subset of the domain ids present
    in ``d``.

    Args:
        x: samples, one row per sample; values must lie in ``[0, 1]``.
        d: per-sample domain ids, compared with ``==`` against the entries of
            ``domain_list``.
        cd: fitted flow handed to ``inb_translate``.
        cd_enc: callable ``(samples, domain_id) -> codes``.
        cd_dec: callable ``(codes, domain_id) -> samples``; its output must
            also lie in ``[0, 1]``.
        domain_list: domain ids to compare.
        mat: when true (the default here), also return the full matrices.
        fid: compute the FID matrix (left at zero otherwise).
        wd: compute the WD matrix (left at zero otherwise).

    Returns:
        ``(avg_wd, wd_mat, avg_fid, fid_mat)`` when ``mat`` is true, otherwise
        ``(avg_wd, avg_fid)``.  The averages are the means over all matrix
        entries.
    """
    n_domains = len(domain_list)
    wd_mat = torch.zeros(n_domains, n_domains)
    fid_mat = torch.zeros(n_domains, n_domains)

    # Real samples of each domain, sliced once and reused as references.
    references = [x[d == dom] for dom in domain_list]

    for row, src in enumerate(domain_list):
        xc = references[row]
        dc = d[d == src]
        x_enc = cd_enc(xc, src)
        for col, tgt in enumerate(domain_list):
            xr = references[col]
            xt = inb_translate(cd, x_enc, dc, tgt)
            xt = cd_dec(xt, tgt)
            assert torch.max(xr) <= 1 and torch.max(xt) <= 1 and torch.min(xr) >= 0\
                   and torch.min(xt) >= 0, 'Check range of output'
            if wd:
                wd_mat[row, col] = part_wd(xr.cpu(), xt.cpu())
            if fid:
                fid_mat[row, col] = evaluate_fid_score(
                    xr.view(-1, 1, 28, 28).cpu().detach().numpy().reshape(xr.shape[0], 28, 28, 1),
                    xt.view(-1, 1, 28, 28).cpu().detach().numpy().reshape(xt.shape[0], 28, 28, 1))

    avg_wd = torch.mean(wd_mat).item()
    avg_fid = torch.mean(fid_mat).item()
    if mat:
        return avg_wd, wd_mat, avg_fid, fid_mat
    return avg_wd, avg_fid

def train_inb(train_imgs,
              train_labels,
              train_domains,
              test_imgs,
              test_labels,
              test_domains,
              label,
              inb_dict,
              nlayers,
              K,
              cd_enc,
              cd_dec,
              tracker_dict,
              domain_list=None,
              vis=False,
              fid=True,
              wd=True,
              log_interval=1,
              bary_type='nb',
              verbose=True,
              max_swd_iters=200,
              hist_bins = 2000,
              quantile=False,
              device = torch.device('cuda')):
    """Fit an ``IterAlignFlow`` on the latent codes of one ``label``.

    The training samples of ``label`` are encoded per domain with ``cd_enc``,
    then ``nlayers`` layers are added one at a time with ``add_one_layer``.
    Metrics are recorded on the test samples before training (raw samples,
    then auto-encoder reconstructions) and after every ``log_interval`` layers.

    Args:
        train_imgs, train_labels, train_domains: training data handed to
            ``prepare_data_domains``.
        test_imgs, test_labels, test_domains: evaluation data handed to
            ``prepare_data_domains`` (with ``train=False``).
        label: the class whose samples are used; also the key under which
            results are stored.
        inb_dict: dict updated in place with ``inb_dict[label] = flow``.
        nlayers: number of layers to add.
        K: forwarded to ``add_one_layer``.
        cd_enc, cd_dec: per-domain encoder / decoder wrappers.
        tracker_dict: dict updated in place with ``tracker_dict[label] = tracker``.
        domain_list: domain ids to align; ``None`` means ``[0, 1, 2, 3, 4]``.
        vis: build a grid of translated samples at each logging step.
        fid, wd: which metrics to compute; when both are false no evaluation
            is run and the tracker lists stay empty.
        log_interval: evaluate every this many layers.
        bary_type, max_swd_iters, quantile: forwarded to ``add_one_layer``.
        hist_bins: forwarded to ``add_one_layer`` as ``swd_bins`` and used in
            the parameter count.
        verbose: print the parameter count and metrics at each logging step.
        device: device the data tensors are moved to.

    Returns:
        ``(inb_dict, tracker_dict)`` -- the same dict objects that were passed in.
    """
    # None sentinel instead of a mutable list literal as the default value.
    if domain_list is None:
        domain_list = [0, 1, 2, 3, 4]

    # =================================================================================== #
    #                                      Prepare data                                   #
    # =================================================================================== #

    x_train, d_train = prepare_data_domains(train_imgs, train_labels, train_domains, label, domain_list)
    x_test, d_test = prepare_data_domains(test_imgs, test_labels, test_domains, label, domain_list, train=False)

    x_train = x_train.to(device)
    d_train = d_train.to(device)
    x_test = x_test.to(device)
    d_test = d_test.to(device)

    # =================================================================================== #
    #                                         Set up                                      #
    # =================================================================================== #

    start = time.time()
    cd = IterAlignFlow()
    tracker = {'fid': [], 'wd': [], 'nparams': []}

    # Encode each domain's samples with its own encoder; 8*7*7 is the flattened
    # code size the decoder wrapper reshapes from.
    z_train = torch.zeros(x_train.shape[0], 8*7*7).to(device)
    for d in domain_list:
        ddx = d_train == d
        z_train[ddx] = cd_enc(x_train[ddx], d)
    n_params = 0

    # Baselines before any layer is fitted.  The fid/wd switches are forwarded
    # so a disabled metric is skipped here exactly as in the per-layer
    # evaluation below (its matrix is then all zeros).
    if fid or wd:
        avg_wd, wd_mat, avg_fid, fid_mat = eval_fid_wd_init(
            x_test, d_test, cd_enc, cd_dec, domain_list, mat=True, fid=fid, wd=wd)
        print(f'Initially, the FID for digit {label} is {avg_fid}')
        print(f'Initially, the WD for digit {label} is {avg_wd}')
        tracker['wd'].append(wd_mat)
        tracker['fid'].append(fid_mat)
        tracker['nparams'].append(n_params)

        avg_wd, wd_mat, avg_fid, fid_mat = eval_fid_wd_init_enc(
            x_test, d_test, cd_enc, cd_dec, domain_list, mat=True, fid=fid, wd=wd)
        print(f'Initially after ae, the FID for digit {label} is {avg_fid}')
        print(f'Initially after ae, the WD for digit {label} is {avg_wd}')
        tracker['wd'].append(wd_mat)
        tracker['fid'].append(fid_mat)
        tracker['nparams'].append(n_params)

    # =================================================================================== #
    #                                         Training                                    #
    # =================================================================================== #
    # add INB layers
    for i in range(nlayers):
        cd, z_train = add_one_layer(cd, z_train, d_train, K, bary_type=bary_type, max_swd_iters=max_swd_iters,
                                    swd_bins = hist_bins, trans_hist=True, quantile=quantile)

        # =================================================================================== #
        #                                        Evaluation                                   #
        # =================================================================================== #
        # Running parameter count, accumulated from the layer that was just added.
        layer = cd.layer[-1]
        n_rows, n_cols = layer.wT.shape[0], layer.wT.shape[1]
        per_domain = (layer.swd_iters * 2 * (n_rows * n_cols + hist_bins * n_cols)
                      + layer.swd_extra_iters * 2 * hist_bins * n_cols)
        n_params += len(domain_list) * per_domain + layer.nb_params

        if (i + 1) % log_interval != 0:
            continue

        print(f'iter {i + 1}')

        if vis:
            # First 10 test samples of domain 0, followed by their translation
            # into every domain of domain_list.
            x_vis = x_test[d_test==0]
            d_vis = d_test[d_test==0]
            x_vis = x_vis[:10]
            d_vis = d_vis[:10]
            x_vis_list = [x_vis]
            x_vis_enc = cd_enc(x_vis,0)
            for vdx in domain_list:
                x_vis_trans = inb_translate(cd,x_vis_enc,d_vis,vdx)
                x_vis_trans = cd_dec(x_vis_trans,vdx)
                x_vis_list.append(x_vis_trans)
            x_vis = torch.cat(x_vis_list)
            # NOTE(review): the grid is built but never saved or displayed here.
            grid_img = torchvision.utils.make_grid(x_vis.view(-1, 1, 28, 28), nrow=10, normalize=True)

        if fid or wd:
            avg_wd, wd_mat, avg_fid, fid_mat = eval_fid_wd(x_test, d_test, cd, cd_enc, cd_dec, domain_list, fid=fid, wd=wd, mat=True)
            tracker['wd'].append(wd_mat)
            tracker['fid'].append(fid_mat)
            tracker['nparams'].append(n_params)

        if verbose:
            print(f'at iter{i + 1}, for digit {label}, {n_params} parameters has been transmitted')
            if fid or wd:
                print(f'at iter{i + 1}, the FID for digit {label} is {avg_fid}')
                print(f'at iter{i + 1}, the WD for digit {label} is {avg_wd}')

    del z_train
    print(f'fitting time: {time.time() - start} s')
    inb_dict[label] = cd

    tracker_dict[label] = tracker
    return inb_dict, tracker_dict


if __name__ == "__main__":
    # Script entry point: parse options, load rotated-MNIST, load the per-domain
    # auto-encoders, fit one flow per digit label and save the results.

    # The held-out target domain is removed from this list further down.
    all_training_domains = ['0', '15', '30', '45', '60', '75']

    parser = argparse.ArgumentParser(description='Train AEINB')

    # training
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=0,
                        help='random seed (default: 0)')
    parser.add_argument('--mnist_subset', type=str, default='full')
    # nargs='+' with type=int parses "--label_list 0 3 7" into [0, 3, 7];
    # type=list would split the raw string into single-character strings.
    parser.add_argument('--label_list', type=int, nargs='+', default=list(range(10)),
                        help='digit labels to train on, e.g. --label_list 0 3 7')
    parser.add_argument('--log-interval', type=int, default=1)

    parser.add_argument('--K',type=int, default=10)
    parser.add_argument('--nlayers', type=int, default=10)
    parser.add_argument('--max_swd_iters', type=int, default=100)
    parser.add_argument('--hist_bins',type=int,default=2000)
    parser.add_argument('--quantile',action='store_true', default=False)

    # data
    parser.add_argument('--all-data', action='store_true', default=False,
                        help='whether to use all MNIST in the training')
    parser.add_argument('--data-dir',default='/',type=str)
    parser.add_argument('--num-supervised', default=60000, type=int,
                        help="number of supervised examples, /10 = samples per class")
    parser.add_argument('--list_train_domains', type=str, nargs='+',
                        default=['0','15','30','45','60','75'],
                        help='domains used during training '
                             '(overwritten below: all domains except --target_domain)')
    parser.add_argument('--target_domain', type=str, default='75',
                        choices=all_training_domains,
                        help='domain used during testing')

    # ae model
    parser.add_argument('--activation', default='sigmoid', choices=['tanh', 'sigmoid'])
    parser.add_argument('--ae_model',default='ae')
    parser.add_argument('--ae_dir',default='./autoencoders/saved/indae/')

    # log
    parser.add_argument('--run_name', default='histindae')
    parser.add_argument('--save_dir', default='/')
    parser.add_argument('--note', type=str, default='')
    parser.add_argument('--fid', action='store_false', default=True,
                        help='pass this flag to DISABLE the FID evaluation')
    parser.add_argument('--wd', action='store_false', default=True,
                        help='pass this flag to DISABLE the WD evaluation')

    args = parser.parse_args()

    # Output directory: <save_dir>/<nlayers>_<K>_<max_swd_iters>_<hist_bins>[_q][0]/<target_domain>/
    # os.path.join makes this independent of a trailing slash on --save_dir;
    # the final '' keeps the trailing separator on the stored path.
    run_dir = f'{args.nlayers}_{args.K}_{args.max_swd_iters}_{args.hist_bins}'
    if args.quantile:
        run_dir += '_q'
    if args.mnist_subset == '0':
        run_dir += '0'
    args.save_dir = os.path.join(args.save_dir, run_dir, args.target_domain, '')
    os.makedirs(args.save_dir, exist_ok=True)

    seed = args.seed
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Torch RNG
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # NumPy RNG
    np.random.seed(seed)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if args.cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': False} if args.cuda else {}

    # Train on every domain except the held-out target.
    all_training_domains.remove(args.target_domain)
    args.list_train_domains = all_training_domains
    print(args.list_train_domains)
    args.run_name = args.run_name + '-' + args.target_domain + f'-{args.nlayers}-{args.K}-swd-{args.max_swd_iters}-' + str(args.hist_bins)
    if args.quantile:
        args.run_name += '-quant'

    train_set = MnistRotated(args.list_train_domains, [args.target_domain], args.data_dir,
                             train=True, mnist_subset=args.mnist_subset, all_data=args.all_data)
    # NOTE(review): the evaluation set is also built with train=True
    # (plus not_eval=False) -- confirm this is the intended split.
    test_set = MnistRotated(args.list_train_domains, [args.target_domain], args.data_dir,
                            train=True, not_eval=False)

    # The batch size equals the dataset size, so each loader yields one batch
    # holding the whole (shuffled) dataset.
    train_loader = data_utils.DataLoader(train_set,
                                         batch_size=train_set.train_data.shape[0],
                                         shuffle=True)
    test_loader = data_utils.DataLoader(test_set,
                                        batch_size=test_set.train_data.shape[0],
                                        shuffle=True)

    train_imgs, train_labels, train_domains = next(iter(train_loader))
    test_imgs, test_labels, test_domains = next(iter(test_loader))
    print('Finish preparing data!!!')
    print('train imgs', train_imgs.shape)
    print('test imgs', test_imgs.shape)

    # Replace the activation name with the module instance before AE(args) is built.
    activations = {'tanh':nn.Tanh(),
                   'sigmoid':nn.Sigmoid()}
    args.activation = activations[args.activation]

    cd_enc = wrap_enc(args,device)
    cd_dec = wrap_dec(args,device)

    inb_dict = dict()
    tracker_dict = dict()
    for label in args.label_list:
        inb_dict, tracker_dict = train_inb(train_imgs, train_labels, train_domains,
                                           test_imgs, test_labels, test_domains,
                                           label, inb_dict, args.nlayers, args.K,
                                           cd_enc, cd_dec, tracker_dict,
                                           vis=True,
                                           fid=args.fid,
                                           wd=args.wd,
                                           log_interval=args.log_interval,
                                           verbose=True,
                                           max_swd_iters=args.max_swd_iters,
                                           hist_bins=args.hist_bins,
                                           quantile=args.quantile,
                                           device=device)

    torch.save(inb_dict, os.path.join(args.save_dir, 'inb.pt'))
    # NOTE(review): stats are only written when BOTH metrics are enabled.
    if args.fid and args.wd:
        torch.save(tracker_dict, os.path.join(args.save_dir, 'stats.pt'))