import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import argparse
import itertools
import numpy as np
from tqdm import tqdm
import logging
import glob
import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_value_

import torch.optim as optim
from utils import *
from Scheduler import WarmUpScheduler
from gldsvae import *
from torchdiffeq import odeint_adjoint as odeint

from dataloader.dataloader_lds import *
from dataloader.dataloader_pendulum import *

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
import matplotlib.pyplot as plt


# Print tensors with 3 digits after the decimal point.
torch.set_printoptions(precision=3)
# Enable autograd anomaly detection process-wide. PyTorch documents this mode
# as a debugging aid that slows execution; consider disabling it for long runs.
torch.autograd.set_detect_anomaly(True)

import gc
# Run a single garbage-collection pass at import time.
gc.collect()


def scaledsigmoid(x):
    """Map ``x`` elementwise into ``(-1, 1)`` via ``2 * sigmoid(x) - 1``."""
    squashed = torch.sigmoid(x)
    return 2 * squashed - 1



class PlotLogLikelihoodCallback(Callback):
    def __init__(self, save_path="loss_plot.png", log_keys=("log_likelihood_x", "log_likelihood_z")):
        """
        Callback to plot the log likelihoods logged during training for two variables (e.g., x and z).

        Args:
            save_path (str): Directory in which ``llk.png`` is written. Note that
                the value is always treated as a directory, including the
                default ``"loss_plot.png"``, so pass an existing directory.
            log_keys (tuple): Names of the two metrics in
                ``trainer.callback_metrics`` to track; ``log_keys[0]`` is drawn
                in the left panel and ``log_keys[1]`` in the right panel.
        """
        super().__init__()
        self.save_path = os.path.join(save_path, 'llk.png')
        self.log_keys = log_keys
        # One entry per training epoch in which the metric was present.
        self.log_likelihoods_x = []
        self.log_likelihoods_z = []

    def on_train_epoch_end(self, trainer, pl_module):
        """
        Record the current values of both metrics and re-draw the plot.

        Args:
            trainer (Trainer): The PyTorch Lightning trainer instance.
            pl_module (LightningModule): The LightningModule being trained.
        """
        metrics = trainer.callback_metrics
        # NOTE(review): a metric logged only on some epochs may still be present
        # in callback_metrics on the others, appending a repeated value — confirm.
        if self.log_keys[0] in metrics:
            self.log_likelihoods_x.append(metrics[self.log_keys[0]].item())
        if self.log_keys[1] in metrics:
            self.log_likelihoods_z.append(metrics[self.log_keys[1]].item())

        # Under DDP every rank runs this hook; only rank 0 writes the file so
        # the processes do not race on the same output path.
        if trainer.is_global_zero:
            self._plot()

    def _plot(self):
        """Save both tracked series as side-by-side panels to ``self.save_path``.

        The x axis is the index of each recorded value, not the trainer epoch.
        """
        fig, (ax_x, ax_z) = plt.subplots(1, 2, figsize=(12, 6))
        panels = (
            (ax_x, self.log_likelihoods_x, "Log p(x|z)", "Log Likelihood of x During Training"),
            (ax_z, self.log_likelihoods_z, "Log p(z|x)", "Log Likelihood of z During Training"),
        )
        for ax, values, ylabel, title in panels:
            ax.plot(values, marker="o")
            ax.set_xlabel("Epoch")
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.grid()

        fig.tight_layout()
        fig.savefig(self.save_path)
        # Close this specific figure so repeated epochs don't accumulate figures.
        plt.close(fig)



class PlotLossCallback(Callback):
    """Record train/val loss after each validation run and periodically plot them.

    Args:
        save_path: File path the loss plot is written to.
        update_interval: The plot is re-written whenever
            ``trainer.current_epoch % update_interval == 0``.
    """

    def __init__(self, save_path="loss_plot.png", update_interval=5):
        super().__init__()
        self.save_path = save_path
        self.update_interval = update_interval
        self.train_losses = []
        self.val_losses = []
        self.epochs = []

    def on_validation_end(self, trainer, pl_module):
        """Append the current ``train_loss``/``val_loss`` pair and maybe re-plot."""
        # The sanity-check validation pass runs before any training step, so
        # there is no train_loss to pair with; skip it instead of plotting
        # empty series.
        if trainer.sanity_checking:
            return

        current_epoch = trainer.current_epoch
        metrics = trainer.callback_metrics
        train_loss = metrics.get("train_loss")
        val_loss = metrics.get("val_loss")

        # Only record epochs where both losses are available so the three
        # lists stay aligned.
        if train_loss is not None and val_loss is not None:
            self.epochs.append(current_epoch)
            self.train_losses.append(train_loss.cpu().item())
            self.val_losses.append(val_loss.cpu().item())

        # Under DDP every rank runs this hook; only rank 0 writes the file.
        if trainer.is_global_zero and current_epoch % self.update_interval == 0:
            self.plot_and_save()

    def plot_and_save(self):
        """Write the recorded training and validation loss curves to ``self.save_path``."""
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(self.epochs, self.train_losses, label="Training Loss", marker="o")
        ax.plot(self.epochs, self.val_losses, label="Validation Loss", marker="o")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.set_title("Training and Validation Loss")
        ax.legend()
        ax.grid()
        fig.savefig(self.save_path)
        plt.close(fig)


# Define your LightningModule
class GLDSVAELightningModule(pl.LightningModule):
    """Lightning wrapper training the GLDSVAE decoder and encoder with an ELBO loss.

    Uses manual optimization with two AdamW optimizers: one over the decoder
    parameters and one over the encoder parameters (the prior's parameters are
    not handed to any optimizer here). Both warm-up/cosine schedulers are
    stepped once per training epoch in ``on_train_epoch_end``.
    """

    def __init__(self, decoder: nn.Module, encoder: nn.Module, prior, config, args):
        super().__init__()
        self.config = config
        self.args = args
        self.prior = prior
        self.decoder = decoder
        self.encoder = encoder

        # Both optimizers are zeroed/stepped by hand in training_step.
        self.automatic_optimization = False
        # Last validation batch of the epoch, kept for snapshot evaluation/plots.
        self.last_validation_batch = None
        # Flattened observation size: a p-vector for the LDS data, otherwise a
        # channel x p x p image flattened to one axis.
        if self.args.config == "lds.yml":
            self.p = self.config.data.p
        else:
            self.p = self.config.data.p ** 2 * self.config.data.channel

        self.q = self.config.data.q
        self.S = self.config.data.S
        self.F = self.config.data.Stotal

    def setup(self, stage=None):
        self.elbo_loss = ELBOc(self.decoder, self.encoder, self.prior,
            self.config.data.S, self.config.flow.feature_dim, const=self.config.model.const)

    # ------------------------------------------------------------------ helpers

    def _reshape_batch(self, X, indices):
        """Return ``X`` viewed as ``(S, -1, p)`` float64 and ``indices`` as ``(S, -1)``."""
        X = X.view(self.config.data.S, -1, self.p).to(torch.float64)
        indices = indices.view(self.config.data.S, -1)
        return X, indices

    def _obs_shape(self):
        """Numpy target shape ``(-1, S, ...)`` used when exporting observations."""
        data = self.config.data
        if self.args.config == "lds.yml":
            return (-1, data.S, data.p)
        return (-1, data.S, data.channel, data.p, data.p)

    def _latent_shape(self):
        """Numpy target shape ``(-1, S, feature_dim)`` used when exporting latents."""
        return (-1, self.config.data.S, self.config.flow.feature_dim)

    # ----------------------------------------------------------------- training

    def training_step(self, batch, batch_idx):
        X, y, indices = batch
        X, indices = self._reshape_batch(X, indices)

        loss = self.elbo_loss(X, indices)
        self.log('train_loss', loss, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)

        # Order follows configure_optimizers: g_opt -> decoder, d_opt -> encoder.
        g_opt, d_opt = self.optimizers()
        d_opt.zero_grad()
        g_opt.zero_grad()
        self.manual_backward(loss)
        # Clipping must happen between backward and step to affect this update.
        self.clip_gradients(d_opt, gradient_clip_val=self.config.training.clipval, gradient_clip_algorithm="norm")
        d_opt.step()
        g_opt.step()

        return loss

    def validation_step(self, batch, batch_idx):
        X, y, indices = batch
        X, indices = self._reshape_batch(X, indices)
        val_loss = self.elbo_loss(X, indices)

        # Store the last batch for plotting in on_validation_epoch_end.
        if batch_idx == self.trainer.num_val_batches[0] - 1:
            self.last_validation_batch = {"X": X, "y": y, "indices": indices}

        self.log('val_loss', val_loss, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)

        return val_loss

    def on_train_epoch_end(self):
        # Manual optimization: Lightning does not step the schedulers for us.
        sch1, sch2 = self.lr_schedulers()
        sch1.step()
        sch2.step()

    def on_validation_epoch_end(self):
        if self.last_validation_batch is not None:
            X = self.last_validation_batch["X"]
            y = self.last_validation_batch["y"]
            indices = self.last_validation_batch["indices"]

            # Snapshot evaluation every `snapshot_freq` epochs.
            if self.current_epoch % self.config.training.snapshot_freq == 0:
                log_llk, log_post_z = self.eval_latent(X, y, indices)

                self.log('val_log_lik', log_llk, on_step=False, on_epoch=True, sync_dist=True, logger=True)
                self.log('val_log_post_z', log_post_z, on_step=False, on_epoch=True, sync_dist=True, logger=True)

        # Clear the stored batch for the next epoch.
        self.last_validation_batch = None

    def _make_optimizer_and_scheduler(self, params):
        """Build an AdamW optimizer over ``params`` plus its warm-up/cosine scheduler config."""
        optimizer = torch.optim.AdamW(
            params,
            lr=self.config.optim.lr,
            weight_decay=self.config.optim.weight_decay)
        cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer,
            T_max=self.config.training.n_epochs,
            eta_min=0,
            last_epoch=-1,
        )
        warmup_scheduler = WarmUpScheduler(
            optimizer=optimizer,
            lr_scheduler=cosine_scheduler,
            warmup_steps=self.config.training.n_epochs // 10,
            warmup_start_lr=0.00005,
            len_loader=self.config.data.samplesize // self.config.training.batch_size,
        )
        scheduler_config = {
            'scheduler': warmup_scheduler,
            'monitor': 'train_loss',
            "interval": "epoch",
            "frequency": 1,
        }
        return optimizer, scheduler_config

    def configure_optimizers(self):
        """Return ``[decoder_opt, encoder_opt]`` and their scheduler configs."""
        optimizer, scheduler = self._make_optimizer_and_scheduler(self.decoder.parameters())
        optimizer_latent, scheduler_latent = self._make_optimizer_and_scheduler(self.encoder.parameters())
        return [optimizer, optimizer_latent], [scheduler, scheduler_latent]

    def forward(self, n, x, indices):
        """Encode ``x`` to latents and decode them back.

        ``n`` is accepted for interface compatibility and is not used.
        Returns ``(x0_np, z0_np)`` as numpy arrays reshaped to ``(-1, S, ...)``.
        """
        z0 = self.encoder.sample(x, indices)
        x0 = self.decoder.sampleS(z0)

        x0_np = x0.detach().cpu().numpy().reshape(self._obs_shape())
        z0_np = z0.detach().cpu().numpy().reshape(self._latent_shape())

        return x0_np, z0_np

    def eval_latent(self, x, y, indices):
        """Evaluate log-likelihoods on one batch and (rank 0 only) save snapshot plots.

        Returns ``(log_llk, log_post_z)``: the mean of ``decoder.log_probS`` and
        of ``encoder.log_prob`` for latents sampled from the encoder. The train/
        eval mode of the decoder and encoder is restored afterwards, so calling
        this during validation or testing does not leave them in train mode.
        """
        was_training = (self.decoder.training, self.encoder.training)
        self.decoder.eval()
        self.encoder.eval()
        try:
            with torch.no_grad():
                # Infer latent z given data x, then decode it back.
                z1 = self.encoder.sample(x, indices)
                x1 = self.decoder.sampleS(z1)

                log_llk = self.decoder.log_probS(z1, x).mean()
                log_post_z = self.encoder.log_prob(x, z1, indices).mean()

                # Under DDP only one process writes the image files.
                if self.global_rank == 0:
                    self._save_snapshot_plots(x, x1, y, z1)
        finally:
            self.decoder.train(was_training[0])
            self.encoder.train(was_training[1])
        return log_llk, log_post_z

    def _save_snapshot_plots(self, x, x1, y, z1):
        """Plot the first sequence of the reconstruction (x1, z1) and of the truth (x, y)."""
        obs_shape = self._obs_shape()
        latent_shape = self._latent_shape()
        x1_np = x1.detach().cpu().numpy().reshape(obs_shape)
        x_np = x.detach().cpu().numpy().reshape(obs_shape)
        z1_np = z1.detach().cpu().numpy().reshape(latent_shape)
        y_np = y.detach().cpu().numpy().reshape(latent_shape)

        plot_image_sequence_and_trajectory(x1_np[0], z1_np[0], figsize=(20, 2))
        plt.savefig(os.path.join(self.args.log_sample_path, f'image_eval_epoch_{self.current_epoch}.png'))
        plt.close()

        plot_image_sequence_and_trajectory(x_np[0], y_np[0], figsize=(20, 2))
        plt.savefig(os.path.join(self.args.log_sample_path, 'image_true.png'))
        plt.close()

    def test_step(self, batch, batch_idx):
        X, y, indices = batch
        _X, indices = self._reshape_batch(X, indices)

        if self.args.figure:
            # Figure mode only produces snapshot plots; no metrics are logged.
            self.eval_latent(_X, y, indices)
            return

        # Generate reconstructions from the encoder's latent samples.
        z1 = self.encoder.sample(_X, indices)
        x1 = self.decoder.sampleS(z1)

        # NOTE(review): x1/z1 are reshaped to (-1, S, ...) while X and y are
        # compared in their raw batch layout — confirm the layouts agree.
        x1_np = x1.detach().cpu().numpy().reshape(self._obs_shape())
        x_np = X.detach().cpu().numpy()
        z1_np = z1.detach().cpu().numpy().reshape(self._latent_shape())
        y_np = y.detach().cpu().numpy()

        # Same call signature as in training/validation.
        test_loss = self.elbo_loss(_X, indices)
        self.log('test_loss', test_loss, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)

        # RMSE between generated and ground-truth sequences.
        rmse_obs = np.sqrt((((x1_np - x_np) ** 2).sum(-1)).mean(1)).mean()
        rmse_latent = np.sqrt((((z1_np - y_np) ** 2).sum(-1)).mean(1)).mean()

        self.log('test_latent_loss', rmse_latent, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)
        self.log('test_obs_loss', rmse_obs, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)

class ShiftedTanh(torch.nn.Module):
    """Elementwise activation computing ``tanh(x) + b``.

    NOTE(review): ``a`` is stored on the module but ``forward`` never uses it;
    confirm whether it was meant to scale the input or the output.
    """

    def __init__(self, a=1.0, b=1.0):
        super().__init__()
        self.a, self.b = a, b

    def forward(self, x):
        activated = torch.tanh(x)
        return activated + self.b

class GLDSVAErunner():
    """Build the GLDSVAE decoder/encoder/prior and drive training and evaluation.

    ``__init__`` creates ``<args.log_path>/samples`` (stored back on
    ``args.log_sample_path``) for plots. It also configures a Lightning
    ``Trainer`` with loss/likelihood plotting callbacks and a
    ``ModelCheckpoint`` that keeps the single best checkpoint by ``val_loss``
    in ``args.log_path``.
    """

    # ModelCheckpoint filename template and a glob matching the files it writes.
    _CKPT_FILENAME = 'best-checkpoint-{epoch:02d}-{val_loss:.2f}'
    _CKPT_GLOB = 'best-checkpoint-*.ckpt'

    def __init__(self, args, config):
        self.args = args
        self.config = config
        args.log_sample_path = os.path.join(args.log_path, 'samples')
        os.makedirs(args.log_sample_path, exist_ok=True)

        # Flattened observation size: p for LDS data, channel x p x p otherwise.
        if self.args.config == "lds.yml":
            self.p = self.config.data.p
        else:
            self.p = self.config.data.p ** 2 * self.config.data.channel

        self.S = self.config.data.S
        self.F = self.config.data.Stotal

        if self.config.model.const:
            self.decoder = LLKc(
                self.p, self.config.flow.feature_dim, self.S,
                hidden_features=[512] * 2,
                # other activations: SiLU(), ShiftedTanh()
                fct=nn.Softplus(),
            )
        else:
            self.decoder = LLK(
                self.p, self.config.flow.feature_dim, self.S,
                dsemb=self.config.model.dsemb,
                hidden_features=[512] * 0,
            )

        self.encoder = fullGaussRecogNet(
            self.config.flow.feature_dim, self.p, self.S, self.F,
            dsemb=self.config.flow.dsemb,
            num_layers=1,
            num_hidden=32,
            emb_dim=100,
            hidden_features=[512] * 0,
            fct=nn.Tanh(),
        )

        self.prior = LatentDynamicalSystem(self.config.flow.feature_dim)

        # Keep only the best checkpoint (lowest val_loss) in args.log_path.
        self.checkpoint_callback = ModelCheckpoint(
            monitor='val_loss',
            dirpath=self.args.log_path,
            filename=self._CKPT_FILENAME,
            save_top_k=1,
            mode='min',
        )

        if self.config.device == "cuda":
            accelerator, strategy = 'gpu', "ddp"
        else:
            accelerator, strategy = 'cpu', "auto"

        plot_loss_callback = PlotLossCallback(save_path=os.path.join(self.args.log_sample_path, 'loss.png'), update_interval=1)
        plot_llk_callback = PlotLogLikelihoodCallback(save_path=self.args.log_sample_path, log_keys=("val_log_lik", "val_log_post_z"))

        self.trainer = pl.Trainer(
            max_epochs=self.config.training.n_epochs,
            accelerator=accelerator,
            devices="auto",
            strategy=strategy,
            callbacks=[plot_loss_callback, plot_llk_callback, self.checkpoint_callback],
            # Debugging aid; slows down the backward pass.
            detect_anomaly=True,
        )

    def _get_dataloaders(self):
        """Build ``(train_dataloader, val_dataloader)`` for ``args.config``.

        Also stores ``self.dataset``/``self.val_dataset`` (LDS) or
        ``self.length`` (pendulum) as side effects.

        Raises:
            ValueError: if ``args.config`` is neither ``"lds.yml"`` nor ``"pendulum.yml"``.
        """
        if self.args.config == "lds.yml":
            self.dataset, self.val_dataset = get_lds(
                self.config.data.samplesize, self.config.data.n, self.config.data.p, self.S,
                self.config.data.theta, "data", sin=self.config.data.sin, gen=True, plot=True)
            train_dataloader = DataLoader(self.dataset, batch_size=self.config.training.batch_size, shuffle=True, persistent_workers=False,
                                          num_workers=self.config.data.num_workers)
            val_dataloader = DataLoader(self.val_dataset, batch_size=self.config.training.batch_size, shuffle=False, persistent_workers=False,
                                        num_workers=self.config.data.num_workers, drop_last=True)
            return train_dataloader, val_dataloader

        if self.args.config == "pendulum.yml":
            self.length = self.config.data.Stotal
            train_dataloader, val_dataloader = get_pendulum_dataloader(
                self.config.data.samplesize, self.config.data.test_samplesize,
                self.config.data.p, self.length, "data",
                window_size=self.S, stride=1, batch_size=self.config.training.batch_size,
                batch_from_same_trajectory=self.config.data.samebatch,
                state_dim=self.config.data.n, gen=self.config.data.gen, shuffle=True, num_workers=self.config.data.num_workers)
            return train_dataloader, val_dataloader

        raise ValueError(
            f"Unsupported config {self.args.config!r}; expected 'lds.yml' or 'pendulum.yml'.")

    def _resolve_checkpoint_path(self, required):
        """Return the checkpoint file to load, or ``None`` if none is available.

        ``checkpoint_callback.best_model_path`` is only set once this process
        has saved a checkpoint. When it is empty, fall back to the most recently
        modified ``best-checkpoint-*.ckpt`` in ``args.log_path``, which covers
        earlier runs.

        Args:
            required: If true, raise when no checkpoint is found. Otherwise log
                a warning and return ``None``.

        Raises:
            FileNotFoundError: if ``required`` and no checkpoint exists.
        """
        # Local imports: the star imports at file top could rebind these names.
        import glob
        import logging

        best = self.checkpoint_callback.best_model_path
        if best:
            return best

        pattern = os.path.join(self.args.log_path, self._CKPT_GLOB)
        candidates = glob.glob(pattern)
        if candidates:
            return max(candidates, key=os.path.getmtime)

        if required:
            raise FileNotFoundError(f"No checkpoint matching {pattern!r} was found.")
        logging.warning("No checkpoint matching %s was found; training from scratch.", pattern)
        return None

    def _evaluate(self):
        """Run ``trainer.test`` on the validation loader using the best checkpoint."""
        _, val_dataloader = self._get_dataloaders()
        model = GLDSVAELightningModule(self.decoder, self.encoder, self.prior, self.config, self.args)
        ckpt_path = self._resolve_checkpoint_path(required=True)
        self.trainer.test(model, val_dataloader, ckpt_path=ckpt_path)

    def train(self):
        """Fit the model; resume from the best checkpoint if ``args.resume_training``."""
        train_dataloader, val_dataloader = self._get_dataloaders()
        model = GLDSVAELightningModule(self.decoder, self.encoder, self.prior, self.config, self.args)
        ckpt_path = self._resolve_checkpoint_path(required=False) if self.args.resume_training else None
        self.trainer.fit(model, train_dataloader, val_dataloader, ckpt_path=ckpt_path)

    def sample(self):
        """Evaluate the best checkpoint on the validation split (metrics mode)."""
        self._evaluate()

    def figure(self):
        """Evaluate the best checkpoint on the validation split.

        Identical to ``sample``; the model's ``test_step`` switches to plotting
        based on ``args.figure``.
        """
        self._evaluate()



        