import click
import os
import torch
import numpy as np
import glob
import torchvision.transforms as transforms
import torch.utils.data as data
from tqdm import tqdm

import create_model
import forward_process

from typing import Any, List, Tuple, Union, Optional

class EasyDict(dict):
    """A ``dict`` whose keys may also be read, set and deleted as attributes.

    Reading a missing key through attribute syntax raises ``AttributeError``
    (rather than ``KeyError``) so ``hasattr``/``getattr(obj, name, default)``
    behave as expected.
    """

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails; fall back to the dict.
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name)
        return value

    def __setattr__(self, name: str, value: Any) -> None:
        # Every attribute assignment is stored as a dict entry.
        self[name] = value

    def __delattr__(self, name: str) -> None:
        # Attribute deletion removes the dict entry (KeyError if absent).
        del self[name]

def npz_concat(filenames):
    """Concatenate the ``'samples'`` arrays of several ``.npz`` files along axis 0.

    Args:
        filenames: Iterable of paths to ``.npz`` archives, each containing a
            ``'samples'`` entry. Files are concatenated in the given order.

    Returns:
        A single array holding all samples.

    Raises:
        ValueError: if ``filenames`` is empty, or if the arrays cannot be
            concatenated (e.g. mismatched trailing shapes).
    """
    chunks = []
    for file in filenames:
        # Context manager closes the underlying NpzFile handle after reading.
        with np.load(file) as archive:
            chunks.append(archive['samples'])
    if not chunks:
        raise ValueError("npz_concat: no input files given")
    # One concatenate at the end instead of re-copying the growing array per file.
    return np.concatenate(chunks)

def npz_concat_cond(filenames):
    """Concatenate ``'samples'`` and ``'label'`` arrays from several ``.npz`` files.

    Args:
        filenames: Iterable of paths to ``.npz`` archives, each containing both
            a ``'samples'`` and a ``'label'`` entry.

    Returns:
        ``(samples, labels)``, each concatenated along axis 0 in file order,
        so row ``i`` of ``labels`` belongs to row ``i`` of ``samples``.

    Raises:
        ValueError: if ``filenames`` is empty, or if either set of arrays
            cannot be concatenated (e.g. mismatched trailing shapes).
    """
    sample_chunks, label_chunks = [], []
    for file in filenames:
        # Open each archive once and close it after reading both entries.
        with np.load(file) as archive:
            sample_chunks.append(archive['samples'])
            label_chunks.append(archive['label'])
    if not sample_chunks:
        raise ValueError("npz_concat_cond: no input files given")
    return np.concatenate(sample_chunks), np.concatenate(label_chunks)

class BasicDataset(data.Dataset):
    """Map-style dataset pairing inputs with targets, optionally transforming inputs.

    Args:
        x_np: Indexable collection of inputs.
        y_np: Indexable collection of targets aligned with ``x_np``.
        transform: Callable applied to an input when ``is_transform`` is true.
        is_transform: Whether ``__getitem__`` applies ``transform``.
    """

    def __init__(self, x_np, y_np, transform=None, is_transform = False):
        super().__init__()
        self.x, self.y = x_np, y_np
        self.transform = transform
        self.is_transform = is_transform

    def __getitem__(self, index):
        sample = self.x[index]
        if self.is_transform:
            sample = self.transform(sample)
        return sample, self.y[index]

    def __len__(self):
        # Length is defined by the inputs, not the targets.
        return len(self.x)

class BasicDatasetCond(data.Dataset):
    """Dataset yielding ``(transform(x), y, cond)`` triples for conditional training.

    Args:
        x_np: Indexable inputs; every item is passed through ``transform``.
        y_np: Indexable targets aligned with ``x_np``.
        cond_np: Indexable conditioning values aligned with ``x_np``.
        transform: Callable applied to each input. The default is a single
            ``transforms.ToTensor()`` instance created once, when the class
            is defined, and shared by all instances using the default.
    """

    def __init__(self, x_np, y_np, cond_np, transform=transforms.ToTensor()):
        super().__init__()
        self.x, self.y, self.cond = x_np, y_np, cond_np
        self.transform = transform

    def __getitem__(self, index):
        # NOTE(review): if ``transform`` is ToTensor, ndarray inputs are treated
        # as HWC and permuted to CHW -- confirm this matches the stored layout.
        x_item = self.transform(self.x[index])
        return x_item, self.y[index], self.cond[index]

    def __len__(self):
        # Length is defined by the inputs.
        return len(self.x)

def _load_real_data(datadir, cond):
    """Load the real samples, plus one-hot condition labels when ``cond`` is set.

    Args:
        datadir: Path to an ``.npz`` archive. Unconditional archives hold the
            samples under ``'arr_0'``; conditional ones hold them under
            ``'samples'`` with integer class ids under ``'label'``.
        cond: Truthy for conditional training.

    Returns:
        ``(real_data, real_label)``. ``real_label`` is ``None`` when ``cond`` is
        falsy, otherwise the labels one-hot encoded via ``np.eye(10)``.
    """
    with np.load(datadir) as archive:
        if not cond:
            real_data = archive['arr_0']
            print("shape of real data is: ", real_data.shape)
            return real_data, None
        real_data = archive['samples']
        # NOTE(review): assumes 'label' holds integer ids in [0, 10).
        real_label = np.eye(10)[archive['label']]
    return real_data, real_label


def _load_fake_data(gendir, cond):
    """Load generated samples (and their stored labels when ``cond`` is set).

    The first time this runs, the ``sample*.npz`` shards in ``gendir`` are
    concatenated in sorted filename order and cached as
    ``gen_data_for_discriminator_training.npz``. Later runs read the cache.

    Returns:
        ``(gen_data, gen_label)``. ``gen_label`` is ``None`` when ``cond`` is
        falsy.
    """
    cache_path = os.path.join(gendir, 'gen_data_for_discriminator_training.npz')
    if os.path.exists(cache_path):
        with np.load(cache_path) as archive:
            gen_data = archive['samples']
            gen_label = archive['label'] if cond else None
    else:
        filenames = sorted(glob.glob(os.path.join(gendir, 'sample*.npz')))
        if cond:
            gen_data, gen_label = npz_concat_cond(filenames)
            np.savez_compressed(cache_path, samples=gen_data, label=gen_label)
        else:
            gen_data, gen_label = npz_concat(filenames), None
            np.savez_compressed(cache_path, samples=gen_data)
    print("shape of generated data is: ", gen_data.shape)
    return gen_data, gen_label


@click.command()
@click.option('--savedir',                     help='Save directory',          metavar='PATH',    type=str, required=True,     default="/checkpoints_25k")
# NOTE: --continue_training is parsed but not currently used by main().
@click.option('--continue_training',           help='Continue training',       metavar='INT',     type=str, required=True,     default=0)
@click.option('--gendir',                      help='Fake sample directory',   metavar='PATH',    type=str, required=True,     default="/latent_datasets/celebA256_25K/fake/VQ-VAE-DDIM_steps_500_batch_size_100_seed_1_eta_0")
@click.option('--datadir',                     help='Real sample directory',   metavar='PATH',    type=str, required=True,     default="/latent_datasets/celebA256_25K/real/real_latents.npz")
@click.option('--img_resolution',              help='Image resolution',        metavar='INT',     type=click.IntRange(min=1),  default=64)
@click.option('--cond',                        help='Is it conditional?',      metavar='INT',     type=click.IntRange(min=0),  default=0)
# timesteps must be >= 1: torch.randint(0, 0, ...) is an empty range and fails.
@click.option('--timesteps',                   help='Timesteps - T',            metavar='INT',     type=click.IntRange(min=1),  default=1000)
@click.option('--num_data',                    help='Num samples',             metavar='INT',     type=click.IntRange(min=1),  default=25000)
@click.option('--batch_size',                  help='Batch size',              metavar='INT',     type=click.IntRange(min=1),  default=128)
@click.option('--epoch',                       help='Num epochs',             metavar='INT',     type=click.IntRange(min=1),  default=100)
@click.option('--lr',                          help='Learning rate',           metavar='FLOAT',   type=click.FloatRange(min=0),default=3e-4)
@click.option('--device',                      help='Device',                  metavar='STR',     type=str,                    default='cuda:0')

def main(**kwargs):
    """Train a timestep-aware discriminator to separate real from generated samples.

    Up to --num_data real samples (target 1) and --num_data generated samples
    (target 0) are noised at a uniformly random timestep in [0, T) and
    classified with a BCE loss. A checkpoint is written to --savedir every 50
    epochs, followed at the end by the per-epoch loss history.
    """
    opts = EasyDict(kwargs)
    # Paths are built by plain string concatenation onto the CWD (the defaults
    # start with '/'), so an absolute path on the CLI is NOT treated as absolute.
    gendir = os.getcwd() + opts.gendir
    savedir = os.getcwd() + opts.savedir
    datadir = os.getcwd() + opts.datadir
    os.makedirs(savedir, exist_ok=True)

    ## Prepare real / fake data
    real_data, real_label = _load_real_data(datadir, opts.cond)
    gen_data, gen_label = _load_fake_data(gendir, opts.cond)

    ## Combine the fake / real (real rows first, then generated rows)
    real_data = real_data[:opts.num_data]
    gen_data = gen_data[:opts.num_data]
    train_data = np.concatenate((real_data, gen_data))
    print("shape of train_data is: ", train_data.shape)

    # Target 1 for the leading real rows, 0 for the generated rows.
    train_label = torch.zeros(train_data.shape[0])
    train_label[:real_data.shape[0]] = 1.

    # Only the conditional dataset applies this transform; the unconditional
    # dataset is built with is_transform=False and returns raw arrays.
    transform = transforms.Compose([transforms.ToTensor()])

    if not opts.cond:
        train_dataset = BasicDataset(train_data, train_label, transform, is_transform = False)
    else:
        # Truncate labels exactly like the data so condition_label stays
        # row-aligned with train_data.
        # NOTE(review): real_label is one-hot (N, 10); gen_label is used as
        # stored -- confirm the generated labels are one-hot too.
        condition_label = np.concatenate((real_label[:opts.num_data], gen_label[:opts.num_data]))
        train_dataset = BasicDatasetCond(train_data, train_label, condition_label, transform)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=opts.batch_size, num_workers=0, shuffle=True, drop_last=True)

    ## Discriminator (fresh weights).
    # To resume training instead, pass a checkpoint path, e.g.
    #   create_model.load_discriminator("/checkpoints/discriminator_150.pt", opts.device, opts.cond, eval=False)
    # and restore the loss history with disc_loss = np.load(<path>).tolist().
    discriminator = create_model.load_discriminator(None, opts.device, opts.cond, eval=False)

    ## Prepare training
    optimizer = torch.optim.Adam(discriminator.parameters(), lr=opts.lr, weight_decay=1e-7)
    loss = torch.nn.BCELoss()
    disc_loss = []  # mean BCE per epoch

    ## Training
    for i in tqdm(range(opts.epoch)):
        outs = []  # per-batch BCE loss
        cors = []  # per-batch accuracy at a 0.5 threshold

        for batch in train_loader:
            optimizer.zero_grad()

            if not opts.cond:
                inputs, labels = batch
            else:
                inputs, labels, cond = batch
                cond = cond.to(opts.device)

            inputs = inputs.to(opts.device)
            labels = labels.to(opts.device)

            # One timestep per sample, drawn uniformly from [0, T).
            t = torch.randint(0, opts.timesteps, (inputs.shape[0],), device=opts.device).long()
            noise = torch.randn_like(inputs)

            ## Data perturbation
            perturbed_inputs = forward_process.q_sample(x_start=inputs, t=t, noise=noise)

            ## Forward
            if not opts.cond:
                label_prediction = discriminator(perturbed_inputs, t, sigmoid=True).view(-1)
            else:
                label_prediction = discriminator(perturbed_inputs, t, sigmoid=True, condition=cond).view(-1)

            ## Backward
            out = loss(label_prediction, labels)
            out.backward()
            optimizer.step()

            ## Report
            cor = ((label_prediction > 0.5).float() == labels).float().mean()
            outs.append(out.item())
            cors.append(cor.item())
            print(f"{i}-th epoch BCE loss: {np.mean(outs)}, correction rate: {np.mean(cors)}")

        epoch_loss = np.mean(outs)
        print(f"end of {i}-th epoch discriminator loss is: {epoch_loss}")
        disc_loss.append(epoch_loss)

        ## Saving every 50 epochs
        if (i + 1) % 50 == 0:
            torch.save(discriminator.state_dict(), os.path.join(savedir, f"discriminator_{i+1}.pt"))

    # Save next to the checkpoints (np.save appends the .npy suffix).
    np.save(os.path.join(savedir, f"discriminator_loss_{opts.epoch}_epochs"), disc_loss)

#----------------------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: main is a click command, so calling it with no
    # arguments makes click parse the command-line options.
    main()
#----------------------------------------------------------------------------