import argparse
import shutil
from random import *
import torch
from torchvision.utils import save_image
import torch.autograd
from tqdm import tqdm
import numpy as np
import os
import sys

from Datasets.dataset_CUBICC import setup_CUBICC_loaders, get_8_CUBICC_samples, get_some_CUBICC_samples
from Models.mvae_CUBICC import MVAE_CUBICC
from utils.classifier import Clf_CUBICC
from utils.test_fuctions_CUBICC import cross_coherence, train_clf_lr, linear_latent_classification, calculate_fid_routine
from utils.test_fuctions_utils import Logger, clustering, save_model_light


def set_seed(seed):
    """Seed all random number generators used by this script.

    Seeds, in order: Python's ``random`` module, NumPy's global RNG, torch's
    CPU RNG and the RNGs of all CUDA devices. It also exports
    ``PYTHONHASHSEED``, which only affects Python interpreters started after
    this call. cuDNN is switched to deterministic algorithms with autotuning
    disabled.

    Args:
        seed: integer seed applied to every generator.
    """
    import random as py_random

    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        py_random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


# File-name stems for the 2x2 image grids returned by
# ``vae.self_and_cross_modal_generation`` and
# ``vae.self_and_cross_modal_generation_spec``: ``cg_imgs[a][b]`` is saved
# under the stem ``_GENERATION_GRID_NAMES[a][b]``.
_GENERATION_GRID_NAMES = (
    ('recon_image', 'image_sentence'),
    ('sentence_image', 'recon_sentence'),
)


def _save_generation_grid(cg_imgs, data_dir, prefix):
    """Save the four grids ``cg_imgs[a][b]`` to ``<data_dir>/img/<prefix><stem>.png``."""
    for a, row_stems in enumerate(_GENERATION_GRID_NAMES):
        for b, stem in enumerate(row_stems):
            save_image(cg_imgs[a][b], os.path.join(data_dir, 'img', prefix + stem + '.png'))


def _save_qualitative_samples(vae, image, sentence, num_items, data_dir, prefix=''):
    """Save reconstruction and generation grids for one fixed batch of test pairs.

    Writes the grids from ``vae.self_and_cross_modal_generation`` under
    ``prefix``. Writes the grids from ``vae.self_and_cross_modal_generation_spec``
    under ``prefix + 'gen_'``.

    NOTE(review): the positional arguments (``1``/``5``, ``num_items``, ``2``)
    are passed exactly as before. Their meaning is defined by MVAE_CUBICC;
    presumably the second one is samples per item. Confirm there.
    """
    cg_imgs = vae.self_and_cross_modal_generation(image, sentence, 1, num_items, 2)
    _save_generation_grid(cg_imgs, data_dir, prefix)

    cg_imgs = vae.self_and_cross_modal_generation_spec(image, sentence, 5, num_items, 2)
    _save_generation_grid(cg_imgs, data_dir, prefix + 'gen_')


def _batches_per_epoch(data_loaders, sup_frac):
    """Return ``(batches_per_epoch, unsup_batches)`` for the training regime.

    With ``sup_frac == 1.0`` only the supervised loader is used. With
    ``sup_frac > 0.0`` an epoch visits every unsupervised batch and then every
    supervised batch.

    Raises:
        ValueError: if ``sup_frac`` is not positive. Purely unsupervised
            training is not supported. This is a real exception rather than an
            ``assert``, so it survives ``python -O``.
    """
    if sup_frac == 1.0:  # fully supervised
        return len(data_loaders["sup"]), 0
    if sup_frac > 0.0:  # semi-supervised
        unsup_batches = len(data_loaders["unsup"])
        return unsup_batches + len(data_loaders["sup"]), unsup_batches
    raise ValueError("Data frac not correct")


def _train_one_epoch(vae, optim, data_loaders, args, batches_per_epoch, unsup_batches):
    """Run one pass over the training loaders and return the summed losses.

    The first ``unsup_batches`` steps draw from ``data_loaders["unsup"]`` and
    call ``vae.unsup`` with the direction selected by ``args.missing``. The
    remaining steps draw from ``data_loaders["sup"]`` and call ``vae.match``.
    Each step performs ``loss.backward()``, ``optim.step()`` and
    ``optim.zero_grad()``.

    Returns:
        ``(sup_loss_sum, unsup_loss_sum)`` as Python floats.

    Raises:
        ValueError: on the first unsupervised step if ``args.missing`` is
            neither ``'sentence'`` nor ``'image'``.
    """
    # Keep the original creation order (sup, then unsup): creating a
    # DataLoader iterator may draw from the torch RNG, which matters for seeding.
    sup_iter = iter(data_loaders["sup"])
    unsup_iter = iter(data_loaders["unsup"]) if args.sup_frac != 1.0 else None

    vae.train()

    sup_loss_sum = 0.0
    unsup_loss_sum = 0.0
    for step in tqdm(range(batches_per_epoch)):
        is_supervised = step >= unsup_batches
        data = next(sup_iter) if is_supervised else next(unsup_iter)

        image_batch = data[0].to(args.device)
        sentence_batch = data[1].to(args.device)

        if is_supervised:
            loss = vae.match(batch_a=image_batch, batch_b=sentence_batch)
            loss.backward()
            sup_loss_sum += loss.item()
        else:
            if args.missing == 'sentence':
                # Only the image batch is given to the model.
                loss = vae.unsup(batch_a=image_batch, batch_b=None, direction='i2s')
            elif args.missing == 'image':
                # Only the sentence batch is given to the model.
                loss = vae.unsup(batch_a=None, batch_b=sentence_batch, direction='s2i')
            else:
                raise ValueError("Modality %s not recognised" % args.missing)
            unsup_loss_sum += loss.item()
            loss.backward()

        optim.step()
        optim.zero_grad()

    return sup_loss_sum, unsup_loss_sum


def _evaluate(vae, data_loaders, args):
    """Run the periodic evaluation on the test split and print the metrics.

    Steps, in order:
    1. Save qualitative image grids for 8 and 16 test pairs.
    2. Produce a t-SNE plot via ``vae.tsne_plot``.
    3. Print cross-coherence using the pretrained classifier.
    4. Train a latent linear classifier and print its accuracies.
    5. Print latent clustering metrics.
    6. Run ``calculate_fid_routine``, writing into ``<args.data_dir>/fid``.

    Leaves ``vae`` in eval mode.
    """
    test_loader = data_loaders['test']
    test_dataset = test_loader.dataset

    vae.eval()
    with torch.no_grad():
        samples = get_8_CUBICC_samples(test_dataset, num_testing_images=len(test_dataset),
                                       device=args.device)
        _save_qualitative_samples(vae, samples[0], samples[1], 8, args.data_dir)

        samples = get_some_CUBICC_samples(test_dataset, len(test_dataset), 16, device=args.device)
        _save_qualitative_samples(vae, samples[0], samples[1], 16, args.data_dir,
                                  prefix='Appendix_')

        vae.tsne_plot(test_loader, args.data_dir, n=8)

    # Compute cross-coherence with the pretrained classifier.
    # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only hosts.
    # load_state_dict copies the values into clf's existing parameters either way.
    # NOTE(review): strict=False silently ignores missing/unexpected keys;
    # confirm the checkpoint actually matches Clf_CUBICC.
    clf = Clf_CUBICC()
    clf.load_state_dict(torch.load("pretrained/img_to_digit_clf_image.pth", map_location='cpu'),
                        strict=False)
    cors = cross_coherence(clf, test_loader, vae, args.device)
    for m in range(2):
        print("Conditional_coherence_m%dxm0: %.3f" % (m, cors[m]))

    # Train the latent classifier, then report linear latent classification accuracies.
    clf_lr = train_clf_lr(data_loaders, vae, args.missing, args.device, n=8)
    accuracies_lc = linear_latent_classification(test_loader, vae, clf_lr, args.device, n=8)
    for key in accuracies_lc:
        print("Latent classify acc " + str(key) + ": " + str(accuracies_lc[key]))

    accuracies = clustering(test_loader, vae, n=8)
    for key in accuracies:
        print(str(key) + ' latent clustering ACC = {:6f} NMI = {:6f} ARI = {:6f} Purity = {:6f}'.format(
            accuracies[key][0], accuracies[key][1], accuracies[key][2], accuracies[key][3]))

    calculate_fid_routine(test_loader, vae, 'data', os.path.join(args.data_dir, 'fid'),
                          'data/pt_inception-2015-12-05-6726825d.pth', 10000, args.device)


def main(args):
    """Train MVAE_CUBICC on CUBICC image/caption pairs and evaluate periodically.

    Loads data from ``data/datasets/CUBICC`` and trains with Adam at
    ``args.learning_rate``. Runs ``args.num_epochs + 1`` epochs, numbered 0
    to ``num_epochs`` inclusive. Evaluates after every epoch whose number is
    a multiple of 10, then saves the model to ``<args.data_dir>/model.rar``.

    Args:
        args: parsed CLI namespace. It must also carry ``device`` and the
            fields MVAE_CUBICC reads.

    Raises:
        ValueError: if ``args.sup_frac`` is not positive, or if an
            unsupervised batch is reached with an unrecognised ``args.missing``.
    """
    data_loaders = setup_CUBICC_loaders(args.batch_size,
                                        sup_frac=args.sup_frac,
                                        missing=args.missing,
                                        root='data/datasets/CUBICC')

    # One batch is handed to the model constructor as "pseudo samples".
    pseudo_loader = data_loaders['unsup'] if args.sup_frac != 1.0 else data_loaders['sup']
    pseudo_samples_a, pseudo_samples_b, _ = next(iter(pseudo_loader))

    vae = MVAE_CUBICC(args, pseudo_samples_a=pseudo_samples_a.to(args.device),
                      pseudo_samples_b=pseudo_samples_b.to(args.device))

    optim = torch.optim.Adam(params=vae.parameters(), lr=args.learning_rate)

    # Loader lengths do not change between epochs, so compute this once.
    batches_per_epoch, unsup_batches = _batches_per_epoch(data_loaders, args.sup_frac)

    for epoch in range(0, args.num_epochs + 1):
        epoch_losses_sup, epoch_losses_unsup = _train_one_epoch(
            vae, optim, data_loaders, args, batches_per_epoch, unsup_batches)

        # NOTE(review): incremented once per epoch despite the name. Confirm
        # MVAE_CUBICC expects an epoch counter rather than an optimizer-step count.
        vae.num_steps += 1

        if epoch % 10 == 0:
            _evaluate(vae, data_loaders, args)

        print("[Epoch %03d] Sup Loss %.3f, Unsup Loss %.3f" % (epoch, epoch_losses_sup, epoch_losses_unsup))

    save_model_light(vae, os.path.join(args.data_dir, 'model.rar'))


def _str2bool(value):
    """Convert a command-line string such as ``True``/``false``/``1``/``no`` to bool.

    ``argparse`` with ``type=bool`` maps every non-empty string, including
    ``"False"``, to ``True``, so an explicit converter is required.

    Raises:
        argparse.ArgumentTypeError: for strings that are not recognised
            booleans. argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)


def parser_args(parser):
    """Register this script's command-line options on ``parser`` and return it.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same parser, for chaining.
    """
    # `--cuda False` / `--cuda 0` now disables CUDA; a bare `--cuda` means True.
    parser.add_argument('--cuda', type=_str2bool, nargs='?', const=True, default=True,
                        help='use cuda')
    parser.add_argument('-n', '--num-epochs', type=int, default=200, help="number of epochs")
    # NOTE: the int default 1 compares equal to 1.0, but str() renders it as
    # '1'. The __main__ block uses str(args.sup_frac) in the run directory
    # name, so it is left unchanged to keep existing run paths stable.
    parser.add_argument('-sup', '--sup-frac', type=float, default=1,
                        help="supervised fractional amount of the data i.e. "
                             "how many of the images have supervised labels. "
                             "Should be a multiple of train_size / batch_size")
    parser.add_argument('--missing', type=str, default='image', help='image|sentence')
    parser.add_argument('-zd', '--z_dim', type=int, default=48,
                        help="latent size")
    parser.add_argument('-wd', '--w_dim', type=int, default=16,
                        help="latent size")
    parser.add_argument('--alpha', type=float, default=0.5)
    parser.add_argument('-lr', '--learning_rate', type=float, default=1e-4)
    parser.add_argument('-bs', '--batch_size', type=int, default=16)
    parser.add_argument('--data_dir', type=str, default='data',
                        help='Data path')
    parser.add_argument('--seed', type=int, default=0)
    return parser


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = parser_args(parser)
    args = parser.parse_args()

    # Fall back to CPU when CUDA is requested but unavailable, instead of
    # failing on first CUDA use. args.cuda is updated so it matches the
    # device actually used.
    use_cuda = bool(args.cuda) and torch.cuda.is_available()
    if args.cuda and not use_cuda:
        print("CUDA requested but not available; running on CPU.")
    args.cuda = use_cuda
    args.device = torch.device("cuda:0" if use_cuda else "cpu")
    args.a_name = 'Image'
    args.b_name = 'Caption'

    set_seed(args.seed)

    # Semi-supervised runs need a known missing modality. main() only handles
    # 'image' and 'sentence'. This is a real check rather than an assert, so
    # it survives `python -O`.
    if args.sup_frac < 1.0 and args.missing not in ('image', 'sentence'):
        parser.error("Set missing modality for semi-sup: --missing must be 'image' or 'sentence'")

    run_name = 'CUBICC'

    args.data_dir = os.path.join(args.data_dir, 'runs', str(args.sup_frac), str(args.missing), run_name, str(args.seed))

    # Start every run with a fresh image directory.
    img_dir = os.path.join(args.data_dir, "img")
    if os.path.isdir(img_dir):
        shutil.rmtree(img_dir)
    os.makedirs(img_dir)

    sys.stdout = Logger(os.path.join(args.data_dir, 'run.log'))
    main(args)
    