import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import time
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
import torch.nn.functional as F
import numpy as np
import glob
import os
from PIL import Image
import math
from inspect import isfunction
from functools import partial
from einops import rearrange
from torch import nn, einsum

from h_vae_model_copy import ResAEN
from unet_model import CAUNET, CAUNET2, CAUNET3, Unet, NorUnet, SM3simple
from utils import *

from celeba_hq_mask_dataset import CelebAHQMaskDS
from h_vae_model import CelebAAttrNewBNAE
from sklearn.metrics import f1_score
    
def get_train_test_dataloader(batch_size, size):
    """Return ``(train_loader, val_loader)`` over ``CelebAHQMaskDS``.

    ``size`` is forwarded to the dataset for both the ``'train'`` and
    ``'val'`` splits; both loaders share ``batch_size``, shuffle, and use two
    worker processes.
    """
    split_datasets = [CelebAHQMaskDS(size=size, ds_type=split) for split in ('train', 'val')]
    return tuple(
        DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=2)
        for ds in split_datasets
    )

def dsm_loss(sm_model, q_z, sigma=1, sigma_index=None):
    """Denoising score-matching loss of ``sm_model`` at latent points ``q_z``.

    ``q_z`` is perturbed with Gaussian noise of standard deviation ``sigma``
    and the model output is regressed onto ``-noise / sigma**2``; each
    sample's squared error is weighted by ``sigma**2``.

    Args:
        sm_model: Score network. Called as ``sm_model(x)`` when ``sigma`` is a
            plain number, and as ``sm_model(x, sigma_index)`` when ``sigma`` is
            a tensor; in the latter case its output is divided by ``sigma``.
        q_z: Clean latents; the leading dimension is the batch.
        sigma: Noise level. Either a Python/NumPy number shared by the whole
            batch, or a tensor broadcastable against ``q_z`` holding one level
            per sample (e.g. shape ``(B, 1, 1, 1)``).
        sigma_index: Per-sample noise-level indices forwarded to ``sm_model``
            when ``sigma`` is a tensor.

    Returns:
        Scalar tensor: the per-sample loss averaged over the batch.
    """
    noise = sigma * torch.randn_like(q_z)
    perturbed_q_z = q_z + noise
    if not torch.is_tensor(sigma):
        # One shared noise level (int or float): unconditional score model.
        sm_out = sm_model(perturbed_q_z)
        loss = 0.5 * (sigma ** 2) * ((sm_out + (noise / sigma ** 2)) ** 2).sum(dim=-1)
    else:
        # Per-sample noise levels: the model is conditioned on the level index
        # and its raw output is scaled by 1/sigma before comparison.
        batch = q_z.shape[0]
        sm_out = sm_model(perturbed_q_z, sigma_index) / sigma
        residual = sm_out.reshape(batch, -1) + (noise / sigma ** 2).reshape(batch, -1)
        loss = 0.5 * (sigma.squeeze() ** 2) * (residual ** 2).sum(dim=-1)
    return loss.mean()

# Display names for the 18 attribute columns selected by ``attr_visible`` in
# ``run``: entry i labels column i of ``target[:, attr_visible]`` when
# ``evaluate`` prints the true/predicted attributes of a sample.
new_id_to_attr = """
    Bald Bangs Black_Hair Blond_Hair Brown_Hair Bushy_Eyebrows
    Eyeglasses Gray_Hair Heavy_Makeup Male Mouth_Slightly_Open Mustache
    Pale_Skin Receding_Hairline Smiling Straight_Hair Wavy_Hair Wearing_Hat
""".split()

# def dsm_loss(sm_model, q_z, sigma=1, sigma_index=None):
#     noise = sigma * torch.normal(mean=0, std=1, size=q_z.shape, device=q_z.device)
#     perturbed_q_z = q_z + noise
#     if isinstance(sigma, int): 
#         sm_out = sm_model(perturbed_q_z)
#         loss = 0.5 * (sigma ** 2) * ((sm_out + (noise/sigma**2)) ** 2).sum(dim=-1)
#     else:
#         sm_out = sm_model(perturbed_q_z, sigma)
#         loss = 0.5 * (sigma.squeeze() ** 2) * ((sm_out.view(q_z.shape[0],-1) + (noise/sigma**2).view(q_z.shape[0],-1)) ** 2).sum(dim=-1)
#     return loss.mean()

def train_model(train_loader, image_vae, mask_vae, attr_vae, attr_visible, n_mod, model, optimizer, size_z, device):
    """Run one epoch of denoising score matching on frozen autoencoder latents.

    Every batch of ``(images, masks, target)`` is encoded by the three frozen
    encoders (attributes restricted to the ``attr_visible`` columns). The
    latents are stacked as ``n_mod`` channels of a ``sqrt(size_z) x
    sqrt(size_z)`` map. ``model`` is then trained with ``dsm_loss`` at a noise
    level drawn uniformly per sample from 500 levels linearly spaced between
    5 and 0.1.

    Returns:
        Mean training loss over the batches of ``train_loader``.

    Raises:
        ValueError: if ``size_z`` is not a perfect square.
    """
    dim = math.isqrt(size_z)
    if dim * dim != size_z:
        raise ValueError(f"size_z must be a perfect square to reshape latents to (n_mod, dim, dim), got {size_z}")

    losses = 0
    image_vae.eval()
    mask_vae.eval()
    attr_vae.eval()
    model.train()
    # DSM noise schedule: 500 levels linearly spaced from 5 down to 0.1.
    # (Earlier runs used geometric schedules such as exp(linspace(log(20), log(0.01), N)).)
    sigmas = torch.tensor(np.linspace(5, 0.1, 500)).to(device)
    start_time = time.time()

    for images, masks, target in train_loader:
        images = images.to(device)
        masks = masks.to(device)
        target = target.to(device)[:, attr_visible]

        # The autoencoders are frozen: encode without building a graph.
        with torch.no_grad():
            z_image = image_vae.encoder(images)
            z_mask = mask_vae.encoder(masks)
            z_attr = attr_vae.encoder(target.float())

        # Gradients are needed for the score model even if the caller disabled them.
        with torch.enable_grad():
            # Stack the three latents as channels: (B, n_mod, dim, dim).
            z = torch.stack([z_image, z_mask, z_attr], dim=1).view(-1, n_mod, dim, dim)
            rand_sigma_index = torch.randint(0, len(sigmas), (z.shape[0],)).to(device)
            # One noise level per sample, broadcastable over (n_mod, dim, dim).
            cur_sigmas = sigmas[rand_sigma_index].view(-1, 1, 1, 1).float()

            loss = dsm_loss(model, z, cur_sigmas, rand_sigma_index)
            losses += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    end_time = time.time()
    losses /= len(train_loader)
    print("TRAINING TIME TAKEN: ", end_time - start_time, flush=True)
    print("Training loss: ", losses, flush=True)
    return losses

@torch.no_grad()
def _annealed_langevin_image(model, z, sigmas, n_mod, dim, size_z, er=0.01, n_comp=2, c=0.8):
    """Refine the image latent ``z['0']`` with annealed Langevin dynamics.

    The mask and attribute latents ``z['1']`` / ``z['2']`` stay fixed. At each
    noise level the stacked latents are scored by ``model``, and only the
    image channel is moved. ``z['0']`` is reassigned in place, and the final
    image latent is also returned.

    Args:
        model: Noise-conditional score network called as ``model(x, sigma_index)``;
            its output is divided by the noise level, as in ``dsm_loss``.
        z: Dict holding latents ``'0'``, ``'1'``, ``'2'``; presumably each is
            ``(1, size_z)`` — TODO confirm encoder output shape.
        sigmas: 1-D tensor of noise levels, visited in order.
        er: Base step size; the step at level i is ``er * sigma_i**2 / sigma_last**2``.
        n_comp: Langevin steps per noise level.
        c: Multiplier on the injected Gaussian noise.
    """
    for s_in in range(len(sigmas)):
        sigma_index = torch.tensor([s_in], device=sigmas.device)
        cur_sigmas = sigmas[sigma_index].float()
        alpha = er * (sigmas[s_in] ** 2) / (sigmas[-1] ** 2)

        for _ in range(n_comp):
            z_all = torch.stack([z['0'], z['1'], z['2']], dim=1).view(-1, n_mod, dim, dim)
            sm_out = model(z_all, sigma_index) / cur_sigmas.view(z_all.shape[0], *([1] * (z_all.dim() - 1)))
            z['0'] = (z['0'] + (alpha * sm_out[:, 0].view(-1, size_z))
                      + c * (torch.sqrt(2 * alpha) * torch.randn_like(z['0'])))
    return z['0']


def _flagged_attr_names(flags):
    """Return the names, each followed by a space, of attributes whose flag is 1.

    Position ``i`` of ``flags`` is labelled by ``new_id_to_attr[i]``.
    """
    return ''.join(new_id_to_attr[ind] + ' ' for ind, flag in enumerate(flags) if int(flag) == 1)


def _save_sample_figures(samples, outs, tar_str, pred_str, epoch, unq_name):
    """Save the image grid, mask grid and attribute-text figures for ``epoch``.

    Each grid shows the input sample next to its decoded output. All figures
    are closed afterwards, even if saving fails.
    """
    def out_path(kind):
        return './images/celeb_score/unet_' + kind + '_' + str(epoch) + '_' + unq_name + '.png'

    try:
        for mod, kind in (('0', 'img'), ('1', 'mask')):
            plt.figure()
            grid = torchvision.utils.make_grid(torch.cat([samples[mod], outs[mod]], dim=0), nrow=10)
            plt.title('Samples')
            plt.axis("off")
            plt.imshow(grid.detach().permute(1, 2, 0).cpu().numpy())
            plt.savefig(out_path(kind))
        plt.figure()
        plt.text(0.05, 0.5, tar_str + '\n' + pred_str)
        plt.savefig(out_path('att'))
    finally:
        plt.close('all')


@torch.no_grad()
def evaluate(val_loader, image_vae, mask_vae, attr_vae, attr_visible, n_mod, model, epoch, unq_name, size_z, device):
    """Compute the validation DSM loss and periodically save a conditional sample.

    The loss is computed exactly as in training: per-sample noise levels are
    drawn from 500 levels linearly spaced between 5 and 0.1.

    When ``epoch > 30`` and ``(epoch + 1) % 10 == 0``, one image is also
    generated. Its latent starts from N(0, I) and is refined by annealed
    Langevin dynamics, conditioned on the encoded mask and attributes of the
    first example of the last validation batch. The image, mask and
    attribute figures are saved under ``./images/celeb_score/``.

    Returns:
        Mean validation loss over the batches of ``val_loader``.

    Raises:
        ValueError: if ``size_z`` is not a perfect square.
    """
    dim = math.isqrt(size_z)
    if dim * dim != size_z:
        raise ValueError(f"size_z must be a perfect square to reshape latents to (n_mod, dim, dim), got {size_z}")

    losses = 0
    image_vae.eval()
    mask_vae.eval()
    attr_vae.eval()
    model.eval()

    # Same noise schedule as training: 500 levels linearly spaced from 5 down to 0.1.
    sigmas = torch.tensor(np.linspace(5, 0.1, 500)).to(device)
    start_time = time.time()

    for images, masks, target in val_loader:
        images = images.to(device)
        masks = masks.to(device)
        target = target.to(device)[:, attr_visible]

        z_image = image_vae.encoder(images)
        z_mask = mask_vae.encoder(masks)
        z_attr = attr_vae.encoder(target.float())
        z = torch.stack([z_image, z_mask, z_attr], dim=1).view(-1, n_mod, dim, dim)

        rand_sigma_index = torch.randint(0, len(sigmas), (z.shape[0],)).to(device)
        cur_sigmas = sigmas[rand_sigma_index].view(-1, 1, 1, 1).float()

        loss = dsm_loss(model, z, cur_sigmas, rand_sigma_index)
        losses += loss.item()

    if (epoch > 30) and ((epoch + 1) % 10 == 0):
        # '0' = image, '1' = mask, '2' = attributes (first example of the last batch).
        samples = {
            '0': images[0].unsqueeze(0),
            '1': masks[0].unsqueeze(0),
            '2': target[0].unsqueeze(0).float(),
        }
        z = {'0': torch.normal(mean=0, std=1, size=(1, size_z), device=device)}
        outs = {}
        for mod, vae in (('1', mask_vae), ('2', attr_vae)):
            z[mod] = vae.encoder(samples[mod])
            outs[mod] = vae.decoder(z[mod])

        # Earlier runs used er between 2e-6 and 2e-4 with n_comp = 5.
        z_img = _annealed_langevin_image(model, z, sigmas, n_mod, dim, size_z, er=0.01, n_comp=2, c=0.8)
        # Only the final latent is visualised, so decode it once.
        outs['0'] = image_vae.decoder(z_img)

        # Reconstructed attributes: sigmoid of the decoder output, rounded to {0, 1}.
        predicted_att = torch.round(torch.sigmoid(outs['2'])).cpu()
        tar_str = 'T: ' + _flagged_attr_names(target[0])
        pred_str = 'P: ' + _flagged_attr_names(predicted_att[0])
        _save_sample_figures(samples, outs, tar_str, pred_str, epoch, unq_name)

    end_time = time.time()
    losses /= len(val_loader)
    print("VALIDATION TIME TAKEN: ", end_time - start_time, flush=True)
    print("Validation loss: ", losses, flush=True)
    return losses


def _load_pretrained(model, ckpt_path, device):
    """Load ``checkpoint['model_state_dict']`` from ``ckpt_path`` into ``model``.

    The checkpoint is mapped to the CPU first, so it loads even if it was
    saved on a GPU that is not present. The model is then moved to ``device``
    and returned.
    """
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    return model.to(device)


def run(epochs, batch_size, lr, size_z1, size_z2, image_model_path, mask_model_path, attr_model_path, unq_name, device=None):
    """Train the joint latent score model and checkpoint the best validation epoch.

    Three pretrained autoencoders are loaded and used frozen:

    * image: ``ResAEN``, 3 input channels;
    * mask: ``ResAEN``, 1 input channel;
    * attributes: ``CelebAAttrNewBNAE``.

    A ``Unet`` score model is trained on their stacked latents for
    ``epochs`` epochs. Its weights are saved whenever the validation loss
    improves, and a train/validation loss plot is written at the end.

    Args:
        epochs: Number of training epochs.
        batch_size: Batch size for both data loaders.
        lr: Initial Adam learning rate; divided by 5 every 50 epochs.
        size_z1: Latent size of the image and mask autoencoders.
        size_z2: Latent size of the attribute autoencoder; must equal ``size_z1``.
        image_model_path, mask_model_path, attr_model_path: Checkpoint files
            containing a ``'model_state_dict'`` entry.
        unq_name: Tag used in checkpoint, image and plot file names.
        device: Torch device (or device string) to train on. Defaults to
            ``cuda:4`` when CUDA is available and to the CPU otherwise.
    """
    res_size = 128
    print('vars: ', epochs, batch_size, lr, size_z1, size_z2, unq_name, flush=True)
    # The three latents are stacked channel-wise for the score model, so the
    # image/mask latent size and the attribute latent size must agree.
    assert size_z1 == size_z2
    train_losses, val_losses = [], []
    # Columns of the dataset's attribute vector used as the attribute modality.
    attr_visible = [4, 5, 8, 9, 11, 12, 15, 17, 18, 20, 21, 22, 26, 28, 31, 32, 33, 35]
    print(attr_visible, flush=True)

    path = {'model': './models/celeb_score/', 'plots': './plots/celeb_score/', 'images': './images/celeb_score/'}
    for p in path.values():
        os.makedirs(p, exist_ok=True)

    cuda = torch.cuda.is_available()
    print("GPU Available: ", cuda, flush=True)
    if device is None:
        # Fall back to the CPU instead of failing when CUDA is unavailable.
        device = "cuda:4" if cuda else "cpu"
    device = torch.device(device)

    # Load mask model (1-channel input).
    mask_vae = _load_pretrained(
        ResAEN([(64, 128, 128, 4), (128, 256, 256, 4)],
               [(256, 256, 128, 4), (128, 128, 64, 4)],
               res_size, size_z1, 1),
        mask_model_path, device)

    # Load image model (3-channel input).
    # Deeper variant tried before:
    #   enc [(64,128,128,2), (128,256,256,2), (256,512,512,2), (512,1024,1024,2)]
    #   dec [(1024,1024,512,2), (512,512,256,2), (256,256,128,2), (128,128,64,2)]
    image_vae = _load_pretrained(
        ResAEN([(64, 128, 128, 2), (128, 256, 256, 2), (256, 512, 512, 2)],
               [(512, 512, 256, 2), (256, 256, 128, 2), (128, 128, 64, 2)],
               res_size, size_z1, 3),
        image_model_path, device)

    # Load attr model.
    attr_vae = _load_pretrained(CelebAAttrNewBNAE(size_z2), attr_model_path, device)

    n_mod = 3
    # score_model = Unet(dim=32, channels=n_mod, dim_mults=(1,2,2,2,2), with_time_emb=True)
    # score_model = CAUNET(n_mod=n_mod, z_dim=size_z1, dim=32, dim2=32, channels=n_mod, dim_mults=(1,2,2,2,2), with_time_emb=True, cross=True)
    # score_model = CAUNET2(n_mod=n_mod, z_dim=size_z1, dim=32, dim2=32, channels=n_mod, dim_mults=(1,2,2,2,2), with_time_emb=True, n_block=3, cross=True)
    # score_model = CAUNET3(n_mod=n_mod, z_dim=size_z1, dim=dim, dim2=dim, channels=n_mod, dim_mults=(1,2,2,2,2), with_time_emb=True, cross=True)
    # score_model = NorUnet(n_mod=n_mod, z_dim=size_z1, dim=32, dim2=32, channels=n_mod, dim_mults=(1,2,2,2,2), with_time_emb=False)
    # score_model = NorUnet(n_mod=n_mod, z_dim=size_z1, dim=32, dim2=32, channels=n_mod, dim_mults=(1,2,2), with_time_emb=False) #3
    # score_model = SM3simple(n_mod, size_z1)
    # Move the model before building the optimizer so it references the final parameters.
    score_model = Unet(dim=128, channels=n_mod, dim_mults=(1, 2, 2, 2, 2), with_time_emb=True).to(device)
    optimizer = torch.optim.Adam(score_model.parameters(), lr=lr)

    train_dataloader, val_dataloader = get_train_test_dataloader(batch_size, res_size)

    best_val_loss = float('inf')
    for epoch in range(epochs):
        print("Epoch: " + str(epoch + 1), flush=True)

        training_loss = train_model(train_dataloader, image_vae, mask_vae, attr_vae, attr_visible, n_mod, score_model, optimizer, size_z1, device)
        validation_loss = evaluate(val_dataloader, image_vae, mask_vae, attr_vae, attr_visible, n_mod, score_model, epoch, unq_name, size_z1, device)
        print(' ', flush=True)

        train_losses.append(training_loss)
        val_losses.append(validation_loss)

        # Keep the best checkpoint so far; the first epoch is always saved.
        if validation_loss < best_val_loss:
            torch.save({
                'epoch': epoch,
                'model_state_dict': score_model.state_dict(),
                # 'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': training_loss,
                'val_loss': validation_loss,
                'size_z': size_z1,
            }, path['model'] + "celeb_hq_unet_AE_" + str(size_z1) + str(unq_name) + str(len(attr_visible)) + unq_name)
            print('Model saved', flush=True)
            best_val_loss = validation_loss

        if (epoch + 1) % 50 == 0:
            lr /= 5
            # Decay in place so Adam keeps its moment estimates; recreating the
            # optimizer would silently reset them.
            for group in optimizer.param_groups:
                group['lr'] = lr

    train_losses = np.array(train_losses)
    val_losses = np.array(val_losses)
    save_loss_plot_train_val(train_losses, val_losses, 'Loss', ['Train', 'Val'], path['plots'] + 'cel_upd_dsm_AE_' + '_' + unq_name)
     

if __name__ == '__main__':
    import argparse

    # Command-line options as (flag, type, default, help), registered in order.
    # The checkpoint defaults are the 1024-d latent autoencoders.
    #
    # Checkpoint triples (image / mask / attr) used in earlier experiments:
    #   ./models/celeba/celeb_hq_ae__dsize_128_z_1024_1k_reg1N_sm_hq_
    #   ./models/celeba_mask/celeb_hq_mask_AE__dsize_128_z_1024_mask_hq_reg1_
    #   ./models/celeba_attr/celeba_attr_bn_hq_AEreg__z_1024_1.0
    #
    #   ./models/celeba/celeb_hq_ae__dsize_128_z_1024_1k_reg001NWK_sm_hq_
    #   ./models/celeba_mask/celeb_hq_mask_AE__dsize_128_z_1024_mask_hq_reg0011k_
    #   ./models/celeba_attr/celeba_attr_bn_hq40_AEreg__z_1024_0.01
    #
    #   ./models/celeba/celeb_hq_ae__dsize_128_z_256_256_reg001NWK_sm_hq_
    #   ./models/celeba_mask/celeb_hq_mask_AE__dsize_128_z_256_mask_hq_reg001256_
    #   ./models/celeba_attr/celeba_attr_bn_hq40_AEreg__z_256_0.01
    #
    #   ./models/celeba/celeb_hq_ae__dsize_128_z_1024_1kae_sm_hq_
    #   ./models/celeba_mask/celeb_hq_mask_AE__dsize_128_z_1024_mask_celebhq_ae_
    #   /users/dwesego/multimodal/models/celeba_attr/celeba_attr_bn_hq_AE__z_1024
    #
    #   (256-d) ./models/celeba/celeb_hq_ae__beta_0.0001_noisecons_0.001_z_256_256_re4ne3_sm_hq_
    #   (256-d) ./models/celeba_mask/celeb_hq_mask_AE__dsize_128_z_256_mask_hq256_re5ne3_
    #   (256-d) ./models/celeba_attr/celeba_attr_bn_hq_AEreg__z_256_0.0001_att_ae_re4ne1
    #
    # A former --vae-type option (default "conv_sig") is no longer used.
    option_specs = (
        ('--size-z1', int, 1024, 'size of z1 [default: 1024]'),
        ('--size-z2', int, 1024, 'size of z2 [default: 1024]'),
        ('--batch-size', int, 512, 'batch size for training [default: 512]'),
        ('--epochs', int, 300, 'number of epochs to train [default: 300]'),
        ('--lr', float, 0.0002, 'learning rate [default: 0.0002]'),
        ('--unq-name', str, 'unet_hq1k', 'name to identify the model [default: "unet_hq1k"]'),
        ('--image-path', str,
         './models/celeba/celeb_hq_ae__beta_0.0001_noisecons_0.001_z_10241024_re4ne3_sm_hq_',
         'image path for vae [default: "./models/celeba/celeb_hq_ae__beta_0.0001_noisecons_0.001_z_10241024_re4ne3_sm_hq_"]'),
        ('--mask-path', str,
         './models/celeba_mask/celeb_hq_mask_AE__dsize_128_z_1024mask_hq1024_re5ne3__',
         'mask path for vae [default: "./models/celeba_mask/celeb_hq_mask_AE__dsize_128_z_1024mask_hq1024_re5ne3__"]'),
        ('--attr-path', str,
         './models/celeba_attr/celeba_attr_bn_hq_AEreg__z_1024_0.0001att_hq1024_re4ne1__',
         'attr path for vae [default: "./models/celeba_attr/celeba_attr_bn_hq_AEreg__z_1024_0.0001att_hq1024_re4ne1__"]'),
    )

    cli_parser = argparse.ArgumentParser()
    for flag, arg_type, default_value, help_text in option_specs:
        cli_parser.add_argument(flag, type=arg_type, default=default_value, help=help_text)

    cli_args = cli_parser.parse_args()
    run(
        epochs=cli_args.epochs,
        batch_size=cli_args.batch_size,
        lr=cli_args.lr,
        size_z1=cli_args.size_z1,
        size_z2=cli_args.size_z2,
        image_model_path=cli_args.image_path,
        mask_model_path=cli_args.mask_path,
        attr_model_path=cli_args.attr_path,
        unq_name=cli_args.unq_name,
    )


