import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import time
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
import torch.nn.functional as F
import numpy as np
import glob
import os
from PIL import Image
import math
from inspect import isfunction
from functools import partial
from einops import rearrange
from torch import nn, einsum

from h_vae_model_copy import ResVAEN
from unet_model import CAUNET, Unet, UnetNodown
from utils import *

from celeba_hq_mask_dataset import CelebAHQMaskDS
from h_vae_model import CelebAAttrNewBN
from sklearn.metrics import f1_score

# Display names used to label attribute vectors after they have been filtered
# with `attr_visible` (position i here names column i of the filtered target).
# 18 entries, in a fixed order; presumably aligned with the 18 indices in
# `attr_visible` in run() — verify if either list changes.
new_id_to_attr = """
    Bald Bangs Black_Hair Blond_Hair Brown_Hair Bushy_Eyebrows
    Eyeglasses Gray_Hair Heavy_Makeup Male Mouth_Slightly_Open Mustache
    Pale_Skin Receding_Hairline Smiling Straight_Hair Wavy_Hair Wearing_Hat
""".split()
    
def get_train_test_dataloader(batch_size, size):
    """Build shuffled DataLoaders for the 'train' and 'val' CelebA-HQ splits.

    Args:
        batch_size: Number of samples per batch (same for both loaders).
        size: Resolution forwarded to ``CelebAHQMaskDS(size=...)``.

    Returns:
        ``(train_dataloader, val_dataloader)``. Batches are unpacked by the
        training/evaluation loops as ``(images, masks, target)``.
    """
    # Both datasets are constructed before any loader, then wrapped identically.
    datasets = [CelebAHQMaskDS(size=size, ds_type=split) for split in ('train', 'val')]
    return tuple(
        DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=2)
        for ds in datasets
    )

def dsm_loss(sm_model, q_z, sigma=1, sigma_index=None):
    """Denoising score-matching loss for a score model evaluated on ``q_z``.

    ``q_z`` is perturbed with Gaussian noise ``noise = sigma * eps`` and the
    model's score ``s`` is regressed onto ``-noise / sigma**2`` with a
    ``sigma**2`` weighting::

        loss = 0.5 * sigma**2 * || s(q_z + noise) + noise / sigma**2 ||**2

    Two calling modes are supported:

    * Scalar ``sigma`` (Python ``int`` or ``float``): ``sm_model(x)`` is called
      and the squared error is summed over the LAST dimension only, then
      averaged over everything else.
    * Tensor ``sigma`` (per-sample levels broadcastable against ``q_z``, e.g.
      shape ``(B, 1, 1, 1)``): ``sm_model(x, sigma_index)`` is called, its
      output is divided by ``sigma``, and the squared error is summed over all
      non-batch dimensions before averaging over the batch.

    Args:
        sm_model: Score network (see calling modes above).
        q_z: Clean inputs; dimension 0 is the batch.
        sigma: Noise level(s).
        sigma_index: Forwarded to ``sm_model`` in tensor-sigma mode.

    Returns:
        Scalar tensor with the mean loss.
    """
    noise = sigma * torch.normal(mean=0, std=1, size=q_z.shape, device=q_z.device)
    perturbed_q_z = q_z + noise
    # Any plain Python scalar (int OR float) is a single global noise level for
    # an unconditional model. Previously only ints took this path, so a float
    # such as 0.5 fell into the tensor branch and crashed on `sigma.squeeze()`.
    if isinstance(sigma, (int, float)):
        sm_out = sm_model(perturbed_q_z)
        loss = 0.5 * (sigma ** 2) * ((sm_out + (noise / sigma ** 2)) ** 2).sum(dim=-1)
    else:
        batch = q_z.shape[0]
        # Noise-conditional model: the raw output is scaled by 1/sigma.
        sm_out = sm_model(perturbed_q_z, sigma_index) / sigma
        target_term = (noise / sigma ** 2).view(batch, -1)
        loss = 0.5 * (sigma.squeeze() ** 2) * ((sm_out.view(batch, -1) + target_term) ** 2).sum(dim=-1)
    return loss.mean()

def train_model(train_loader, image_vae, mask_vae, attr_vae, attr_visible, n_mod, model, size_z, optimizer, device, sigmas=None):
    """Train ``model`` for one epoch with denoising score matching on VAE latents.

    The three VAEs are kept in eval mode and run without gradients. Each batch
    is encoded to the first encoder output (the mean; no reparametrised
    sampling). The three latents are stacked into a ``(B, n_mod, dim, dim)``
    tensor with ``dim = sqrt(size_z)``. Each sample gets a random noise level
    drawn from ``sigmas``, and ``dsm_loss`` is minimised with ``optimizer``.

    Args:
        train_loader: Yields ``(images, masks, target)`` batches; must support
            ``len()``.
        image_vae, mask_vae, attr_vae: Models exposing ``encoder(x)``, which is
            unpacked as ``(mu, logvar)``.
        attr_visible: Column indices selecting the attributes fed to ``attr_vae``.
        n_mod: Number of stacked modalities (exactly three are concatenated).
        model: Noise-conditional score network, called via ``dsm_loss`` as
            ``model(z, sigma_index)``.
        size_z: Latent size of each VAE; must be a perfect square.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device for batches and noise levels.
        sigmas: Optional 1-D sequence/array/tensor of noise levels. Defaults to
            500 levels linearly spaced from 5 down to 0.1. It should match the
            levels used for evaluation and sampling.

    Returns:
        Mean training loss over the epoch's batches.

    Raises:
        ValueError: If ``size_z`` is not a perfect square.
    """
    image_vae.eval()
    mask_vae.eval()
    attr_vae.eval()
    model.train()

    if sigmas is None:
        # Other schedules tried: linspace(3, 0.01, 200), linspace(3, 0.1, 200).
        sigmas = np.linspace(5, 0.1, 500)
    sigmas = torch.as_tensor(sigmas).to(device)

    dim = int(np.sqrt(size_z))
    if dim * dim != size_z:
        # Latents are viewed as dim x dim grids; a non-square size cannot be
        # reshaped faithfully.
        raise ValueError(f'size_z must be a perfect square, got {size_z}')

    losses = 0
    start_time = time.time()
    for images, masks, target in train_loader:
        images = images.to(device)
        masks = masks.to(device)
        target = target.to(device)[:, attr_visible]

        with torch.no_grad():
            # Encoder means are used directly as latents (VAEs are frozen).
            z_image, _ = image_vae.encoder(images)
            z_mask, _ = mask_vae.encoder(masks)
            z_attr, _ = attr_vae.encoder(target.float())

        with torch.enable_grad():
            z = torch.cat([z_image.unsqueeze(1), z_mask.unsqueeze(1), z_attr.unsqueeze(1)], dim=1).view(-1, n_mod, dim, dim)
            # One random noise level per sample, broadcast over the other dims.
            rand_sigma_index = torch.randint(0, len(sigmas), (z.shape[0],)).to(device)
            cur_sigmas = sigmas[rand_sigma_index].view(z.shape[0], *([1] * len(z.shape[1:]))).float().to(device)

            loss = dsm_loss(model, z, cur_sigmas, rand_sigma_index)
            losses += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    end_time = time.time()
    losses /= len(train_loader)
    print("TRAINING TIME TAKEN: ", end_time - start_time, flush=True)
    print("Training loss: ", losses, flush=True)
    return losses

@torch.no_grad()
def _save_conditional_samples(images, masks, target, image_vae, mask_vae, attr_vae, n_mod, model, size_z, dim, sigmas, device, epoch, unq_name):
    """Sample an image conditioned on the first (mask, attributes) of a batch and save plots.

    The mask and attribute latents are fixed to their encoder means. The image
    latent starts from N(0, I) and is refined by annealed Langevin dynamics,
    using the image channel of ``model``'s score. Three PNGs are written to
    ``./images/celeb_score/``:

    * real image vs. sampled image;
    * real mask vs. reconstructed mask;
    * target vs. reconstructed attribute names.
    """
    # Modality keys: '0' image, '1' mask, '2' attributes.
    models = {'0': image_vae, '1': mask_vae, '2': attr_vae}
    samples = {'0': images[0].unsqueeze(0), '1': masks[0].unsqueeze(0), '2': target[0].unsqueeze(0).float()}
    z, outs = {}, {}
    for mod in '012':
        if mod == '0':
            # The image latent is drawn from the prior and refined below.
            z[mod] = torch.normal(mean=0, std=1, size=(1, size_z), device=device)
        else:
            z[mod] = models[mod].encoder(samples[mod])[0]
            outs[mod] = models[mod].decoder(z[mod])

    er = 0.01   # base Langevin step size (2e-3 was also tried)
    n_comp = 2  # Langevin steps per noise level
    c = 0.8     # scale applied to the injected Gaussian noise

    for s_in in range(len(sigmas)):
        sigma_index = torch.tensor([s_in]).to(device)
        cur_sigmas = sigmas[sigma_index].float().to(device)
        # Step size scales with sigma**2 relative to the last level sigmas[-1].
        alpha = er * (sigmas[s_in] ** 2) / (sigmas[-1] ** 2)
        for _ in range(n_comp):
            z_all = torch.cat([z['0'].unsqueeze(1), z['1'].unsqueeze(1), z['2'].unsqueeze(1)], dim=1).view(-1, n_mod, dim, dim)
            sm_out = model(z_all, sigma_index) / cur_sigmas.view(z_all.shape[0], *([1] * len(z_all.shape[1:])))
            # Only the image channel (index 0) of the score updates the image latent.
            z['0'] = z['0'] + (alpha * sm_out[:, 0].view(-1, size_z)) + c * (torch.sqrt(2 * alpha) * torch.randn_like(z['0']))
    # Decode once after annealing: decoding after every step was wasted work,
    # since only the final decode was ever plotted.
    outs['0'] = models['0'].decoder(z['0'])

    predicted_att = torch.round(torch.sigmoid(outs['2']).detach().cpu())
    tar_str = 'T: ' + ''.join(new_id_to_attr[ind] + ' ' for ind, att in enumerate(target[0]) if int(att) == 1)
    pred_str = 'P: ' + ''.join(new_id_to_attr[ind] + ' ' for ind, att in enumerate(predicted_att[0]) if int(att) == 1)

    prefix = './images/celeb_score/unetVAE_'
    suffix = '_' + str(epoch) + '_' + unq_name + '.png'
    try:
        for tag, mod in (('img', '0'), ('mask', '1')):
            plt.figure()
            grid = torchvision.utils.make_grid(torch.cat([samples[mod], outs[mod]], dim=0), nrow=10)
            plt.title('Samples')
            plt.axis("off")
            plt.imshow(grid.detach().permute(1, 2, 0).cpu().numpy())
            plt.savefig(prefix + tag + suffix)
        plt.figure()
        plt.text(0.05, 0.5, tar_str + '\n' + pred_str)
        plt.savefig(prefix + 'att' + suffix)
    finally:
        # Release figures even if saving fails, so they don't accumulate across epochs.
        plt.close('all')


@torch.no_grad()
def evaluate(val_loader, image_vae, mask_vae, attr_vae, attr_visible, n_mod, model, size_z, device, epoch, unq_name, sigmas=None):
    """Compute the mean validation DSM loss and periodically save conditional samples.

    Latents are built exactly as in training: encoder means, stacked into a
    ``(B, n_mod, dim, dim)`` tensor with ``dim = sqrt(size_z)``. Each sample
    gets a random noise level from ``sigmas``. When ``epoch > 20`` and
    ``(epoch + 1) % 10 == 0``, sample plots are generated from the last
    validation batch via ``_save_conditional_samples``.

    Args:
        val_loader: Yields ``(images, masks, target)`` batches; must support
            ``len()``.
        image_vae, mask_vae, attr_vae: Frozen VAEs exposing ``encoder`` and
            ``decoder``.
        attr_visible: Column indices selecting the attributes fed to ``attr_vae``.
        n_mod: Number of stacked modalities (exactly three are concatenated).
        model: Noise-conditional score network ``model(z, sigma_index)``.
        size_z: Latent size of each VAE; must be a perfect square.
        device: Device for batches and noise levels.
        epoch: Zero-based epoch index; controls when plots are written.
        unq_name: Tag included in the saved plot file names.
        sigmas: Optional 1-D noise levels. Defaults to 500 levels linearly
            spaced from 5 down to 0.1, the same default as ``train_model``.

    Returns:
        Mean validation loss over the batches.

    Raises:
        ValueError: If ``size_z`` is not a perfect square, or ``val_loader``
            yields no batches.
    """
    image_vae.eval()
    mask_vae.eval()
    attr_vae.eval()
    model.eval()

    if sigmas is None:
        # Other schedules tried: linspace(3, 0.01, 200), linspace(3, 0.1, 200).
        sigmas = np.linspace(5, 0.1, 500)
    sigmas = torch.as_tensor(sigmas).to(device)

    dim = int(np.sqrt(size_z))
    if dim * dim != size_z:
        raise ValueError(f'size_z must be a perfect square, got {size_z}')

    losses = 0
    last_batch = None
    start_time = time.time()
    for images, masks, target in val_loader:
        images = images.to(device)
        masks = masks.to(device)
        target = target.to(device)[:, attr_visible]

        # Encoder means are used directly as latents.
        z_image, _ = image_vae.encoder(images)
        z_mask, _ = mask_vae.encoder(masks)
        z_attr, _ = attr_vae.encoder(target.float())

        z = torch.cat([z_image.unsqueeze(1), z_mask.unsqueeze(1), z_attr.unsqueeze(1)], dim=1).view(-1, n_mod, dim, dim)
        rand_sigma_index = torch.randint(0, len(sigmas), (z.shape[0],)).to(device)
        cur_sigmas = sigmas[rand_sigma_index].view(z.shape[0], *([1] * len(z.shape[1:]))).float().to(device)

        loss = dsm_loss(model, z, cur_sigmas, rand_sigma_index)
        losses += loss.item()
        last_batch = (images, masks, target)

    if last_batch is None:
        # Previously this surfaced as a NameError on the loop variables.
        raise ValueError('val_loader yielded no batches')

    if (epoch > 20) and ((epoch + 1) % 10 == 0):
        _save_conditional_samples(*last_batch, image_vae, mask_vae, attr_vae, n_mod, model,
                                  size_z, dim, sigmas, device, epoch, unq_name)

    end_time = time.time()
    losses /= len(val_loader)
    print("VALIDATION TIME TAKEN: ", end_time - start_time, flush=True)
    print("Validation loss: ", losses, flush=True)
    return losses


def run(epochs, batch_size, lr, size_z1, size_z2, image_model_path, mask_model_path, attr_model_path, unq_name):
    """Train a score model over the joint latents of three frozen CelebA-HQ VAEs.

    The function proceeds in three steps:

    1. Load pretrained image and mask VAEs (``ResVAEN``) and an attribute VAE
       (``CelebAAttrNewBN``).
    2. Train a time-conditioned ``Unet`` with denoising score matching on their
       stacked latents.
    3. After every epoch, save a checkpoint to ``./models/celeb_score/``
       whenever the validation loss is the best seen so far. At the end, write
       a train/val loss plot to ``./plots/celeb_score/``.

    Args:
        epochs: Number of training epochs.
        batch_size: Batch size for both data loaders.
        lr: Initial Adam learning rate; divided by 5 every 50 epochs.
        size_z1: Latent size of the image and mask VAEs. It must equal
            ``size_z2`` and be a perfect square.
        size_z2: Latent size of the attribute VAE.
        image_model_path, mask_model_path, attr_model_path: Checkpoint files
            holding a ``'model_state_dict'`` entry.
        unq_name: Tag used in checkpoint, plot and image file names.

    Raises:
        ValueError: If ``size_z1 != size_z2``.
    """
    res_size = 128
    print('vars: ', epochs, batch_size, lr, size_z1, size_z2, unq_name, flush=True)
    # The three latents are stacked as channels of one tensor, so their sizes
    # must agree. Check before spending time loading checkpoints.
    if size_z1 != size_z2:
        raise ValueError(f'size_z1 ({size_z1}) must equal size_z2 ({size_z2})')

    train_losses, val_losses = [], []
    # Attribute columns kept for the attribute VAE. Presumably aligned with
    # `new_id_to_attr` (both have 18 entries) — verify if either changes.
    attr_visible = [4, 5, 8, 9, 11, 12, 15, 17, 18, 20, 21, 22, 26, 28, 31, 32, 33, 35]
    print(attr_visible, flush=True)

    path = {'model': './models/celeb_score/', 'plots': './plots/celeb_score/', 'images': './images/celeb_score/'}
    for p in path.values():
        os.makedirs(p, exist_ok=True)

    cuda = torch.cuda.is_available()
    print("GPU Available: ", cuda, flush=True)
    device = torch.device("cuda:2" if cuda else "cpu")

    def load_state_dict(ckpt_path):
        # Load onto the CPU first so checkpoints saved on any GPU can be
        # restored here; each module is moved to `device` afterwards.
        return torch.load(ckpt_path, map_location='cpu')['model_state_dict']

    # Mask VAE (single-channel input).
    enc_channel_list = [(64, 128, 128, 4), (128, 256, 256, 4)]
    dec_channel_list = [(256, 256, 128, 4), (128, 128, 64, 4)]
    mask_img_ch = 1
    mask_vae = ResVAEN(enc_channel_list, dec_channel_list, res_size, size_z1, mask_img_ch)
    mask_vae.load_state_dict(load_state_dict(mask_model_path))
    mask_vae = mask_vae.to(device)

    # Image VAE (3-channel input). A deeper 4-stage variant (up to 1024
    # channels) was also tried.
    enc_channel_list = [(64, 128, 128, 2), (128, 256, 256, 2), (256, 512, 512, 2)]
    dec_channel_list = [(512, 512, 256, 2), (256, 256, 128, 2), (128, 128, 64, 2)]
    img_ch = 3
    image_vae = ResVAEN(enc_channel_list, dec_channel_list, res_size, size_z1, img_ch)
    image_vae.load_state_dict(load_state_dict(image_model_path))
    image_vae = image_vae.to(device)

    # Attribute VAE.
    attr_vae = CelebAAttrNewBN(size_z2)
    attr_vae.load_state_dict(load_state_dict(attr_model_path))
    attr_vae = attr_vae.to(device)

    n_mod = 3
    # Variants also tried: UnetNodown, smaller Unet(dim=16, ...), CAUNET.
    score_model = Unet(dim=128, channels=n_mod, dim_mults=(1, 2, 2, 2, 2), with_time_emb=True)

    optimizer = torch.optim.Adam(score_model.parameters(), lr=lr)
    score_model = score_model.to(device)

    train_dataloader, val_dataloader = get_train_test_dataloader(batch_size, res_size)

    # Starting from +inf means the first epoch's model is always saved.
    # Previously epoch 0 was never saved, so a single-epoch run, or one that
    # never improved, produced no checkpoint.
    best_val_loss = float('inf')
    for epoch in range(epochs):
        print("Epoch: " + str(epoch + 1), flush=True)

        training_loss = train_model(train_dataloader, image_vae, mask_vae, attr_vae, attr_visible, n_mod, score_model, size_z1, optimizer, device)
        validation_loss = evaluate(val_dataloader, image_vae, mask_vae, attr_vae, attr_visible, n_mod, score_model, size_z1, device, epoch, unq_name)
        print(' ', flush=True)

        train_losses.append(training_loss)
        val_losses.append(validation_loss)

        if validation_loss < best_val_loss:
            # NOTE(review): `unq_name` appears twice in this file name; kept
            # unchanged because other scripts may load it by this exact name.
            torch.save({
                'epoch': epoch,
                'model_state_dict': score_model.state_dict(),
                'train_loss': training_loss,
                'val_loss': validation_loss,
                'size_z': size_z1,
            }, path['model'] + "celeb_hq_unetVAE_" + str(size_z1) + str(unq_name) + str(len(attr_visible)) + unq_name)
            print('Model saved', flush=True)
            best_val_loss = validation_loss

        if (epoch + 1) % 50 == 0:
            lr /= 5
            # NOTE: a fresh Adam instance also discards the accumulated moment
            # estimates, not just the learning rate.
            optimizer = torch.optim.Adam(score_model.parameters(), lr=lr)

    save_loss_plot_train_val(np.array(train_losses), np.array(val_losses), 'Loss', ['Train', 'Val'],
                             path['plots'] + 'cel_upd_dsm_' + '_' + unq_name)
     

if __name__ == '__main__':
    import argparse

    cli = argparse.ArgumentParser()

    # (flag, type, default, help) for each option, in the order shown by --help.
    # Checkpoint defaults below are for 256-dim latents. Alternatives used in
    # experiments:
    #   image VAE, beta 1.0 / 0.5: ./models/celeba/celeb_hq_res_dsize_128_z_256_beta_{1.0,0.5}_smN_256__
    #   1024-dim latents:
    #     ./models/celeba/celeb_hq_res_dsize_128_z_1024_beta_0.1_smN_
    #     ./models/celeba_mask/celeb_hq_mask_dsize_128_z_1024_beta_1_mask_celebhq
    #     ./models/celeba_attr/celeba_attr_bn_hq__z_1024_beta_0.1
    option_specs = [
        ('--size-z1', int, 256, 'size of z1 [default: 256]'),
        ('--size-z2', int, 256, 'size of z2 [default: 256]'),
        ('--batch-size', int, 512, 'batch size for training [default: 512]'),
        ('--epochs', int, 300, 'number of epochs to train [default: 300]'),
        ('--lr', float, 0.0002, 'learning rate [default: 0.0002]'),
        ('--unq-name', str, 'unet_hqdsm', 'name to identify the model [default: "unet_hqdsm"]'),
        ('--image-path', str, './models/celeba/celeb_hq_res_dsize_128_z_256_beta_0.1_smN_256__',
         'vae model path [default: "./models/celeba/celeb_hq_res_dsize_128_z_256_beta_0.1_smN_256__"]'),
        ('--mask-path', str, './models/celeba_mask/celeb_hq_mask_dsize_128_z_256_beta_1_mask_celebhq',
         'mask vae model path [default: "./models/celeba_mask/celeb_hq_mask_dsize_128_z_256_beta_1_mask_celebhq"]'),
        ('--attr-path', str, './models/celeba_attr/celeba_attr_bn_hq__z_256_beta_0.1',
         'vae model path [default: "./models/celeba_attr/celeba_attr_bn_hq__z_256_beta_0.1"]'),
    ]
    for flag, arg_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=arg_type, default=default_value, help=help_text)

    opts = cli.parse_args()
    run(opts.epochs, opts.batch_size, opts.lr, opts.size_z1, opts.size_z2,
        opts.image_path, opts.mask_path, opts.attr_path, opts.unq_name)


