import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import argparse
import itertools
import numpy as np
from tqdm import tqdm
import logging
import glob
import numpy as np
from sklearn.manifold import TSNE

import ot
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_value_

import torch.optim as optim
from utils import *
from Scheduler import GradualWarmupScheduler
from gmvae import *
# from zuko.flow.gaussianization import GF

# from zuko.utils import odeint
from torchdiffeq import odeint_adjoint as odeint

from dataloader.dataloader_pinwheel import *

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning.strategies import DDPStrategy



# Make float64 the default floating-point dtype: tensors and parameters created
# after this point without an explicit dtype are double precision.
torch.set_default_dtype(torch.float64)
# Show three digits after the decimal point when tensors are printed.
torch.set_printoptions(precision=3)



def map_to_position(input_tensor, num_classes):
    """
    One-hot encode a 1-D tensor of class indices.

    Parameters:
    - input_tensor: A 1-D integer tensor of class indices; each value must be a
      valid column index for a row of width ``num_classes``.
    - num_classes: Number of columns of the output (width of the one-hot axis).

    Returns:
    - A tensor of shape (len(input_tensor), num_classes) in the current default
      floating dtype, on the same device as ``input_tensor``, where row i holds
      a 1 in column ``input_tensor[i]`` and 0 everywhere else.
    """
    num_rows = input_tensor.size(0)
    index_device = input_tensor.device

    # Allocate on the index tensor's device: a CPU tensor cannot be indexed
    # with CUDA indices, so a CPU-only allocation fails for GPU inputs.
    output_tensor = torch.zeros((num_rows, num_classes), device=index_device)
    row_ids = torch.arange(num_rows, device=index_device)
    output_tensor[row_ids, input_tensor] = 1

    return output_tensor

def first_three_eigen_proj(x):
    """
    Project ``x`` onto the three leading eigen-directions of its standardized data.

    Each column of ``x`` is standardized (mean removed, divided by its standard
    deviation), the covariance matrix of the standardized data is
    eigendecomposed, and the eigenvectors belonging to the three largest
    eigenvalues are selected. Note that the *original* ``x`` (not the
    standardized copy) is what gets projected onto those eigenvectors.

    Parameters:
    - x: NumPy array of shape (n_samples, n_features). A column with zero
      variance leads to a division by zero during standardization.

    Returns:
    - Real-valued array of shape (n_samples, 3) (fewer columns if ``x`` has
      fewer than three features). The sign of each column is arbitrary, as the
      sign of an eigenvector always is.
    """
    # Step 1: Standardize the columns and compute their covariance matrix
    x_standardized = (x - x.mean(0, keepdims=True)) / np.sqrt(x.var(0, keepdims=True))
    cov_matrix = np.cov(x_standardized, rowvar=False)

    # Step 2: Eigendecomposition. The covariance matrix is symmetric, so use
    # eigh: it guarantees real eigenvalues and orthonormal eigenvectors, whereas
    # the general-purpose eig may return complex output through round-off.
    eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)

    # Step 3: Keep the eigenvectors of the three largest eigenvalues
    leading = np.argsort(eigenvalues)[::-1][:3]
    top_eigenvectors = eigenvectors[:, leading]

    # Step 4: Project the array onto the selected eigen-directions
    return np.dot(x, top_eigenvectors)

def logit_trans(data, eps=1e-6):
    """
    Apply a logit transform after shrinking the input away from 0 and 1.

    The input is first mapped affinely to ``eps + (1 - 2 * eps) * data`` (so
    values in [0, 1] land in [eps, 1 - eps]); the result ``p`` is then turned
    into ``log(p) - log(1 - p)``. The input tensor is not modified.
    """
    squeezed = eps + (1 - 2 * eps) * data
    log_p = torch.log(squeezed)
    log_one_minus_p = torch.log1p(-squeezed)
    return log_p - log_one_minus_p

class PlotLogLikelihoodCallback(Callback):
    """
    Lightning callback that tracks two logged metrics and redraws a two-panel
    line plot of their history after every training epoch.
    """

    def __init__(self, save_path="loss_plot.png", log_keys=("log_likelihood_x", "log_likelihood_z")):
        """
        Args:
            save_path (str): Directory-like path; the figure is written to
                ``<save_path>/llk.png``.
            log_keys (tuple): Names of the two metrics to read from
                ``trainer.callback_metrics``. The first feeds the "x" panel,
                the second the "z" panel.
        """
        super().__init__()
        self.log_keys = log_keys
        self.save_path = os.path.join(save_path, 'llk.png')
        # One history list per tracked metric, appended to once per epoch
        # in which the metric is available.
        self.log_likelihoods_x = []
        self.log_likelihoods_z = []

    def on_train_epoch_end(self, trainer, pl_module):
        """
        Record the latest metric values (when present) and rewrite the figure.

        Args:
            trainer (Trainer): The PyTorch Lightning trainer instance.
            pl_module (LightningModule): The LightningModule being trained.
        """
        metrics = trainer.callback_metrics

        # Pull whichever of the two metrics has been logged so far.
        histories = (self.log_likelihoods_x, self.log_likelihoods_z)
        for slot, history in enumerate(histories):
            key = self.log_keys[slot]
            if key in metrics:
                history.append(metrics[key].item())

        # (subplot position, data, y-axis label, title) for each panel.
        panels = (
            (1, self.log_likelihoods_x, "Log p(x|z)", "Log Likelihood of x During Training"),
            (2, self.log_likelihoods_z, "Log p(z|x)", "Log Likelihood of z During Training"),
        )

        plt.figure(figsize=(12, 6))
        for position, history, y_label, title in panels:
            plt.subplot(1, 2, position)
            plt.plot(history, marker="o")
            plt.xlabel("Epoch")
            plt.ylabel(y_label)
            plt.title(title)
            plt.grid()

        plt.tight_layout()
        plt.savefig(self.save_path)
        plt.close()



class PlotLossCallback(Callback):
    """
    Lightning callback that collects the ``train_loss`` and ``val_loss``
    metrics after validation and periodically saves a plot of both curves.
    """

    def __init__(self, save_path="loss_plot.png", update_interval=5):
        """
        Args:
            save_path (str): File the figure is written to.
            update_interval (int): The figure is rewritten whenever the current
                epoch is a multiple of this value.
        """
        super().__init__()
        self.save_path = save_path
        self.update_interval = update_interval
        # Parallel lists: epochs[i] is the epoch at which train_losses[i] and
        # val_losses[i] were recorded.
        self.epochs = []
        self.train_losses = []
        self.val_losses = []

    def on_validation_end(self, trainer, pl_module):
        """Record the current losses (if both are logged) and maybe redraw."""
        epoch = trainer.current_epoch
        metrics = trainer.callback_metrics

        train_loss = metrics.get("train_loss")
        val_loss = metrics.get("val_loss")

        # Only record a point when both curves have a value for this epoch.
        both_available = train_loss is not None and val_loss is not None
        if both_available:
            self.epochs.append(epoch)
            self.train_losses.append(train_loss.cpu().item())
            self.val_losses.append(val_loss.cpu().item())

        # Redraw on every `update_interval`-th epoch.
        redraw_due = epoch % self.update_interval == 0
        if redraw_due:
            self.plot_and_save()

    def plot_and_save(self):
        """Draw both loss curves against the recorded epochs and save the figure."""
        curves = (
            (self.train_losses, "Training Loss"),
            (self.val_losses, "Validation Loss"),
        )

        plt.figure(figsize=(10, 6))
        for values, label in curves:
            plt.plot(self.epochs, values, label=label, marker="o")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training and Validation Loss")
        plt.legend()
        plt.grid()
        plt.savefig(self.save_path)
        plt.close()


class GMVAELightningModule(pl.LightningModule):
    """
    LightningModule that trains the GMVAE components with an ELBO loss.

    The module owns a decoder, a categorical encoder (``encoderpi``), a latent
    encoder (``encoderz``) and a latent prior (``priorz``). Optimization is
    manual (``automatic_optimization = False``): ``training_step`` drives the
    optimizer itself and ``on_train_epoch_end`` steps the LR scheduler.

    Args:
        decoder: Module exposing ``sample(z)`` and ``log_prob(z, x)``.
        encoderpi: Module exposing ``rsample(x)``, which returns a 3-tuple
            (used here as component assignment, logits, probabilities).
        encoderz: Module exposing ``rsample(idx, x)`` and
            ``log_prob(idx, x, z)``.
        priorz: Object exposing ``rsample(idx, None)``.
        config: Configuration namespace; the fields read are ``flow.k_dim``,
            ``flow.z_dim``, ``data.size``, ``data.n_classes``, ``optim.*``,
            ``training.*`` and ``sample.n_gen``.
        args: Run arguments; ``args.log_sample_path`` is where figures go.
    """

    def __init__(self, decoder: nn.Module, encoderpi: nn.Module, encoderz: nn.Module, priorz, config, args):
        super().__init__()
        self.config = config
        self.args = args

        self.decoder = decoder
        self.encoderz = encoderz
        self.encoderpi = encoderpi
        self.priorz = priorz
        self.k, self.d = self.config.flow.k_dim, self.config.flow.z_dim
        self.p = self.config.data.size

        # The optimizer and scheduler are stepped by hand in this module.
        self.automatic_optimization = False
        # Last validation batch of the epoch, kept for the diagnostic plots.
        self.last_validation_batch = None

    def setup(self, stage=None):
        """Build the ELBO loss object from the model components."""
        self.elbo_loss = ELBO(self.decoder, self.encoderz, self.encoderpi, self.priorz, self.k)

    def training_step(self, batch, batch_idx):
        """Compute the ELBO loss on one batch and take one optimizer step."""
        X, y = batch
        X = X.to(dtype=torch.float64)
        loss = self.elbo_loss(X)

        self.log('train_loss', loss, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)

        # Manual optimization: zero grads, backprop, step.
        g_opt = self.optimizers()
        g_opt.zero_grad()

        self.manual_backward(loss)
        g_opt.step()

        return loss

    def validation_step(self, batch, batch_idx):
        """
        Log the validation loss plus the mean posterior and likelihood
        log-probabilities of a posterior sample for this batch.
        """
        X, y = batch
        X = X.to(dtype=torch.float64)
        val_loss = self.elbo_loss(X)

        # Sample from the posterior.
        zc, zcidx, _ = self.sample(X)

        # Evaluate log likelihood of z|x and x|z.
        log_post_z = self.encoderz.log_prob(zcidx, X, zc).mean()
        log_lik = self.decoder.log_prob(zc, X).mean()

        # Store the last batch of the (first) validation loader for plotting.
        if batch_idx == self.trainer.num_val_batches[0] - 1:
            self.last_validation_batch = {"X": X, "y": y}

        self.log('val_loss', val_loss, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)
        self.log('val_log_post_z', log_post_z, on_step=False, on_epoch=True, sync_dist=True, logger=True)
        self.log('val_log_lik', log_lik, on_step=False, on_epoch=True, sync_dist=True, logger=True)

        return val_loss

    def on_train_epoch_end(self):
        """Step the LR scheduler once per epoch (manual optimization)."""
        sch = self.lr_schedulers()
        sch.step()

    def on_validation_epoch_end(self):
        """Every ``snapshot_freq`` epochs, draw diagnostics from the stored batch."""
        if self.last_validation_batch is not None:
            X = self.last_validation_batch["X"]
            y = self.last_validation_batch["y"]

            if self.current_epoch % self.config.training.snapshot_freq == 0:
                print("sampling")
                self.generate(X, y)

        # Clear the stored batch for the next epoch.
        self.last_validation_batch = None

    def configure_optimizers(self):
        """
        Build an AdamW optimizer with a warmup-then-cosine LR schedule.

        Only the decoder and the two encoders are optimized; ``priorz`` is not
        passed to the optimizer. Warmup spans ``n_epochs // 10`` epochs and the
        cosine schedule uses ``T_max = n_epochs // 10 * 9``.
        """
        optimizer = torch.optim.AdamW(
            itertools.chain(
                self.decoder.parameters(),
                self.encoderz.parameters(),
                self.encoderpi.parameters(),
            ),
            lr=self.config.optim.lr,
            weight_decay=self.config.optim.weight_decay)

        cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer,
            T_max=self.config.training.n_epochs // 10 * 9,
            eta_min=0,
            last_epoch=-1
        )
        # GradualWarmupScheduler is project-defined; presumably it ramps the LR
        # for `warm_epoch` epochs and then defers to `after_scheduler`.
        warmUpScheduler = GradualWarmupScheduler(
            optimizer=optimizer,
            multiplier=self.config.optim.multiplier,
            warm_epoch=self.config.training.n_epochs // 10,
            after_scheduler=cosineScheduler,
            last_epoch=self.current_epoch
        )

        return [optimizer], [{'scheduler': warmUpScheduler,
                              'monitor': 'train_loss',
                              "interval": "epoch",
                              "frequency": 1}]

    def generate(self, x, y):
        """
        Save two diagnostic figures to ``args.log_sample_path``.

        1. ``image_grid_<epoch>.png``: scatter of the first two coordinates of
           samples decoded from prior draws, colored by their re-encoded latent.
        2. ``postpi_grid_epoch_<epoch>.png``: one row per class and up to three
           inputs per row; each panel is a histogram of the arg-max component
           index over 100 posterior draws for a single input.

        Args:
            x: Batch of inputs.
            y: Labels for ``x``, compared against the row index 0..n_classes-1.
        """
        self.decoder.eval()
        self.encoderz.eval()
        self.encoderpi.eval()

        with torch.no_grad():
            # Draw uniformly random components, encode them one-hot and sample
            # the prior, then decode and re-encode the generated points.
            z1idx = torch.randint(0, self.k, (self.config.sample.n_gen, ))
            z1idx = map_to_position(z1idx, self.k).to(self.device)
            z1 = self.priorz.rsample(z1idx, None)

            xnew = self.decoder.sample(z1)
            zcidx, _, _ = self.encoderpi.rsample(xnew)
            zc = self.encoderz.rsample(zcidx, xnew)

            # NOTE(review): the sampled outputs are indexed with [0] —
            # presumably a leading sample axis; confirm against gmvae.
            znp = zc[0].cpu().detach().numpy()
            xnp = xnew[0].cpu().detach().numpy()

            plt.figure()

            plt.scatter(xnp[:, 0], xnp[:, 1], c=znp, marker=".", cmap=plt.colormaps['gist_rainbow'])
            plt.colorbar()
            plt.savefig(os.path.join(self.args.log_sample_path, 'image_grid_{}.png'.format(self.current_epoch)))
            plt.close()

            # Posterior component histograms. squeeze=False keeps `axes` 2-D
            # even for a single class, so the row/column loops always work.
            n_classes = self.config.data.n_classes
            _, axes = plt.subplots(n_classes, 3, figsize=(10, n_classes), sharex=True,
                                   constrained_layout=True, squeeze=False)
            for row_idx, row_axes in enumerate(axes):
                mask = y == row_idx
                if mask.sum() == 0:
                    # No example of this class in the batch: leave the row empty.
                    continue

                x_k = x[mask]
                for col_idx, ax in enumerate(row_axes):
                    if col_idx < len(x_k):
                        # Repeat one input 100 times to draw 100 posterior samples.
                        x0 = x_k[col_idx].repeat(100, 1)
                        pic, _, _ = self.encoderpi.rsample(x0)

                        kidx = torch.argmax(pic, dim=-1).squeeze()
                        kidx_np = kidx.cpu().detach().numpy()

                        ax.hist(kidx_np, bins=20, alpha=0.7)
                    else:
                        ax.axis('off')
                    if col_idx == 0:
                        ax.set_ylabel(f"y={row_idx}")

            plt.savefig(os.path.join(self.args.log_sample_path, 'postpi_grid_epoch_{}.png'.format(self.current_epoch)))
            plt.close()

        self.decoder.train()
        self.encoderz.train()
        self.encoderpi.train()

    def sample(self, x):
        """
        Draw a posterior sample for ``x``.

        Returns:
            Tuple ``(zc, zcidx, prob)``: the latent sample from ``encoderz``,
            and the first and third outputs of ``encoderpi.rsample(x)``.
        """
        zcidx, logits, prob = self.encoderpi.rsample(x)
        zc = self.encoderz.rsample(zcidx, x)

        return zc, zcidx, prob

    def test_generate(self, n):
        """
        Generate ``n`` samples from the prior and re-encode them.

        Returns:
            Tuple ``(xnew[0], zcidx[0])``: the first entry of the decoded
            samples and of the re-encoded component assignment.
        """
        z1idx = torch.randint(0, self.k, (n, ))
        z1idx = map_to_position(z1idx, self.k).to(self.device)
        z1 = self.priorz.rsample(z1idx, None)

        xnew = self.decoder.sample(z1)
        zcidx, _, _ = self.encoderpi.rsample(xnew)

        return xnew[0], zcidx[0]

    def test_step(self, batch, batch_idx):
        """
        Log the ELBO loss and the earth mover's distance between the batch and
        an equally sized set of generated samples.
        """
        X, y = batch
        loss = self.elbo_loss(X.to(torch.float64))
        n = len(X)

        # Compare the true samples with the same number of generated samples.
        xnew, _ = self.test_generate(n)
        xnew = xnew.cpu().detach().numpy()
        xorg = X.cpu().detach().numpy()
        w = 1 / n * np.ones(n)  # uniform weights on both point clouds

        M = ot.dist(xorg, xnew, "euclidean")
        # Rescale so the largest cost is 10; skip when every distance is zero,
        # which would otherwise turn the cost matrix into NaNs.
        scale = M.max() * 0.1
        if scale > 0:
            M /= scale
        d_emd = ot.emd2(w, w, M)

        self.log('test_loss', loss, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)
        self.log('test_emd', d_emd, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)




class GMVAERunner():
    """
    Builds the GMVAE networks and a Lightning ``Trainer`` and exposes
    ``train`` / ``sample`` entry points.

    Args:
        args: Run arguments. Reads ``log_path``, ``tb_path`` and
            ``resume_training``; ``log_sample_path`` is set here to
            ``<log_path>/samples`` and the directory is created.
        config: Configuration namespace (``flow``, ``data``, ``model``,
            ``training`` sections are read).
    """

    # Where checkpoints are looked up, relative to ``args.tb_path``. Only the
    # first logger version is considered.
    _CKPT_SUBDIR = "lightning_logs/version_0/checkpoints"

    def __init__(self, args, config):
        self.args = args
        self.config = config
        args.log_sample_path = os.path.join(args.log_path, 'samples')
        os.makedirs(args.log_sample_path, exist_ok=True)

        self.k, self.d = self.config.flow.k_dim, self.config.flow.z_dim
        self.p = self.config.data.size

        hidden = self.config.model.ngf
        self.decoder = LLKNet(self.p, self.d, hidden_features=[hidden] * 4)
        self.encoderz = GaussianNet(self.k, self.d, self.p, hidden_features=[hidden] * 2)
        self.encoderpi = CatNet(self.p, self.k, hidden_features=[hidden] * 2)
        self.priorz = GaussianNet(self.k, self.d, hidden_features=[], batch_norm=True)

        # NOTE(review): this callback is built but is not in the Trainer's
        # `callbacks` list below, so it has no effect as written. `train` and
        # `sample` read checkpoints from the logger directory instead; confirm
        # the intended checkpoint location before registering it.
        self.checkpoint_callback = ModelCheckpoint(
            monitor='val_loss',  # Metric to monitor
            dirpath=self.args.log_path,  # Directory where checkpoints will be saved
            filename='best-checkpoint-{epoch:02d}-{val_loss:.2f}',  # Filename convention
            save_top_k=5,  # Keep the five best checkpoints by val_loss
            mode='min'  # Minimize the validation loss
        )

        # Use DDP on GPU when CUDA is available, otherwise let Lightning pick.
        if torch.cuda.is_available():
            accelerator = 'gpu'
            strategy = "ddp"
        else:
            accelerator = 'cpu'
            strategy = "auto"
        devices = "auto"

        plot_loss_callback = PlotLossCallback(save_path=os.path.join(self.args.log_sample_path, 'loss.png'), update_interval=1)
        plot_llk_callback = PlotLogLikelihoodCallback(save_path=self.args.log_sample_path, log_keys=("val_log_lik", "val_log_post_z"))

        self.trainer = pl.Trainer(
            max_epochs=self.config.training.n_epochs,
            accelerator=accelerator,
            devices=devices,
            strategy=strategy,
            logger=pl.loggers.TensorBoardLogger(save_dir=self.args.tb_path, name="lightning_logs"),
            callbacks=[plot_loss_callback, plot_llk_callback],
        )

    def _latest_checkpoint(self):
        """
        Return the path of the most recently modified ``.ckpt`` file under
        ``<tb_path>/lightning_logs/version_0/checkpoints``.

        ``os.listdir`` returns entries in arbitrary order, so the newest file
        is chosen explicitly by modification time (ties broken by path).

        Raises:
            FileNotFoundError: If the directory is missing or holds no
                ``.ckpt`` file.
        """
        ckpt_dir = os.path.join(self.args.tb_path, self._CKPT_SUBDIR)
        candidates = [
            os.path.join(ckpt_dir, entry)
            for entry in os.listdir(ckpt_dir)
            if entry.endswith('.ckpt')
        ]
        if not candidates:
            raise FileNotFoundError(f"no .ckpt file found in {ckpt_dir}")
        return max(candidates, key=lambda path: (os.path.getmtime(path), path))

    def train(self):
        """Fit the model, resuming from the latest checkpoint if requested."""
        # Load data.
        dataset, test_dataset = get_dataset(self.config.data.n_classes, "data", self.config.data.samplesize, self.config.data.test_samplesize)
        train_dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                      num_workers=self.config.data.num_workers)
        val_dataloader = DataLoader(test_dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                    num_workers=self.config.data.num_workers, drop_last=True)

        model = GMVAELightningModule(self.decoder, self.encoderpi, self.encoderz, self.priorz, self.config, self.args)

        ckpt_path = self._latest_checkpoint() if self.args.resume_training else None

        self.trainer.fit(model, train_dataloader, val_dataloader, ckpt_path=ckpt_path)

    def sample(self):
        """Run the test loop on a fresh test set using the latest checkpoint."""
        dataset, test_dataset = get_dataset(self.config.data.n_classes, "data", 2000, 2000)

        test_dataloader = DataLoader(test_dataset, batch_size=500, shuffle=True,
                                     num_workers=self.config.data.num_workers, drop_last=True)

        model = GMVAELightningModule(self.decoder, self.encoderpi, self.encoderz, self.priorz, self.config, self.args)

        self.trainer.test(model, dataloaders=test_dataloader, ckpt_path=self._latest_checkpoint())


