import lightning as L
import torch
from torch import nn
from torch.nn import functional as F
import os
from torch.utils.data import Dataset, RandomSampler, DataLoader
import h5py
import warnings
from torch.utils.data import random_split
from lightning.pytorch.profilers import SimpleProfiler
from torch.nn.utils.rnn import pad_sequence
from lightning.pytorch.loggers import WandbLogger
import wandb
from rsl_rl.addons.invdynamics.inv_dynamics_module import build_mlp
from rsl_rl.addons.invdynamics.inv_dynamics_utils import DynamicSlidingWindowDataset
from typing import Union, List, Dict, Tuple
from rsl_rl.rsl_rl.utils.utils import resolve_nn_activation


def batch_left_pinv(A: torch.Tensor) -> torch.Tensor:
    """
    Compute the left pseudo-inverse of a batch of matrices A with shape [b, m, n], where m > n.

    Mathematically this is ``(A^T A)^+ A^T`` (for full column rank ``A`` it satisfies
    ``batch_left_pinv(A) @ A == I_n``), i.e. the least-squares solution operator.

    The pseudo-inverse is computed directly from ``A`` with ``torch.linalg.pinv``
    (SVD-based) instead of pseudo-inverting the normal matrix ``A^T A``.
    Forming ``A^T A`` squares the condition number, which wastes roughly half of the
    available precision (significant in float32). It also makes the singular-value
    cutoff act on squared values, so near-singular directions blow up.

    Args:
        A: Tensor of shape [b, m, n] (any leading batch dimensions are accepted).

    Returns:
        Tensor of shape [b, n, m].
    """
    return torch.linalg.pinv(A)


def prepare_jacobian_input_from_obs(
    obs: torch.Tensor, history_length: int = 6
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Split a flat policy observation into the inputs used by ``JacobianMLP``.

    See P4RL-Pre-Dynamic-Pedipulation-Flat-Blind-Anymal-D-v0's config for details.
    The observation is read as consecutive history blocks, each holding
    ``history_length`` (H) timesteps, followed by the command:

        joint_pos [H*12] | action [H*12] | lin_vel [H*3] | ang_vel [H*3]
        | gravity [H*3] | joint_vel [H*12] | command [remaining entries]

    With the default H = 6 the observation is [batch_size, 273] and the command
    occupies the final 3 entries. The last entry of each history is treated as
    the current step.

    Args:
        obs: Tensor of shape [batch_size, obs_dim].
        history_length: Number of stacked timesteps per history block.

    Returns:
        Tuple ``(states, actions, current_step_input)``:
            states: [b, H-1, 33] - per step (lin_vel, ang_vel, gravity, joint_pos,
                joint_vel) for the newest H-1 steps.
            actions: [b, H-1, 12] - the matching newest H-1 actions.
            current_step_input: [b, 48] for the default layout - newest state (33)
                + newest action (12) + command (3); mimics the default policy input.
    """
    b = obs.shape[0]
    # Per-timestep widths of the history blocks, in storage order:
    # joint_pos, action, lin_vel, ang_vel, gravity, joint_vel.
    block_widths = (12, 12, 3, 3, 3, 12)

    blocks = []
    offset = 0
    for width in block_widths:
        span = history_length * width
        blocks.append(obs[:, offset:offset + span].reshape(b, history_length, width))
        offset += span
    jp, a, lin_vel, ang_vel, grav, jv = blocks
    command = obs[:, offset:]  # everything after the histories

    obs_segment = torch.cat([lin_vel, ang_vel, grav, jp, jv], dim=-1)  # [b, H, 33]

    # Mimic the default policy input: newest state, newest action, command.
    current_step_input = torch.cat([obs_segment[:, -1, :], a[:, -1], command], dim=-1)  # [b, 48]
    return obs_segment[:, 1:, :], a[:, 1:, :], current_step_input  # [b, H-1, 33], [b, H-1, 12], [b, 48]


class JacobianMLP(nn.Module):
    """
    MLP that predicts a state-space Jacobian and a bias from a state/action history.

    ``forward`` maps ``input_timesteps`` steps of states and actions to
    ``(jacobian, bias_state)`` with shapes ``[b, dim_states_output, dim_actions]``
    and ``[b, dim_states_output]``. The pair is used as a local linear model
    ``delta_state ~= jacobian @ action + bias_state``.

    When ``weight_path`` is given, two RL heads are added on top of the (optionally
    frozen) dynamics network and used by ``forward_RL``:
      * ``state_space_policy_network``: current-step input -> commanded state change;
      * ``action_residual_net``: (current-step input, least-squares action guess)
        -> action residual.
    """

    # Width of ``current_step_input`` built by ``prepare_jacobian_input_from_obs`` for
    # the default observation layout: 33 (state) + 12 (action) + 3 (command).
    _CURRENT_STEP_INPUT_DIM = 48

    def __init__(self, dim_states, dim_actions, dim_states_output, input_timesteps, hidden_dims: list[int], representation_dim,
                 backbone_output_dim,
                 weight_path = None,
                 finetune_frozen = False, activation_name: str = "elu",
                 **kwargs):
        """
        Args:
            dim_states: Per-step state width fed to ``state_encoder``.
            dim_actions: Per-step action width (columns of the Jacobian). 0 disables the
                action encoder (pure state forecasting; not the normal use case).
            dim_states_output: Number of modelled state entries (rows of the Jacobian).
            input_timesteps: History length seen by ``forward``.
            hidden_dims: Hidden layer sizes of ``main_MLP``.
            representation_dim: Per-step embedding width of each encoder.
            backbone_output_dim: Output width of ``main_MLP``.
            weight_path: ``None`` builds only the dynamics network (pretraining). A file
                path loads pretrained dynamics weights and adds the RL heads.
                ``"random_init"`` adds the RL heads on randomly initialised dynamics
                weights (comparison baseline).
            finetune_frozen: Freeze the dynamics modules; only used when ``weight_path``
                is given.
            activation_name: Activation name forwarded to ``build_mlp`` /
                ``resolve_nn_activation``.
            **kwargs: Ignored; their names are printed.

        Raises:
            NotImplementedError: If ``activation_name == "siren"``.
            Exception: If loading ``weight_path`` fails (the original error is chained).
        """
        super().__init__()

        if activation_name == "siren":
            raise NotImplementedError

        if kwargs:
            print(
                "JacobianMLP.__init__ got unexpected arguments, which will be ignored: "
                + str([key for key in kwargs.keys()])
            )
        self.module_type = "jacobian"

        self.dim_states = dim_states
        self.dim_actions = dim_actions
        self.input_timesteps = input_timesteps
        self.dim_states_output = dim_states_output
        self.action_output = True

        penultimate_dim = backbone_output_dim

        self.hidden_dim = hidden_dims
        # Dynamics (pretrained) modules. The construction order is unchanged so seeded
        # random initialisation stays reproducible.
        if dim_actions != 0:
            self.action_encoder = build_mlp(dim_actions, [128], representation_dim, activation_name=activation_name)
            num_embedding_streams = 2  # state and action embeddings are concatenated
        else:
            print("No action encoder for pretraining, this model now is a series forecasting model for states. You should not see this normally.")
            num_embedding_streams = 1  # only state embeddings reach main_MLP
        self.state_encoder = build_mlp(dim_states, [128], representation_dim, activation_name=activation_name)
        self.main_MLP = build_mlp(
            num_embedding_streams * representation_dim * input_timesteps,
            hidden_dims,
            penultimate_dim,
            activation_name=activation_name,
        )

        # Flattened Jacobian (dim_states_output x dim_actions) followed by the bias.
        output_dim = dim_states_output * dim_actions + dim_states_output
        # NOTE(review): adds 48 to the head width; the consumer of this attribute is
        # not visible in this class - confirm what the extra 48 stands for.
        self.backbone_output_dim = output_dim + 48
        self.output_layer = nn.Sequential(
            resolve_nn_activation(activation_name),
            nn.Linear(penultimate_dim, output_dim)
        )

        if weight_path is not None:
            if weight_path != "random_init":
                try:
                    self.load_state_dict(torch.load(weight_path, weights_only=True))
                except Exception as err:
                    raise Exception("Failed to load pretrained weights; Check if the path is correct and the architectures match.") from err
            else:
                print("The dynamics module is initialized with random weights: only for comparison.")
            self.weight_path = weight_path

            # RL heads (see forward_RL). They are created after loading, so the
            # pretrained state dict only has to contain the dynamics modules.
            self.state_space_policy_network = build_mlp(
                self._CURRENT_STEP_INPUT_DIM, [128] * 3, dim_states_output, activation_name=activation_name
            )
            self.action_residual_net = build_mlp(
                self._CURRENT_STEP_INPUT_DIM + dim_actions, [128] * 3, dim_actions, activation_name=activation_name
            )

            if finetune_frozen:
                self.freeze_pretrained_weights()

    def freeze_pretrained_weights(self):
        """
        Disable gradients for the encoders and ``main_MLP``.

        ``output_layer`` is not frozen here. In ``forward_RL`` it runs under
        ``torch.no_grad()`` anyway.
        """
        self.state_encoder.requires_grad_(False)
        if self.dim_actions != 0:  # no action encoder exists otherwise
            self.action_encoder.requires_grad_(False)
        self.main_MLP.requires_grad_(False)

    def forward_RL(self, x):
        """
        Compute an action from a flat policy observation.

        The dynamics network (no gradients) gives ``(J, bias)``. The policy head proposes
        a commanded state change ``s_cmd``. The least-squares action solving
        ``J @ a + bias = s_cmd`` is then refined by a learned residual.

        Requires the RL heads, i.e. the model must be built with ``weight_path``.

        Args:
            x: Flat observation as consumed by ``prepare_jacobian_input_from_obs``.

        Returns:
            Tensor of shape [b, dim_actions].
        """
        states, actions, current_step_input = prepare_jacobian_input_from_obs(x)
        with torch.no_grad():
            # jacobian net is NOT trained in RL
            jacobian, bias_state = self.forward(states, actions)

        state_space_command = self.state_space_policy_network(current_step_input)  # [b, dim_states_output]
        # Least-squares action realising the commanded state change under the linear model.
        action_guess = (batch_left_pinv(jacobian) @ (state_space_command - bias_state).unsqueeze(-1)).squeeze(-1)
        # TODO: add state space command to the action residual net input
        action_residual = self.action_residual_net(torch.cat([current_step_input, action_guess], dim=-1))  # [b, dim_actions]
        action = action_guess + action_residual  # [b, dim_actions]
        return action

    def forward(self, states, actions) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict the Jacobian and bias from a state/action history.

        Args:
            states: [batch_size, input_timesteps, dim_states]
            actions: [batch_size, input_timesteps, dim_actions] (ignored when dim_actions == 0)

        Returns:
            jacobian: [batch_size, dim_states_output, dim_actions]
            bias_state: [batch_size, dim_states_output]
        """
        embedded_states = torch.flatten(self.state_encoder(states), start_dim=1)
        if self.dim_actions != 0:
            embedded_actions = torch.flatten(self.action_encoder(actions), start_dim=1)
            x = torch.cat((embedded_states, embedded_actions), dim=-1)
        else:
            x = embedded_states
        x = self.main_MLP(x)
        x = self.output_layer(x)
        # Leading entries: row-major Jacobian; trailing dim_states_output entries: bias.
        n_jac = self.dim_states_output * self.dim_actions
        jacobian = x[:, :n_jac].reshape(x.shape[0], self.dim_states_output, self.dim_actions)
        bias_state = x[:, n_jac:].reshape(x.shape[0], self.dim_states_output)
        return jacobian, bias_state



class JacobianLightningModule(L.LightningModule):
    """
    Lightning wrapper that pretrains a ``JacobianMLP`` on recorded state/action windows.

    From each window the model sees all but the last step. It must predict the
    change of the first 21 state entries between the last two steps as
    ``jacobian @ a_last + bias_state``.
    """

    def __init__(self, model, grad_norm_clip_value=1.0, enable_noise = False, mode: str = "jacobian", lr: float = 1e-5):
        """
        Args:
            model: The ``JacobianMLP`` to train.
            grad_norm_clip_value: Stored only (see note below).
            enable_noise: Add uniform input noise during training.
            mode: Tag of the trained model family (previously read from an undefined
                name; defaults to the value the training script uses).
            lr: Adam learning rate.
        """
        super().__init__()
        self.save_hyperparameters()
        self.model: JacobianMLP = model
        self.error_per_epoch = []
        self.error_accumulated = 0.0
        self.step_counter = 0
        self.automatic_optimization = True
        # NOTE(review): stored but never applied by this module; gradients are only
        # clipped if clipping is configured on the Trainer - confirm this is intended.
        self.grad_norm_clip_value = grad_norm_clip_value
        self.mode = mode
        self.enable_noise = enable_noise
        self.lr = lr

        # Half-width of the per-feature uniform input noise. Order must match the state
        # layout: lin_vel(3), ang_vel(3), gravity_vector(3), joint_angles(12), joint_vels(12).
        self.noise_magnitude = torch.tensor([0.1]*3 + [0.2]*3 + [0.05]*3 + [0.01]*12 + [1.5]*12)

    def forward(self, x_t, x_tp1):
        """Delegate to ``model(x_t, x_tp1)``, i.e. ``JacobianMLP.forward(states, actions)``."""
        return self.model(x_t, x_tp1)

    def on_train_epoch_start(self):
        """Reset the per-epoch accumulators."""
        self.error_accumulated = 0.0
        self.step_counter = 0

    def training_step(self, batch, batch_idx):
        return self.training_step_fixed_step_mlp(batch, batch_idx)

    def training_step_fixed_step_mlp(self, batch, batch_idx):
        """
        One optimisation step on a batch of fixed-length windows.

        Args:
            batch: ``(x, a)``, both of shape (batch_size, window_size, input_dim).

        Returns:
            Mean L1 loss between predicted and observed state change.
        """
        x, a = batch
        x_input, a_input = x[:, :-1], a[:, :-1]
        # Target: change of the first 21 state entries between the last two steps.
        x_plus_delta = (x[:, -1] - x[:, -2])[..., :21]
        a_t = a[:, -1]  # the last action in the sequence

        if self.enable_noise:
            # add the same amount of noise as observed in simulation during RL; recorded data is noise free.
            noise = (torch.rand_like(x_input) * 2 - 1) * self.noise_magnitude[None, None, :].to(x_input.device)
            x_input = x_input + noise

        jacobian, bias_state = self.model(x_input, a_input)

        # Predicted change in state: J @ a_t + bias.
        delta_pred = torch.bmm(jacobian, a_t.unsqueeze(-1)).squeeze(-1) + bias_state  # [batch_size, dim_states_output]
        loss = F.l1_loss(delta_pred, x_plus_delta, reduction='none')  # [batch_size, dim_states_output]
        avg_loss = loss.mean()
        # Entries 9:21 are the joint angles (after lin_vel, ang_vel, gravity).
        train_loss_joint_angles = loss[:, 9:21].mean()

        self.log("train_loss", avg_loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_loss_joint_angles", train_loss_joint_angles, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_magnitude_joint_changes", torch.abs(x_plus_delta[:, 9:21]).mean(), on_step=True, on_epoch=True, prog_bar=True)
        return avg_loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def save_torch(self, path):
        """
            Save the model weights to a file.
        """
        torch.save(self.model.state_dict(), path)


def get_number_of_trainable_parameters(model: nn.Module) -> int:
    """
    Count the scalar parameters of ``model`` that will receive gradients.

    Parameters with ``requires_grad=False`` (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total



def train(dataset_path, mode):
    """
    Pretrain a ``JacobianMLP`` on an HDF5 dataset and save the result.

    Logs to the W&B project ``f"{mode}_dynamics_new"``. Writes a Lightning checkpoint
    to ``logs/pretrain/lightning/{mode}.ckpt`` and a plain state dict to
    ``logs/pretrain/lightning/{mode}.pt``. The W&B run is finished even if training fails.

    Args:
        dataset_path: Path of the HDF5 file read by ``DynamicSlidingWindowDataset``.
        mode: Tag used in the W&B project name and the output file names.
    """
    wandb_logger = WandbLogger(project=f"{mode}_dynamics_new")

    dim_hidden = 512
    # History steps fed to the model; each dataset window carries one extra step
    # that provides the prediction target.
    window_size = 5
    save_dir = "logs/pretrain/lightning"

    try:
        model = JacobianMLP(
            dim_states=33,  # state dimension
            dim_actions=12,  # action dimension
            dim_states_output=21,
            input_timesteps=window_size,
            hidden_dims=[dim_hidden, dim_hidden//2, dim_hidden//4],
            representation_dim=dim_hidden//8,
            backbone_output_dim=dim_hidden//2,
            weight_path=None,  # pretraining only: no RL heads
            finetune_frozen=False,
            activation_name="elu"
        )

        # Rebuild the dataset on every call: recorded samples may have grown since the last run.
        dataset = DynamicSlidingWindowDataset(h5_path=dataset_path, window_size=window_size+1)
        print("num samples in the dataset:", len(dataset))
        full_dataset_size = len(dataset)
        print("full dataset trajectories number:", full_dataset_size)
        len_timesteps = dataset.len_timesteps()
        print("full dataset timesteps number:", len_timesteps)

        # Reshuffle every epoch. For bootstrap sampling, pass
        # sampler=RandomSampler(dataset, replacement=True, num_samples=len(dataset))
        # instead of shuffle=True.
        dataloader = DataLoader(dataset, batch_size=1024, num_workers=0, shuffle=True)
        l_model = JacobianLightningModule(model=model, enable_noise=True)

        trainer = L.Trainer(max_epochs=10, log_every_n_steps=10, logger=wandb_logger, default_root_dir=save_dir)
        trainer.fit(model=l_model, train_dataloaders=dataloader)

        trainer.save_checkpoint(f"{save_dir}/{mode}.ckpt")
        l_model.save_torch(f"{save_dir}/{mode}.pt")
    finally:
        # Close the W&B run even when training raises.
        wandb.finish()


if __name__ == "__main__":
    # Script entry: pretrain the Jacobian dynamics model on the recorded dataset.
    mode, dataset_path = "jacobian", "logs/datasets/inv_dynamics/inv_datasets/vanilla_pedi.h5"
    train(dataset_path=dataset_path, mode=mode)
