import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import time
from torch.utils.data import Dataset, DataLoader
import torchvision
import torch.nn.functional as F
import numpy as np
import os
from torchvision.utils import save_image

from polymnist_dataset import test_dataset_upd10_32x32
from h_vae_model_copy import ResAE
from unet_model import Unet
from polymnist_model import PMCLF
# from lat_sm2_model import LSMPoly64_sm
# from mopoe_model import MOPOEPolyRes, MMVAEPolyRes, MVPolyRes

from pytorch_fid.fid_score import calculate_fid_given_paths
from utils import *

def get_train_test_dataloader_upd10_32x32(batch_size):
    """Return a shuffled DataLoader over the PolyMNIST 32x32 test split.

    Despite the name, only the test split is built; no train/validation
    loaders are created.

    Args:
        batch_size: number of samples per batch.

    Returns:
        ``DataLoader`` over ``test_dataset_upd10_32x32()`` with shuffling
        enabled and two worker processes.
    """
    return DataLoader(
        test_dataset_upd10_32x32(),
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
    )


def calc_poly_cond(test_loader, sample_path, vae_dict, sm_model, predicted_mods, all_mods, p_clf, c, er, n_comp, size_z, device, write_input):
    """Conditionally sample ``predicted_mods`` from the other modalities and measure coherence.

    The given (non-predicted) modalities are encoded with their autoencoders.
    The latents of the predicted modalities start from N(0, I) and are refined
    with annealed Langevin dynamics driven by ``sm_model``; only the predicted
    latents are updated. The decoded samples are classified with ``p_clf`` and
    compared with the batch labels. All generated (and optionally input)
    images are written to disk with ``save_batch_image``.

    Args:
        test_loader: iterable of ``(images, target)`` batches, where ``images``
            maps ``'m' + mod`` to an image tensor for every key of ``vae_dict``.
        sample_path: maps ``'p' + mod`` and
            ``'cond_pAE<len(all_mods)>_<mod>_<given mods>'`` to output
            path prefixes passed to ``save_batch_image``.
        vae_dict: modality key (str) -> autoencoder exposing ``encoder`` and
            ``decoder``.
        sm_model: score network, called as ``sm_model(z, sigma_index)``.
        predicted_mods: modality keys to generate; every other key of
            ``all_mods`` is conditioned on.
        all_mods: all modality keys, in the order used for output path names.
        p_clf: digit classifier applied to the centre 28x28 crop.
        c: scale of the Langevin noise term.
        er: base Langevin step size.
        n_comp: Langevin steps per noise level.
        size_z: latent size per modality (the reshape below assumes 64 = 8x8).
        device: torch device used for all tensors.
        write_input: if true, also save the ground-truth input images.

    Returns:
        dict mapping each predicted modality to its coherence: the fraction of
        test samples whose generated image is classified as the true label.
    """
    with torch.no_grad():
        # 200 noise levels annealed linearly from sigma=5 down to sigma=0.1.
        sigmas = torch.tensor(np.linspace(5, 0.1, 200)).to(device)
        for vae in vae_dict.values():
            vae.eval()
        # The score network is only used for inference here as well.
        sm_model.eval()

        mods = sorted(vae_dict.keys())
        # Concatenated keys of the given modalities; part of the output path key.
        given_tag = ''.join([i for i in all_mods if i not in predicted_mods])
        # Count hits and samples so the final score is a true per-sample mean
        # (a smaller last batch is not over-weighted).
        n_correct = {pred: 0 for pred in predicted_mods}
        n_samples = 0
        p = {}
        p_out = {}

        for batch_idx, (images, target) in enumerate(test_loader):
            for key in mods:
                p[key] = images['m' + key].to(device)
            target = target.to(device)

            b_size = p[all_mods[0]].shape[0]
            n_samples += b_size

            # Predicted modalities start from pure noise; given ones are encoded.
            z = {}
            for key in mods:
                if key in predicted_mods:
                    z[key] = torch.normal(mean=0, std=1, size=(p[key].shape[0], size_z), device=device)
                else:
                    z[key] = vae_dict[key].encoder(p[key])

            for s_in in range(len(sigmas)):
                sigma_index = torch.tensor([s_in] * b_size).to(device)
                cur_sigmas = sigmas[sigma_index].float().to(device)
                # Step size scaled relative to the smallest noise level.
                alpha = er * (sigmas[s_in] ** 2) / (sigmas[-1] ** 2)

                for _ in range(n_comp):
                    # Stack all modality latents as channels of an 8x8 "image".
                    z_in = torch.cat([z[mod].unsqueeze(1) for mod in mods], dim=1).view(-1, len(mods), 8, 8).detach()
                    sm_out = sm_model(z_in, sigma_index) / cur_sigmas.view(z_in.shape[0], *([1] * len(z_in.shape[1:])))

                    # Langevin update, applied to the predicted latents only.
                    for ind, mod in enumerate(mods):
                        if mod in predicted_mods:
                            noise = torch.normal(mean=0, std=1, size=z[mod].shape, device=device)
                            z[mod] = z[mod] + (alpha * sm_out[:, ind].view(b_size, size_z)) + c * torch.sqrt(2 * alpha) * noise

            for mod in predicted_mods:
                p_out[mod] = vae_dict[mod].decoder(z[mod])
                # Classify the centre 28x28 crop of the 32x32 reconstruction.
                logits = p_clf(p_out[mod].view(-1, 3, 32, 32)[:, :, 2:30, 2:30])
                predicted_out = torch.argmax(logits, 1)
                n_correct[mod] += torch.sum(predicted_out == target).item()

                if write_input:
                    save_batch_image(p[mod], sample_path['p' + mod] + 'p' + mod + str(batch_idx) + '_')
                save_batch_image(p_out[mod], sample_path['cond_pAE' + str(len(all_mods)) + '_' + mod + '_' + given_tag] + str(batch_idx) + '_')

        cond_accuracies = {
            mod: (n_correct[mod] / n_samples if n_samples else 0.0)
            for mod in n_correct
        }
        print("Cond Coherence AE: ", cond_accuracies, flush=True)
        return cond_accuracies

def calc_poly_uncond(test_loader, sample_path, vae_dict, sm_model, p_clf, c, er, n_comp, size_z, device, write_input):
    """Jointly sample all modalities unconditionally and measure their coherence.

    The latents of every modality start from N(0, I) and are refined with
    annealed Langevin dynamics driven by ``sm_model``. Each decoded sample is
    classified with ``p_clf``. For every threshold k from
    ``min(5, n_mods)`` to ``n_mods``, the fraction of samples is measured for
    which at least k modalities (the first sorted modality included) share
    the first modality's predicted label. Generated (and optionally input)
    images are written to disk with ``save_batch_image``.

    Args:
        test_loader: iterable of ``(images, target)`` batches, where ``images``
            maps ``'m' + mod`` to an image tensor for every key of ``vae_dict``.
            The labels are not used.
        sample_path: maps ``'p' + mod`` and ``'unc_pAE<n_mods>_<mod>'`` to
            output path prefixes passed to ``save_batch_image``.
        vae_dict: modality key (str) -> autoencoder exposing ``decoder``.
        sm_model: score network, called as ``sm_model(z, sigma_index)``.
        p_clf: digit classifier applied to the centre 28x28 crop.
        c: scale of the Langevin noise term.
        er: base Langevin step size.
        n_comp: Langevin steps per noise level.
        size_z: latent size per modality (the reshape below assumes 64 = 8x8).
        device: torch device used for all tensors.
        write_input: if true, also save the ground-truth input images.

    Returns:
        The fraction of samples on which all modalities are classified as the
        same digit (the last, strictest threshold).
    """
    with torch.no_grad():
        # 200 noise levels annealed linearly from sigma=5 down to sigma=0.1.
        sigmas = torch.tensor(np.linspace(5, 0.1, 200)).to(device)
        for vae in vae_dict.values():
            vae.eval()
        # The score network is only used for inference here as well.
        sm_model.eval()

        mods = sorted(vae_dict.keys())
        n_mods = len(mods)
        # Agreement thresholds k = 5 .. n_mods (clamped for fewer than 5
        # modalities) so that full agreement (k = n_mods) is always the last.
        thresholds = list(range(min(5, n_mods), n_mods + 1))
        n_agree = [0] * len(thresholds)
        n_samples = 0
        p = {}
        p_out = {}

        for batch_idx, (images, _target) in enumerate(test_loader):
            for key in mods:
                p[key] = images['m' + key].to(device)
            b_size = p[mods[0]].shape[0]
            n_samples += b_size

            z = {mod: torch.normal(mean=0, std=1, size=(p[mod].shape[0], size_z), device=device) for mod in mods}

            for s_in in range(len(sigmas)):
                sigma_index = torch.tensor([s_in] * b_size).to(device)
                cur_sigmas = sigmas[sigma_index].float().to(device)
                # Step size scaled relative to the smallest noise level.
                alpha = er * (sigmas[s_in] ** 2) / (sigmas[-1] ** 2)

                for _ in range(n_comp):
                    # Stack all modality latents as channels of an 8x8 "image".
                    z_in = torch.cat([z[mod].unsqueeze(1) for mod in mods], dim=1).view(-1, n_mods, 8, 8).detach()
                    sm_out = sm_model(z_in, sigma_index) / cur_sigmas.view(z_in.shape[0], *([1] * len(z_in.shape[1:])))

                    # Langevin update of every modality's latent.
                    for ind, mod in enumerate(mods):
                        noise = torch.normal(mean=0, std=1, size=z[mod].shape, device=device)
                        z[mod] = z[mod] + (alpha * sm_out[:, ind].view(b_size, size_z)) + c * torch.sqrt(2 * alpha) * noise

            predicted_out = {}
            for mod in mods:
                p_out[mod] = vae_dict[mod].decoder(z[mod])
                # Classify the centre 28x28 crop of the 32x32 reconstruction.
                logits = p_clf(p_out[mod].view(-1, 3, 32, 32)[:, :, 2:30, 2:30])
                predicted_out[mod] = torch.argmax(logits, 1)

                if write_input:
                    save_batch_image(p[mod], sample_path['p' + mod] + 'p' + mod + str(batch_idx) + '_')
                save_batch_image(p_out[mod], sample_path['unc_pAE' + str(n_mods) + '_' + mod] + mod + '_out_' + str(batch_idx) + '_')

            # Per sample: how many modalities share the first modality's label.
            n_matching = (torch.stack([predicted_out[mod] for mod in mods], dim=0) == predicted_out[mods[0]]).sum(dim=0)
            for ind, k in enumerate(thresholds):
                n_agree[ind] += torch.sum(n_matching >= k).item()

        unc_accuracies = [(count / n_samples if n_samples else 0.0) for count in n_agree]
        print("UNC acc AE: ", unc_accuracies, flush=True)
        return unc_accuracies[-1]

def check_file_len(path, amount):
    """Check that every directory in ``path`` holds exactly ``amount`` regular files.

    Subdirectories are not counted. The first mismatching directory is
    reported on stdout.

    Args:
        path: iterable of directory paths.
        amount: expected number of regular files in each directory.

    Returns:
        True if every directory matches (vacuously true for an empty
        iterable), otherwise False.
    """
    for directory in path:
        n_files = sum(
            1
            for entry in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, entry))
        )
        if n_files != amount:
            print('file len error: ', directory, flush=True)
            return False
    return True


def _cond_sample_key(all_mod, pred, predicted_mod):
    """Return the ``sample_path`` key for conditional samples of modality ``pred``.

    The key encodes the number of modalities, the predicted modality and the
    concatenation of the given (non-predicted) modalities. The matching output
    directory is ``'./samples/' + key + '/'``.
    """
    given = ''.join([i for i in all_mod if i not in predicted_mod])
    return 'cond_pAE' + str(len(all_mod)) + '_' + pred + '_' + given


def _register_cond_paths(sample_path, sample_output_path, all_mod, predicted_mod):
    """Register the conditional output directory of each predicted modality (no duplicates)."""
    for pred in predicted_mod:
        key = _cond_sample_key(all_mod, pred, predicted_mod)
        directory = './samples/' + key + '/'
        sample_path[key] = directory
        if directory not in sample_output_path:
            sample_output_path.append(directory)


def _make_dirs(paths):
    """Create every directory in ``paths`` that does not exist yet."""
    for directory in paths:
        os.makedirs(directory, exist_ok=True)


def _save_array(path, values):
    """Save ``values`` with ``np.save``, creating the parent directory first."""
    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
    np.save(path, np.array(values))


def run(batch_size, size_z, all_mod, predicted_mod, model_paths, sm_path, pclf_path, c, unc_c, er, unc_er, n_comp, unc_n_comp, fid_n_times, unq_name, inc_fid):
    """Load the models, generate samples, and report coherence and FID scores.

    Three evaluation modes are supported:

    * ``inc_fid`` truthy: incremental conditional evaluation. For
      g = 0 .. len(all_mod) - 2, the first g + 1 modalities of ``all_mod`` are
      given and the rest are predicted. Coherence and FID are measured for the
      last modality of ``all_mod`` only. ``predicted_mod`` is ignored.
    * ``inc_fid`` falsy and ``predicted_mod`` non-empty: conditional
      generation of ``predicted_mod`` from the remaining modalities.
    * ``inc_fid`` falsy and ``predicted_mod`` empty: unconditional generation
      of all modalities.

    Each measurement is repeated ``fid_n_times`` times. The raw results are
    saved as ``.npy`` files under ``./ar/`` (created if missing). Sample
    images are written under ``./samples/``.

    Args:
        batch_size: batch size for sampling and for the FID computation.
        size_z: latent size per modality.
        all_mod: string of modality keys (e.g. ``'0123456789'``).
        predicted_mod: string of modality keys to generate conditionally.
        model_paths: autoencoder checkpoint paths, indexed by modality number.
        sm_path: score-network checkpoint path.
        pclf_path: classifier checkpoint path.
        c, er, n_comp: Langevin noise scale, step size and steps per noise
            level for conditional sampling.
        unc_c, unc_er, unc_n_comp: the same for unconditional sampling.
        fid_n_times: number of repetitions of each measurement.
        unq_name: experiment name inserted into the conditional result files.
        inc_fid: selects the incremental mode when truthy.
    """
    print('vars: ', all_mod, predicted_mod, batch_size, size_z, c, unc_c, er, unc_er, fid_n_times, sm_path, unq_name, flush=True)
    print("All mods: ", all_mod, flush=True)
    print("Predicted mod: ", predicted_mod, flush=True)
    print('inc fid: ', 'true' if inc_fid else 'false')

    cuda = torch.cuda.is_available()
    print("GPU Available: ", cuda, flush=True)
    # Experiments ran on GPU 3; fall back to the CPU when CUDA is unavailable.
    device = torch.device("cuda:3" if cuda else "cpu")

    enc_channel_list = [(64, 64, 64, 2), (64, 128, 128, 2), (128, 256, 256, 2)]
    dec_channel_list = [(256, 128, 128, 2), (128, 128, 64, 2), (64, 64, 64, 2)]
    size_in = 32
    img_ch = 3
    n_mod = len(all_mod)

    # One autoencoder per selected modality; model_paths is indexed by modality number.
    pvae_dict = {}
    for ind, model_path in enumerate(model_paths):
        if str(ind) in all_mod:
            pmvae = ResAE(enc_channel_list, dec_channel_list, size_in, size_z, img_ch)
            pmvae.load_state_dict(torch.load(model_path, map_location=device)['model_state_dict'])
            pvae_dict[str(ind)] = pmvae.to(device)

    dim = 64 if n_mod > 5 else 32
    score_ae = Unet(dim=dim, channels=n_mod, dim_mults=(1, 2, 2, 2), with_time_emb=True)
    score_ae.load_state_dict(torch.load(sm_path, map_location=device)['model_state_dict'])
    score_ae = score_ae.to(device)
    score_ae.eval()

    test_dataloader = get_train_test_dataloader_upd10_32x32(batch_size)
    print('data loaded', flush=True)

    poly_clf = PMCLF()
    poly_clf.load_state_dict(torch.load(pclf_path, map_location=device))
    poly_clf = poly_clf.to(device)
    poly_clf.eval()

    sample_path = {}
    sample_input_path = []
    sample_output_path = []
    for mod in all_mod:
        sample_path['p' + mod] = './samples/p' + mod + '/'
        sample_input_path.append('./samples/p' + mod + '/')
        if not inc_fid and not predicted_mod:
            unc_key = 'unc_pAE' + str(len(all_mod)) + '_' + mod
            sample_path[unc_key] = './samples/' + unc_key + '/'
            sample_output_path.append('./samples/' + unc_key + '/')
    if not inc_fid and predicted_mod:
        _register_cond_paths(sample_path, sample_output_path, all_mod, predicted_mod)

    _make_dirs(sample_input_path + sample_output_path)

    print("Input path: ", sample_input_path, flush=True)
    print("Output path: ", sample_output_path, flush=True)

    # DataLoader workers used by the FID computation.
    num_workers = 2

    if inc_fid:
        # Only the last modality is scored; it is predicted at every step below.
        calculated_mod = all_mod[-1]
        all_fid_score, all_cond_acc = [], []

        # Step g conditions on the first g + 1 modalities and predicts the rest.
        for g in range(len(all_mod) - 1):
            predicted_mod = all_mod[g + 1:]
            print('my pred: ', predicted_mod, flush=True)

            _register_cond_paths(sample_path, sample_output_path, all_mod, predicted_mod)
            _make_dirs(sample_input_path + sample_output_path)
            out_key = _cond_sample_key(all_mod, calculated_mod, predicted_mod)

            fid_scores_cond, cond_accs = [], []
            for _ in range(fid_n_times):
                # Ground-truth images only need (re)writing while incomplete.
                write_input = not check_file_len(sample_input_path, 10000)
                if write_input:
                    print('write input true', flush=True)

                cond_coherence = calc_poly_cond(test_dataloader, sample_path, pvae_dict, score_ae, predicted_mod, all_mod, poly_clf, c, er, n_comp, size_z, device, write_input)

                if not check_file_len(sample_input_path + sample_output_path, 10000):
                    raise Exception('file len check not correct!')

                fid_scores_cond.append(calculate_fid_given_paths([sample_path['p' + calculated_mod], sample_path[out_key]], batch_size, device, 2048, num_workers))
                cond_accs.append(cond_coherence[calculated_mod])

            all_fid_score.append(fid_scores_cond)
            all_cond_acc.append(cond_accs)

            print("Conditional coherence: ", np.array(cond_accs).mean(), flush=True)
            print("Mean Fid Scores conditional: ", np.array(fid_scores_cond).mean(), flush=True)

        _save_array('./ar/dsmAE_increm_fid_' + calculated_mod, all_fid_score)
        _save_array('./ar/dsmAE_increm_acc_' + calculated_mod, all_cond_acc)
        return

    if predicted_mod:
        fid_scores_cond, cond_accs = [], []
        for _ in range(fid_n_times):
            write_input = not check_file_len(sample_input_path, 10000)

            cond_coherence = calc_poly_cond(test_dataloader, sample_path, pvae_dict, score_ae, predicted_mod, all_mod, poly_clf, c, er, n_comp, size_z, device, write_input)

            if not check_file_len(sample_input_path + sample_output_path, 10000):
                raise Exception('file len check not correct!')

            fids = [
                calculate_fid_given_paths([sample_path['p' + pred], sample_path[_cond_sample_key(all_mod, pred, predicted_mod)]], batch_size, device, 2048, num_workers)
                for pred in predicted_mod
            ]
            fid_scores_cond.append(fids)
            print('fids: ', fids, flush=True)
            cond_accs.append(list(cond_coherence.values()))

        fid_scores_cond = np.array(fid_scores_cond)
        cond_accs = np.array(cond_accs)

        print("Conditional coherence: ", np.mean(cond_accs, axis=0), flush=True)
        print("Mean Fid Scores conditional: ", np.mean(fid_scores_cond, axis=0), flush=True)

        _save_array('./ar/cond_fid_dsmAE_' + unq_name + predicted_mod, fid_scores_cond)
        _save_array('./ar/cond_acc_dsmAE_' + unq_name + predicted_mod, cond_accs)
        return

    # Unconditional generation of all modalities.
    fid_scores_unc, unc_accs = [], []
    for _ in range(fid_n_times):
        write_input = not check_file_len(sample_input_path, 10000)

        unc_coherence = calc_poly_uncond(test_dataloader, sample_path, pvae_dict, score_ae, poly_clf, unc_c, unc_er, unc_n_comp, size_z, device, write_input)

        if not check_file_len(sample_input_path + sample_output_path, 10000):
            raise Exception('file len check not correct!')

        fid_scores_unc.append([
            calculate_fid_given_paths([sample_path['p' + mod], sample_path['unc_pAE' + str(len(all_mod)) + '_' + mod]], batch_size, device, 2048, num_workers)
            for mod in all_mod
        ])
        unc_accs.append(unc_coherence)

    fid_scores_unc = np.array(fid_scores_unc)
    unc_accs = np.array(unc_accs)

    print("Mean Fid Scores unconditional: ", np.mean(fid_scores_unc, axis=0), flush=True)
    print("Unc coherence: ", np.mean(unc_accs), flush=True)

    _save_array('./ar/dsmAE_unc_fid_' + all_mod, fid_scores_unc)
    _save_array('./ar/dsmAE_unc_acc_' + all_mod, unc_accs)


if __name__ == '__main__':
    # Command-line entry point: parse the evaluation settings and call run().
    import argparse
    parser = argparse.ArgumentParser()

    parser.add_argument('--all_mod', type=str, default='0123456789',
                        help='all modalities [default: "0123456789"]')
    parser.add_argument('--predicted_mod', type=str, default='',
                        help='predicted modalities [default: ""]')
    parser.add_argument('--size-z', type=int, default=64,
                        help='size of z [default: 64]')
    parser.add_argument('--fid-n-times', type=int, default=5,
                        help='number of times to repeat fid calc [default: 5]')
    parser.add_argument('--batch-size', type=int, default=256,
                        help='batch size for sampling and FID evaluation [default: 256]')
    parser.add_argument('--c', type=float, default=0.5,
                        help='Langevin noise scale, conditional sampling [default: 0.5]')
    parser.add_argument('--er', type=float, default=2e-3,
                        help='Langevin step-size constant, conditional sampling [default: 2e-3]')
    parser.add_argument('--unc-c', type=float, default=0.7,
                        help='Langevin noise scale, unconditional sampling [default: 0.7]')
    parser.add_argument('--unc-er', type=float, default=0.01,
                        help='Langevin step-size constant, unconditional sampling [default: 0.01]')
    parser.add_argument('--ncomp', type=int, default=20,
                        help='Langevin iterations per noise level, conditional sampling [default: 20]')
    parser.add_argument('--unc-ncomp', type=int, default=2,
                        help='Langevin iterations per noise level, unconditional sampling [default: 2]')
    parser.add_argument('--inc-fid', type=int, default=1,
                        help='calculate fid incrementally [default: 1]')
    parser.add_argument('--unq-name', type=str, default='',
                        help='unique name for experiment [default: ""]')

    # One autoencoder checkpoint per PolyMNIST modality: --p0-path-ae ... --p9-path-ae.
    for m in range(10):
        default_ae_path = f'./models/polyupd10_m{m}/polyNEWAE_m{m}_64_0.01_1e-05'
        parser.add_argument(f'--p{m}-path-ae', type=str, default=default_ae_path,
                            help=f'multimodal model path-ae [default: "{default_ae_path}"]')
    parser.add_argument('--score-ae', type=str, default='./models/psm_upd/0123456789_64_AE_psm_aeNEWre5_dim32_s5_01_200_dim64_unet_',
                        help='score ae model path [default: "./models/psm_upd/0123456789_64_AE_psm_aeNEWre5_dim32_s5_01_200_dim64_unet_"]')

    parser.add_argument('--pclf-path', type=str, default='./models/pm_clf/pm_clf_best',
                        help='poly classifier path [default: "./models/pm_clf/pm_clf_best"]')

    args = parser.parse_args()
    # argparse stores --p<m>-path-ae as attribute p<m>_path_ae.
    model_paths_ae = [getattr(args, f'p{m}_path_ae') for m in range(10)]

    run(args.batch_size, args.size_z, args.all_mod, args.predicted_mod, model_paths_ae, args.score_ae, args.pclf_path, args.c, args.unc_c, args.er, args.unc_er, args.ncomp, args.unc_ncomp, args.fid_n_times, args.unq_name, args.inc_fid)


