import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import time
from torch.utils.data import Dataset, DataLoader
import torchvision
import torch.nn.functional as F
import numpy as np
import os

from polymnist_dataset import get_train_test_dataset_upd10_32x32
from h_vae_model_copy import ResAE
from unet_model import Unet

from utils import *

def get_train_test_dataloader_upd10_32x32(batch_size):
    """Wrap the three PolyMNIST splits in ``DataLoader`` objects.

    Args:
        batch_size: batch size used for all three loaders.

    Returns:
        Tuple ``(train_dataloader, val_dataloader, test_dataloader)``. The
        train and validation loaders shuffle; the test loader does not.
    """
    train_set, val_set, test_set = get_train_test_dataset_upd10_32x32()
    # (dataset, shuffle) in the order the loaders are returned.
    split_config = ((train_set, True), (val_set, True), (test_set, False))
    return tuple(
        DataLoader(split, batch_size=batch_size, shuffle=do_shuffle, num_workers=2)
        for split, do_shuffle in split_config
    )

def dsm_loss(sm_model, q_z, sigma=1, sigma_index=None):
    """Denoising score matching loss for ``sm_model`` on samples ``q_z``.

    ``q_z`` is perturbed with Gaussian noise of scale ``sigma`` and the model
    output is regressed onto ``-noise / sigma**2``.

    Args:
        sm_model: callable score network. Called as ``sm_model(x)`` when
            ``sigma`` is a Python scalar and as ``sm_model(x, sigma_index)``
            otherwise.
        q_z: tensor of clean samples; the first dimension is the batch.
        sigma: noise scale. Either a Python ``int``/``float`` shared by the
            whole batch, or a tensor that broadcasts against ``q_z`` and
            squeezes to one value per sample (e.g. shape ``(B, 1, 1, 1)``).
        sigma_index: forwarded to ``sm_model`` in the tensor-``sigma`` case.

    Returns:
        Scalar tensor: the batch mean of the weighted squared error.
    """
    noise = sigma * torch.normal(mean=0, std=1, size=q_z.shape, device=q_z.device)
    perturbed_q_z = q_z + noise
    # Plain Python numbers (int *or* float) take the unconditional path.
    # Checking only for `int` sent float sigmas into the tensor branch, which
    # then failed on `sigma.squeeze()`.
    if isinstance(sigma, (int, float)):
        sm_out = sm_model(perturbed_q_z)
        # NOTE: only the last dimension is summed here (original behaviour),
        # whereas the tensor branch below sums over all non-batch dimensions.
        loss = 0.5 * (sigma ** 2) * ((sm_out + (noise / sigma ** 2)) ** 2).sum(dim=-1)
    else:
        # The conditional network output is rescaled by 1/sigma to form the score.
        sm_out = sm_model(perturbed_q_z, sigma_index) / sigma
        batch = q_z.shape[0]
        residual = sm_out.view(batch, -1) + (noise / sigma ** 2).view(batch, -1)
        loss = 0.5 * (sigma.squeeze() ** 2) * (residual ** 2).sum(dim=-1)
    return loss.mean()

def train_model(train_loader, pvae_dict, model, optimizer, sm_type, device):
    """Train the score model for one pass over ``train_loader``.

    Every batch is encoded by the autoencoders in ``pvae_dict`` (without
    gradients), the per-modality latents are stacked into one tensor, a noise
    level is drawn at random for each sample, and one optimizer step is taken
    on the DSM loss.

    Args:
        train_loader: iterable of ``(images, target)`` pairs; ``images`` is
            indexed with ``'m' + mod`` for every key ``mod`` of ``pvae_dict``.
        pvae_dict: maps modality key to an autoencoder exposing ``.encoder``.
        model: score network being trained.
        optimizer: optimizer stepping ``model``'s parameters.
        sm_type: not used inside this function.
        device: device the batches and noise levels are moved to.

    Returns:
        The mean per-batch loss over the epoch (a Python float).
    """
    mod_keys = sorted(pvae_dict.keys())
    for key in mod_keys:
        pvae_dict[key].eval()
    model.train()

    # Linear noise schedule with 200 levels, from 5 down to 0.1.
    # Alternative geometric schedule that was tried:
    # sigmas = torch.tensor(np.exp(np.linspace(np.log(5), np.log(0.05),300))).to(device)
    noise_levels = torch.tensor(np.linspace(5, 0.1, 200)).to(device)
    n_levels = len(noise_levels)

    running_loss = 0
    t0 = time.time()

    for images, _ in train_loader:
        # The autoencoders are frozen feature extractors: no gradients needed.
        with torch.no_grad():
            latents = {}
            for key in pvae_dict:
                latents[key] = pvae_dict[key].encoder(images['m' + key].to(device))

            # One channel per modality, in sorted key order. The 8x8 view
            # presumably relies on 64 latent values per sample - confirm
            # against the encoder's output size.
            stacked = torch.stack([latents[key] for key in mod_keys], dim=1)
            stacked = stacked.view(-1, len(mod_keys), 8, 8).detach()

        with torch.enable_grad():
            n_samples = stacked.shape[0]
            # One random noise level per sample, reshaped to broadcast over
            # all non-batch dimensions of the latent tensor.
            level_idx = torch.randint(0, n_levels, (n_samples,)).to(device)
            broadcast_shape = (n_samples,) + (1,) * (stacked.dim() - 1)
            batch_sigmas = noise_levels[level_idx].view(broadcast_shape).float().to(device)

            loss = dsm_loss(model, stacked, batch_sigmas, level_idx)
            running_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    elapsed = time.time() - t0
    running_loss /= len(train_loader)
    print("TRAINING TIME TAKEN: ", elapsed, flush=True)
    print("Training loss: ", running_loss, flush=True)
    return running_loss


def evaluate(test_loader, pvae_dict, sm_model, iter_num, amount, size_z, sm_type, epoch, device, unq_name, save_paths, all_mods):
    """Compute the mean DSM loss over ``test_loader`` and periodically save sample images.

    The function has three phases, all run under ``torch.no_grad()``:

    1. When ``epoch > 30`` and ``(epoch + 1) % 10 == 0``: draw latents from a
       standard normal, refine them with an annealed Langevin-style update
       driven by ``sm_model``, decode them and save a figure.
    2. Always: encode every batch of ``test_loader``, stack the latents and
       accumulate the DSM loss with a random noise level per sample.
    3. Under the same epoch condition as (1): repeat the sampling, but keep the
       latent of modality ``'0'`` fixed to the first encoded sample of the last
       batch seen in phase 2, and save a second figure.

    Args:
        test_loader: iterable of ``(images, target)`` pairs; ``images`` is
            indexed with ``'m' + mod``.
        pvae_dict: maps modality key to an autoencoder exposing ``.encoder``
            and ``.decoder``.
        sm_model: score network, called as ``sm_model(z, sigma_index)``.
        iter_num: number of update steps per noise level.
        amount: number of latents sampled per modality.
        size_z: latent size per modality.
        sm_type: label used in the saved image file names.
        epoch: zero-based epoch index; gates the sampling phases.
        device: device for all tensors created here.
        unq_name: identifier used in the saved image file names.
        save_paths: dict whose ``'image'`` entry is the output path prefix.
        all_mods: string used in the saved image file names.

    Returns:
        The mean per-batch DSM loss over ``test_loader`` (a Python float).
    """
    # NOTE(review): `.squeeze().permute(1,2,0)` on the decoder output and
    # `z[mod][0].view(1,-1)` for the given modality presumably only work when
    # amount == 1 (the value the caller in this file passes) - confirm before
    # using a larger `amount`.
    with torch.no_grad():
        start_time = time.time()
        losses = 0
        sm_model.eval()
        z = {}
        img_outs = {}
        # Scale applied to the random term of the latent update below.
        c = 0.5
        # Alternative geometric noise schedule that was tried:
        # sigmas = torch.tensor(np.exp(np.linspace(np.log(5), np.log(0.05),300))).to(device)
        sigmas = torch.tensor(np.linspace(5, 0.1, 200)).to(device)
        # Base step size; rescaled per noise level into `alpha` below.
        er = 2e-3

        # Phase 1 setup: one random starting latent (and an empty frame list)
        # per modality. This runs every epoch, even when sampling is skipped.
        for mod in sorted(pvae_dict.keys()):
            pvae_dict[mod].eval()
            z[mod] = torch.normal(mean=0, std=1, size=(amount,size_z), requires_grad=True, device=device)
            img_outs[mod] = []

        # Phase 1: unconditional sampling, only on selected epochs.
        if (epoch > 30) and ((epoch + 1) % 10 == 0):
            for s_in, s in enumerate(sigmas):
                sigma_index = torch.tensor([s_in]*amount).to(device)
                cur_sigmas = sigmas[sigma_index].float().to(device)
                # Step size shrinks with the noise level, relative to the last (smallest) sigma.
                alpha = er * (sigmas[s_in]**2)/(sigmas[-1]**2)
                outs = {}

                for i in range(iter_num):
                    # On the final noise level, decode and record the current
                    # latents before every update step (iter_num frames per modality).
                    if s_in == len(sigmas)-1:
                        for ind,mod in enumerate(sorted(pvae_dict.keys())):
                            outs[mod] = pvae_dict[mod].decoder(z[mod])
                            img_outs[mod].append(torch.clip(outs[mod],0.,1.).squeeze().permute(1,2,0))

                    # Stack the per-modality latents as channels, in sorted key order.
                    z_in = torch.cat([z[mod].unsqueeze(1) for mod in sorted(pvae_dict.keys())], dim=1).view(-1,len(pvae_dict.keys()),8,8).detach()
                    # Network output divided by sigma gives the score (same scaling as in dsm_loss).
                    sm_out = sm_model(z_in, sigma_index) / cur_sigmas.view(z_in.shape[0],*([1]*len(z_in.shape[1:])))

                    # Update: z <- z + alpha * score + c * sqrt(2 * alpha) * noise.
                    for ind,mod in enumerate(sorted(pvae_dict.keys())):
                        z[mod] = z[mod] + (alpha * sm_out[:,ind].view(amount,size_z)) + c*torch.sqrt(2*alpha)*torch.normal(mean=0, std=1, size=z[mod].shape, device=device)

            # Frames are concatenated along dim 0 per modality, then the
            # modalities are placed side by side along dim 1.
            for mod in sorted(pvae_dict.keys()):
                img_outs[mod] = torch.cat(img_outs[mod], dim=0)
            img_outs = torch.cat([img_outs[mod] for mod in sorted(pvae_dict.keys())], dim=1)
            plt.figure(figsize=(10,10))
            plt.imshow(img_outs.cpu().numpy())
            plt.axis('off')
            plt.savefig(save_paths['image'] +  all_mods + '_' + sm_type + unq_name + '__' + str(epoch) + '.png')

        # Phase 2: DSM loss over the evaluation set. `z` is rebound on every
        # batch, so after the loop it holds the encodings of the last batch.
        for batch_idx, (images, target) in enumerate(test_loader):
            p = {}
            z = {}
            for mod in sorted(pvae_dict.keys()):
                p[mod] = images['m'+mod].to(device)
                z[mod] = pvae_dict[mod].encoder(p[mod])
                # Alternative encoder call that was tried:
                # z[mod] = pvae_dict[mod].pm_encoder(p[mod])[0]

            # Stack the per-modality latents as channels.
            z_all = torch.cat([z[mod].unsqueeze(1) for mod in sorted(pvae_dict.keys())], dim=1).view(-1,len(pvae_dict.keys()),8,8).detach()
            # One random noise level per sample, so the reported loss is itself stochastic.
            rand_sigma_index = torch.randint(0, len(sigmas), (z_all.shape[0],)).to(device)
            cur_sigmas = sigmas[rand_sigma_index].view(z_all.shape[0],*([1]*len(z_all.shape[1:]))).float().to(device)

            loss = dsm_loss(sm_model, z_all, cur_sigmas, rand_sigma_index)
            losses += loss.item()

        # Phase 3 setup: modalities listed in `given` keep an encoded latent
        # (first sample of the last batch above); the others restart from noise.
        img_outs = {}
        given = '0'
        noised = {}
        for mod in sorted(pvae_dict.keys()):
            if mod not in given:
                z[mod] = torch.normal(mean=0, std=1, size=(amount,size_z), requires_grad=True, device=device)
            else:
                z[mod] = z[mod][0].view(1,-1)
            img_outs[mod] = []

        # Phase 3: conditional sampling, on the same epochs as phase 1.
        if (epoch > 30) and ((epoch + 1) % 10 == 0):
            for s_in, s in enumerate(sigmas):
                sigma_index = torch.tensor([s_in]*amount).to(device)
                cur_sigmas = sigmas[sigma_index].float().to(device)
                alpha = er * (sigmas[s_in]**2)/(sigmas[-1]**2)
                outs = {}

                # Disabled experiment: add noise of the current level to the given latent.
                # if mod in given:
                #     noised[mod] = s * torch.randn_like(z[mod])
                #     z[mod] = z[mod] + noised[mod]

                for i in range(iter_num):
                    # Record decoded frames on the final noise level, as in phase 1.
                    if s_in == len(sigmas)-1:
                        for ind,mod in enumerate(sorted(pvae_dict.keys())):
                            outs[mod] = pvae_dict[mod].decoder(z[mod])
                            img_outs[mod].append(torch.clip(outs[mod],0.,1.).squeeze().permute(1,2,0))

                    z_in = torch.cat([z[mod].unsqueeze(1) for mod in sorted(pvae_dict.keys())], dim=1).view(-1,len(pvae_dict.keys()),8,8).detach()
                    sm_out = sm_model(z_in, sigma_index) / cur_sigmas.view(z_in.shape[0],*([1]*len(z_in.shape[1:])))

                    # Only the modalities that are not given are updated; the
                    # given latent stays fixed for the whole run.
                    for ind,mod in enumerate(sorted(pvae_dict.keys())):
                        if mod not in given:
                            z[mod] = z[mod] + (alpha * sm_out[:,ind].view(amount,size_z)) + c*torch.sqrt(2*alpha)*torch.normal(mean=0, std=1, size=z[mod].shape, device=device)

                # Disabled experiment (counterpart of the one above): remove the added noise again.
                # if mod in given:
                #     z[mod] = z[mod] - noised[mod]

            for mod in sorted(pvae_dict.keys()):
                img_outs[mod] = torch.cat(img_outs[mod], dim=0)
            img_outs = torch.cat([img_outs[mod] for mod in sorted(pvae_dict.keys())], dim=1)
            plt.figure(figsize=(10,10))
            plt.imshow(img_outs.cpu().numpy())
            plt.axis('off')
            plt.savefig(save_paths['image'] + '_given0_' +all_mods + '_' + sm_type + unq_name + '__' + str(epoch) + '.png')
            # Closes both figures opened in this call (phase 1 and phase 3).
            plt.close('all')

        losses /= len(test_loader)
        print("VALIDATION TIME TAKEN: ", time.time() - start_time, flush=True)
        print("Validation loss: ", losses, flush=True)
        return losses


def run(epochs, batch_size, lr, size_z, all_mods, savefolder, model_paths, sm_type, unq_name):
    """Train the latent score model and save checkpoints, sample images and a loss plot.

    Pretrained ``ResAE`` autoencoders are loaded for the selected modalities,
    a ``Unet`` score model is trained on their stacked latents, and the model
    with the lowest validation loss so far is written to disk.

    Args:
        epochs: number of training epochs.
        batch_size: batch size for the train/validation dataloaders.
        lr: initial Adam learning rate; divided by 5 after every 70 epochs.
        size_z: latent size passed to each ``ResAE`` and to ``evaluate``.
        all_mods: string of modality digits (e.g. ``"01"``). Its length sets the
            number of score-model channels and it is used in output file names.
        savefolder: sub-folder created under ``./models``, ``./plots`` and
            ``./images``.
        model_paths: sequence of autoencoder checkpoint paths; entry ``i`` is
            loaded only when ``str(i)`` occurs in ``all_mods``.
        sm_type: label forwarded to training/evaluation and used in file names.
        unq_name: identifier used in output file names.
    """
    print('vars: ', epochs, batch_size, lr, size_z, sm_type, unq_name, all_mods, savefolder, flush=True)
    train_losses, val_losses = [], []

    savefolder += '/'
    save_paths = {'model': './models/' + savefolder, 'plot': './plots/' + savefolder, 'image': './images/' + savefolder}
    for p in save_paths.values():
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(p, exist_ok=True)

    # GPU 2 is preferred (the original hard-coded choice). Fall back to the
    # first GPU, then to CPU, instead of crashing on smaller machines.
    if torch.cuda.is_available():
        device = torch.device("cuda:2" if torch.cuda.device_count() > 2 else "cuda:0")
        print("device: ", torch.cuda.get_device_properties(device), flush=True)
    else:
        device = torch.device("cpu")
        print("device: ", device, flush=True)

    enc_channel_list = [(64,64,64,2), (64,128,128,2), (128,256,256,2)]
    dec_channel_list = [(256,128,128,2), (128,128,64,2), (64,64,64,2)]
    size_in = 32
    img_ch = 3
    pvae_dict = {}
    n_mod = len(all_mods)

    for ind, model_path in enumerate(model_paths):
        if str(ind) in all_mods:
            pmvae = ResAE(enc_channel_list, dec_channel_list, size_in, size_z, img_ch)
            # map_location lets checkpoints saved on another device load here.
            checkpoint = torch.load(model_path, map_location=device)
            pmvae.load_state_dict(checkpoint['model_state_dict'])
            pmvae = pmvae.to(device)
            pvae_dict[str(ind)] = pmvae

    # Wider network when more than five modalities are modelled jointly.
    dim = 64 if n_mod > 5 else 32
    score_model = Unet(dim=dim, channels=n_mod, dim_mults=(1,2,2,2), with_time_emb=True)
    score_model = score_model.to(device)
    sm_optimizer = torch.optim.Adam(score_model.parameters(), lr=lr)

    train_dataloader, val_dataloader, _ = get_train_test_dataloader_upd10_32x32(batch_size)

    iter_num = 30  # sampling steps per noise level used by evaluate
    # Start at +inf so the first epoch is checkpointed too; previously no
    # checkpoint was written at all unless a later epoch beat epoch 0.
    best_val_loss = float('inf')

    for epoch in range(epochs):
        print("Epoch: "+str(epoch + 1), flush=True)

        training_loss = train_model(train_dataloader, pvae_dict, score_model, sm_optimizer, sm_type, device)
        validation_loss = evaluate(val_dataloader, pvae_dict, score_model, iter_num, 1, size_z, sm_type, epoch, device, unq_name, save_paths, all_mods)
        print(' ', flush=True)

        train_losses.append(training_loss)
        val_losses.append(validation_loss)

        if validation_loss < best_val_loss:
            torch.save({
                'epoch': epoch,
                'model_state_dict': score_model.state_dict(),
                # 'optimizer_state_dict': sm_optimizer.state_dict(),
                'train_loss': training_loss,
                'val_loss': validation_loss,
                'size_z': size_z,
            }, save_paths['model'] + all_mods + '_' + str(size_z) + '_AE_' + str(unq_name))
            print('Model saved', flush=True)
            best_val_loss = validation_loss

        if (epoch + 1) % 70 == 0:
            lr /= 5
            # Apply the decayed rate to the optimizer that is actually used for
            # training. The old code built a new Adam under an unused name, so
            # the decay never took effect. Updating param_groups in place also
            # keeps Adam's moment estimates.
            for group in sm_optimizer.param_groups:
                group['lr'] = lr

    train_losses = np.array(train_losses)
    val_losses = np.array(val_losses)
    save_loss_plot_train_val(train_losses, val_losses, 'Loss', ['Train', 'Val'], save_paths['plot'] + all_mods + '_' + str(size_z) + sm_type + '_' + unq_name)
    

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()

    # (flag, type, default, help) for the general options, in help-output order.
    general_options = [
        ('--size-z', int, 64, 'size of z [default: 64]'),
        ('--batch-size', int, 512, 'batch size for training [default: 512]'),
        ('--epochs', int, 300, 'number of epochs to train [default: 300]'),
        ('--upd', str, '', 'updated polymnist dataset [default: ]'),
        ('--lr', float, 0.0002, 'learning rate [default: 0.0002]'),
        ('--sm-type', str, 'dsm', 'loss type [default: "dsm"]'),
        ('--allmods', str, '01', 'Mods to train on [default: "01"]'),
        ('--savefolder', str, 'psm_upd', 'folder name to save output [default: "psm_upd"]'),
        ('--unq-name', str, 'psm_aeNEW_unet_', 'identifier name for saving [default: "psm_aeNEW_unet_"]'),
    ]
    for flag, arg_type, default_value, help_text in general_options:
        parser.add_argument(flag, type=arg_type, default=default_value, help=help_text)

    # One checkpoint path per modality: --p0-path ... --p9-path.
    # Earlier experiments pointed these flags at other checkpoints with the same
    # naming scheme (suffixes _1e-07, _0.0, _0.1 and _1.0 instead of _1e-05) and
    # at older polyAE_m<i>_64 / polyupd10_res_beta_0.5__64 files.
    n_path_flags = 10
    for mod_idx in range(n_path_flags):
        ckpt_default = f'./models/polyupd10_m{mod_idx}/polyNEWAE_m{mod_idx}_64_0.01_1e-05'
        parser.add_argument(
            f'--p{mod_idx}-path',
            type=str,
            default=ckpt_default,
            help=f'multimodal model path [default: "{ckpt_default}"]',
        )

    args = parser.parse_args()

    # Paths are passed in modality order; run() selects those named in --allmods.
    model_paths = [getattr(args, f'p{mod_idx}_path') for mod_idx in range(n_path_flags)]
    run(args.epochs, args.batch_size, args.lr, args.size_z, args.allmods, args.savefolder, model_paths, args.sm_type, args.unq_name)


