import os
import argparse
import itertools
import numpy as np
from tqdm import tqdm
import logging
import glob
import numpy as np
# from torchvision.utils import save_image

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_value_

import torch.optim as optim
from utils import get_named_beta_schedule
from Scheduler import GradualWarmupScheduler
from sfa_lds import *
from torchdiffeq import odeint_adjoint as odeint

from dataloader.dataloader_lds import *

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning.strategies import DDPStrategy

# import torch.distributed as dist
# import torch.multiprocessing as mp
# from torch.nn.parallel import DistributedDataParallel as DDP
# import socket

# Tensors created without an explicit dtype in this module default to float64.
torch.set_default_dtype(torch.float64)
# Show 3 digits of floating-point precision when printing tensors.
torch.set_printoptions(precision=3)


def map_to_position(input_tensor, num_classes):
    """
    One-hot encode a 1-D tensor of class indices.

    Parameters:
    - input_tensor: 1-D integer tensor of class indices, each in
      ``[0, num_classes)``.
    - num_classes: Number of columns in the output.

    Returns:
    - A ``(len(input_tensor), num_classes)`` tensor of the default floating
      dtype, on the same device as ``input_tensor``, with a 1 in column
      ``input_tensor[i]`` of row ``i`` and 0 everywhere else.
    """
    num_rows = input_tensor.size(0)
    device = input_tensor.device
    # Allocate the buffer and the row indices on the input's device; indexing
    # a CPU tensor with CUDA indices (or vice versa) fails.
    output_tensor = torch.zeros((num_rows, num_classes), device=device)
    output_tensor[torch.arange(num_rows, device=device), input_tensor] = 1
    return output_tensor

def scaledsigmoid(x):
    """Rescale the logistic sigmoid from (0, 1) to (-1, 1), i.e. ``2 * sigmoid(x) - 1``."""
    squashed = torch.sigmoid(x)
    return 2 * squashed - 1


class PlotLogLikelihoodCallback(Callback):
    """Record two logged log-likelihood metrics at every training epoch end and
    redraw them side by side into one image file."""

    # Suffixes that mark ``save_path`` as an output file rather than a directory.
    _IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg", ".pdf", ".svg")

    def __init__(self, save_path="loss_plot.png", log_keys=("log_likelihood_x", "log_likelihood_z")):
        """
        Callback to plot the log likelihoods logged during training for two variables (e.g., x and z).

        Args:
            save_path (str): Where to write the figure. A path ending in an image
                suffix (such as the default ``"loss_plot.png"``) is used as the
                output file; any other path is treated as a directory and the
                figure is written to ``<save_path>/llk.png``.
            log_keys (tuple): Keys in ``trainer.callback_metrics`` of the two
                log likelihoods, in the order (x, z).
        """
        super().__init__()
        # Joining 'llk.png' onto a file-like path (e.g. the default) would give
        # an unusable path such as "loss_plot.png/llk.png".
        if os.path.splitext(save_path)[1].lower() in self._IMAGE_SUFFIXES:
            self.save_path = save_path
        else:
            self.save_path = os.path.join(save_path, 'llk.png')
        self.log_keys = log_keys
        self.log_likelihoods_x = []
        self.log_likelihoods_z = []

    def on_train_epoch_end(self, trainer, pl_module):
        """
        Append the current metric values (when present) and rewrite the figure.

        Args:
            trainer (Trainer): The PyTorch Lightning trainer instance.
            pl_module (LightningModule): The LightningModule being trained.
        """
        # With a multi-process strategy (e.g. DDP) every rank runs this hook;
        # only rank 0 records and saves, so processes do not clobber one file.
        if not trainer.is_global_zero:
            return

        metrics = trainer.callback_metrics
        if self.log_keys[0] in metrics:
            self.log_likelihoods_x.append(metrics[self.log_keys[0]].item())
        if self.log_keys[1] in metrics:
            self.log_likelihoods_z.append(metrics[self.log_keys[1]].item())

        panels = (
            (self.log_likelihoods_x, "Log p(x|z)", "Log Likelihood of x During Training"),
            (self.log_likelihoods_z, "Log p(z|x)", "Log Likelihood of z During Training"),
        )
        fig = plt.figure(figsize=(12, 6))
        for position, (values, ylabel, title) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            plt.plot(values, marker="o")
            plt.xlabel("Epoch")
            plt.ylabel(ylabel)
            plt.title(title)
            plt.grid()

        plt.tight_layout()
        plt.savefig(self.save_path)
        plt.close(fig)


class PlotLossCallback(Callback):
    """After each validation run, record the current ``train_loss`` and
    ``val_loss`` metrics and periodically plot both curves to ``save_path``."""

    def __init__(self, save_path="loss_plot.png", update_interval=5):
        """
        Args:
            save_path (str): Output image file for the loss curves.
            update_interval (int): The figure is redrawn whenever
                ``current_epoch % update_interval == 0``.
        """
        super().__init__()
        self.save_path = save_path
        self.update_interval = update_interval
        self.train_losses = []
        self.val_losses = []
        self.epochs = []

    def on_validation_end(self, trainer, pl_module):
        # With a multi-process strategy (e.g. DDP) every rank runs this hook;
        # only rank 0 records and saves, so processes do not clobber one file.
        if not trainer.is_global_zero:
            return

        current_epoch = trainer.current_epoch
        train_loss = trainer.callback_metrics.get("train_loss")
        val_loss = trainer.callback_metrics.get("val_loss")

        # A point is recorded only when both metrics exist (e.g. not during a
        # validation run that happens before any training step was logged).
        if train_loss is not None and val_loss is not None:
            self.epochs.append(current_epoch)
            self.train_losses.append(train_loss.cpu().item())
            self.val_losses.append(val_loss.cpu().item())

        if current_epoch % self.update_interval == 0:
            self.plot_and_save()

    def plot_and_save(self):
        """Draw the recorded training and validation losses and save them to ``save_path``."""
        fig = plt.figure(figsize=(10, 6))
        plt.plot(self.epochs, self.train_losses, label="Training Loss", marker="o")
        plt.plot(self.epochs, self.val_losses, label="Validation Loss", marker="o")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training and Validation Loss")
        plt.legend()
        plt.grid()
        plt.savefig(self.save_path)
        plt.close(fig)


# Define your LightningModule
class FlowMatchingLightningModule(pl.LightningModule):
    """Jointly train the observation flow ``vt`` and the latent flow ``rt``.

    Training uses manual optimization with two AdamW optimizers, one over
    ``vt.parameters()`` and one over ``rt.parameters()``. Each is wrapped in a
    ``GradualWarmupScheduler`` followed by cosine annealing, and both schedulers
    are stepped once per training epoch. Every
    ``config.training.snapshot_freq`` epochs, the last validation batch is used
    to log ``val_log_lik`` / ``val_log_post_z`` and to save a sample figure.
    """

    def __init__(self, vt: nn.Module, rt: nn.Module, prior, config, args):
        """
        Args:
            vt: Observation model; ``vt.decodeS(x1, z0)`` decodes noise into
                observations given latents and ``vt.log_probS(x, z)`` scores them.
            rt: Latent model; ``rt.decodeS(z1, x)`` decodes prior samples into
                latents given observations and ``rt.log_probS`` scores them.
            prior: Latent prior exposing ``sample(n, S, device=...)``.
            config: Experiment config (reads ``optim``, ``model``, ``training``,
                ``data`` and ``flow`` sections).
            args: Run arguments; figures are written to ``args.log_sample_path``
                and ``args.figure`` selects the plotting branch of ``test_step``.
        """
        super().__init__()
        self.config = config
        self.args = args
        self.lr = self.config.optim.lr
        self.prior = prior
        self.vt = vt
        self.rt = rt
        # config.model.const selects the FlowMatchingLossc variant.
        loss_cls = FlowMatchingLossc if self.config.model.const else FlowMatchingLoss
        self.flow_matching_loss = loss_cls(vt, rt, prior, alpha=self.config.training.alpha)

        # Both optimizers are stepped by hand in training_step.
        self.automatic_optimization = False
        # Last validation batch of the current epoch: set in validation_step,
        # consumed and cleared in on_validation_epoch_end.
        self.last_validation_batch = None

    # ------------------------------------------------------------------ helpers

    def _as_sequence(self, X):
        """Reshape a batch to ``(S, -1, p)`` in float64, the layout the flows consume.

        NOTE(review): ``view`` reinterprets memory rather than transposing. If
        batches arrive as ``(batch, S, p)``, this interleaves samples and time
        steps. ``_to_numpy`` applies the inverse reshape, so round trips are
        consistent, but confirm this layout is intended.
        """
        return X.view(self.config.data.S, -1, self.config.data.p).to(dtype=torch.float64)

    def _to_numpy(self, tensor, last_dim):
        """Detach ``tensor`` to a numpy array reshaped to ``(-1, S, last_dim)``."""
        return tensor.cpu().detach().numpy().reshape((-1, self.config.data.S, last_dim))

    def _generate(self, x):
        """Sample prior latents, decode latents given ``x``, then decode observations.

        Args:
            x: Conditioning observations in the ``_as_sequence`` layout; ``n`` is
               taken from ``x.shape[1]``.
        Returns:
            (x0, z0): the generated observations and latents (torch tensors).
        """
        S, p = self.config.data.S, self.config.data.p
        n = x.shape[1]
        z1 = self.prior.sample(n, S, device=self.device)
        x1 = torch.randn(S, n, p, device=self.device)
        z0 = self.rt.decodeS(z1, x)
        x0 = self.vt.decodeS(x1, z0)
        return x0, z0

    def _plot_samples(self, x_np, x0_np, y_np, z0_np, filename, sample_idx=10):
        """Save a 2x2 figure comparing real vs generated data and latents for one sample.

        Only rank 0 writes, so multi-process runs do not overwrite the same file.
        ``sample_idx`` is clamped to the last available sample for small batches.
        """
        if self.global_rank != 0:
            return
        idx = min(sample_idx, len(x_np) - 1, len(x0_np) - 1, len(y_np) - 1, len(z0_np) - 1)

        fig, axes = plt.subplots(2, 2, figsize=(12, 5))
        axes[0, 0].imshow(x_np[idx].T, cmap="gray", origin="lower", aspect=.5)
        axes[0, 1].imshow(x0_np[idx].T, cmap="gray", origin="lower", aspect=.5)
        # Both coordinates of a latent trajectory must come from the same sample.
        axes[1, 0].plot(y_np[idx][:, 0], y_np[idx][:, 1])
        axes[1, 0].set_xmargin(0)
        axes[1, 1].plot(z0_np[idx][:, 0], z0_np[idx][:, 1])
        axes[1, 1].set_xmargin(0)
        axes[0, 0].set_ylabel("real data")
        axes[0, 1].set_ylabel("gen data")
        axes[1, 0].set_ylabel("real latent")
        axes[1, 1].set_ylabel("gen latent")
        plt.tight_layout()
        plt.savefig(os.path.join(self.args.log_sample_path, filename))
        plt.close(fig)

    def _build_optimizer(self, params):
        """Create an AdamW optimizer over ``params`` and its warm-up + cosine schedule."""
        optimizer = torch.optim.AdamW(
            params,
            lr=self.lr,
            weight_decay=self.config.optim.weight_decay)
        cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer,
            T_max=self.config.training.n_epochs,
            eta_min=1e-6,
            last_epoch=-1,
        )
        warmup_scheduler = GradualWarmupScheduler(
            optimizer=optimizer,
            multiplier=self.config.optim.multiplier,
            warm_epoch=self.config.training.n_epochs // 10,
            after_scheduler=cosine_scheduler,
            last_epoch=self.current_epoch,
        )
        return optimizer, warmup_scheduler

    # ---------------------------------------------------------- Lightning hooks

    def forward(self, n, x):
        """Generate observations and latents conditioned on ``x``.

        Args:
            n: Number of sequences to sample from the prior.
            x: Conditioning observations (the debug print states the expected
               shape ``(S, 1, p)``).
        Returns:
            (x0_np, z0_np): numpy arrays reshaped to ``(-1, S, p)`` and
            ``(-1, S, feature_dim)``.
        """
        print("x shape should be (S,1,p)", x.shape)
        z1 = self.prior.sample(n, self.config.data.S, device=self.device)
        z0 = self.rt.decodeS(z1, x)
        x1 = torch.randn(self.config.data.S, n, self.config.data.p, device=self.device)
        x0 = self.vt.decodeS(x1, z0)
        return self._to_numpy(x0, self.config.data.p), self._to_numpy(z0, self.config.flow.feature_dim)

    def training_step(self, batch, batch_idx):
        """Compute the flow-matching loss and step both optimizers by hand."""
        X, _ = batch
        X = self._as_sequence(X)
        loss = self.flow_matching_loss(X)
        self.log('train_loss', loss, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)

        # Unpacked in configure_optimizers order: (vt optimizer, rt optimizer).
        g_opt, d_opt = self.optimizers()
        d_opt.zero_grad()
        g_opt.zero_grad()
        self.manual_backward(loss)
        # Clip before stepping so the optimizers apply the clipped gradients.
        clip_val = self.config.training.clipval
        self.clip_gradients(d_opt, gradient_clip_val=clip_val, gradient_clip_algorithm="norm")
        self.clip_gradients(g_opt, gradient_clip_val=clip_val, gradient_clip_algorithm="norm")
        d_opt.step()
        g_opt.step()
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and keep the epoch's last batch for snapshot evaluation."""
        X, y = batch
        X = self._as_sequence(X)
        val_loss = self.flow_matching_loss(X)

        # Last batch of the first validation dataloader.
        if batch_idx == self.trainer.num_val_batches[0] - 1:
            self.last_validation_batch = {"X": X, "y": y}

        self.log('val_loss', val_loss, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)
        return val_loss

    def on_train_epoch_end(self):
        """Step both epoch-level LR schedulers (manual optimization does not do it)."""
        sch1, sch2 = self.lr_schedulers()
        sch1.step()
        sch2.step()

    def on_validation_epoch_end(self):
        """Every ``snapshot_freq`` epochs, evaluate and log likelihoods on the stored batch."""
        stored = self.last_validation_batch
        if stored is not None and self.current_epoch % self.config.training.snapshot_freq == 0:
            log_llk, log_post_z = self.eval_latent(stored["X"], stored["y"])

            self.log('val_log_post_z', log_post_z, on_step=False, on_epoch=True, sync_dist=True, logger=True)
            self.log('val_log_lik', log_llk, on_step=False, on_epoch=True, sync_dist=True, logger=True)

            print("log_post", log_post_z)
            print("log_lik", log_llk)

        # Clear the stored batch so it is never reused in a later epoch.
        self.last_validation_batch = None

    def configure_optimizers(self):
        """Return the ``vt`` and ``rt`` optimizers and their epoch-level schedulers.

        The order matters: ``training_step`` unpacks ``self.optimizers()`` as
        (vt optimizer, rt optimizer), and ``on_train_epoch_end`` steps both
        schedulers itself, since under manual optimization Lightning leaves
        scheduler stepping to the module.
        """
        optimizer, warmup_scheduler = self._build_optimizer(self.vt.parameters())
        optimizer_latent, warmup_scheduler_latent = self._build_optimizer(self.rt.parameters())
        return [optimizer, optimizer_latent], [
            {'scheduler': warmup_scheduler,
             'monitor': 'train_loss',
             "interval": "epoch",
             "frequency": 1},
            {'scheduler': warmup_scheduler_latent,
             'monitor': 'train_loss',
             "interval": "epoch",
             "frequency": 1}]

    def eval_latent(self, x, y):
        """Generate from a validation batch, score the result, and save a sample figure.

        Args:
            x: Validation observations in the ``_as_sequence`` layout.
            y: The batch's latent targets (only used for plotting as "real latent").
        Returns:
            (log_lik, log_post_z): the mean of ``vt.log_probS(x, z0)`` and the mean of
            ``rt.log_probS(z0, x, 0., prior)``, where ``z0`` are the generated latents.
        """
        vt_was_training, rt_was_training = self.vt.training, self.rt.training
        self.vt.eval()
        self.rt.eval()
        try:
            with torch.no_grad():
                x0, z0 = self._generate(x)
                log_lik = self.vt.log_probS(x, z0).mean()
                log_post_z = self.rt.log_probS(z0, x, 0., self.prior).mean()

            self._plot_samples(
                self._to_numpy(x, self.config.data.p),
                self._to_numpy(x0, self.config.data.p),
                y.cpu().detach().numpy(),
                self._to_numpy(z0, self.config.flow.feature_dim),
                f'image_eval_epoch_{self.current_epoch}.png',
            )
        finally:
            # Restore the modes the submodules had before this call.
            self.vt.train(vt_was_training)
            self.rt.train(rt_was_training)
        return log_lik, log_post_z

    def test_step(self, batch, batch_idx):
        """Generate from a test batch, then either plot it or log test metrics.

        With ``args.figure`` set, a sample figure is saved. Otherwise the step logs:
        ``test_loss`` (the flow-matching loss); and ``test_obs_loss`` /
        ``test_latent_loss``, computed per sample as the square root of the
        time-averaged squared Euclidean error between generated and true arrays
        (axis 1 of the ``(-1, S, ...)`` arrays), then averaged over samples.
        """
        X, y = batch
        _X = self._as_sequence(X)
        x0, z0 = self._generate(_X)

        x0_np = self._to_numpy(x0, self.config.data.p)
        z0_np = self._to_numpy(z0, self.config.flow.feature_dim)
        x_np = X.cpu().detach().numpy()
        y_np = y.cpu().detach().numpy()

        if self.args.figure:
            self._plot_samples(x_np, x0_np, y_np, z0_np, 'test_eval_gen.png')
        else:
            test_loss = self.flow_matching_loss(_X)
            self.log('test_loss', test_loss, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)

            rmse_obs = np.sqrt((((x0_np - x_np) ** 2).sum(-1)).mean(1)).mean()
            rmse_latent = np.sqrt((((z0_np - y_np) ** 2).sum(-1)).mean(1)).mean()

            self.log('test_latent_loss', rmse_latent, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)
            self.log('test_obs_loss', rmse_obs, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)


class LDSRunner():
    """Build the LDS models, datasets and Lightning Trainer.

    ``train`` fits the model; ``sample`` and ``figure`` evaluate the saved
    checkpoint on freshly generated data.
    """

    # Filenames produced by the ModelCheckpoint configured in __init__.
    _CKPT_PATTERN = 'best-checkpoint-*.ckpt'

    def __init__(self, args, config):
        """
        Args:
            args: Run arguments; reads ``log_path`` and ``resume_training``.
                Sets ``args.log_sample_path`` to ``<log_path>/samples`` and
                creates that directory.
            config: Experiment configuration.
        """
        self.args = args
        self.config = config
        args.log_sample_path = os.path.join(args.log_path, 'samples')
        os.makedirs(args.log_sample_path, exist_ok=True)

        self.S = self.config.data.S

        # Observation model vt: LLKc when config.model.const is set, otherwise LLK.
        if self.config.model.const:
            self.vt = LLKc(
                self.config.data.p, self.config.flow.feature_dim, self.S,
                hidden_features=[self.config.model.ngf]*3,
                fct=nn.Tanh(),
                )
        else:
            self.vt = LLK(
                self.config.data.p, self.config.flow.feature_dim, self.S, dsemb=self.config.model.dsemb,
                hidden_features=[self.config.model.ngf]*1,
                fct=nn.Tanh(),
                )
        # Latent model rt: ARCNF when config.flow.AR is set, otherwise fullCNF.
        if self.config.flow.AR:
            self.rt = ARCNF(
                self.config.data.p, self.config.flow.feature_dim, self.S, dsemb=self.config.flow.dsemb,
                hidden_features=[self.config.flow.ngf]*1,
                fct=nn.Tanh(),
                )
        else:
            self.rt = fullCNF(
                self.config.data.p, self.config.flow.feature_dim, self.S, dsemb=self.config.flow.dsemb,
                num_hidden=128,
                hidden_features=[self.config.flow.ngf]*1,
                fct=nn.Tanh(),
                batch_norm=True
                )

        self.prior = LatentDynamicalSystem(self.config.flow.feature_dim)
        self.dataset, self.val_dataset = get_lds(
            self.config.data.samplesize, self.config.data.n, self.config.data.p, self.S,
            self.config.data.theta, "data", sin=self.config.data.sin, gen=True, plot=True)

        # Keep only the best checkpoint (lowest val_loss) in args.log_path.
        self.checkpoint_callback = ModelCheckpoint(
            monitor='val_loss',
            dirpath=self.args.log_path,
            filename='best-checkpoint-{epoch:02d}-{val_loss:.2f}',
            save_top_k=1,
            mode='min'
        )

        if torch.cuda.is_available():
            accelerator, strategy = 'gpu', "ddp"
        else:
            accelerator, strategy = 'cpu', "auto"
        devices = "auto"

        plot_loss_callback = PlotLossCallback(save_path=os.path.join(self.args.log_sample_path, 'loss.png'), update_interval=1)
        plot_llk_callback = PlotLogLikelihoodCallback(save_path=self.args.log_sample_path, log_keys=("val_log_lik", "val_log_post_z"))
        self.trainer = pl.Trainer(
            max_epochs=self.config.training.n_epochs,
            accelerator=accelerator,
            devices=devices,
            strategy=strategy,
            callbacks=[plot_loss_callback, plot_llk_callback, self.checkpoint_callback],
            log_every_n_steps=1,
        )

    def _find_checkpoint(self):
        """Return the checkpoint path to load, or None if there is none.

        Prefers the best checkpoint tracked by this process's ModelCheckpoint.
        That path stays empty until the callback saves something in this
        process (e.g. a fresh run launched only to resume or to sample), so in
        that case fall back to the most recently modified checkpoint matching
        ``_CKPT_PATTERN`` in ``args.log_path``.
        """
        best = self.checkpoint_callback.best_model_path
        if best:
            return best
        pattern = os.path.join(glob.escape(self.args.log_path), self._CKPT_PATTERN)
        candidates = glob.glob(pattern)
        if not candidates:
            return None
        return max(candidates, key=os.path.getmtime)

    def _run_test(self, n_samples, batch_size):
        """Generate ``n_samples`` fresh LDS sequences and run ``trainer.test`` on the saved checkpoint.

        Raises:
            FileNotFoundError: if no checkpoint can be located.
        """
        ckpt_path = self._find_checkpoint()
        if ckpt_path is None:
            raise FileNotFoundError(
                f"No checkpoint matching {self._CKPT_PATTERN!r} found in {self.args.log_path!r}; train first.")

        _, test_data = get_lds(
            n_samples, self.config.data.n, self.config.data.p, self.S,
            self.config.data.theta, "data", sin=self.config.data.sin, gen=True, plot=True)
        test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, persistent_workers=False,
                                     num_workers=self.config.data.num_workers, drop_last=True)
        model = FlowMatchingLightningModule(self.vt, self.rt, self.prior, self.config, self.args)
        self.trainer.test(model, test_dataloader, ckpt_path=ckpt_path)

    def train(self):
        """Fit the model.

        When ``args.resume_training`` is set, training resumes from the saved
        checkpoint if one is found.
        """
        train_dataloader = DataLoader(self.dataset, batch_size=self.config.training.batch_size, shuffle=True, persistent_workers=False,
                                      num_workers=self.config.data.num_workers)
        val_dataloader = DataLoader(self.val_dataset, batch_size=self.config.training.batch_size, shuffle=False, persistent_workers=False,
                                    num_workers=self.config.data.num_workers, drop_last=True)
        model = FlowMatchingLightningModule(self.vt, self.rt, self.prior, self.config, self.args)

        ckpt_path = None
        if self.args.resume_training:
            ckpt_path = self._find_checkpoint()
            if ckpt_path is None:
                logging.warning(
                    "resume_training is set but no checkpoint was found in %s; training from scratch.",
                    self.args.log_path)

        self.trainer.fit(model, train_dataloader, val_dataloader, ckpt_path=ckpt_path)

    def sample(self):
        """Evaluate the saved checkpoint on 500 fresh sequences (batches of 100)."""
        self._run_test(500, 100)

    def figure(self):
        """Evaluate the saved checkpoint on 30 fresh sequences in a single batch."""
        self._run_test(30, 30)



