import argparse
import shutil
from random import *
import torch
from torchvision.utils import save_image
import torch.autograd
from tqdm import tqdm
import numpy as np
import os
import sys

from Datasets.dataset_MNIST_SVHN import setup_MNIST_SVHN_loaders, get_10_mnist_svhn_samples, get_some_mnist_svhn_samples
from Models.mvae_MNIST_SVHN import MVAE_MNIST_SVHN
from utils.classifier import Clf_MNIST_SVHN
from utils.test_fuctions_MNIST_SVHN import cross_coherence, train_clf_lr, linear_latent_classification, calculate_fid_routine
from utils.test_fuctions_utils import Logger, clustering, save_model_light


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also exports ``PYTHONHASHSEED`` to the process environment and switches
    cuDNN to deterministic mode with benchmarking disabled.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    # Import the module explicitly so that ``.seed`` is looked up on the
    # ``random`` module itself rather than on whatever ``random`` is bound
    # to at file level.
    import random as py_random

    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        py_random.seed,                # python random generator
        np.random.seed,                # numpy random generator
        torch.manual_seed,             # torch default generator
        torch.cuda.manual_seed_all,    # torch generators of every CUDA device
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def _count_batches(args, data_loaders):
    """Return ``(sup_batches, unsup_batches)`` making up one training epoch.

    Raises:
        ValueError: If ``args.sup_frac`` is not strictly positive.
    """
    if args.sup_frac == 1.0:  # fully supervised
        return len(data_loaders["sup"]), 0
    if args.sup_frac > 0.0:  # semi-supervised
        return len(data_loaders["sup"]), len(data_loaders["unsup"])
    # An explicit exception instead of ``assert``: asserts are stripped
    # under ``python -O``, which would leave the batch counts undefined.
    raise ValueError("Data frac not correct: sup_frac must be > 0, got %r"
                     % (args.sup_frac,))


def _train_one_epoch(vae, optim, data_loaders, args, sup_batches, unsup_batches):
    """Run one pass over the training loaders and update ``vae``.

    All unsupervised batches are consumed first, then the supervised ones.

    Returns:
        Tuple ``(sup_loss_sum, unsup_loss_sum)``: the per-batch losses summed
        (not averaged) over the epoch.

    Raises:
        Exception: If an unsupervised batch is met and ``args.missing`` is
            neither ``'mnist'`` nor ``'svhn'``.
    """
    sup_loss_sum = 0.0
    unsup_loss_sum = 0.0

    # Fresh iterators for this epoch (supervised one is created first).
    sup_iter = iter(data_loaders["sup"]) if args.sup_frac != 0.0 else None
    unsup_iter = iter(data_loaders["unsup"]) if args.sup_frac != 1.0 else None

    vae.train()

    for i in tqdm(range(sup_batches + unsup_batches)):
        # The first ``unsup_batches`` steps are unsupervised, the rest supervised.
        is_supervised = i >= unsup_batches
        data = next(sup_iter) if is_supervised else next(unsup_iter)

        # Each batch carries SVHN at index 0 and MNIST at index 1.
        svhn_batch = data[0].to(args.device)
        mnist_batch = data[1].to(args.device)

        if is_supervised:
            loss = vae.match(batch_a=svhn_batch, batch_b=mnist_batch)
            sup_loss_sum += loss.detach().item()
        else:
            # Only the modality that is *not* missing is fed to the model.
            if args.missing == 'mnist':
                loss = vae.unsup(batch_a=svhn_batch, batch_b=None, direction='s2m')
            elif args.missing == 'svhn':
                loss = vae.unsup(batch_a=None, batch_b=mnist_batch, direction='m2s')
            else:
                raise Exception("Modality %s not recognised" % args.missing)
            unsup_loss_sum += loss.detach().item()

        loss.backward()
        optim.step()
        optim.zero_grad()

    return sup_loss_sum, unsup_loss_sum


def _save_cross_modal_grids(vae, svhn, mnist, size_arg, data_dir, prefix):
    """Save the self-/cross-modal generation images for one set of samples.

    Eight PNG files are written under ``<data_dir>/img``: four from
    ``vae.self_and_cross_modal_generation`` named ``<prefix><name>.png`` and
    four from ``vae.self_and_cross_modal_generation_spec`` named
    ``<prefix>gen_<name>.png``.

    Args:
        size_arg: Forwarded unchanged as the fourth positional argument of
            both generation methods.
        prefix: File-name prefix ('' or 'Appendix_' at the call sites).
    """
    # names[i][j] is the file stem for the image at cg_imgs[i][j].
    names = (('recon_svhn', 'svhn_mnist'), ('mnist_svhn', 'recon_mnist'))

    def save_all(cg_imgs, stem_prefix):
        for row, row_names in enumerate(names):
            for col, stem in enumerate(row_names):
                save_image(cg_imgs[row][col],
                           os.path.join(data_dir, 'img/' + stem_prefix + stem + '.png'))

    save_all(vae.self_and_cross_modal_generation(svhn, mnist, 1, size_arg, -1), prefix)
    save_all(vae.self_and_cross_modal_generation_spec(svhn, mnist, 5, size_arg, -1),
             prefix + 'gen_')


def _evaluate(vae, data_loaders, args):
    """Write qualitative figures and print the evaluation metrics on the test set.

    Puts ``vae`` in eval mode; the caller is responsible for switching back
    to train mode before the next optimisation step.
    """
    vae.eval()
    test_loader = data_loaders['test']
    test_set = test_loader.dataset

    # Figure generation needs no gradients.
    with torch.no_grad():
        samples = get_10_mnist_svhn_samples(test_set,
                                            num_testing_images=len(test_set),
                                            device=args.device)
        _save_cross_modal_grids(vae, samples[0], samples[1], 10, args.data_dir, '')

        samples = get_some_mnist_svhn_samples(test_set, len(test_set), 128,
                                              device=args.device)
        _save_cross_modal_grids(vae, samples[0], samples[1], 8, args.data_dir, 'Appendix_')

        vae.tsne_plot(test_loader, args.data_dir, n=10)

    # Compute cross-coherence.
    # The checkpoints are mapped to CPU so that files saved from a GPU also
    # load on a CPU-only machine; load_state_dict then copies the values into
    # the classifiers' own parameters.
    clfs = [Clf_MNIST_SVHN() for _ in range(2)]
    clfs[0].load_state_dict(
        torch.load("pretrained/img_to_digit_clf_svhn.pth", map_location="cpu"),
        strict=False)
    clfs[1].load_state_dict(
        torch.load("pretrained/img_to_digit_clf_mnist.pth", map_location="cpu"),
        strict=False)
    cors = cross_coherence(clfs, test_loader, vae, args.device)
    for src in range(2):
        for dst in range(2):
            print("Conditional_coherence_m%dxm%d: %.3f" % (src, dst, cors[src][dst]))

    # Train latent classifier (outside no_grad: this step is named as training).
    clf_lr = train_clf_lr(data_loaders, vae, args.missing, args.device)
    # Calculate unconditional coherence and linear latent classification accuracies
    accuracies_lc = linear_latent_classification(test_loader, vae, clf_lr, args.device)
    for key in accuracies_lc:
        print("Latent classify acc " + str(key) + ": " + str(accuracies_lc[key]))

    accuracies = clustering(test_loader, vae, n=10)
    for key in accuracies:
        scores = accuracies[key]
        print(str(key) + ' latent clustering ACC = {:6f} NMI = {:6f} ARI = {:6f} Purity = {:6f}'.format(scores[0], scores[1], scores[2], scores[3]))

    calculate_fid_routine(test_loader, vae, 'data', os.path.join(args.data_dir, 'fid'),
                          'data/pt_inception-2015-12-05-6726825d.pth', 10000, args.device)


def main(args):
    """Train the MNIST-SVHN multimodal VAE and evaluate it periodically.

    Builds the data loaders and the model, runs ``args.num_epochs + 1``
    epochs (epoch indices ``0..num_epochs``), evaluates on every epoch whose
    index is a multiple of 10, prints the summed epoch losses, and finally
    saves the model to ``<args.data_dir>/model.rar``.

    Args:
        args: Parsed command-line namespace. Reads ``batch_size``,
            ``sup_frac``, ``missing``, ``device``, ``learning_rate``,
            ``num_epochs`` and ``data_dir``; the whole namespace is also
            passed to the model constructor.

    Raises:
        ValueError: If ``args.sup_frac`` is not strictly positive.
    """
    data_loaders = setup_MNIST_SVHN_loaders(args.batch_size,
                                            sup_frac=args.sup_frac,
                                            missing=args.missing,
                                            root='data/datasets/MNIST_SVHN')

    # One batch is handed to the model constructor as "pseudo samples"; it is
    # drawn from the unsupervised loader unless training is fully supervised.
    pseudo_source = 'unsup' if args.sup_frac != 1.0 else 'sup'
    pseudo_samples_a, pseudo_samples_b, _ = next(iter(data_loaders[pseudo_source]))

    vae = MVAE_MNIST_SVHN(args, pseudo_samples_a=pseudo_samples_a.to(args.device),
                                pseudo_samples_b=pseudo_samples_b.to(args.device))

    optim = torch.optim.Adam(params=vae.parameters(), lr=args.learning_rate)

    # The number of batches per epoch does not change between epochs.
    sup_batches, unsup_batches = _count_batches(args, data_loaders)

    for epoch in range(0, args.num_epochs + 1):
        epoch_losses_sup, epoch_losses_unsup = _train_one_epoch(
            vae, optim, data_loaders, args, sup_batches, unsup_batches)

        # NOTE: despite its name, this counter advances once per epoch.
        vae.num_steps += 1

        if epoch % 10 == 0:
            _evaluate(vae, data_loaders, args)

        print("[Epoch %03d] Sup Loss %.3f, Unsup Loss %.3f" % (epoch, epoch_losses_sup, epoch_losses_unsup))

    save_model_light(vae, args.data_dir + '/model.rar')



def _str2bool(value):
    """Convert a command-line token to ``bool``.

    ``type=bool`` is not usable with argparse: it calls ``bool()`` on the raw
    string, so every non-empty token (including ``"False"``) becomes ``True``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``bool('')``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


def parser_args(parser):
    """Register the script's command-line options on ``parser``.

    Args:
        parser: An ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same ``parser`` object, for call chaining.
    """
    parser.add_argument('--cuda', type=_str2bool, default=True, help='use cuda')
    parser.add_argument('-n', '--num-epochs', type=int, default=100, help="number of epochs")
    # NOTE: the default is the int 1 (argparse does not run ``type`` on
    # non-string defaults), so str(args.sup_frac) is '1' when the option is
    # omitted but '1.0' when '-sup 1' is given explicitly.
    parser.add_argument('-sup', '--sup-frac', type=float, default=1,
                        help="supervised fractional amount of the data i.e. "
                             "how many of the images have supervised labels. "
                             "Should be a multiple of train_size / batch_size")
    parser.add_argument('--missing', type=str, default='mnist', help='svhn|mnist')
    parser.add_argument('-zd', '--z_dim', type=int, default=32,
                        help="latent size")
    parser.add_argument('-wd', '--w_dim', type=int, default=32,
                        help="latent size")
    parser.add_argument('--alpha', type=float, default=0.5)
    parser.add_argument('-lr', '--learning_rate', type=float, default=5e-4)
    parser.add_argument('-bs', '--batch_size', type=int, default=64)
    parser.add_argument('--data_dir', type=str, default='data',
                        help='Data path')
    parser.add_argument('--seed', type=int, default=0)
    return parser


if __name__ == "__main__":
    # Script entry point: parse options, prepare the run directory, redirect
    # stdout to the run log and start training.
    parser = argparse.ArgumentParser()
    parser = parser_args(parser)
    args = parser.parse_args()

    # Only select the GPU when one is actually usable; otherwise fall back
    # to CPU instead of failing later on the first ``.to(device)`` call.
    use_cuda = bool(args.cuda) and torch.cuda.is_available()
    if args.cuda and not use_cuda:
        print("CUDA requested but not available; falling back to CPU", file=sys.stderr)
    args.device = torch.device("cuda:0" if use_cuda else "cpu")
    args.a_name = 'SVHN'
    args.b_name = 'MNIST'

    set_seed(args.seed)

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if args.sup_frac < 1.0 and args.missing is None:
        parser.error("Set missing modality for semi-sup")

    run_name = 'MNIST_SVHN'

    # Run directory layout: <data_dir>/runs/<sup_frac>/<missing>/<run_name>/<seed>
    args.data_dir = os.path.join(args.data_dir, 'runs', str(args.sup_frac), str(args.missing), run_name, str(args.seed))

    # Start from an empty image directory on every run.
    img_dir = os.path.join(args.data_dir, "img")
    if os.path.isdir(img_dir):
        shutil.rmtree(img_dir)
    os.makedirs(img_dir)

    # From here on everything printed goes through the project Logger.
    sys.stdout = Logger('{}/run.log'.format(args.data_dir))
    main(args)
    