import logging
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
from model import ConditionalModel, Discriminator, Navigator
from ema import EMA
from loss import ct_loss, ct_withd_loss
# NOTE: the relative order around this star import is kept deliberately; it can shadow
# names imported before it and be shadowed by names imported after it.
from data_loader import *
import torch.optim as optim
import matplotlib.pyplot as plt
plt.style.use('ggplot')


def get_optimizer(config, parameters):
    """Build the torch optimizer selected by ``config.optim.optimizer``.

    Args:
        config: experiment config; reads ``config.optim.optimizer`` and ``lr``, plus
            ``weight_decay`` (Adam, RMSProp) and ``beta1``/``amsgrad``/``eps`` (Adam).
        parameters: the parameters (or parameter iterator) to optimize.

    Returns:
        ``optim.Adam`` (second beta fixed at 0.999), ``optim.RMSprop`` or
        ``optim.SGD`` (momentum fixed at 0.9).

    Raises:
        NotImplementedError: for any other optimizer name.
    """
    opt_cfg = config.optim
    name = opt_cfg.optimizer

    if name == 'Adam':
        return optim.Adam(
            parameters,
            lr=opt_cfg.lr,
            weight_decay=opt_cfg.weight_decay,
            betas=(opt_cfg.beta1, 0.999),
            amsgrad=opt_cfg.amsgrad,
            eps=opt_cfg.eps,
        )
    if name == 'RMSProp':
        return optim.RMSprop(parameters, lr=opt_cfg.lr, weight_decay=opt_cfg.weight_decay)
    if name == 'SGD':
        return optim.SGD(parameters, lr=opt_cfg.lr, momentum=0.9)

    raise NotImplementedError('Optimizer {} not understood.'.format(name))

def get_dataset(args, config):
    """Draw the training set for the toy dataset named in ``config.data.dataset``.

    Args:
        args: unused; kept for call-site compatibility.
        config: reads ``config.data.dataset`` and ``config.data.dataset_size``.

    Returns:
        Whatever the matching ``data_loader`` sampler's ``sample`` returns. For the
        Gaussian mixtures the size is rounded down to a multiple of the number of modes.

    Raises:
        NotImplementedError: for an unrecognised dataset name.
    """
    name = config.data.dataset
    if name == 'swiss_roll':
        return SwissRoll().sample(config.data.dataset_size)
    if name == 'moons':
        return Moons().sample(config.data.dataset_size)
    if name in ('8gaussians', '25gaussians'):
        modes = 8 if name == '8gaussians' else 25
        size = config.data.dataset_size
        return Gaussians().sample(size - size % modes, mode=modes)
    raise NotImplementedError("Toy dataset options: swiss_roll, moons, 8gaussians and 25gaussians.")

# Forward function
def make_beta_schedule(schedule='linear', num_timesteps=1000, start=1e-5, end=1e-2):
    """Return a 1-D tensor of ``num_timesteps`` noise variances (betas).

    Args:
        schedule: one of
            ``'linear'``  - evenly spaced from ``start`` to ``end``;
            ``'const'``   - every entry equals ``end``;
            ``'quad'``    - evenly spaced from ``sqrt(start)`` to ``sqrt(end)``, then squared;
            ``'jsd'``     - ``1/T, 1/(T-1), ..., 1`` (ignores ``start``/``end``);
            ``'sigmoid'`` - ``sigmoid`` of ``linspace(-6, 6)`` rescaled into ``(start, end)``.
        num_timesteps: length ``T`` of the schedule.
        start: first / lower value for the parametrised schedules.
        end: last / upper value for the parametrised schedules.

    Raises:
        NotImplementedError: for an unknown ``schedule``. Previously this surfaced as an
            ``UnboundLocalError`` on the final ``return``.
    """
    if schedule == 'linear':
        return torch.linspace(start, end, num_timesteps)
    if schedule == 'const':
        return end * torch.ones(num_timesteps)
    if schedule == 'quad':
        return torch.linspace(start ** 0.5, end ** 0.5, num_timesteps) ** 2
    if schedule == 'jsd':
        return 1.0 / torch.linspace(num_timesteps, 1, num_timesteps)
    if schedule == 'sigmoid':
        return torch.sigmoid(torch.linspace(-6, 6, num_timesteps)) * (end - start) + start
    raise NotImplementedError(
        "Unknown beta schedule {!r}; expected one of: linear, const, quad, jsd, sigmoid.".format(schedule)
    )


def extract(input, t, x):
    """Gather per-sample coefficients ``input[t]`` shaped to broadcast against ``x``.

    Args:
        input: coefficient table indexed along dim 0. Callers in this file pass 1-D
            per-step tables.
        t: integer indices into ``input``; moved to ``input``'s device first.
        x: tensor whose rank determines the output rank.

    Returns:
        Tensor of shape ``(len(t), 1, ..., 1)`` with ``x.ndim - 1`` trailing singleton
        dimensions.
    """
    index = t.to(input.device)
    values = input.gather(0, index)
    trailing = (1,) * (len(x.shape) - 1)
    return values.reshape(index.shape[0], *trailing)

def q_sample(x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, t, noise=None):
    """Return ``alphas_bar_sqrt[t] * x_0 + one_minus_alphas_bar_sqrt[t] * noise``.

    With the tables ``Diffusion`` builds (sqrt(alpha_bar) and sqrt(1 - alpha_bar)), this
    draws x_t from the forward marginal q(x_t | x_0).

    Args:
        x_0: clean samples, batch dimension first.
        alphas_bar_sqrt: per-step coefficient table applied to ``x_0``.
        one_minus_alphas_bar_sqrt: per-step coefficient table applied to the noise.
        t: integer step indices; one per sample, or a single index broadcast to all.
        noise: optional noise like ``x_0``; fresh standard-normal noise when ``None``.
    """
    eps = torch.randn_like(x_0) if noise is None else noise
    signal_scale = extract(alphas_bar_sqrt, t, x_0)
    noise_scale = extract(one_minus_alphas_bar_sqrt, t, x_0)
    return signal_scale * x_0 + noise_scale * eps

def q_sample_next(x_t_minus_1, betas_sqrt, one_minus_betas_sqrt, t, noise=None):
    """Take one forward-diffusion step from ``x_{t-1}`` to ``x_t``.

    Returns ``one_minus_betas_sqrt[t] * x_t_minus_1 + betas_sqrt[t] * noise``, i.e.
    ``sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * eps`` with the tables ``Diffusion`` builds.

    Args:
        x_t_minus_1: samples at step ``t - 1``, batch dimension first.
        betas_sqrt: per-step table of sqrt(beta), applied to the noise.
        one_minus_betas_sqrt: per-step table of sqrt(1 - beta), applied to ``x_t_minus_1``.
        t: integer step indices (one per sample, or one broadcast index).
        noise: optional noise like ``x_t_minus_1``; standard normal when ``None``.
    """
    if noise is None:
        noise = torch.randn_like(x_t_minus_1)
    keep = extract(one_minus_betas_sqrt, t, x_t_minus_1)
    inject = extract(betas_sqrt, t, x_t_minus_1)
    return keep * x_t_minus_1 + inject * noise

def q_sample_posterior(x0, xt, posterior_mean_coeff_1, posterior_mean_coeff_2, posterior_variance_coeff, t, noise=None):
    """Sample ``x_{t-1}`` from the Gaussian posterior q(x_{t-1} | x_t, x_0).

    Computes ``c1[t] * x0 + c2[t] * xt + sqrt(var[t]) * noise``.

    Args:
        x0: (predicted) clean samples.
        xt: samples at step ``t``; also fixes the output rank.
        posterior_mean_coeff_1: per-step mean coefficient on ``x0``.
        posterior_mean_coeff_2: per-step mean coefficient on ``xt``.
        posterior_variance_coeff: per-step posterior variance. Its square root is taken
            here.
        t: integer step indices (one per sample, or one broadcast index).
        noise: optional noise like ``xt``; standard normal when ``None``.
    """
    eps = torch.randn_like(xt) if noise is None else noise
    coef_x0 = extract(posterior_mean_coeff_1, t, xt)
    coef_xt = extract(posterior_mean_coeff_2, t, xt)
    std = extract(posterior_variance_coeff, t, xt).sqrt()
    mean = coef_x0 * x0 + coef_xt * xt
    return mean + std * eps

# Reverse function
def p_sample(model, x, t, z_dim=2):
    """Run ``model(x, step)`` with ``t`` wrapped in a 1-element tensor on the model's device.

    In this file the model output is used as a prediction of x_0.
    ``z_dim`` is accepted for signature compatibility and is unused.
    """
    device = next(model.parameters()).device
    step = torch.tensor([t], device=device)
    return model(x, step)

def p_sample_loop(model, shape, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, posterior_mean_coeff_1, posterior_mean_coeff_2, posterior_variance_coeff, n_steps=1, z_dim=2):
    """Run the reverse chain starting from standard-normal noise.

    1. Draw ``z ~ N(0, I)`` of ``shape``.
    2. The model maps ``z`` (queried at step ``n_steps``) to an x_0 estimate. That
       estimate is diffused to step ``n_steps - 1`` with ``q_sample``.
    3. For each step ``s = n_steps - 1, ..., 1``: predict x_0 from the current sample,
       then draw the step ``s - 1`` sample from the posterior q(x_{s-1} | x_s, x_0).

    Returns:
        A list of ``n_steps + 1`` tensors: the noise draw, then the samples at steps
        ``n_steps - 1, n_steps - 2, ..., 0``. No gradient context is set here; callers
        wrap this in ``torch.no_grad()``.
    """
    device = next(model.parameters()).device

    # Drawn on the default device and then moved, so the random stream matches the CPU generator.
    x = torch.randn(shape).to(device)
    trajectory = [x]

    x0_hat = p_sample(model, x, n_steps, z_dim)
    x = q_sample(x0_hat, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, torch.tensor([n_steps - 1]).to(device))
    trajectory.append(x)

    for step in range(n_steps - 1, 0, -1):
        x0_hat = p_sample(model, x, step, z_dim)
        # ``x`` on the right-hand side is still the step-``step`` sample.
        x = q_sample_posterior(
            x0_hat, x,
            posterior_mean_coeff_1, posterior_mean_coeff_2, posterior_variance_coeff,
            torch.tensor([step]).to(device),
        )
        trajectory.append(x)

    assert len(trajectory) == n_steps + 1
    return trajectory

# Evaluation with KLD
def kld(x, y, grid=(-20,20), num_grid=400):
    """Histogram estimate of KL(p_x || p_y) between two 2-D point clouds.

    Columns 0 and 1 of each array are binned on a ``num_grid`` x ``num_grid`` grid
    spanning ``grid`` on both axes. ``np.histogram2d`` does not count points outside
    that window. Counts are divided by the total number of points, so mass outside the
    window is simply lost. Every bin gets ``1e-7`` added to avoid ``log(0)`` and division
    by zero.

    Returns:
        ``sum(p_x * log(p_x / p_y))`` as a NumPy float.
    """
    lo, hi = grid[0], grid[1]
    window = [[lo, hi], [lo, hi]]

    def _density(points):
        counts, _, _ = np.histogram2d(points[:, 0], points[:, 1], bins=num_grid, range=window)
        return counts / points.shape[0] + 1e-7

    p_x = _density(x)
    p_y = _density(y)
    return np.sum(p_x * np.log(p_x / p_y))

class Diffusion(object):
    """Truncated diffusion model with an adversarially trained prior, for 2-D toy data.

    ``__init__`` precomputes the noise schedule and the coefficient tables consumed by the
    module-level ``q_sample`` / ``q_sample_next`` / ``q_sample_posterior`` /
    ``p_sample_loop`` helpers. ``train`` fits a ``ConditionalModel`` (used as an x_0
    predictor) jointly with a ``Discriminator`` on the last diffusion step. ``sample``
    loads a checkpoint and writes comparison figures.
    """

    def __init__(self, args, config, device=None):
        """Store the run configuration and precompute schedule tensors on ``device``.

        Args:
            args: run arguments; ``train``/``sample`` read ``resume_training``, ``log_path``
                and ``im_path`` from it.
            config: experiment config; the ``model``, ``diffusion``, ``data``, ``training``,
                ``optim`` and ``sampling`` sections are read.
            device: torch device for the schedule tensors and models; defaults to CUDA when
                available, otherwise CPU.
        """
        self.args = args
        self.config = config
        if device is None:
            device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.device = device

        self.model_var_type = config.model.var_type
        self.num_timesteps = config.diffusion.timesteps
        # The schedule is laid out over 1000 steps and then truncated to the first
        # ``num_timesteps`` entries, so a truncated chain keeps the full chain's per-step betas.
        betas = make_beta_schedule(
            schedule=config.diffusion.beta_schedule,
            num_timesteps=1000,
            start=config.diffusion.beta_start,
            end=config.diffusion.beta_end,
        )[:self.num_timesteps]

        betas = self.betas = betas.float().to(self.device)
        self.betas_sqrt = torch.sqrt(betas)

        alphas = 1.0 - betas
        self.one_minus_betas_sqrt = torch.sqrt(alphas)
        alphas_cumprod = alphas.cumprod(dim=0)
        self.alphas_bar_sqrt = torch.sqrt(alphas_cumprod)
        self.one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_cumprod)
        # alpha_bar shifted right by one step, with a leading 1.
        alphas_cumprod_prev = torch.cat([torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0)
        self.alphas_cumprod_prev = alphas_cumprod_prev
        # Mean coefficients (on x_0 and on x_t) and variance of q(x_{t-1} | x_t, x_0).
        self.posterior_mean_coeff_1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.posterior_mean_coeff_2 = torch.sqrt(alphas) * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.posterior_variance = posterior_variance
        # ``logvar`` is only defined for these two variance types.
        if self.model_var_type == "fixedlarge":
            self.logvar = betas.log()
        elif self.model_var_type == "fixedsmall":
            self.logvar = posterior_variance.clamp(min=1e-20).log()

    def train(self):
        """Train the model and the prior discriminator.

        Each batch performs three steps:
          1. Reconstruction term. Draw antithetic steps ``t``, diffuse the data to
             ``t - 1`` and then ``t``, predict x_0 from x_t, sample x_{t-1} from the
             posterior using that prediction, and take the MSE against the true x_{t-1}.
          2. Discriminator step. Classify data diffused to ``num_timesteps - 1`` against
             generated samples. A generated sample is the model output from Gaussian noise
             at step ``num_timesteps``, diffused to the same step.
          3. Generator step. Non-saturating GAN loss on fresh generated samples, added to
             the reconstruction term with weight ``1 / num_timesteps``.

        Checkpoints are written to ``args.log_path``; validation figures go to
        ``config.tb_logger`` and ``args.im_path``.
        """
        args = self.args
        config = self.config
        tb_logger = config.tb_logger
        dataset = get_dataset(args, config)
        train_loader = data.DataLoader(
            dataset,
            batch_size=config.training.batch_size,
            shuffle=True,
            num_workers=config.data.num_workers,
        )
        model = ConditionalModel(config).to(self.device)
        discriminator = Discriminator(config).to(self.device)

        optimizer = get_optimizer(config, model.parameters())
        optimizer_d = get_optimizer(config, discriminator.parameters())
        d_criterion = nn.BCEWithLogitsLoss()

        if config.model.ema:
            ema_helper = EMA(mu=config.model.ema_rate)
            ema_helper.register(model)
        else:
            ema_helper = None

        start_epoch, step = 0, 0
        if args.resume_training:
            # NOTE(review): checkpoints hold no discriminator / optimizer_d state, so a resumed
            # run restarts the discriminator from its initialization.
            states = torch.load(os.path.join(args.log_path, "ckpt.pth"))
            model.load_state_dict(states[0])
            states[1]["param_groups"][0]["eps"] = config.optim.eps
            optimizer.load_state_dict(states[1])
            start_epoch = states[2]
            step = states[3]
            if ema_helper is not None:
                ema_helper.load_state_dict(states[4])

        # Index of the last diffusion step. The model is queried at ``t_max + 1`` to map noise to x_0.
        t_max = torch.tensor([self.num_timesteps - 1]).to(self.device)
        # Clip only when the config provides a value. This replaces a try/except that hid every error.
        grad_clip = getattr(config.optim, "grad_clip", None)
        # Forward-diffusion reference for validation. It is built at the first validation,
        # so it also exists when a resumed run first validates after ``start_epoch``.
        x0_seq = None

        for epoch in range(start_epoch, config.training.n_epochs):
            data_start = time.time()
            data_time = 0
            for batch_idx, x_0 in enumerate(train_loader):
                n = x_0.size(0)
                data_time += time.time() - data_start
                model.train()
                discriminator.train()
                step += 1

                # Antithetic sampling: each t drawn from [1, T - 1] is paired with T - t.
                t = torch.randint(low=1, high=self.num_timesteps, size=(n // 2 + 1,)).to(self.device)
                t = torch.cat([t, self.num_timesteps - t], dim=0)[:n]

                # Reconstruction term.
                x_0 = x_0.to(self.device)
                x_t1 = q_sample(x_0, self.alphas_bar_sqrt, self.one_minus_alphas_bar_sqrt, t - 1)
                x_t = q_sample_next(x_t1, self.betas_sqrt, self.one_minus_betas_sqrt, t)
                x_0_prime = model(x_t, t)
                x_t1_prime = q_sample_posterior(
                    x_0_prime, x_t, self.posterior_mean_coeff_1, self.posterior_mean_coeff_2,
                    self.posterior_variance, t,
                )
                loss = (x_t1 - x_t1_prime).square().mean()

                # Discriminator step. The fake sample is built exactly as in the generator step
                # below (model output diffused to t_max), so both players see the same fake
                # distribution. Only the discriminator is updated here, so the generator forward
                # runs without gradient tracking.
                x_si = torch.randn(n, config.model.data_dim).to(self.device)
                x_t_si = q_sample(x_0, self.alphas_bar_sqrt, self.one_minus_alphas_bar_sqrt, t_max)
                with torch.no_grad():
                    x_0_fake = model(x_si, t_max + 1)
                x_t_fake = q_sample(x_0_fake, self.alphas_bar_sqrt, self.one_minus_alphas_bar_sqrt, t_max)

                x_true_logits = discriminator(x_t_si)
                x_fake_logits = discriminator(x_t_fake)
                d_loss = (
                    d_criterion(x_true_logits, torch.ones_like(x_true_logits))
                    + d_criterion(x_fake_logits, torch.zeros_like(x_fake_logits))
                )
                tb_logger.add_scalar("discriminator loss", d_loss, global_step=step)

                optimizer_d.zero_grad()
                d_loss = d_loss / self.num_timesteps
                d_loss.backward()
                optimizer_d.step()

                # Generator (implicit prior) step on fresh noise.
                x_si = torch.randn(n, config.model.data_dim).to(self.device)
                x_0_prime_si = model(x_si, t_max + 1)
                x_t_prime_si = q_sample(x_0_prime_si, self.alphas_bar_sqrt, self.one_minus_alphas_bar_sqrt, t_max)
                x_fake_logits = discriminator(x_t_prime_si)
                loss_T = d_criterion(x_fake_logits, torch.ones_like(x_fake_logits))

                tb_logger.add_scalar("implicit loss", loss_T, global_step=step)
                tb_logger.add_scalar("loss", loss, global_step=step)
                logging.info(
                    f"step: {step}, Noise Estimation loss: {loss.item()}, Prior loss: {loss_T.item()}, Discriminator loss: {d_loss.item()}, data time: {data_time / (batch_idx+1)}"
                )

                optimizer.zero_grad()
                # This backward also writes discriminator grads. They are cleared by
                # optimizer_d.zero_grad() before the next discriminator update.
                loss = loss + loss_T / self.num_timesteps
                loss.backward()
                if grad_clip is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()

                if ema_helper is not None:
                    ema_helper.update(model)

                # NOTE(review): this tests ``epoch`` (not ``step``) inside the batch loop, so every
                # batch of a qualifying epoch writes a checkpoint; confirm that is intended.
                if epoch % config.training.snapshot_freq == 0 or step == 1:
                    self._save_checkpoint(model, optimizer, ema_helper, epoch, step)

                if step % config.training.validation_freq == 0 or step == 1:
                    with torch.no_grad():
                        if x0_seq is None:
                            x0_seq = self._plot_forward_reference(dataset, tb_logger, step)
                        self._plot_reverse_samples(model, dataset, x0_seq, tb_logger, step)

                data_start = time.time()

    def _save_checkpoint(self, model, optimizer, ema_helper, epoch, step):
        """Save ``[model, optimizer, epoch, step(, ema)]`` to ``ckpt_{step}.pth`` and ``ckpt.pth``."""
        states = [model.state_dict(), optimizer.state_dict(), epoch, step]
        if ema_helper is not None:
            states.append(ema_helper.state_dict())
        torch.save(states, os.path.join(self.args.log_path, "ckpt_{}.pth".format(step)))
        torch.save(states, os.path.join(self.args.log_path, "ckpt.pth"))

    def _plot_forward_reference(self, dataset, tb_logger, step):
        """Diffuse ``dataset`` to every step, log a scatter figure, and return the per-step samples.

        Entry ``k`` of the returned list holds ``dataset`` diffused to step ``k`` (on CPU).
        """
        T = self.num_timesteps
        fig, axs = plt.subplots(1, T, figsize=(T * 5.5, 5))
        x0_seq = []
        for k in range(T):
            cur_x = q_sample(dataset, self.alphas_bar_sqrt.cpu(), self.one_minus_alphas_bar_sqrt.cpu(), torch.tensor([k])).detach().cpu()
            x0_seq.append(cur_x)
            ax = axs[T - k - 1]
            ax.scatter(cur_x[:, 0], cur_x[:, 1], s=10)
            ax.set_title(r'$p(\mathbf{x}_{' + str(k) + '})$', fontsize=25)
        tb_logger.add_figure('data', fig, step)
        return x0_seq

    def _plot_reverse_samples(self, model, dataset, x0_seq, tb_logger, step):
        """Run the reverse chain, plot every step with KL scores against ``x0_seq``, log and save."""
        T = self.num_timesteps
        x_seq = p_sample_loop(
            model, (dataset.shape[0], self.config.model.data_dim),
            self.alphas_bar_sqrt, self.one_minus_alphas_bar_sqrt,
            self.posterior_mean_coeff_1, self.posterior_mean_coeff_2, self.posterior_variance,
            T, self.config.model.z_dim,
        )
        fig, axs = plt.subplots(1, T + 1, figsize=((T + 1) * 5.5, 5))
        prior = x_seq.pop(0).detach().cpu()
        axs[0].scatter(prior[:, 0], prior[:, 1], s=10)
        axs[0].set_title(r'$p({z}_\mathbf{prior})$', fontsize=25)
        # After the pop, x_seq[k] is the generated sample at step T - k - 1.
        for k in range(T):
            level = T - k - 1
            cur_x = x_seq[k].detach().cpu()
            kl = kld(x0_seq[level].numpy(), cur_x.numpy())
            kl_x0 = kld(x0_seq[0].numpy(), cur_x.numpy())
            ax = axs[k + 1]
            ax.scatter(cur_x[:, 0], cur_x[:, 1], s=10)
            ax.set_title(r'$q(\mathbf{x}_{' + str(level) + '})$', fontsize=25)
            ax.set_xlabel('KL($p(x_{})||q(x_{})$)={:.2f}\nKL($p(x_0)||q(x_{})$)={:.2f}'.format(level, level, kl, level, kl_x0), fontsize=25)
        # Save before logging: add_figure closes the figure by default.
        fig.savefig(os.path.join(self.args.im_path, 'samples_T{}_{}.pdf'.format(T, step)))
        tb_logger.add_figure('samples', fig, step)

    @staticmethod
    def _kl_label(level, kl, kl_x0):
        """Axis label reporting KL(q(x_level) || p(x_level)) and KL(q(x_0) || p(x_level))."""
        return 'KL($q(x_{})||p(x_{})$)={:.2f}\nKL($q(x_0)||p(x_{})$)={:.2f}'.format(level, level, kl, level, kl_x0)

    @staticmethod
    def _toy_samples(dataset_name, size):
        """Draw ``size`` points from the named toy dataset.

        For the Gaussian mixtures, ``size`` is rounded down to a multiple of the mode count.

        Raises:
            NotImplementedError: for an unrecognised dataset name.
        """
        if dataset_name == 'swiss_roll':
            return SwissRoll().sample(size)
        if dataset_name == 'moons':
            return Moons().sample(size)
        if dataset_name == '8gaussians':
            return Gaussians().sample(size - size % 8, mode=8)
        if dataset_name == '25gaussians':
            return Gaussians().sample(size - size % 25, mode=25)
        raise NotImplementedError("Toy dataset options: swiss_roll, moons, 8gaussians and 25gaussians.")

    def sample(self):
        """Load a checkpoint, draw samples, and save three comparison figures to ``args.im_path``.

        Loads ``ckpt.pth``, or ``ckpt_{config.sampling.ckpt_id}.pth`` when that id is set,
        from ``args.log_path``. When ``config.model.ema`` is set it calls
        ``ema_helper.ema(model)`` with the last checkpoint entry; presumably this copies the
        EMA weights into the model. It then writes:
          * ``diffusion_samples_*.pdf``: fresh data diffused to every step;
          * ``generated_samples_*.pdf``: forward vs. generated scatter plots with KL values;
          * ``generated_distribution_*.pdf``: the same comparison as 2-D histograms.

        Raises:
            NotImplementedError: if ``config.data.dataset`` is not a supported toy dataset.
        """
        config = self.config
        T = self.num_timesteps
        model = ConditionalModel(config)

        ckpt_id = getattr(config.sampling, "ckpt_id", None)
        if ckpt_id is None:
            ckpt_name = "ckpt.pth"
            ckpt_id = 'last'
        else:
            ckpt_name = f"ckpt_{ckpt_id}.pth"
        states = torch.load(os.path.join(self.args.log_path, ckpt_name), map_location=config.device)
        model = model.to(self.device)
        model.load_state_dict(states[0], strict=True)

        if config.model.ema:
            ema_helper = EMA(mu=config.model.ema_rate)
            ema_helper.register(model)
            ema_helper.load_state_dict(states[-1])
            ema_helper.ema(model)
        model.eval()

        x_0 = self._toy_samples(config.data.dataset, config.sampling.sampling_size)

        with torch.no_grad():
            # Figure 1: x0_seq[k] is the fresh data diffused to step k.
            x0_seq = []
            fig, axs = plt.subplots(1, T, figsize=(T * 8.5, 8))
            for k in range(T):
                cur_x = q_sample(x_0, self.alphas_bar_sqrt.cpu(), self.one_minus_alphas_bar_sqrt.cpu(), torch.tensor([k])).detach().cpu()
                x0_seq.append(cur_x)
                axs[k].scatter(cur_x[:, 0], cur_x[:, 1], s=10)
                axs[k].set_title(r'$q(\mathbf{x}_{' + str(k) + '})$', fontsize=25)
                axs[k].tick_params(axis='x', labelsize=20)
                axs[k].tick_params(axis='y', labelsize=20)
            fig.savefig(os.path.join(self.args.im_path, 'diffusion_samples_T{}_{}.pdf'.format(T, ckpt_id)))
            plt.close(fig)

            # Reverse chain without the initial noise draw: x_seq[j] is the sample at step T - j - 1.
            x_seq = p_sample_loop(
                model, (config.sampling.sampling_size, config.model.data_dim),
                self.alphas_bar_sqrt, self.one_minus_alphas_bar_sqrt,
                self.posterior_mean_coeff_1, self.posterior_mean_coeff_2, self.posterior_variance,
                T, config.model.z_dim,
            )[1:]
            # Re-index by step and compute both KL values once for the two figures below.
            generated = [x_seq[T - k - 1].detach().cpu() for k in range(T)]
            kls = [
                (kld(x0_seq[k].numpy(), generated[k].numpy()), kld(x0_seq[0].numpy(), generated[k].numpy()))
                for k in range(T)
            ]

            # Figure 2: scatter plots, forward (top) vs. generated (bottom).
            fig, axs = plt.subplots(2, T, figsize=(T * 8.5, 8 * 2.2))
            for k in range(T):
                ref, gen = x0_seq[k], generated[k]
                top, bottom = axs[0, k], axs[1, k]
                top.scatter(ref[:, 0], ref[:, 1], s=10)
                top.set_title(r'$q(\mathbf{x}_{' + str(k) + '})$', fontsize=25)
                bottom.scatter(gen[:, 0], gen[:, 1], s=10)
                bottom.set_title(r'$p_{\theta}(\mathbf{x}_{' + str(k) + '})$', fontsize=25)
                bottom.set_xlabel(self._kl_label(k, *kls[k]), fontsize=25)
                for ax in (top, bottom):
                    ax.tick_params(axis='x', labelsize=20)
                    ax.tick_params(axis='y', labelsize=20)
            fig.savefig(os.path.join(self.args.im_path, 'generated_samples_T{}_{}.pdf'.format(T, ckpt_id)))
            plt.close(fig)

            # Figure 3: 100x100 histograms, forward (top) vs. generated (bottom).
            fig, axs = plt.subplots(2, T, figsize=(T * 8.5, 8 * 2.2))
            for k in range(T):
                ref, gen = x0_seq[k], generated[k]
                top, bottom = axs[0, k], axs[1, k]
                heatmap_x, _, _ = np.histogram2d(ref[:, 0].numpy(), ref[:, 1].numpy(), bins=100)
                top.imshow(heatmap_x.T)
                top.set_title(r'$q(\mathbf{x}_{' + str(k) + '})$', fontsize=25)
                top.axis('off')
                heatmap, _, _ = np.histogram2d(gen[:, 0].numpy(), gen[:, 1].numpy(), bins=100)
                bottom.imshow(heatmap.T)
                bottom.set_title(r'$p_{\theta}(\mathbf{x}_{' + str(k) + '})$', fontsize=25)
                bottom.axis('off')
                bottom.set_xlabel(self._kl_label(k, *kls[k]), fontsize=25)
                for ax in (top, bottom):
                    ax.tick_params(axis='x', labelsize=20)
                    ax.tick_params(axis='y', labelsize=20)
            fig.savefig(os.path.join(self.args.im_path, 'generated_distribution_T{}_{}.pdf'.format(T, ckpt_id)))
            plt.close(fig)
