import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import time
from torch.utils.data import Dataset, DataLoader
import torchvision
import torch.nn.functional as F
import numpy as np
import os

from polymnist_dataset import get_train_test_dataset_upd10_32x32
from h_vae_model_copy import ResVAE
from lat_sm2_model import LSMPoly64_sm, Poly_sm, Poly_sm2

from utils import *

def get_train_test_dataloader_upd10_32x32(batch_size):
    """Wrap the three dataset splits in DataLoaders.

    The splits are obtained from ``get_train_test_dataset_upd10_32x32()``.
    Training and validation batches are shuffled, test batches are not, and
    every loader uses two worker processes.

    Args:
        batch_size: Batch size shared by all three loaders.

    Returns:
        Tuple ``(train_dataloader, val_dataloader, test_dataloader)``.
    """
    train_split, val_split, test_split = get_train_test_dataset_upd10_32x32()
    # Options common to every loader; only the shuffle flag differs per split.
    shared_opts = dict(batch_size=batch_size, num_workers=2)
    return (
        DataLoader(train_split, shuffle=True, **shared_opts),
        DataLoader(val_split, shuffle=True, **shared_opts),
        DataLoader(test_split, shuffle=False, **shared_opts),
    )

def sm_loss(score_net, input):
    """Score-matching loss: mean over rows of 0.5 * (||s(x)||^2 + tr(ds/dx)).

    Args:
        score_net: Callable mapping ``input`` to a score tensor.
        input: Tensor the score is differentiated with respect to; it must
            require grad, because ``torch.autograd.grad`` is taken w.r.t. it.
            The trace loop runs over ``input.shape[1]``.

    Returns:
        Scalar tensor; the graph is retained (``create_graph=True``) so the
        loss can be back-propagated through the derivative terms.
    """
    score = score_net(input)
    n_rows = input.shape[0]
    squared_norm = torch.norm(score, dim=-1).square()

    # Accumulate the diagonal of d(score)/d(input) one coordinate at a time.
    jac_trace = torch.zeros(n_rows, device=input.device)
    for dim in range(input.shape[1]):
        grad_wrt_input = torch.autograd.grad(
            score[:, dim].sum(), input, create_graph=True, retain_graph=True
        )[0]
        jac_trace = jac_trace + grad_wrt_input.view(n_rows, -1)[:, dim]

    per_row = 0.5 * (squared_norm + jac_trace)
    return per_row.mean()

def dsm_loss(sm_model, q_z, sigma=1):
    """Denoising score-matching loss at noise level ``sigma``.

    ``q_z`` is perturbed with Gaussian noise scaled by ``sigma`` and the model
    output at the perturbed point is compared against ``-noise / sigma**2``.
    The squared error is summed over the last dimension, weighted by
    ``0.5 * sigma**2`` and averaged.

    Args:
        sm_model: Callable mapping the perturbed latents to a score tensor of
            the same shape.
        q_z: Tensor of latents to perturb.
        sigma: Noise level. Either a Python number (int or float) or a tensor
            that broadcasts against ``q_z``; a tensor is squeezed before being
            used as the per-row weight.

    Returns:
        Scalar tensor with the mean loss.
    """
    noise = sigma * torch.normal(mean=0, std=1, size=q_z.shape, device=q_z.device)
    perturbed_q_z = q_z + noise
    sm_out = sm_model(perturbed_q_z)
    residual_sq = ((sm_out + (noise / sigma ** 2)) ** 2).sum(dim=-1)

    # Only tensors have (and need) .squeeze(); any plain number is used as-is.
    # The previous isinstance(sigma, int) check sent float sigmas down the
    # tensor path, where sigma.squeeze() raised AttributeError.
    weight = sigma.squeeze() if torch.is_tensor(sigma) else sigma
    loss = 0.5 * (weight ** 2) * residual_sq
    return loss.mean()

def ssm_loss(score_net, input):
    """Sliced score-matching loss using one random projection per row.

    A projection vector ``v`` is drawn with ``torch.randn_like(input)``. The
    loss is the mean of ``0.5 * (v . s(x))^2 + v . d(v . s(x))/dx`` where the
    dot products reduce over the last dimension.

    Args:
        score_net: Callable mapping ``input`` to a score tensor.
        input: Tensor the projected score is differentiated with respect to;
            it must require grad.

    Returns:
        Scalar tensor; the graph of the derivative term is kept
        (``create_graph=True``) so the loss can be back-propagated.
    """
    score = score_net(input)
    direction = torch.randn_like(input)

    # Per-row projection of the score onto the random direction.
    projected = (score * direction).sum(dim=-1)
    quadratic_term = 0.5 * projected.square()

    # Gradient of the total projected score w.r.t. the input, projected again.
    grad_projected = torch.autograd.grad(projected.sum(), input, create_graph=True)[0]
    hessian_term = (grad_projected * direction).sum(dim=-1)

    return (quadratic_term + hessian_term).mean()

def train_model(train_loader, pvae_dict, model, optimizer, sm_type, device):
    """Train the score model for one pass over ``train_loader``.

    Each batch is encoded by the per-modality VAEs in ``pvae_dict`` (kept in
    eval mode, no gradients), the sampled latents are concatenated along
    dim 1 in sorted-key order, and the score model is optimised on them.

    Args:
        train_loader: Iterable of ``(images, target)`` pairs; ``images`` is
            indexed with ``'m' + mod`` for every key ``mod`` of ``pvae_dict``.
        pvae_dict: Mapping from modality key to a VAE exposing ``encoder`` and
            ``reparametrize``.
        model: Score model being trained.
        optimizer: Optimizer over ``model``'s parameters.
        sm_type: ``'sm'`` (uses ``sm_loss``) or ``'dsm'`` (uses ``dsm_loss``).
        device: Device the image batches are moved to.

    Returns:
        The training loss averaged over the number of batches.

    Raises:
        ValueError: If ``sm_type`` is not ``'sm'`` or ``'dsm'``.
    """
    # Fail early: previously an unknown type only surfaced as an unbound
    # ``loss`` variable after the first batch had already been encoded.
    if sm_type not in ('sm', 'dsm'):
        raise ValueError("Unsupported sm_type %r; expected 'sm' or 'dsm'" % (sm_type,))

    mods = sorted(pvae_dict.keys())
    for mod in mods:
        pvae_dict[mod].eval()
    model.train()
    start_time = time.time()
    losses = 0

    for images, _ in train_loader:
        # The VAEs are frozen feature extractors here: no graph is needed.
        with torch.no_grad():
            latents = []
            for mod in mods:
                vae = pvae_dict[mod]
                batch = images['m' + mod].to(device)
                latents.append(vae.reparametrize(*vae.encoder(batch)))

            # stack zs
            z = torch.cat(latents, dim=1).detach()

        with torch.enable_grad():
            if sm_type == 'sm':
                # sm_loss differentiates the score w.r.t. its input.
                z.requires_grad = True
                loss = sm_loss(model, z)
            else:
                loss = dsm_loss(model, z)

            losses += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    end_time = time.time()
    losses /= len(train_loader)
    print("TRAINING TIME TAKEN: ", end_time - start_time, flush=True)
    print("Training loss: ", losses, flush=True)
    return losses


def evaluate(test_loader, pvae_dict, sm_model, iter_num, lr1, lr2, amount, size_z, sm_type, epoch, device, unq_name, save_paths, all_mods):
    """Compute the validation loss and, every 10th epoch, save sample images.

    When ``(epoch + 1) % 10 == 0`` the function starts from standard-normal
    latents for every modality and runs ``iter_num`` Langevin-style updates
    ``z <- z + lr1 * score + lr2 * noise``. The decoded image of every
    modality at every step is saved as a single grid (steps stacked along
    dim 0, modalities along dim 1).

    Afterwards the score-matching loss selected by ``sm_type`` is averaged
    over ``test_loader``.

    Args:
        test_loader: Iterable of ``(images, target)`` pairs; ``images`` is
            indexed with ``'m' + mod`` for every key of ``pvae_dict``.
        pvae_dict: Mapping from modality key to a VAE exposing ``encoder``,
            ``decoder`` and ``reparametrize``.
        sm_model: Score model under evaluation; its ``size_z`` attribute is
            used in the image file name.
        iter_num: Number of sampling iterations.
        lr1: Multiplier on the score in the sampling update.
        lr2: Multiplier on the Gaussian noise in the sampling update.
        amount: Number of latent rows sampled per modality.
        size_z: Latent width per modality; used to slice the score output.
        sm_type: ``'sm'`` or ``'dsm'``.
        epoch: Zero-based epoch index.
        device: Device for sampled latents and image batches.
        unq_name: Run identifier appended to the image file name.
        save_paths: Mapping with an ``'image'`` entry (output path prefix).
        all_mods: String of modality ids, used in the image file name.

    Returns:
        The validation loss averaged over the number of batches.

    Raises:
        ValueError: If ``sm_type`` is not ``'sm'`` or ``'dsm'``.
    """
    # Fail early instead of hitting an unbound ``loss`` inside the batch loop.
    if sm_type not in ('sm', 'dsm'):
        raise ValueError("Unsupported sm_type %r; expected 'sm' or 'dsm'" % (sm_type,))

    mods = sorted(pvae_dict.keys())

    with torch.no_grad():
        start_time = time.time()
        losses = 0
        sm_model.eval()
        z = {}
        img_outs = {}

        for mod in mods:
            pvae_dict[mod].eval()
            z[mod] = torch.normal(mean=0, std=1, size=(amount, size_z), requires_grad=True, device=device)
            img_outs[mod] = []

        if (epoch + 1) % 10 == 0:
            for _ in range(iter_num):
                for mod in mods:
                    decoded = pvae_dict[mod].decoder(z[mod])
                    # NOTE(review): squeeze() + permute(1, 2, 0) presumably turns a
                    # (1, C, H, W) decoder output into H x W x C, which only works
                    # for amount == 1 -- confirm before sampling more rows.
                    img_outs[mod].append(torch.clip(decoded, 0., 1.).squeeze().permute(1, 2, 0))

                z_in = torch.cat([z[mod] for mod in mods], dim=1)
                sm_out = sm_model(z_in)

                # Each modality owns a contiguous size_z-wide slice of the score.
                for ind, mod in enumerate(mods):
                    z[mod] = z[mod] + (lr1 * sm_out[:, ind*size_z:ind*size_z+size_z]) + lr2*torch.normal(mean=0, std=1, size=z[mod].shape, device=device)

            grid = torch.cat([torch.cat(img_outs[mod], dim=0) for mod in mods], dim=1)
            fig = plt.figure(figsize=(10, 10))
            plt.imshow(grid.cpu().numpy())
            plt.axis('off')
            plt.savefig(save_paths['image'] + all_mods + '_' + str(sm_model.size_z) + sm_type + unq_name + '__' + str(epoch) + '.png')
            # Release the figure; otherwise one figure leaks per sampling epoch.
            plt.close(fig)

        for images, _ in test_loader:
            latents = []
            for mod in mods:
                vae = pvae_dict[mod]
                batch = images['m' + mod].to(device)
                latents.append(vae.reparametrize(*vae.encoder(batch)))

            # stack zs
            z = torch.cat(latents, dim=1).detach()
            # Gradients are re-enabled locally: sm_loss differentiates w.r.t. z.
            with torch.enable_grad():
                z.requires_grad = True

                if sm_type == 'sm':
                    loss = sm_loss(sm_model, z)
                else:
                    loss = dsm_loss(sm_model, z)

                losses += loss.item()

        losses /= len(test_loader)
        print("VALIDATION TIME TAKEN: ", time.time() - start_time, flush=True)
        print("Validation loss: ", losses, flush=True)
        return losses


def run(epochs, batch_size, lr, size_z, all_mods, savefolder, model_paths, sm_type, unq_name):
    """Train a latent score model on top of pretrained per-modality VAEs.

    Loads one ``ResVAE`` per selected modality, builds the score model chosen
    by ``sm_type``, trains for ``epochs`` epochs with Adam and a
    ``ReduceLROnPlateau`` scheduler driven by the validation loss, checkpoints
    the score model whenever the validation loss improves, and finally saves
    the train/validation loss plot.

    Args:
        epochs: Number of training epochs.
        batch_size: Batch size for the data loaders.
        lr: Learning rate for Adam.
        size_z: Latent size passed to the VAEs and the score model.
        all_mods: String of modality ids; ``model_paths[i]`` is loaded when
            ``str(i)`` occurs in it.
        savefolder: Sub-folder created under ``./models``, ``./plots`` and
            ``./images``.
        model_paths: Checkpoint paths of the VAEs, indexed by modality id.
        sm_type: ``'sm'`` (``Poly_sm``) or ``'dsm'`` (``LSMPoly64_sm``).
        unq_name: Run identifier used in output file names.

    Raises:
        ValueError: If ``sm_type`` is not ``'sm'`` or ``'dsm'``.
    """
    # Validate before any directory is created or checkpoint is loaded.
    if sm_type not in ('sm', 'dsm'):
        raise ValueError("Unsupported sm_type %r; expected 'sm' or 'dsm'" % (sm_type,))

    print('vars: ', epochs, batch_size, lr, size_z, sm_type, unq_name, all_mods, savefolder, flush=True)
    train_losses, val_losses = [], []

    savefolder += '/'
    save_paths = {'model': './models/' + savefolder, 'plot': './plots/' + savefolder, 'image': './images/' + savefolder}
    for p in save_paths.values():
        os.makedirs(p, exist_ok=True)

    cuda = torch.cuda.is_available()
    print("GPU Available: ", cuda, flush=True)
    if cuda:
        # Keep using GPU 4 where it exists (the previous hard-coded choice);
        # otherwise fall back to the first GPU instead of crashing.
        gpu_index = 4 if torch.cuda.device_count() > 4 else 0
        device = torch.device("cuda:" + str(gpu_index))
        print("device: ", torch.cuda.get_device_properties(device), flush=True)
    else:
        device = torch.device("cpu")
        print("device: ", device, flush=True)

    enc_channel_list = [(64,64,64,2), (64,128,128,2), (128,256,256,2)]
    dec_channel_list = [(256,128,128,2), (128,128,64,2), (64,64,64,2)]
    size_in = 32
    img_ch = 3
    pvae_dict = {}
    n_mod = len(all_mods)

    for ind, model_path in enumerate(model_paths):
        if str(ind) in all_mods:
            pmvae = ResVAE(enc_channel_list, dec_channel_list, size_in, size_z, img_ch)
            # map_location lets checkpoints saved on another device load here.
            checkpoint = torch.load(model_path, map_location=device)
            pmvae.load_state_dict(checkpoint['model_state_dict'])
            pmvae = pmvae.to(device)
            pvae_dict[str(ind)] = pmvae

    if sm_type == 'sm':
        score_model = Poly_sm(n_mod, size_z)
    else:
        score_model = LSMPoly64_sm(n_mod, size_z)
    sm_optimizer = torch.optim.Adam(score_model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(sm_optimizer, 'min')
    score_model = score_model.to(device)

    train_dataloader, val_dataloader, _ = get_train_test_dataloader_upd10_32x32(batch_size)

    # Start from +inf so the first epoch is checkpointed too; before, nothing
    # was ever saved if no later epoch beat the first one.
    best_val_loss = float('inf')
    checkpoint_path = save_paths['model'] + all_mods + '_' + str(size_z) + sm_type + str(unq_name)

    for epoch in range(epochs):
        print("Epoch: "+str(epoch + 1), flush=True)

        training_loss = train_model(train_dataloader, pvae_dict, score_model, sm_optimizer, sm_type, device)
        validation_loss = evaluate(val_dataloader, pvae_dict, score_model, 30, 0.2, 0.1, 1, size_z, sm_type, epoch, device, unq_name, save_paths, all_mods)
        print(' ', flush=True)

        train_losses.append(training_loss)
        val_losses.append(validation_loss)
        scheduler.step(validation_loss)

        if validation_loss < best_val_loss:
            torch.save({
                'epoch': epoch,
                'model_state_dict': score_model.state_dict(),
                'train_loss': training_loss,
                'val_loss': validation_loss,
                'size_z': size_z,
            }, checkpoint_path)
            print('Model saved', flush=True)
            best_val_loss = validation_loss

    train_losses = np.array(train_losses)
    val_losses = np.array(val_losses)
    save_loss_plot_train_val(train_losses, val_losses, 'Loss', ['Train', 'Val'], save_paths['plot'] + all_mods + '_' + str(size_z) + sm_type + '_' + unq_name)
    

if __name__ == '__main__':
    import argparse

    # Command-line entry point: parse the options and hand them to run().
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--size-z', type=int, default=64,
        help='size of z [default: 64]')
    parser.add_argument(
        '--batch-size', type=int, default=512,
        help='batch size for training [default: 512]')
    parser.add_argument(
        '--epochs', type=int, default=400,
        help='number of epochs to train [default: 400]')
    parser.add_argument(
        '--upd', type=str, default='',
        help='updated polymnist dataset [default: ]')
    parser.add_argument(
        '--lr', type=float, default=0.0002,
        help='learning rate [default: 0.0002]')
    parser.add_argument(
        '--sm-type', type=str, default='dsm',
        help='loss type [default: "dsm"]')
    parser.add_argument(
        '--allmods', type=str, default='0123456789',
        help='Mods to train on [default: "0123456789"]')
    parser.add_argument(
        '--savefolder', type=str, default='psm_upd',
        help='folder name to save output [default: "psm_upd"]')
    parser.add_argument(
        '--unq-name', type=str, default='psm_upd_res_',
        help='identifier name for saving [default: "psm_upd_res_"]')

    # One --pN-path option per modality 0..9, all following the same pattern.
    num_path_options = 10
    for mod_idx in range(num_path_options):
        default_path = './models/polyupd10_m' + str(mod_idx) + '/polyupd10_res_beta_0.5__64'
        parser.add_argument(
            '--p' + str(mod_idx) + '-path', type=str, default=default_path,
            help='multimodal model path [default: "' + default_path + '"]')

    args = parser.parse_args()

    # argparse stores '--pN-path' under the attribute name 'pN_path'.
    vae_paths = [getattr(args, 'p' + str(mod_idx) + '_path') for mod_idx in range(num_path_options)]

    run(
        args.epochs,
        args.batch_size,
        args.lr,
        args.size_z,
        args.allmods,
        args.savefolder,
        vae_paths,
        args.sm_type,
        args.unq_name,
    )