import lightning as L
import torch
from torch import nn
from torch.nn import functional as F
import os
from torch.utils.data import Dataset, RandomSampler, DataLoader
import h5py
import warnings
from torch.utils.data import random_split
from lightning.pytorch.profilers import SimpleProfiler
from torch.nn.utils.rnn import pad_sequence
from lightning.pytorch.loggers import WandbLogger
from typing import Union, List
import wandb
from rsl_rl.addons.invdynamics.inv_dynamics_module import InvDynamicsRNN
from rsl_rl.addons.invdynamics.inv_dynamics_dataset_paths import dataset_paths


def inv_dynamics_sequence_collate_fn(batch):
    """Collate variable-length (inv_input, actions) pairs into zero-padded batch tensors.

    NOTE: a function with the same name is defined again further down in this
    module; that later definition rebinds the name, so this float-mask variant
    is not the one picked up by code that runs after the module has loaded.

    Args:
        batch: list of ``(inv_input, actions)`` tuples, each tensor shaped
            ``[seq_len, dim]``.

    Returns:
        Tuple ``(padded_inputs, padded_actions, mask)``. Both padded tensors are
        batch-first (``[batch_size, max_seq_len, dim]``) and padded with zeros.
        ``mask`` is a float32 tensor of shape ``[batch_size, max_seq_len]``
        holding 1.0 on real timesteps and 0.0 on padding.
    """
    states, actions = zip(*batch)

    # pad_sequence fills with 0 by default; batch_first puts the batch axis first.
    padded_states = pad_sequence(states, batch_first=True)
    padded_actions = pad_sequence(actions, batch_first=True)

    # A position is valid when its index is smaller than the length of its sequence.
    seq_lengths = torch.tensor([s.size(0) for s in states])
    positions = torch.arange(padded_states.size(1))
    valid = positions.unsqueeze(0) < seq_lengths.unsqueeze(1)

    return padded_states, padded_actions, valid.to(torch.float32)


class INVSequenceDataset(Dataset):
    """In-memory dataset of (inv_input, actions) sequence pairs read from an HDF5 file.

    Every top-level entry of the file that contains both an ``inv_input`` and an
    ``actions`` member contributes one item; entries missing either one are
    skipped.
    """

    def __init__(self, h5_path):
        # One (inv_input, actions) tensor pair per stored entry.
        self.data = []

        # Everything is read eagerly, so the file is closed once construction ends.
        with h5py.File(h5_path, "r") as h5_file:
            for name in h5_file:
                entry = h5_file[name]
                if "inv_input" not in entry or "actions" not in entry:
                    continue
                # The first axis of each array is assumed to be time — TODO confirm
                # against the code that writes these files.
                states = torch.tensor(entry["inv_input"][...])
                controls = torch.tensor(entry["actions"][...])
                self.data.append((states, controls))

    def __len__(self):
        """Return the number of stored sequences."""
        return len(self.data)

    def len_timesteps(self):
        """Return the summed first-axis length of all stored ``inv_input`` tensors."""
        return sum(pair[0].shape[0] for pair in self.data)

    def __getitem__(self, idx):
        """Return the ``(inv_input, actions)`` tuple stored at position ``idx``."""
        return self.data[idx]


class P4RLSequenceLightningModule(L.LightningModule):
    """Lightning module that trains a sequence dynamics model with manual optimization.

    The wrapped model is called as ``model(x, a, mask)``. Depending on ``mode``
    its output is compared (masked L1) against either the actions from the
    second timestep on (``"inv"``) or the step-to-step difference of state
    columns 9:21 (``"fwd"``).

    Args:
        model: the network to train.
        mode: ``"inv"`` for inverse dynamics, ``"fwd"`` for forward dynamics.
        enable_noise: input-noise switch; currently rejected with
            ``NotImplementedError`` when true.

    Attributes:
        error_per_epoch: mean training loss of every finished epoch.
    """

    # State columns used as forward-dynamics target. Presumably the joint-angle
    # block, going by the ordering noted on ``noise_magnitude`` — TODO confirm.
    _FWD_STATE_SLICE = slice(9, 21)

    def __init__(self, model, mode="inv", enable_noise = False):
        super().__init__()
        # NOTE(review): this also records ``model`` (an nn.Module) as a
        # hyperparameter; consider ``ignore=["model"]`` if checkpoints do not
        # rely on it — confirm against how checkpoints are loaded.
        self.save_hyperparameters()
        self.model: InvDynamicsRNN = model
        self.error_per_epoch = []
        self.error_accumulated = 0.0
        self.step_counter = 0
        # training_step drives zero_grad / backward / step by hand.
        self.automatic_optimization = False
        self.mode = mode  # "inv" for inverse dynamics, "fwd" for forward dynamics
        self.enable_noise = enable_noise
        if enable_noise:
            raise NotImplementedError("Adding noise not allowed yet because the target will also be corrupted")

        # TODO: grad norm clipping value should be set as a hyperparameter

        # Per-dimension noise scale. Expected order: lin_vel, ang_vel,
        # gravity_vector, joint_angles, joint_vels. Noise is not in effect
        # while enable_noise is rejected above.
        self.noise_magnitude = torch.tensor([0.1]*3 + [0.2]*3 + [0.05]*3 + [0.01]*12 + [1.5]*12)

        assert mode in ["inv", "fwd"], "P4RLSequenceLightningModule: mode must be either 'inv' or 'fwd'"

    def on_train_epoch_start(self):
        """Reset the running loss statistics of the epoch."""
        self.error_accumulated = 0.0
        self.step_counter = 0

    @staticmethod
    def _masked_mean(per_step_loss, valid):
        """Average ``per_step_loss`` over the positions where ``valid`` is set.

        The denominator is clamped to 1, so a batch without any valid timestep
        yields a zero that is still attached to the autograd graph. A detached
        constant would make the backward call in ``training_step`` fail.
        """
        total = (per_step_loss * valid).sum()
        return total / torch.clamp(valid.sum(), min=1)

    def _fwd_target(self, x):
        """Return the step-to-step state difference used as forward-dynamics target."""
        cols = self._FWD_STATE_SLICE
        return x[:, 1:, cols] - x[:, :-1, cols]

    def cal_loss_inv(self, pred, x, a, mask):
        """Masked mean L1 loss between ``pred`` and the actions from timestep 1 on.

        ``x`` is unused; it is accepted so both loss methods share a signature.
        """
        per_step = F.l1_loss(pred, a[:, 1:, :], reduction='none').mean(-1)
        # The first timestep has no prediction, so it is dropped from the mask too.
        return self._masked_mean(per_step, mask[:, 1:])

    def cal_loss_fwd(self, pred, x, a, mask):
        """Masked mean L1 loss between ``pred`` and the state deltas.

        ``a`` is unused; it is accepted so both loss methods share a signature.
        """
        per_step = F.l1_loss(pred, self._fwd_target(x), reduction='none').mean(-1)
        return self._masked_mean(per_step, mask[:, 1:])

    def training_step(self, batch, batch_idx):
        """Run one manual optimization step and return the loss.

        Args:
            batch: tuple of (x_t, a_tp, mask)
                (x_t, a_tp): each is [batch_size, seq_len, dim]
                mask: [batch_size, seq_len], where 1 means valid and 0 means invalid
            batch_idx: index of the batch within the epoch (unused).
        """
        x, a, mask = batch
        if self.enable_noise:
            # Add noise to the input states
            noise = torch.randn_like(x) * self.noise_magnitude.to(x.device)
            x = x + noise
        pred = self.model(x, a, mask)

        if self.mode == "inv":
            loss = self.cal_loss_inv(pred, x, a, mask)
        elif self.mode == "fwd":
            loss = self.cal_loss_fwd(pred, x, a, mask)
        else:
            raise ValueError(f"Unsupported mode: {self.mode!r}")

        optimizer: torch.optim.Optimizer = self.optimizers()

        optimizer.zero_grad()
        self.manual_backward(loss)
        # Gradient clipping is currently disabled (see the TODO in __init__).
        optimizer.step()

        self.error_accumulated += loss.item()
        self.step_counter += 1

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    @torch.no_grad()
    def validate_batch_detail(self, batch):
        """Return per-element absolute errors and target magnitudes on valid timesteps.

        Args:
            batch: tuple ``(x, a, mask)`` in the same layout as ``training_step``.

        Returns:
            Tuple ``(error, magnitude)`` of flattened tensors: ``|pred - target|``
            and ``|target|``, restricted to timesteps where the mask is set.
        """
        x, a, mask = batch
        # Boolean indexing needs a bool mask; converting here also accepts the
        # 0/1 float mask format that the loss methods work with.
        m = mask[:, 1:].bool()
        pred = self.model(x, a, mask)

        if self.mode == "inv":
            target = a[:, 1:, :]
        elif self.mode == "fwd":
            target = self._fwd_target(x)
        else:
            raise ValueError(f"Unsupported mode: {self.mode!r}")

        error = torch.abs(pred - target)[m].flatten()
        a_magnitude = torch.abs(target)[m].flatten()
        return error, a_magnitude

    def on_train_epoch_end(self):
        """Append the mean loss of the finished epoch to ``error_per_epoch``."""
        self.error_per_epoch.append(self.error_accumulated / self.step_counter if self.step_counter > 0 else 0.0)

    def configure_optimizers(self):
        """Return an Adam optimizer over the wrapped model's parameters."""
        return torch.optim.Adam(self.model.parameters(), lr=1e-5)


def inv_dynamics_sequence_collate_fn(batch):
    """Zero-pad a list of (inv_input, actions) sequences and build a validity mask.

    Args:
        batch: list of ``(inv_input, actions)`` tuples, each tensor shaped
            ``[seq_len, dim]``.

    Returns:
        Tuple ``(padded_inputs, padded_actions, mask)``. The padded tensors are
        batch-first (``[batch_size, max_seq_len, dim]``) with zeros as filler.
        ``mask`` is a bool tensor of shape ``[batch_size, max_seq_len]`` that is
        True on real timesteps and False on padding.
    """
    inputs, targets = zip(*batch)

    # Pad every sequence up to the longest one in the batch (fill value 0).
    batch_inputs = pad_sequence(inputs, batch_first=True)
    batch_targets = pad_sequence(targets, batch_first=True)

    # Compare each time index against the true length of its sequence; the
    # comparison already produces a bool tensor.
    n_valid = torch.tensor([item.size(0) for item in inputs])
    step_index = torch.arange(batch_inputs.size(1))
    mask = step_index[None, :] < n_valid[:, None]

    return batch_inputs, batch_targets, mask


def reinitialized_and_train_model_rnn(model: InvDynamicsRNN, dataset: INVSequenceDataset, epochs: int = 10,
                                      batch_size: int = 32, replacement: bool = False, mode: str = "inv",
                                      save_path = None):
    """Reset the model's weights, train it on ``dataset`` and return the per-epoch losses.

    Args:
        model: network to train; ``reinitialize_weights()`` is called on it first.
        dataset: sequence dataset fed through a randomly sampling DataLoader.
        epochs: value passed to the trainer as ``max_epochs``.
        batch_size: number of sequences per batch.
        replacement: whether the random sampler draws with replacement.
        mode: forwarded to ``P4RLSequenceLightningModule``.
        save_path: when not ``None``, a trainer checkpoint is written there after fitting.

    Returns:
        The ``error_per_epoch`` list collected by the Lightning module.
    """
    model.reinitialize_weights()

    # One pass draws len(dataset) indices in random order.
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=0,
        sampler=RandomSampler(dataset, replacement=replacement, num_samples=len(dataset)),
        collate_fn=inv_dynamics_sequence_collate_fn,
    )
    lightning_module = P4RLSequenceLightningModule(model=model, mode=mode, enable_noise=False)

    trainer = L.Trainer(max_epochs=epochs, log_every_n_steps=10)
    trainer.fit(model=lightning_module, train_dataloaders=loader)

    if save_path is None:
        return lightning_module.error_per_epoch

    trainer.save_checkpoint(save_path)
    return lightning_module.error_per_epoch


def plot_epoch_errors(epoch_errors: List[List[float]], labels: List[str]):
    """Plot training error curves over epochs and save the figure as a PNG.

    Args:
        epoch_errors: one list of per-epoch errors for every training run.
        labels: legend label for every run, in the same order as ``epoch_errors``.

    The figure is written to ``logs/analysis/plots/inv_epoch_errors_plot.png``;
    the directory is created when it does not exist yet.
    """
    import matplotlib.pyplot as plt

    plt.figure(figsize=(10, 5))
    for i, errors in enumerate(epoch_errors):
        # A run that produced no epochs has no final value to report.
        final_error = f'{errors[-1]:.2f}' if errors else 'n/a'
        plt.plot(errors, label=f'{labels[i]}, final error: {final_error}')
    plt.xlabel('Epoch')
    plt.ylabel('Error')
    plt.title('Training Errors Over Epochs')
    plt.legend()
    plt.grid()

    save_path = 'logs/analysis/plots/inv_epoch_errors_plot.png'
    # savefig does not create missing parent directories; without this the
    # plot would be lost after training has already finished.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    plt.savefig(save_path, dpi=300)
    print(f"Plot saved as '{save_path}'")



def train_single_model_rnn(inv_dynamics_cfg, save_path):
    """Build one model from ``inv_dynamics_cfg``, train it and return its per-epoch errors.

    Args:
        inv_dynamics_cfg: configuration dict; ``"class_name"`` selects the model
            class, ``"mode"`` selects the training mode, and the whole dict is
            forwarded to the model constructor as keyword arguments.
        save_path: checkpoint destination forwarded to the training helper.

    Returns:
        The per-epoch error list returned by ``reinitialized_and_train_model_rnn``.
    """
    # wandb_logger = WandbLogger(project=f"{mode}_dynamics_new")

    dataset_path = dataset_paths["Pedipulation Init (no random)"]
    dataset = INVSequenceDataset(h5_path=dataset_path)

    # NOTE(review): eval() executes whatever string the config carries; this is
    # only acceptable while the config comes from trusted code.
    model_cls = eval(inv_dynamics_cfg["class_name"])
    model: InvDynamicsRNN = model_cls(device="cuda", **inv_dynamics_cfg)

    print("full dataset trajectories number:", len(dataset))
    print("full dataset timesteps number:", dataset.len_timesteps())
    # Rule of thumb used here: ten samples per trainable parameter.
    print("least samples number recommended for training:", 10 * model.get_number_of_trainable_parameters())

    return reinitialized_and_train_model_rnn(
        model,
        dataset,
        mode=inv_dynamics_cfg["mode"],
        epochs=10,
        batch_size=1024,
        replacement=False,
        save_path=save_path,
    )


def train_model_rnn_sweep():
    """Train one model per (num_layers, hidden_dim) combination and plot their error curves.

    Every run gets its own copy of the base configuration with ``num_layers``
    and ``hidden_dim`` overridden, is checkpointed under
    ``logs/pretrain/lightning/`` and contributes one curve to the final plot.
    """
    inv_dynamics_cfg = {
        "class_name": "InvDynamicsRNN",
        "dim_states": 33,
        "dim_actions": 12,
        "representation_dim": 256,
        "hidden_dim": 512,
        "num_layers": 1,
        "rnn_type": "LSTM",
        "activation_name": "elu",
        "mode": "fwd",
        "weight_path": None,
        "finetune_frozen": False,
    }

    errors_to_plot = []
    errors_labels = []
    saved_checkpoints = []

    for num_layers in [1,]:
        for hidden_dim in [1024]:
            # A per-run copy keeps the base configuration untouched between runs.
            run_cfg = {**inv_dynamics_cfg, "num_layers": num_layers, "hidden_dim": hidden_dim}

            # File name and plot label both follow the configured RNN type so
            # they cannot drift apart when rnn_type changes.
            run_tag = f"{run_cfg['rnn_type']}_{num_layers}layers_{hidden_dim}hidden"
            save_path = f"logs/pretrain/lightning/{run_cfg['mode']}_vanilla_pedi_{run_tag}.ckpt"

            print(f"Training model with num_layers={num_layers}, hidden_dim={hidden_dim}")
            epoch_error = train_single_model_rnn(run_cfg, save_path)

            errors_to_plot.append(epoch_error)
            errors_labels.append(run_tag)
            saved_checkpoints.append(save_path)

    plot_epoch_errors(errors_to_plot, errors_labels)
    # Report every checkpoint of the sweep, not only the last one.
    for checkpoint_path in saved_checkpoints:
        print(f"model saved to {checkpoint_path}")

# Script entry point: running this file directly starts the training sweep.
if __name__ == "__main__":
    train_model_rnn_sweep()
