import lightning as L
import torch
from torch import nn
from torch.nn import functional as F
import os
from torch.utils.data import Dataset, RandomSampler, DataLoader
import h5py
import warnings
from torch.utils.data import random_split
from lightning.pytorch.profilers import SimpleProfiler
from torch.nn.utils.rnn import pad_sequence
from lightning.pytorch.loggers import WandbLogger
import wandb
from rsl_rl.addons.invdynamics.inv_dynamics_module import build_mlp


def inv_dynamics_sequence_collate_fn(batch):
    """Collate variable-length ``(inv_input, actions)`` pairs into one zero-padded batch.

    Args:
        batch: list of ``(inv_input, actions)`` tuples; each tensor is shaped
            ``[seq_len, dim]`` with time as the first axis.

    Returns:
        ``(padded_inputs, padded_actions, mask)``:
            * ``padded_inputs`` / ``padded_actions``: ``[batch_size, max_seq_len, dim]``,
              zero-padded at the end of each sequence (``pad_sequence`` default);
            * ``mask``: float32 ``[batch_size, max_seq_len]``, 1.0 on real timesteps and
              0.0 on padding. It is derived from the ``inv_input`` lengths only.
    """
    inputs, actions = zip(*batch)

    padded_inputs = pad_sequence(inputs, batch_first=True)
    padded_actions = pad_sequence(actions, batch_first=True)

    # Position index along time compared against each sequence's true length.
    seq_lengths = torch.tensor([sample.shape[0] for sample in inputs])
    time_index = torch.arange(padded_inputs.shape[1])
    valid = time_index.unsqueeze(0) < seq_lengths.unsqueeze(1)

    return padded_inputs, padded_actions, valid.float()


class INVSequenceDataset(Dataset):
    """In-memory dataset of whole trajectories loaded from an HDF5 file.

    Every top-level group of the file that holds both an ``"inv_input"`` and an
    ``"actions"`` dataset contributes one item: the tuple ``(inv_input, actions)``
    of tensors. Groups missing either key are skipped. The file is closed once
    everything has been copied into memory.

    NOTE(review): ``inv_dynamics_sequence_collate_fn`` builds its padding mask from
    the ``inv_input`` lengths only, so both arrays are presumably the same length per
    trajectory — confirm against whatever writes the HDF5 file.
    """

    def __init__(self, h5_path):
        # One (inv_input, actions) tensor tuple per trajectory group.
        self.data = []

        with h5py.File(h5_path, "r") as h5_file:
            for name in h5_file:
                trajectory = h5_file[name]
                if "inv_input" not in trajectory or "actions" not in trajectory:
                    continue
                self.data.append(
                    (
                        torch.tensor(trajectory["inv_input"][...]),
                        torch.tensor(trajectory["actions"][...]),
                    )
                )

    def __len__(self):
        """Number of trajectories."""
        return len(self.data)

    def len_timesteps(self):
        """Total number of timesteps, i.e. the sum of ``inv_input.shape[0]`` over all trajectories."""
        return sum(pair[0].shape[0] for pair in self.data)

    def __getitem__(self, idx):
        """Return the ``(inv_input, actions)`` tuple of trajectory ``idx``."""
        return self.data[idx]




class P4RLSequenceLightningModule(L.LightningModule):
    """Lightning module that trains an inverse- or forward-dynamics model on padded trajectory batches.

    Batches are ``(x_t, a_tm1, mask)`` as produced by ``inv_dynamics_sequence_collate_fn``.
    Each batch is cut along time into windows of ``chunk_steps`` steps (a trailing partial
    window is dropped). The optimizer is stepped once per window (manual optimization).

    * ``mode="inv"``: the input at step t is ``(x_t, a_{t-2})`` and the target is ``a_{t-1}``
      (the batch's action sequence at t). Step 0 is never supervised.
    * ``mode="fwd"``: the input at step t is ``(x_t, a_t)`` and the target is
      ``(x_{t+1} - x_t)[..., 9:21]``. A step is supervised only if step t+1 is real data.

    Backbones:

    * any ``nn.RNNBase`` (GRU/LSTM) — trained with truncated BPTT. The detached hidden
      state is carried from window to window. Inputs are ``[B, T, D]``, so the RNN is
      assumed to be ``batch_first=True``.
    * anything else (e.g. an MLP) — fed one flattened window ``[B, chunk_steps * D]`` and
      supervised only on the window's last step.

    ``output_layer`` (``dim_hidden -> dim_output``) maps backbone features to predictions.
    It is optimized together with the backbone.
    """

    def __init__(self, model, dim_hidden, dim_output, chunk_steps=20, grad_norm_clip_value=1.0, mode="inv", enable_noise=False, lr=1e-5):
        """
        Args:
            model: backbone network (``nn.RNNBase`` or a single-input module such as an MLP).
            dim_hidden: feature size produced by ``model``.
            dim_output: size of each prediction.
            chunk_steps: window length in timesteps.
            grad_norm_clip_value: max gradient L2 norm on the RNN path (the MLP path clips
                gradient *values* at 1.0).
            mode: ``"inv"`` (inverse dynamics) or ``"fwd"`` (forward dynamics).
            enable_noise: if True, add Gaussian noise with per-feature std
                ``noise_magnitude`` to ``x_t``. This applies in training *and* in
                ``validate_manual_fixed_step_mlp``.
            lr: Adam learning rate.
        """
        assert mode in ["inv", "fwd"], "P4RLSequenceLightningModule: mode must be either 'inv' or 'fwd'"
        super().__init__()
        self.save_hyperparameters()
        self.model: nn.Module = model
        self.output_layer = nn.Linear(dim_hidden, dim_output)
        # Per-epoch bookkeeping; reset in on_train_epoch_start.
        self.error_per_epoch = []
        self.error_accumulated = 0.0
        self.step_counter = 0
        self.chunk_steps = chunk_steps
        # Optimizer steps are issued manually, once per time window.
        self.automatic_optimization = False
        self.grad_norm_clip_value = grad_norm_clip_value
        self.mode = mode
        self.enable_noise = enable_noise
        self.lr = lr

        # Per-feature std of the observation noise. Feature order:
        # lin_vel(3), ang_vel(3), gravity_vector(3), joint_angles(12), joint_vels(12) = 33.
        # A non-persistent buffer follows the module across devices but is not written to
        # the state_dict, so checkpoints stay compatible.
        self.register_buffer(
            "noise_magnitude",
            torch.tensor([0.1] * 3 + [0.2] * 3 + [0.05] * 3 + [0.01] * 12 + [1.5] * 12),
            persistent=False,
        )

    def forward(self, x_t, x_tp1):
        """Call the backbone as ``self.model(x_t, x_tp1)``.

        NOTE(review): the training/validation paths call ``self.model`` directly and never
        use this method. For an RNN backbone the second positional argument is taken as the
        initial hidden state. A single-input MLP presumably does not accept two arguments.
        Confirm the intended use before relying on this.
        """
        return self.model(x_t, x_tp1)

    def on_train_epoch_start(self):
        """Reset the per-epoch bookkeeping counters."""
        self.error_accumulated = 0.0
        self.step_counter = 0

    def training_step(self, batch, batch_idx):
        """Dispatch to truncated BPTT for RNN backbones, otherwise to the fixed-window MLP path."""
        if isinstance(self.model, nn.RNNBase):
            return self.training_step_rnn_tbptt(batch, batch_idx)
        return self.training_step_fixed_step_mlp(batch, batch_idx)

    def _add_observation_noise(self, x_t):
        """Return ``x_t`` plus per-feature Gaussian noise when ``enable_noise`` is set, else ``x_t``."""
        if not self.enable_noise:
            return x_t
        return x_t + torch.randn_like(x_t) * self.noise_magnitude.to(x_t.device)

    def _split_full_chunks(self, inputs, targets, mask):
        """Yield aligned ``chunk_steps``-long time windows, dropping the trailing partial one."""
        windows = zip(
            inputs.split(self.chunk_steps, dim=1),
            targets.split(self.chunk_steps, dim=1),
            mask.split(self.chunk_steps, dim=1),
        )
        for x_seq, y_seq, m_seq in windows:
            if x_seq.shape[1] != self.chunk_steps:
                continue
            yield x_seq, y_seq, m_seq

    def chunks_generator(self, batch):
        """Build per-step inputs/targets for ``self.mode`` and yield them in time windows.

        Args:
            batch: ``(x_t, a_tm1, mask)`` with ``x_t`` ``[B, T, D_x]``, ``a_tm1``
                ``[B, T, D_a]`` and ``mask`` ``[B, T]`` (1 = real step, 0 = padding).

        Yields:
            ``(x_seq, y_seq, m_seq)`` windows with exactly ``chunk_steps`` timesteps.
            ``x_seq`` is the concatenation of (possibly noisy) state and shifted action,
            ``y_seq`` the target and ``m_seq`` the supervision mask.

        The caller's ``mask`` tensor is not modified; a copy is edited.
        """
        x_t, a_tm1, mask = batch
        mask = mask.clone()

        if self.mode == "inv":
            # Align (x_t, a_{t-2}) with target a_{t-1}. The roll wraps the last action to
            # t=0, so zero it to avoid leaking future information.
            a_tm2 = torch.roll(a_tm1, shifts=1, dims=1)
            a_tm2[:, 0, :] = 0
            x_input = torch.cat((self._add_observation_noise(x_t), a_tm2), dim=-1)
            targets = a_tm1
            mask[:, 0] = 0  # no valid prediction at the first timestep
        elif self.mode == "fwd":
            # Align (x_t, a_t) with target x_{t+1} - x_t. The roll wraps the first action
            # to the last step, so zero it.
            a_t = torch.roll(a_tm1, shifts=-1, dims=1)
            a_t[:, -1, :] = 0
            x_input = torch.cat((self._add_observation_noise(x_t), a_t), dim=-1)
            # Delta prediction on feature slice 9:21 (joint angles per the noise layout above).
            targets = (torch.roll(x_t, shifts=-1, dims=1) - x_t)[..., 9:21]
            # A step only has a valid target if the next step is real data as well;
            # otherwise the rolled "next state" is padding (or wraps around at the end).
            mask[:, :-1] = mask[:, :-1] * mask[:, 1:]
            mask[:, -1] = 0
        else:
            raise ValueError(f"P4RLSequenceLightningModule: unknown mode {self.mode!r}")

        yield from self._split_full_chunks(x_input, targets, mask)

    @staticmethod
    def _masked_l1(y_pred, y, m):
        """Per-sample L1 error (mean over the last dim), zeroed where ``m`` is 0."""
        return F.l1_loss(y_pred, y, reduction="none").mean(dim=-1) * m

    def _log_train_metrics(self, losses_and_masks, y_gt_and_masks):
        """Log the mask-weighted mean loss and mean |target|, and return the mean loss.

        Returns None (and logs nothing) when no window was processed. Lightning accepts
        this as a skipped step under manual optimization; it avoids dividing by zero.
        """
        if not losses_and_masks:
            return None
        num_valid = sum(mask.sum() for _, mask in losses_and_masks)
        avg_loss = sum((loss_m * mask).sum() for loss_m, mask in losses_and_masks) / num_valid
        avg_y_magnitude = sum((y.abs().mean(-1) * mask).sum() for y, mask in y_gt_and_masks) / num_valid

        self.log("train_loss", avg_loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_y_magnitude", avg_y_magnitude, on_step=True, on_epoch=True, prog_bar=True)
        return avg_loss

    def training_step_fixed_step_mlp(self, batch, batch_idx):
        """Train a non-recurrent backbone: one optimizer step per flattened window.

        Only the last step of each window is supervised. Windows whose last step is
        masked out for every sample are skipped.
        """
        optimizer: torch.optim.Optimizer = self.optimizers()
        losses_and_masks = []
        y_gt_and_masks = []

        for x_seq, y_seq, m_seq in self.chunks_generator(batch):
            x = x_seq.flatten(1, 2)  # [B, chunk_steps * D]
            y = y_seq[:, -1]
            m = m_seq[:, -1]
            if not m.any():
                continue

            loss_m = self._masked_l1(self.output_layer(self.model(x)), y, m)
            loss_mean = loss_m.sum() / m.sum()  # m.any() above guarantees m.sum() > 0

            optimizer.zero_grad()
            self.manual_backward(loss_mean)
            torch.nn.utils.clip_grad_value_(self.parameters(), clip_value=1.0)
            optimizer.step()

            losses_and_masks.append((loss_m.detach(), m))
            y_gt_and_masks.append((y, m))

        return self._log_train_metrics(losses_and_masks, y_gt_and_masks)

    def validate_manual_fixed_step_mlp(self, batch):
        """Evaluate the non-recurrent path on ``batch`` without updating weights.

        Runs under ``torch.no_grad()``. It does not switch the module to eval mode, and
        observation noise is still applied if ``enable_noise`` is set.

        Returns:
            ``(losses_vec, magnitude_vec)``: 1-D tensors with one entry per unmasked
            window-end sample — the L1 error and the mean absolute target. Both are empty
            when no window qualifies.
        """
        losses_and_masks = []
        y_gt_and_masks = []

        with torch.no_grad():
            for x_seq, y_seq, m_seq in self.chunks_generator(batch):
                x = x_seq.flatten(1, 2)
                y = y_seq[:, -1]
                m = m_seq[:, -1]
                if not m.any():
                    continue

                loss_m = self._masked_l1(self.output_layer(self.model(x)), y, m)
                losses_and_masks.append((loss_m, m))
                y_gt_and_masks.append((y, m))

        if not losses_and_masks:
            empty = torch.empty(0, device=batch[0].device)
            return empty, empty.clone()

        losses_vec = torch.cat([loss_m[mask.bool()] for loss_m, mask in losses_and_masks], dim=-1)
        magnitude_vec = torch.cat([y.abs().mean(-1)[mask.bool()] for y, mask in y_gt_and_masks])
        return losses_vec, magnitude_vec

    def training_step_rnn_tbptt(self, batch, batch_idx):
        """Train an RNN backbone with truncated backpropagation through time.

        Args:
            batch: ``(x_t, a_tm1, mask)``; see ``chunks_generator``.

        Each window gets its own backward pass and optimizer step. The hidden state is
        carried to the next window but detached, so gradients stop at window borders.
        """
        hiddens = None
        optimizer: torch.optim.Optimizer = self.optimizers()
        losses_and_masks = []
        y_gt_and_masks = []

        for x, y, m in self.chunks_generator(batch):
            if not m.any():
                continue

            features, hiddens = self.model(x, hiddens)
            loss_m = self._masked_l1(self.output_layer(features), y, m)
            loss_mean = loss_m.sum() / m.sum()  # m.any() above guarantees m.sum() > 0

            optimizer.zero_grad()
            self.manual_backward(loss_mean)
            torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=self.grad_norm_clip_value)
            optimizer.step()

            # Truncate: detach the carried state. Keep LSTM's (h, c) as a tuple.
            if isinstance(hiddens, tuple):
                hiddens = tuple(state.detach() for state in hiddens)
            else:
                hiddens = hiddens.detach()

            losses_and_masks.append((loss_m.detach(), m))
            y_gt_and_masks.append((y, m))

        return self._log_train_metrics(losses_and_masks, y_gt_and_masks)

    def configure_optimizers(self):
        """Adam over all parameters: the backbone *and* ``output_layer``."""
        # Optimizing only self.model would leave output_layer at its random initialization.
        return torch.optim.Adam(self.parameters(), lr=self.lr)


def get_number_of_trainable_parameters(model) -> int:
    """Count the scalar entries of every parameter of ``model`` that has ``requires_grad`` set.

    Returns 0 for a module without trainable parameters.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total



def train(dataset_path, mode):
    """Train an inverse (``"inv"``) or forward (``"fwd"``) dynamics model from an HDF5 dataset.

    Steps:

    * build a fixed-window MLP backbone (a GRU alternative is sketched in comments);
    * load the whole dataset into memory;
    * if data is plentiful, keep a random half of the trajectories;
    * fit with Lightning for 10 epochs;
    * save a checkpoint to ``logs/pretrain/lightning/{mode}_vanilla_pedi.ckpt``.

    Metrics are logged to the Weights & Biases project ``"{mode}_dynamics_new"``. The
    W&B run is finished even if loading or training raises.

    Args:
        dataset_path: path to the HDF5 file read by ``INVSequenceDataset``.
        mode: ``"inv"`` or ``"fwd"``; forwarded to ``P4RLSequenceLightningModule``.
    """
    wandb_logger = WandbLogger(project=f"{mode}_dynamics_new")
    try:
        dim_state = 33  # must match the 33-entry noise_magnitude of the Lightning module
        dim_action = 12
        dim_output = 12  # actions ("inv") or joint-angle deltas x[..., 9:21] ("fwd")
        dim_hidden = 128
        chunk_steps = 30

        # Alternative recurrent backbone (trained with truncated BPTT by the module):
        # model = nn.GRU(
        #     input_size=dim_state + dim_action,
        #     hidden_size=dim_hidden,
        #     num_layers=2,
        #     batch_first=True,  # inputs are (batch, seq, feature)
        #     dropout=0.1,
        # )
        model = build_mlp(
            input_dims=(dim_state + dim_action) * chunk_steps,  # one flattened window
            hidden_dims=[512, 256, 128],
            output_dims=dim_hidden,  # features consumed by the module's output_layer
            activation_name="elu",
        )

        # Reload the data on every run, since the dataset file may have grown.
        dataset = INVSequenceDataset(h5_path=dataset_path)
        num_trajectories = len(dataset)
        num_timesteps = dataset.len_timesteps()
        print("full dataset trajectories number:", num_trajectories)
        print("full dataset timesteps number:", num_timesteps)

        num_trainable_params = get_number_of_trainable_parameters(model)
        # Rule of thumb: at least ~10 timesteps per trainable parameter.
        samples_number_needed = 10 * num_trainable_params

        if num_timesteps < samples_number_needed:
            warnings.warn(
                f"The number of samples in the dataset may be too small for training the {mode} dynamics model! "
                f"Number of timesteps: {num_timesteps}, "
                f"Ratio num_timesteps/num_trainable_parameters: {num_timesteps / num_trainable_params:.3f}."
            )
        else:
            # Plenty of data: train on a random half of the trajectories to bound training time.
            ratio_of_training_samples = 0.5
            dataset, _ = random_split(dataset, [ratio_of_training_samples, 1 - ratio_of_training_samples])

        # Bootstrap-style sampling: one epoch draws len(dataset) trajectories with replacement.
        sampler = RandomSampler(dataset, replacement=True, num_samples=len(dataset))
        dataloader = DataLoader(
            dataset,
            batch_size=1024,
            num_workers=0,
            sampler=sampler,
            collate_fn=inv_dynamics_sequence_collate_fn,
        )
        l_model = P4RLSequenceLightningModule(
            model=model,
            dim_hidden=dim_hidden,
            dim_output=dim_output,
            chunk_steps=chunk_steps,
            mode=mode,
            enable_noise=True,
        )

        profiler = SimpleProfiler()
        trainer = L.Trainer(
            max_epochs=10,
            log_every_n_steps=10,
            profiler=profiler,
            logger=wandb_logger,
            default_root_dir="logs/pretrain/lightning",
        )
        trainer.fit(model=l_model, train_dataloaders=dataloader)

        trainer.save_checkpoint(f"logs/pretrain/lightning/{mode}_vanilla_pedi.ckpt")
    finally:
        # Close the W&B run even on failure so no dangling run is left open.
        wandb.finish()


if __name__ == "__main__":
    # Training target: "inv" (inverse dynamics) or "fwd" (forward dynamics).
    mode = "inv"
    dataset_path = "logs/datasets/inv_dynamics/inv_datasets/vanilla_pedi.h5"
    train(dataset_path, mode=mode)
