import argparse, os, sys, glob, datetime, yaml
sys.path.insert(1, os.getcwd())

os.environ['OPENBLAS_NUM_THREADS'] = '1'

import torch
import time
import numpy as np
from tqdm import trange
import tqdm

from omegaconf import OmegaConf
from PIL import Image

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config

#import trainable discriminator codes here
from trainable_discriminator import create_model

############################################################################################
#dataset class goes here
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from datasets import load_dataset
import PIL.Image
import itertools

#additional frameworks for obtaining divergence metrics
import ot
import torch.nn as nn
kl_loss = nn.KLDivLoss(reduction="batchmean")
from sklearn.neighbors import NearestNeighbors

# Module-level default preprocessing: PIL image -> tensor via torchvision's ToTensor.
transform = transforms.Compose([transforms.ToTensor()])


def scaler(x):
    """Map ``x`` linearly from [0, 1] to [-1, 1] (computes ``2x - 1``)."""
    return 2. * x - 1.

class Dataset(Dataset):
    """Wrap an indexable collection of records and serve each record's 'image'.

    NOTE: this class deliberately reuses the name of its base class
    (``torch.utils.data.Dataset``, imported above); the base is resolved when
    this statement runs, so later references to ``Dataset`` in this module mean
    this wrapper.

    Args:
        dataset: indexable, sized collection whose items support
            ``item['image']`` (``run`` passes a Hugging Face ``datasets`` split).
        transform: optional callable applied to each image; it is only applied
            when truthy.
    """

    def __init__(self, dataset, transform = None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image = self.dataset[idx]['image']
        return self.transform(image) if self.transform else image

############################################################################################

def rescale(x):
    """Map ``x`` linearly from [-1, 1] to [0, 1] (computes ``(x + 1) / 2``)."""
    return (x + 1.) / 2.

def custom_to_pil(x):
    """Convert one channels-first image tensor with values in [-1, 1] to a PIL RGB image.

    The tensor is detached and copied to the CPU, clamped to [-1, 1], mapped to
    [0, 1], moved to channels-last via ``permute(1, 2, 0)`` and scaled to
    [0, 255]. The uint8 cast truncates rather than rounds. Non-RGB results are
    converted to RGB.
    """
    pixels = torch.clamp(x.detach().cpu(), -1., 1.)
    pixels = ((pixels + 1.) / 2.).permute(1, 2, 0).numpy()
    pil_img = Image.fromarray((255 * pixels).astype(np.uint8))
    return pil_img if pil_img.mode == "RGB" else pil_img.convert("RGB")

# This code has been taken from "diffusers".
def save_img_grid(image, path, gridh = 4, gridw = 4,img_resolution = 256, img_channels = 3):
    """Tile a batch of images in [-1, 1] into a ``gridh`` x ``gridw`` mosaic and save it.

    Args:
        image: tensor laid out as (gridh * gridw, img_channels, img_resolution,
            img_resolution); the reshapes below require exactly this layout.
        path: output file path handed to PIL's ``save``.
        gridh: number of grid rows.
        gridw: number of grid columns.
        img_resolution: side length of each (square) image.
        img_channels: channels per image; 1 produces a grayscale file, 3 RGB.

    Raises:
        ValueError: if the batch size is not ``gridh * gridw``.
    """
    if image.shape[0] != gridh * gridw:
        raise ValueError(
            f"Expected {gridh * gridw} images for a {gridh}x{gridw} grid, got {image.shape[0]}"
        )
    # Adding 128 (not 127.5) makes the truncating uint8 cast round to nearest.
    image = (image * 127.5 + 128).clip(0, 255).to(torch.uint8)
    # (N, C, H, W) -> (gridh, gridw, C, H, W) -> (gridh, H, gridw, W, C)
    image = image.reshape(gridh, gridw, *image.shape[1:]).permute(0, 3, 1, 4, 2)
    image = image.reshape(gridh * img_resolution, gridw * img_resolution, img_channels)
    array = image.cpu().numpy()
    if img_channels == 1:
        # PIL infers a grayscale image from a 2-D array; an (H, W, 1) array is not accepted.
        array = array[..., 0]
    # PIL infers the mode from the array (RGB for 3 channels, as before).
    PIL.Image.fromarray(array).save(path)


def custom_to_np(x):
    """Convert a batch of channels-first images in [-1, 1] to uint8 channels-last.

    Layout follows guided-diffusion's sampler output:
    https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py

    The input is detached and copied to the CPU, mapped to [0, 255] (clamped,
    uint8 cast truncates), permuted from (N, C, H, W) to (N, H, W, C) and made
    contiguous. Despite the name, the result is a torch tensor, not a NumPy array.
    """
    scaled = ((x.detach().cpu() + 1) * 127.5).clamp(0, 255)
    return scaled.to(torch.uint8).permute(0, 2, 3, 1).contiguous()

def logs2pil(logs, keys=("sample",)):
    """Convert every logged tensor into a PIL image for quick inspection.

    For each entry of ``logs``:
    - a 4-D value contributes its first element (``value[0, ...]``);
    - a 3-D value is converted directly;
    - anything else maps to ``None``.
    Values that fail to convert (e.g. no ``.shape``) also map to ``None``.

    Args:
        logs: mapping of name -> tensor-like value.
        keys: currently unused; every key in ``logs`` is converted. Kept for
            backward compatibility (now an immutable default).

    Returns:
        dict with the same keys as ``logs`` and PIL images (or None) as values.
    """
    imgs = {}
    for k, value in logs.items():
        try:
            ndim = len(value.shape)
            if ndim == 4:
                img = custom_to_pil(value[0, ...])
            elif ndim == 3:
                img = custom_to_pil(value)
            else:
                print(f"Unknown format for key {k}. ")
                img = None
        # Catch Exception, not everything: a bare except would also swallow
        # KeyboardInterrupt / SystemExit.
        except Exception as err:
            print(f"Could not convert key {k} to PIL: {err}")
            img = None
        imgs[k] = img
    return imgs


@torch.no_grad()
def convsample(model, shape, return_intermediates=True,
               verbose=True,
               make_prog_row=False):
    """Run the model's own (non-DDIM) sampling loop without gradient tracking.

    Args:
        model: object providing ``p_sample_loop`` and ``progressive_denoising``.
        shape: sample shape forwarded to the model.
        return_intermediates: forwarded to ``p_sample_loop`` only.
        verbose: progress flag forwarded to whichever method is called.
        make_prog_row: if True call ``progressive_denoising`` instead of
            ``p_sample_loop``.

    Returns:
        Whatever the selected model method returns.
    """
    if make_prog_row:
        # Previously verbose=True was hard-coded here; the parameter is now honoured
        # (its default is still True, so existing callers see no change).
        return model.progressive_denoising(None, shape, verbose=verbose)
    return model.p_sample_loop(None, shape,
                               return_intermediates=return_intermediates, verbose=verbose)

# Discriminator-guidance options are currently wired into the DDIM sampler only.
@torch.no_grad()
def convsample_ddim(model, steps, shape, eta=1.0, dg = 0, ds_iter = None, dataset = None, seed1 = 0, seed2 = 1, trainable = 0, discriminator = None, tss = 0, cut_off_value = 0, window_size = 0):
    """Draw one batch of latents with the project's DDIM sampler.

    Args:
        model: diffusion model wrapped by ``DDIMSampler``.
        steps: number of DDIM steps.
        shape: ``[batch, C, H, W]``; the first entry is the batch size, the rest
            is passed on as the per-sample shape.
        eta: DDIM eta.
        dg, ds_iter, dataset, seed1, seed2, trainable, discriminator, tss,
        cut_off_value, window_size: forwarded unchanged to ``DDIMSampler.sample``
            (guidance weight, real-data iterator, dataset name, seeds, trainable-vs-
            closed-form discriminator flag, the discriminator, time-shift-sampler
            flag, Tc, and window size, per the CLI help text).

    Returns:
        ``(samples, intermediates)`` as produced by ``DDIMSampler.sample``.
    """
    batch_size = shape[0]
    per_sample_shape = shape[1:]
    guidance_kwargs = dict(
        dg=dg,
        ds_iter=ds_iter,
        dataset=dataset,
        seed1=seed1,
        seed2=seed2,
        trainable=trainable,
        discriminator=discriminator,
        tss=tss,
        cut_off_value=cut_off_value,
        window_size=window_size,
    )
    sampler = DDIMSampler(model)
    samples, intermediates = sampler.sample(
        steps, batch_size=batch_size, shape=per_sample_shape, eta=eta, verbose=False, **guidance_kwargs
    )
    return samples, intermediates


@torch.no_grad()
def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0, dg = 0, ds_iter = None, dataset = None, seed1 = 0, seed2 = 1, trainable = 0, discriminator = None, imglogdir = None, tss = 0, cut_off_value = 0, window_size = 0):
    """Sample one batch of latents, decode it to image space and time the sampling.

    Sampling runs inside ``model.ema_scope("Plotting")``. ``vanilla`` selects
    ``convsample`` (progressive denoising); otherwise ``convsample_ddim`` is used
    with the guidance options forwarded. Decoding happens after the EMA scope
    exits, and only sampling (not decoding) is timed.

    ``imglogdir`` is accepted but currently unused (grid saving is disabled).

    Returns:
        dict with ``"sample"`` (decoded images), ``"time"`` (seconds spent
        sampling) and ``"throughput"`` (latents per second).
    """
    unet = model.model.diffusion_model
    shape = [batch_size, unet.in_channels, unet.image_size, unet.image_size]

    with model.ema_scope("Plotting"):
        started = time.time()
        if vanilla:
            sample, _progrow = convsample(model, shape, make_prog_row=True)
        else:
            sample, _intermediates = convsample_ddim(
                model, steps=custom_steps, shape=shape, eta=eta,
                dg=dg, ds_iter=ds_iter, dataset=dataset, seed1=seed1, seed2=seed2,
                trainable=trainable, discriminator=discriminator, tss=tss,
                cut_off_value=cut_off_value, window_size=window_size,
            )
        finished = time.time()

    # The first-stage decoder maps latents back to image space.
    x_sample = model.decode_first_stage(sample)

    # Disabled: saving a 4x4 grid of x_sample[:16] to imglogdir + "/grid_16.png"
    # via save_img_grid.

    elapsed = finished - started
    log = {
        "sample": x_sample,
        "time": elapsed,
        "throughput": sample.shape[0] / elapsed,
    }
    print(f'Throughput for this batch: {log["throughput"]}')
    return log


############################################################################
#KL-divergence-estimators

def verify_sample_shapes(s1, s2, k):
    """Validate the inputs of the k-nearest-neighbour KL estimator.

    Args:
        s1: (N_1, D) samples from P.
        s2: (N_2, D) samples from Q.
        k: number of neighbours the estimator will query.

    Raises:
        ValueError: if either sample is not 2-D, the feature dimensions differ,
            ``k < 1``, ``s1`` has ``k`` or fewer rows (each point's k-th
            neighbour must exclude the point itself, so k + 1 rows are needed),
            or ``s2`` has fewer than ``k`` rows.
    """
    # Explicit raises rather than assert: asserts are stripped under `python -O`.
    if len(s1.shape) != 2 or len(s2.shape) != 2:
        raise ValueError(
            f"Expected 2-D [N, D] samples, got shapes {tuple(s1.shape)} and {tuple(s2.shape)}"
        )
    if s1.shape[1] != s2.shape[1]:
        raise ValueError(
            f"Sample dimensionality differs: {s1.shape[1]} vs {s2.shape[1]}"
        )
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    if s1.shape[0] <= k:
        raise ValueError(f"s1 needs more than k={k} rows, got {s1.shape[0]}")
    if s2.shape[0] < k:
        raise ValueError(f"s2 needs at least k={k} rows, got {s2.shape[0]}")

def skl_estimator(s1, s2, k=1):
    """Estimate D(P || Q) from samples with a k-nearest-neighbour estimator.

    Uses scikit-learn's NearestNeighbors to evaluate

        D = log(m / (n - 1)) + (d / n) * sum_i log(nu_i / rho_i)

    where ``rho_i`` is the distance from ``s1[i]`` to its k-th nearest
    neighbour among the other points of ``s1`` and ``nu_i`` its distance to
    the k-th nearest neighbour in ``s2`` (this appears to be the
    Wang-Kulkarni-Verdu k-NN divergence estimator).

    Args:
        s1: (N_1, D) sample drawn from distribution P.
        s2: (N_2, D) sample drawn from distribution Q.
        k: number of neighbours considered (default 1).

    Returns:
        The estimated D(P || Q) as a NumPy float. Zero neighbour distances
        (duplicate points) make the estimate infinite or NaN.

    Raises:
        ValueError: if ``s1`` has ``k`` or fewer rows.
    """
    verify_sample_shapes(s1, s2, k)

    n, m = len(s1), len(s2)
    if n <= k:
        # Needed for the k + 1 self-inclusive search and the n - 1 divisor.
        raise ValueError(f"s1 needs more than k={k} samples, got {n}")
    d = float(s1.shape[1])

    # When the query points are passed explicitly, kneighbors counts each s1
    # point as its own zero-distance neighbour, hence k + 1 for the s1 search.
    rho = NearestNeighbors(n_neighbors=k + 1).fit(s1).kneighbors(s1, k + 1)[0][:, -1]
    nu = NearestNeighbors(n_neighbors=k).fit(s2).kneighbors(s1, k)[0][:, -1]

    # One batched query for all points instead of one kneighbors call per point.
    return np.log(m / (n - 1)) + (d / n) * np.sum(np.log(nu / rho))

############################################################################

def _encode_flattened(model, batch):
    """Encode a batch of images in [0, 1] and return one flattened latent per row (NumPy).

    The batch is moved to ``model.device`` and mapped to [-1, 1] (2x - 1)
    before ``model.encode_first_stage``.
    NOTE(review): assumes encode_first_stage returns a tensor whose first
    dimension is the batch -- confirm for the first-stage model in use.
    """
    batch = batch.to(model.device)
    batch = 2. * batch - 1.
    with torch.no_grad():
        latents = model.encode_first_stage(batch)
    # Flatten every non-batch dimension and move to host memory immediately,
    # so latents do not accumulate on the GPU.
    return latents.reshape(latents.shape[0], -1).cpu().numpy()


def plot_distance(model, dl1, dl2, ds, batch_size = 100, num_samples = 1000):
    """Estimate KL divergence from real data for baseline and proposed samples, in latent space.

    Despite the name this produces numbers, not a plot. ``num_samples //
    batch_size`` batches are drawn from each iterator. Each batch is encoded
    with the model's first stage and flattened per sample, and the k-NN
    estimator ``skl_estimator`` is applied.

    Args:
        model: provides ``device`` and ``encode_first_stage``.
        dl1: iterator of baseline image batches in [0, 1].
        dl2: iterator of proposed-model image batches in [0, 1].
        ds: iterator of real image batches in [0, 1].
        batch_size: batch size of the iterators, used to derive the batch count.
        num_samples: target number of samples per source.

    Returns:
        ``(d1, d2)``: estimates of D(baseline || real) and D(proposed || real).

    Raises:
        ValueError: if ``num_samples < batch_size`` (no batches would be drawn).
    """
    num_iter = num_samples // batch_size
    if num_iter == 0:
        raise ValueError(
            f"num_samples ({num_samples}) must be at least batch_size ({batch_size})"
        )

    encoded_dl1, encoded_dl2, encoded_real = [], [], []
    for _ in tqdm.tqdm(range(num_iter)):
        # One batch per source per iteration: baseline, proposed, real.
        encoded_dl1.append(_encode_flattened(model, next(dl1)))
        encoded_dl2.append(_encode_flattened(model, next(dl2)))
        encoded_real.append(_encode_flattened(model, next(ds)))

    numpy_dl1 = np.concatenate(encoded_dl1)
    numpy_dl2 = np.concatenate(encoded_dl2)
    numpy_real = np.concatenate(encoded_real)

    d1 = skl_estimator(numpy_dl1, numpy_real)
    d2 = skl_estimator(numpy_dl2, numpy_real)

    print("KL divergence estimate for baseline and real: ", d1)
    print("KL divergence estimate for proposed and real: ", d2)
    return d1, d2


# Hugging Face Hub dataset identifiers for the supported `--ds` values.
_HF_DATASETS = {
    "celeba": "korexyz/celeba-hq-256x256",
    "ffhq": "merkol/ffhq-256",
    "lsun_church": "tglcourse/lsun_church_train",
}


def _load_image_dataset(dataset):
    """Return the train split of a supported dataset, wrapped to yield ToTensor images.

    Raises:
        ValueError: if ``dataset`` is not a key of ``_HF_DATASETS``.
    """
    if dataset not in _HF_DATASETS:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {sorted(_HF_DATASETS)}"
        )
    steps = [transforms.ToTensor()]
    if dataset == "lsun_church":
        # Only LSUN church is center-cropped to 256x256 before tensor conversion.
        steps.insert(0, transforms.CenterCrop((256, 256)))
    hf_dataset = load_dataset(_HF_DATASETS[dataset])
    return Dataset(hf_dataset['train'], transform=transforms.Compose(steps))


def _endless(loader):
    """Yield batches from ``loader`` indefinitely, re-iterating it on each pass.

    Unlike ``itertools.cycle``, which keeps a copy of every batch from the
    first pass in memory, this holds no batches. A shuffling loader therefore
    reshuffles on each pass. Stops if a pass yields nothing, as cycle would.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            return


class _NumpyDataset(torch.utils.data.Dataset):
    """Serve items of an in-memory array, optionally transformed.

    Args:
        numpy_data: the array of samples (documented as N x H x W x C).
        transform: optional callable applied to each sample when truthy.
    """

    def __init__(self, numpy_data, transform=None):
        self.data = numpy_data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample


def _estimate_metrics(model, transformed_ds):
    """Compare two saved sample arrays against the real dataset via plot_distance.

    Uses unshuffled batches of 100 and 500 samples per source.
    """
    real_iter = _endless(DataLoader(transformed_ds, batch_size=100, shuffle=False, drop_last=True))

    # Converts each array sample to a tensor.
    to_tensor = transforms.Compose([transforms.ToTensor()])

    # TODO: set these paths before use -- np.load("") always raises.
    # NOTE(review): expects arrays shaped N x H x W x C; the .npz archives written
    # by run() would need their array extracted first -- confirm.
    baseline = np.load("")  # baseline model samples
    proposed = np.load("")  # proposed model samples

    baseline_iter = _endless(DataLoader(_NumpyDataset(baseline, to_tensor), batch_size=100, shuffle=False, drop_last=True))
    proposed_iter = _endless(DataLoader(_NumpyDataset(proposed, to_tensor), batch_size=100, shuffle=False, drop_last=True))

    plot_distance(model, baseline_iter, proposed_iter, real_iter, batch_size=100, num_samples=500)


def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None, dg = 0, dataset =None, seed1 = 0, seed2 = 1, trainable = 0, discriminator = None, tss = 0, cut_off_value = 0, window_size = 0, estimate_tc = 0, estimate_metrics = 0):
    """Sample from an unconditional model, or run one of the one-off estimation modes.

    Loads the real-image dataset named by ``dataset``; its iterator is passed
    to the sampler as ``ds_iter``. Then one of:

    * ``estimate_tc != 0``: save the Tc-estimation plot via
      ``DDIMSampler.obtain_cutoff_value`` and exit the process;
    * ``estimate_metrics != 0``: compute latent KL estimates for previously
      generated samples and exit the process;
    * otherwise: generate at least ``n_samples`` images in batches. Each image
      is written as a PNG in ``logdir``, and the first ``n_samples`` go into
      one ``.npz`` in ``nplog``.

    Seeds are incremented before every batch.

    Raises:
        NotImplementedError: for conditional models (``model.cond_stage_model`` set).
        ValueError: if ``dataset`` is not a supported name.
    """
    if vanilla:
        print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.')
    else:
        print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}')

    tstart = time.time()
    # Number new files after any PNGs already in logdir (an empty dir starts at sample_000000.png).
    n_saved = len(glob.glob(os.path.join(logdir, '*.png')))

    if model.cond_stage_model is not None:
        raise NotImplementedError('Currently only sampling for unconditional models supported.')

    transformed_ds = _load_image_dataset(dataset)

    if estimate_tc != 0:
        # One-off mode: the saved plot is used to read off Tc. Batches are fixed
        # at 100 and unshuffled.
        ddim = DDIMSampler(model)
        tc_iter = _endless(DataLoader(transformed_ds, batch_size=100, shuffle=False, drop_last=True))
        ddim.obtain_cutoff_value(num=100, ds=tc_iter, batch_size=100)
        sys.exit(0)

    if estimate_metrics != 0:
        _estimate_metrics(model, transformed_ds)
        sys.exit(0)

    data_iter = _endless(DataLoader(transformed_ds, batch_size=batch_size, shuffle=True, drop_last=True))

    print(f"Running unconditional sampling for {n_samples} samples")
    all_images = []
    # Ceil division so a non-multiple n_samples still yields at least n_samples
    # images; the .npz is trimmed to n_samples below.
    n_batches = -(-n_samples // batch_size)
    for _ in trange(n_batches, desc="Sampling Batches (unconditional)"):

        # Seeds to make results reproducible.
        seed1 += 1
        seed2 += 1

        with torch.no_grad():
            logs = make_convolutional_sample(model, batch_size=batch_size,
                                             vanilla=vanilla, custom_steps=custom_steps,
                                             eta=eta, dg = dg, ds_iter = data_iter, dataset = dataset, seed1 = seed1, seed2 = seed2, trainable = trainable, discriminator = discriminator, imglogdir = logdir, tss = tss, cut_off_value = cut_off_value, window_size = window_size)
            torch.cuda.empty_cache()

        n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
        all_images.append(custom_to_np(logs["sample"]))

        if n_saved >= n_samples:
            print(f'Finish after generating {n_saved} samples')
            break

    if all_images:
        all_img = np.concatenate(all_images, axis=0)[:n_samples]
        shape_str = "x".join(str(x) for x in all_img.shape)
        nppath = os.path.join(nplog, f"{shape_str}-samples.npz")
        np.savez(nppath, all_img)

    print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")


def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
    """Persist the batch stored under ``logs[key]`` and return the updated counter.

    When ``np_path`` is None, each image in the batch is written as
    ``{path}/{key}_{counter:06}.png`` (via ``custom_to_pil``). Otherwise the
    whole batch is written as one ``{n_saved}-{shape}-samples.npz`` file in
    ``np_path`` (via ``custom_to_np``). Other keys in ``logs`` are ignored.

    Returns:
        ``n_saved`` plus the number of images written (unchanged if ``key`` is
        absent).
    """
    if key not in logs:
        return n_saved
    batch = logs[key]

    if np_path is not None:
        npbatch = custom_to_np(batch)
        shape_str = "x".join(map(str, npbatch.shape))
        np.savez(os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz"), npbatch)
        return n_saved + npbatch.shape[0]

    for x in batch:
        custom_to_pil(x).save(os.path.join(path, f"{key}_{n_saved:06}.png"))
        n_saved += 1
    return n_saved


def get_parser():
    """Build the command-line parser for the sampling script.

    Every option except ``-v/--vanilla_sample`` (a store_true flag) takes a
    single optional value (``nargs="?"``). Unknown ``key=value`` tokens are
    left for ``parse_known_args`` to return.
    """
    parser = argparse.ArgumentParser()

    def optional(*flags, **kwargs):
        # Shared shape of every value-taking option.
        parser.add_argument(*flags, nargs="?", **kwargs)

    optional("-gpu", "--gpu", type=int, help="GPU index")
    optional("-r", "--resume", type=str,
             help="load from logdir or checkpoint in logdir")
    optional("-n", "--n_samples", type=int, default=50000,
             help="number of samples to draw")
    optional("-e", "--eta", type=float, default=1.0,
             help="eta for ddim sampling (0.0 yields deterministic sampling)")
    parser.add_argument("-v", "--vanilla_sample", default=False, action='store_true',
                        help="vanilla sampling (default option is DDIM sampling)?")
    optional("-l", "--logdir", type=str, default="none", help="extra logdir")
    optional("-c", "--custom_steps", type=int, default=50,
             help="number of steps for ddim and fastdpm sampling")
    optional("--batch_size", type=int, default=10,
             help="the bs and also the number of centres for the closed-form discriminator")

    # Discriminator-guidance options.
    optional("-dg", "--dg_weight", type=float, default=0,
             help="the weight to discriminator guidance (0 indicates without any guidance)")
    optional("-d", "--ds", type=str, default="none",
             help="dataset name. (this is required for centres in the closed-form discriminator)")
    optional("-seed1", "--seed1", type=int, default=0, help="seed to img")
    # The closed-form discriminator uses seed2.
    optional("-seed2", "--seed2", type=int, default=1, help="seed to img_next")
    optional("-trainable", "--trainable", type=int, default=0,
             help="1 for trainable and 0 for closed-form discriminator")
    optional("-time_shift_sampler", "--time_shift_sampler", type=int, default=0,
             help="1 for time shift sampler and 0 for vanilla sampler")
    optional("-cutoff_time", "--cutoff_time", type=int, default=0, help="Value of Tc")
    # Set to 0 once Tc is known.
    optional("-estimate_tc", "--estimate_tc", type=int, default=0,
             help="Estimate the value of Tc")
    # Only meaningful for previously generated samples.
    optional("-estimate_metrics", "--estimate_metrics", type=int, default=0,
             help="Obtain plots of Wasserstein 1 and KL Divergence plots (maybe I will add more later)")
    optional("-window_size", "--window_size", type=int, default=10,
             help="Size of the window")
    return parser
    

def load_model_from_config(config, sd, gpu):
    """Instantiate the model described by ``config`` and prepare it for inference.

    ``sd`` is loaded with ``strict=False``, so mismatched keys do not raise.
    The model is then moved to CUDA and put in eval mode.

    NOTE(review): ``gpu`` is accepted but not used for placement; ``.cuda()``
    targets the current default CUDA device.
    """
    net = instantiate_from_config(config)
    net.load_state_dict(sd, strict=False)
    net.cuda()
    net.eval()
    return net

def load_model(config, ckpt, gpu, eval_mode, gpu_device):
    """Build the model from ``config.model`` and optionally load a checkpoint.

    ``gpu`` and ``eval_mode`` are accepted but unused; ``gpu_device`` is
    forwarded to ``load_model_from_config``.

    NOTE: ``torch.load`` unpickles the file, so only load trusted checkpoints.

    Returns:
        ``(model, global_step)``; ``global_step`` is None when ``ckpt`` is falsy.
    """
    if not ckpt:
        return load_model_from_config(config.model, None, gpu_device), None

    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    global_step = pl_sd["global_step"]
    return load_model_from_config(config.model, pl_sd["state_dict"], gpu_device), global_step



if __name__ == "__main__":
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    sys.path.append(os.getcwd())
    command = " ".join(sys.argv)

    parser = get_parser()
    # Unrecognised "key=value" tokens are merged into the config below as OmegaConf dotlist overrides.
    opt, unknown = parser.parse_known_args()
    ckpt = None

    # --resume has no default; report a usage error instead of letting
    # os.path.exists(None) raise a TypeError.
    if opt.resume is None:
        parser.error("-r/--resume is required (a checkpoint file or the log directory containing model.ckpt)")
    if not os.path.exists(opt.resume):
        raise ValueError("Cannot find {}".format(opt.resume))
    if os.path.isfile(opt.resume):
        # Checkpoint file: everything before the last "/" is the log directory.
        logdir = opt.resume.rpartition('/')[0]
        print(f'Logdir is {logdir}')
        ckpt = opt.resume
    else:
        assert os.path.isdir(opt.resume), f"{opt.resume} is not a directory"
        logdir = opt.resume.rstrip("/")
        ckpt = os.path.join(logdir, "model.ckpt")

    # The run's config.yaml (if present) merged with CLI dotlist overrides.
    base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml")))
    opt.base = base_configs

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)

    gpu = True
    eval_mode = True

    if opt.logdir != "none":
        locallog = logdir.split(os.sep)[-1]
        if locallog == "": locallog = logdir.split(os.sep)[-2]
        print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
        logdir = os.path.join(opt.logdir, locallog)

    print("printing the config here: ")
    print(OmegaConf.to_yaml(config))

    model, global_step = load_model(config, ckpt, gpu, eval_mode, opt.gpu)

    print(f"global step: {global_step}")
    print(75 * "=")
    print("logging to:")
    logdir = os.path.join(logdir, "samples", f"{global_step:08}", now)
    imglogdir = os.path.join(logdir, "img")
    numpylogdir = os.path.join(logdir, "numpy")

    # estimate_tc and estimate_metrics work on previously generated samples,
    # do not generate new ones, and terminate the process themselves. Output
    # directories are therefore only created for a normal sampling run.
    if ((opt.estimate_tc == 0) and (opt.estimate_metrics== 0)):

        os.makedirs(imglogdir)
        os.makedirs(numpylogdir)
        print(logdir)
        print(75 * "=")

        # Write the sampling options out next to the samples.
        sampling_file = os.path.join(logdir, "sampling_config.yaml")
        sampling_conf = vars(opt)

        with open(sampling_file, 'w') as f:
            yaml.dump(sampling_conf, f, default_flow_style=False)
        print(sampling_conf)

    # Load the trainable discriminator only when --trainable 1.
    discriminator = None
    if(opt.trainable == 1):
        # NOTE(review): the checkpoint path is absolute and the device is fixed
        # to "cuda:0" regardless of --gpu -- confirm both are intended.
        discriminator = create_model.load_discriminator("/trainable_discriminator/checkpoints/discriminator_100.pt", "cuda:0", 0, eval=True)

    run(model, imglogdir, eta=opt.eta,
        vanilla=opt.vanilla_sample,  n_samples=opt.n_samples, custom_steps=opt.custom_steps,
        batch_size=opt.batch_size, nplog=numpylogdir, dg = opt.dg_weight, dataset = opt.ds, seed1 = opt.seed1, seed2 = opt.seed2, trainable = opt.trainable, discriminator = discriminator, tss = opt.time_shift_sampler, cut_off_value = opt.cutoff_time, window_size = opt.window_size, estimate_tc = opt.estimate_tc, estimate_metrics = opt.estimate_metrics)

    print("done.")
