import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import argparse
import itertools
import numpy as np
from tqdm import tqdm
import logging
import glob
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure

# from torchvision.utils import save_image
import pandas as pd

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_value_

import torch.optim as optim
from utils import *
import matplotlib.gridspec as gridspec
from Scheduler import WarmUpScheduler # GradualWarmupScheduler

# from zuko.utils import odeint
from torchdiffeq import odeint_adjoint as odeint

from dataloader.dataloader_mnist import *
from dataloader.dataloader_cifar import *
from dataloader.dataloader_pendulum import *

from runner_mnist import PlotLogLikelihoodCallback, PlotLossCallback
from FM.fm import *

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning.strategies import DDPStrategy

# Process-wide torch settings applied at import time.
# Floating-point tensors created without an explicit dtype become float64.
torch.set_default_dtype(torch.float64)
# Tensors are printed with three digits of precision.
torch.set_printoptions(precision=3)

class FlowLightning(pl.LightningModule):
    """Lightning module that trains a velocity network with a flow-matching loss.

    Optimisation is manual: ``training_step`` zeroes, back-propagates and
    steps the optimizer itself, and ``on_train_epoch_end`` steps the learning
    rate scheduler once per epoch.

    Args:
        vt: Velocity network. It must expose ``decode`` (used by ``sample``).
        config: Configuration object; the ``flow``, ``data``, ``optim``,
            ``training`` and ``model`` sections are read.
        args: Run arguments; ``log_sample_path`` is read when saving samples.
        stage: Free-form stage label, stored as ``self.stage``.
    """

    def __init__(self, vt: nn.Module, config, args, stage="recon"):
        super().__init__()
        self.config = config
        self.args = args
        self.stage = stage

        self.vt = vt

        self.d = self.config.flow.z_dim
        self.c, self.p = self.config.data.channel, self.config.data.size
        self.sig_min = 1e-4

        # The optimizer and the scheduler are stepped by hand (see
        # training_step / on_train_epoch_end).
        self.automatic_optimization = False
        # Last batch seen during validation; used as a sampling seed.
        self.last_validation_batch = None

    def setup(self, stage=None):
        """Build the standard-normal prior over images and the loss object."""
        # NOTE(review): self.device may not be the final training device when
        # this hook runs - confirm the loss copes with a prior on another device.
        self.prior = Normal(
            torch.zeros(self.c, self.p, self.p).to(self.device),
            torch.ones(self.c, self.p, self.p).to(self.device),
        )
        self.flow_matching_loss = FlowMatchingLoss_marginal(self.vt, self.prior, self.sig_min)

    def configure_optimizers(self):
        """Return AdamW plus a warm-up wrapped cosine-annealing scheduler."""
        optimizer = torch.optim.AdamW(
            self.vt.parameters(),
            lr=self.config.optim.lr,
            weight_decay=self.config.optim.weight_decay,
        )

        cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer,
            T_max=self.config.training.n_epochs,
            eta_min=1e-5,
            last_epoch=-1,
        )
        warmup_scheduler = WarmUpScheduler(
            optimizer=optimizer,
            lr_scheduler=cosine_scheduler,
            warmup_steps=self.config.training.n_epochs // 10,
            warmup_start_lr=0.00005,
            len_loader=self.config.data.samplesize // self.config.training.batch_size,
        )

        # "interval" and "frequency" belong inside the lr_scheduler entry;
        # at the top level Lightning reports them as unsupported keys.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": warmup_scheduler,
                "interval": "epoch",
                "frequency": 1,
            },
        }

    def training_step(self, batch, batch_idx):
        """Run one manual optimisation step and return the loss."""
        X, _ = batch
        if not self.config.model.cnn:
            X = X.view(-1, self.c * self.p ** 2)

        loss = self.flow_matching_loss(X)
        self.log('train_loss', loss, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)

        opt = self.optimizers()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()

        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and remember the final validation batch."""
        X, y = batch
        if not self.config.model.cnn:
            X = X.view(-1, self.c * self.p ** 2)

        if batch_idx == self.trainer.num_val_batches[0] - 1:
            self.last_validation_batch = {"X": X, "y": y}

        val_loss = self.flow_matching_loss(X)
        self.log('val_loss', val_loss, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)

        return val_loss

    def on_train_epoch_end(self):
        """Step the learning-rate scheduler (manual optimisation)."""
        sch = self.lr_schedulers()
        sch.step()

    def on_validation_epoch_end(self):
        """Save a sample grid every ``snapshot_freq`` epochs."""
        # Always drop the stored batch so it is not kept alive between epochs.
        stored = self.last_validation_batch
        self.last_validation_batch = None

        if self.current_epoch % self.config.training.snapshot_freq == 0:
            X = stored["X"] if stored is not None else None
            self.sample(100, x=X)

    def sample(self, n, x=None):
        """Decode up to ``n`` samples and save them as an image grid.

        Args:
            n: Number of samples requested; the grid has 10 rows and
                ``max(1, n // 10)`` columns.
            x: Optional batch of images. When given, each start point is the
                average of an image and a prior draw; otherwise start points
                are pure prior draws. Fewer than ``n`` images are accepted.
        """
        was_training = self.vt.training
        self.vt.eval()
        try:
            with torch.no_grad():
                if x is not None:
                    # The stored batch may be smaller than n and is flattened
                    # when config.model.cnn is False; bring it back to image
                    # shape so it lines up with the prior draws.
                    x = x[:n].reshape(-1, self.c, self.p, self.p)
                    noise = self.prior.sample((x.shape[0],)).to(self.device)
                    x1 = x * 0.5 + 0.5 * noise
                else:
                    x1 = self.prior.sample((n,)).to(self.device)
                x0 = self.vt.decode(x1)
                x0_np = inv_transform(x0).cpu().detach().numpy()

                n_cols = max(1, n // 10)
                fig, axes = plt.subplots(10, n_cols, figsize=(10, 10), squeeze=False)
                for i, ax in enumerate(axes.flatten()):
                    ax.axis('off')  # hides ticks and labels
                    if i < len(x0_np):
                        # channel-first -> channel-last for imshow
                        ax.imshow(np.transpose(x0_np[i], (1, 2, 0)), cmap='gray')

                plt.tight_layout()
                plt.savefig(os.path.join(self.args.log_sample_path, f'sampx_grid_epoch_{self.current_epoch}.png'))
                plt.close(fig)
        finally:
            # Restore whatever mode the network was in, even if decoding failed.
            self.vt.train(was_training)


class FMRunner():
    """Builds the CNN velocity model and a Lightning trainer for image data.

    Args:
        args: Run arguments. ``log_path``, ``config`` and ``resume_training``
            are read; ``log_sample_path`` is set to ``<log_path>/samples``.
        config: Configuration object; the ``flow``, ``data``, ``model`` and
            ``training`` sections are read.
    """

    def __init__(self, args, config):
        self.args = args
        self.config = config
        args.log_sample_path = os.path.join(args.log_path, 'samples')
        os.makedirs(args.log_sample_path, exist_ok=True)

        self.d = self.config.flow.z_dim
        self.c, self.p = self.config.data.channel, self.config.data.size

        self.vt = cnnLLK(
            self.p,
            freqs=10,
            in_ch=self.c,
            fct=nn.Softplus(),
            mod_ch=self.config.model.mod_ch,
            unet=self.config.model.unet,
            hidden_features=[self.config.model.ngf] * self.config.model.depth,
        )

        if torch.cuda.is_available():
            accelerator = 'gpu'
            strategy = DDPStrategy(find_unused_parameters=True)
        else:
            accelerator = 'cpu'
            strategy = "auto"
        devices = "auto"

        self.checkpoint_callback = ModelCheckpoint(
            monitor='val_loss',  # Metric to monitor
            dirpath=self.args.log_path,  # Directory where checkpoints will be saved
            filename='best-checkpoint-{epoch:02d}-{val_loss:.2f}',  # Filename convention
            save_top_k=1,  # Only save the best model based on val_loss
            mode='min',  # Minimize the validation loss
        )

        plot_loss_callback = PlotLossCallback(
            save_path=os.path.join(self.args.log_sample_path, 'loss.png'),
            update_interval=1,
        )

        self.trainer = pl.Trainer(
            max_epochs=self.config.training.n_epochs,
            accelerator=accelerator,
            devices=devices,
            strategy=strategy,
            callbacks=[self.checkpoint_callback, plot_loss_callback],
        )

    def _find_resume_checkpoint(self):
        """Return the checkpoint path to resume from, or None if there is none."""
        # best_model_path is only filled in once this callback has saved
        # something, so on a fresh process it is an empty string.
        best = self.checkpoint_callback.best_model_path
        if best and os.path.isfile(best):
            return best

        candidates = glob.glob(os.path.join(self.args.log_path, '*.ckpt'))
        if not candidates:
            logging.warning(
                "resume_training is set but no checkpoint was found in %s; "
                "training starts from scratch.",
                self.args.log_path,
            )
            return None
        # Newest file by modification time.
        return max(candidates, key=os.path.getmtime)

    def train(self):
        """Load the dataset named by ``args.config`` and fit the model.

        Raises:
            NotImplementedError: If the config file is neither ``mnist.yml``
                nor ``cifar.yml``.
        """
        print(self.args.config)
        # Compare the file name only, so a config given with a directory
        # prefix is still recognised.
        config_name = os.path.basename(self.args.config)
        if config_name == "mnist.yml":
            dataset, val_dataset, sampler, val_sampler = get_mnist(
                self.config.data.n_classes, "data", self.config.data.samplesize, self.config.data.test_samplesize)
        elif config_name == "cifar.yml":
            dataset, val_dataset, sampler, val_sampler = get_cifar(
                self.config.data.n_classes, "data", self.config.data.samplesize, self.config.data.test_samplesize, download=self.config.data.download)
        else:
            raise NotImplementedError("Either 'mnist.yml' or 'cifar.yml'")

        train_dataloader = DataLoader(
            dataset,
            batch_size=self.config.training.batch_size,
            num_workers=self.config.data.num_workers,
            sampler=sampler,
        )
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=self.config.training.batch_size,
            num_workers=self.config.data.num_workers,
            sampler=val_sampler,
            drop_last=True,
        )

        model = FlowLightning(self.vt, self.config, self.args)
        ckpt_path = self._find_resume_checkpoint() if self.args.resume_training else None
        self.trainer.fit(model, train_dataloader, val_dataloader, ckpt_path=ckpt_path)
        


# -------------------------------------------- LDS ------------------------------------------

def _to_displayable_image(img):
    """Convert one frame into an array layout that ``imshow`` accepts."""
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()

    if img.ndim == 3 and img.shape[0] in [1, 3]:
        if img.shape[0] == 1:
            # Drop only the channel axis; other size-1 axes are kept.
            img = img[0]
        else:
            # (C, H, W) -> (H, W, C)
            img = np.transpose(img, (1, 2, 0))
    elif img.ndim == 3 and img.shape[-1] > 3:
        # Last axis is too wide to be a colour axis: keep its first slice.
        img = img[..., 0]
    return img


def _plot_image_row(fig, grid, row, images, n_cols):
    """Draw ``images`` left to right in row ``row`` of ``grid``, without ticks."""
    for col, raw in enumerate(images[:n_cols]):
        ax_img = fig.add_subplot(grid[row, col])
        img = _to_displayable_image(raw)
        ax_img.imshow(img, cmap='gray' if img.ndim == 2 else None, aspect='equal')
        ax_img.set_xticks([])
        ax_img.set_yticks([])
        if col == 0:
            ax_img.set_ylabel(r"$x$")  # row label on the first image's y-axis


def plot_image_sequence_and_trajectory(image_sequence1, image_sequence2, figsize=(12, 2)):
    """
    Plots two image sequences as two aligned rows of one figure.

    The figure is left open as the current matplotlib figure so the caller
    can save or close it.

    Args:
        image_sequence1 (list or np.ndarray or torch.Tensor):
            Images for the first row; its length sets the number of columns.
            Each image may be (H, W), (C, H, W) with C in {1, 3}, or
            channel-last. A tensor with 3 or more dimensions is split along
            its first dimension.
        image_sequence2 (list or np.ndarray or torch.Tensor):
            Images for the second row, same formats. Images beyond the
            number of columns are ignored.
        figsize (tuple, optional):
            The figure size for the plot. Defaults to (12, 2).

    Returns:
        The created figure, or None when ``image_sequence1`` is empty.
    """
    # --- Input Validation and Setup ---
    if isinstance(image_sequence1, torch.Tensor) and image_sequence1.ndim >= 3:
        image_sequence1 = list(image_sequence1)
    if isinstance(image_sequence2, torch.Tensor) and image_sequence2.ndim >= 3:
        image_sequence2 = list(image_sequence2)

    n_steps = len(image_sequence1)
    if n_steps == 0:
        print("Input image sequence is empty, nothing to plot.")
        return

    # --- Plotting Setup ---
    fig = plt.figure(figsize=figsize)
    # Two equally tall rows, one column per step.
    gs = gridspec.GridSpec(2, n_steps, height_ratios=[1, 1], hspace=0.3)

    _plot_image_row(fig, gs, 0, image_sequence1, n_steps)
    _plot_image_row(fig, gs, 1, image_sequence2, n_steps)

    fig.tight_layout()
    return fig

class FlowLightning_LDS(pl.LightningModule):
    """Lightning module that trains a sequence velocity network with a flow-matching loss.

    Optimisation is manual: ``training_step`` zeroes, back-propagates and
    steps the optimizer itself, and ``on_train_epoch_end`` steps the learning
    rate scheduler once per epoch.

    Args:
        vt: Velocity network. It must expose ``decode`` (used by ``sample``).
        config: Configuration object; the ``data``, ``optim`` and
            ``training`` sections are read.
        args: Run arguments; ``log_sample_path`` is read when saving samples.
        stage: Free-form stage label, stored as ``self.stage``.
    """

    def __init__(self, vt: nn.Module, config, args, stage="recon"):
        super().__init__()
        self.config = config
        self.args = args
        self.stage = stage

        self.vt = vt

        self.c, self.p = self.config.data.channel, self.config.data.p
        self.S = self.config.data.S
        self.sig_min = 1e-4

        # The optimizer and the scheduler are stepped by hand (see
        # training_step / on_train_epoch_end).
        self.automatic_optimization = False
        self.last_validation_batch = None

    def setup(self, stage=None):
        """Build the loss object."""
        self.flow_matching_loss = FlowMatchingLoss_marginal_LDS(self.vt, self.sig_min)

    def configure_optimizers(self):
        """Return AdamW plus a warm-up wrapped cosine-annealing scheduler."""
        optimizer = torch.optim.AdamW(
            self.vt.parameters(),
            lr=self.config.optim.lr,
            weight_decay=self.config.optim.weight_decay,
        )

        cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer,
            T_max=self.config.training.n_epochs,
            eta_min=1e-5,
            last_epoch=-1,
        )
        warmup_scheduler = WarmUpScheduler(
            optimizer=optimizer,
            lr_scheduler=cosine_scheduler,
            warmup_steps=self.config.training.n_epochs // 10,
            warmup_start_lr=0.00005,
            len_loader=self.config.data.samplesize // self.config.training.batch_size,
        )

        # "interval" and "frequency" belong inside the lr_scheduler entry;
        # at the top level Lightning reports them as unsupported keys.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": warmup_scheduler,
                "interval": "epoch",
                "frequency": 1,
            },
        }

    def _as_sequence_batch(self, X):
        """View ``X`` as (S, batch, c, p, p) in float64."""
        return X.view(self.S, -1, self.c, self.p, self.p).to(dtype=torch.float64)

    def training_step(self, batch, batch_idx):
        """Run one manual optimisation step and return the loss."""
        X, _, _ = batch
        X = self._as_sequence_batch(X)

        loss = self.flow_matching_loss(X)
        self.log('train_loss', loss, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)

        opt = self.optimizers()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()

        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and remember the final validation batch."""
        X, y, indices = batch
        X = self._as_sequence_batch(X)
        indices = indices.view(self.S, -1)

        if batch_idx == self.trainer.num_val_batches[0] - 1:
            self.last_validation_batch = {"X": X, "y": y, "indices": indices}

        val_loss = self.flow_matching_loss(X)
        self.log('val_loss', val_loss, on_epoch=True, sync_dist=True, prog_bar=True, logger=True)

        return val_loss

    def on_train_epoch_end(self):
        """Step the learning-rate scheduler (manual optimisation)."""
        sch = self.lr_schedulers()
        sch.step()

    def on_validation_epoch_end(self):
        """Save generated sequences every ``snapshot_freq`` epochs."""
        # Sampling here starts from noise only, so the stored batch is just released.
        self.last_validation_batch = None

        if self.current_epoch % self.config.training.snapshot_freq == 0:
            self.sample(100)

    def sample(self, n, x=None):
        """Decode ``n`` noise sequences and save the first two as a figure.

        Args:
            n: Number of sequences to generate; must be at least 2 because
                two sequences are plotted.
            x: Unused; kept for signature compatibility.
        """
        was_training = self.vt.training
        self.vt.eval()
        try:
            with torch.no_grad():
                x1 = torch.randn(self.S, n, self.c, self.p, self.p).to(self.device)
                x0 = self.vt.decode(x1)
                # NOTE(review): this is a reshape, not a transpose, from an
                # (S, n, ...) input to (n, S, ...). It mirrors the .view() used
                # on training batches - confirm against the dataloader layout.
                x0_np = inv_transform(x0).cpu().detach().numpy().reshape(n, self.S, self.c, self.p, self.p)

                plot_image_sequence_and_trajectory(x0_np[0], x0_np[1], figsize=(20, 2))
                plt.savefig(os.path.join(self.args.log_sample_path, f'gen_video_{self.current_epoch}.png'))
                plt.close()
        finally:
            # Restore whatever mode the network was in, even if decoding failed.
            self.vt.train(was_training)


class FMRunner_LDS():
    """Builds the GRU velocity model and a Lightning trainer for pendulum sequences.

    Args:
        args: Run arguments. ``log_path`` and ``resume_training`` are read;
            ``log_sample_path`` is set to ``<log_path>/samples``.
        config: Configuration object; the ``data``, ``model`` and
            ``training`` sections are read.
    """

    def __init__(self, args, config):
        self.args = args
        self.config = config
        args.log_sample_path = os.path.join(args.log_path, 'samples')
        os.makedirs(args.log_sample_path, exist_ok=True)

        self.c, self.p = self.config.data.channel, self.config.data.p
        self.S = self.config.data.S

        self.vt = GRUVelocityModel(
            self.S, self.c, self.p,
            num_time_frequencies=10,
            rnn_hidden_dim=128,
            num_rnn_layers=3,
            frame_feature_dim=64,
            hidden_features=[self.config.model.ngf] * self.config.model.depth,
            fct=nn.Softplus(),
        )

        if torch.cuda.is_available():
            accelerator = 'gpu'
            strategy = DDPStrategy(find_unused_parameters=True)
        else:
            accelerator = 'cpu'
            strategy = "auto"
        devices = "auto"

        self.checkpoint_callback = ModelCheckpoint(
            monitor='val_loss',  # Metric to monitor
            dirpath=self.args.log_path,  # Directory where checkpoints will be saved
            filename='best-checkpoint-{epoch:02d}-{val_loss:.2f}',  # Filename convention
            save_top_k=1,  # Only save the best model based on val_loss
            mode='min',  # Minimize the validation loss
        )

        plot_loss_callback = PlotLossCallback(
            save_path=os.path.join(self.args.log_sample_path, 'loss.png'),
            update_interval=1,
        )

        self.trainer = pl.Trainer(
            max_epochs=self.config.training.n_epochs,
            accelerator=accelerator,
            devices=devices,
            strategy=strategy,
            callbacks=[self.checkpoint_callback, plot_loss_callback],
        )

    def _find_resume_checkpoint(self):
        """Return the checkpoint path to resume from, or None if there is none."""
        # best_model_path is only filled in once this callback has saved
        # something, so on a fresh process it is an empty string.
        best = self.checkpoint_callback.best_model_path
        if best and os.path.isfile(best):
            return best

        candidates = glob.glob(os.path.join(self.args.log_path, '*.ckpt'))
        if not candidates:
            logging.warning(
                "resume_training is set but no checkpoint was found in %s; "
                "training starts from scratch.",
                self.args.log_path,
            )
            return None
        # Newest file by modification time.
        return max(candidates, key=os.path.getmtime)

    def train(self):
        """Build the pendulum dataloaders and fit the model."""
        # Trajectory length passed to the dataloader factory.
        self.length = self.config.data.Stotal
        train_dataloader, val_dataloader = get_pendulum_dataloader(
            self.config.data.samplesize, self.config.data.test_samplesize,
            self.p, self.length, "data",
            window_size=self.S, stride=1, batch_size=self.config.training.batch_size,
            batch_from_same_trajectory=self.config.data.samebatch,
            state_dim=self.config.data.n, gen=self.config.data.gen, shuffle=True,
            num_workers=self.config.data.num_workers)

        model = FlowLightning_LDS(self.vt, self.config, self.args)

        ckpt_path = self._find_resume_checkpoint() if self.args.resume_training else None
        self.trainer.fit(model, train_dataloader, val_dataloader, ckpt_path=ckpt_path)

