import torch
import torch.nn as nn
from rsl_rl.rsl_rl.utils.utils import resolve_nn_activation, siren_init
from abc import abstractmethod, ABC
from rsl_rl.rsl_rl.addons.nn_utils import init_lstm, init_gru
import math
from einops import repeat
import lightning as L
from datetime import datetime
import os
from torch.utils.data import DataLoader, RandomSampler


def build_mlp(input_dims: int, hidden_dims: list[int], output_dims: int, activation_name: str = "elu") -> nn.Sequential:
    """Assemble a fully connected network as an ``nn.Sequential``.

    Layout: Linear(input_dims -> h0), act, Linear(h0 -> h1), act, ..., Linear(h_last -> output_dims).
    No activation follows the final Linear layer.

    Args:
        input_dims: Size of the input features.
        hidden_dims: Widths of the hidden layers; an entry of -1 is replaced by ``input_dims``.
            Must be non-empty (its first entry is indexed unconditionally).
        output_dims: Size of the final Linear layer's output.
        activation_name: Name passed to ``resolve_nn_activation``. The special value ``"siren"``
            uses ``"siren_30"`` after the first Linear layer and ``"siren_1"`` after the others,
            and additionally applies ``siren_init`` to every Linear layer (``is_first=True`` only
            for the first one).

    Returns:
        The assembled ``nn.Sequential``.
    """
    # -1 is a placeholder meaning "same width as the input"
    widths = [input_dims if width == -1 else width for width in hidden_dims]

    use_siren = activation_name == "siren"
    if use_siren:
        first_act = resolve_nn_activation("siren_30")
        later_act = resolve_nn_activation("siren_1")
    else:
        first_act = resolve_nn_activation(activation_name)
        later_act = resolve_nn_activation(activation_name)

    # input projection, then hidden-to-hidden pairs, then the output projection
    layers = [nn.Linear(input_dims, widths[0]), first_act]
    for in_width, out_width in zip(widths[:-1], widths[1:]):
        layers.extend((nn.Linear(in_width, out_width), later_act))
    layers.append(nn.Linear(widths[-1], output_dims))
    mlp_model = nn.Sequential(*layers)

    if use_siren:
        first_linear = mlp_model[0]
        for module in mlp_model.modules():
            if isinstance(module, nn.Linear):
                siren_init(module, is_first=module is first_linear)
    return mlp_model


class RNNCore(nn.Module):
    """Recurrent encoder over concatenated (state, action) sequences with two linear heads.

    A GRU or LSTM (``batch_first=True``) runs over ``cat(x, a)`` along the last dimension.
    Its per-timestep hidden states feed a "continuous" and a "discrete" linear head.
    """

    def __init__(self,
                 dim_states,
                 dim_actions,
                 representation_dim,
                 continuous_output_dim,
                 discrete_output_dim,
                 hidden_dim: int,
                 num_layers: int = 1,
                 rnn_type: str = "gru",
                 activation_name: str = "elu",
                 weight_path=None,
                 finetune_frozen=False,
                 device: str = "cpu",
                 **kwargs):
        """
            Args:
                dim_states: Size of the per-timestep state features.
                dim_actions: Size of the per-timestep action features.
                representation_dim: Not used by this class; accepted for config compatibility.
                continuous_output_dim: Output size of the continuous head.
                discrete_output_dim: Output size of the discrete head.
                hidden_dim: Hidden size of the recurrent layers.
                num_layers: Number of stacked recurrent layers.
                rnn_type: "gru" or "lstm" (case-insensitive).
                activation_name: Not used by this class; accepted for config compatibility.
                weight_path: Optional path to a state dict saved from an RNNCore, or the string
                    "random_init" to keep the random initialization.
                finetune_frozen: If True (and weight_path is given), freeze all weights after loading.
                device: Device the module is moved to at the end of construction.
                **kwargs: Ignored; their names are printed.

            Raises:
                ValueError: If rnn_type is not supported.
                Exception: If the pretrained weights cannot be loaded.
        """
        super().__init__()
        if kwargs:
            print(
                "RNNCore.__init__ got unexpected arguments, which will be ignored: "
                + str(list(kwargs.keys()))
            )
        self.dim_states = dim_states
        self.dim_actions = dim_actions
        self.device = device  # NOTE: only stored; the module is moved with self.to(device) below
        self.hidden_dim = hidden_dim

        self.rnn_type = rnn_type.lower()
        rnn_classes = {"gru": nn.GRU, "lstm": nn.LSTM}
        if self.rnn_type not in rnn_classes:
            raise ValueError(f"Unsupported RNN type: {self.rnn_type}. Supported types are 'gru' and 'lstm'.")
        self.main_RNN = rnn_classes[self.rnn_type](
            input_size=dim_states + dim_actions,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,  # input shape is (batch, seq, feature)
            # PyTorch applies RNN dropout only between stacked layers and warns when
            # dropout > 0 with a single layer, so disable it in that case.
            dropout=0.1 if num_layers > 1 else 0.0,
        )

        self.continuous_output_layer = nn.Linear(hidden_dim, continuous_output_dim)
        self.discrete_output_layer = nn.Linear(hidden_dim, discrete_output_dim)

        # Initialize BEFORE loading, so that pretrained weights are not overwritten.
        self.reinitialize_weights()

        if weight_path is not None:
            if weight_path != "random_init":
                try:
                    # Load onto CPU first so CUDA-saved checkpoints work on CPU-only hosts;
                    # the module is moved to `device` afterwards.
                    state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
                    self.load_state_dict(state_dict)
                except Exception as err:
                    raise Exception("Failed to load pretrained weights; Check if the path is correct and the architectures match.") from err
            else:
                print("The inv dynamics module is initialized with random weights: only for comparison.")
            self.weight_path = weight_path
            if finetune_frozen:
                self.freeze_pretrained_weights()

        self.to(device)
        print(f"INV Dynamics trainable parameters: {self.get_number_of_trainable_parameters()}")

    def get_number_of_trainable_parameters(self) -> int:
        """Return the total number of elements in parameters with ``requires_grad=True``."""
        return sum(param.numel() for param in self.parameters() if param.requires_grad)

    def freeze_pretrained_weights(self):
        """Disable gradients for the recurrent core and both output heads."""
        for module in (self.main_RNN, self.continuous_output_layer, self.discrete_output_layer):
            module.requires_grad_(False)

    def forward(self, x, a, mask=None):
        """
            Args:
                x: States, [batch_size, T, dim_states].
                a: Actions, [batch_size, T, dim_actions].
                mask: Accepted for interface compatibility; not used.

            Returns:
                (out_continuous, out_discrete, h), where h has shape [batch_size, T, hidden_dim]
                and each head output has shape [batch_size, T, head_output_dim].
        """
        input_per_timestep = torch.cat((x, a), dim=-1)  # [batch_size, T, dim_states + dim_actions]
        h, _ = self.main_RNN(input_per_timestep)  # h: [batch_size, T, hidden_dim]
        out_continuous = self.continuous_output_layer(h)  # [batch_size, T, continuous_output_dim]
        out_discrete = self.discrete_output_layer(h)  # [batch_size, T, discrete_output_dim]
        return out_continuous, out_discrete, h

    def forward_RL(self, observations):
        raise NotImplementedError("This method is not implemented for the PIDM")

    def reinitialize_weights(self):
        """
            Re-initialize all weights in place.

            The Linear heads get Kaiming-uniform weights and zero biases. The recurrent core is
            passed to ``init_lstm`` or ``init_gru``, depending on ``self.rnn_type``.
        """
        for head in (self.continuous_output_layer, self.discrete_output_layer):
            for layer in head.modules():
                if isinstance(layer, nn.Linear):
                    nn.init.kaiming_uniform_(layer.weight)
                    if layer.bias is not None:
                        nn.init.zeros_(layer.bias)
        # Compare against the type *names*; the original compared against the init functions,
        # which was always False and skipped the recurrent initialization.
        if self.rnn_type == "lstm":
            init_lstm(self.main_RNN)
        elif self.rnn_type == "gru":
            init_gru(self.main_RNN)



class NeuralINVSolver(nn.Module):
    """ELU MLP mapping ``cat(rnn_latent, target)`` (along the last dim) to an action vector."""

    def __init__(self,
                 rnn_latent_dim: int,
                 target_dim: int,
                 dim_actions: int,
                 hidden_dims: "list[int] | None" = None,
                 ):
        """
            Args:
                rnn_latent_dim: Size of the RNN latent input.
                target_dim: Size of the target input concatenated to the latent.
                dim_actions: Size of the predicted action.
                hidden_dims: Hidden layer widths for the MLP; defaults to [128, 128].
                    An entry of -1 is resolved by ``build_mlp`` to the input width.
        """
        super().__init__()
        # None instead of a list literal avoids a shared mutable default argument.
        if hidden_dims is None:
            hidden_dims = [128, 128]
        self.mlp = build_mlp(
            input_dims=rnn_latent_dim + target_dim,
            hidden_dims=list(hidden_dims),
            output_dims=dim_actions,
            activation_name="elu",
        )

    def forward(self, rnn_latent, target):
        """Concatenate the latent and target along the last dim and apply the MLP."""
        return self.mlp(torch.cat((rnn_latent, target), dim=-1))
    


class NeuralINVSolverModule(L.LightningModule):
    """Lightning module training an ``RNNCore`` + ``NeuralINVSolver`` pair for inverse dynamics.

    For a batch ``(x, a)`` indexed as ``[batch, time, feature]``:
    - the RNN consumes all but the last timestep of both x and a;
    - the solver predicts the last action ``a[:, -1]`` from two inputs: the RNN latent at the
      final consumed step, and the desired state change ``x[:, -1] - x[:, -2]``.
    """

    def __init__(self, rnn_core, inv_solver, mode="inv", lr=1e-5, noise_magnitude=None):
        """
        Args:
            rnn_core: Encoder called as ``rnn_core(x, a)`` returning ``(out_continuous, out_discrete, h)``.
            inv_solver: Head called as ``inv_solver(h[:, -1, :], desired_delta_x)``.
            mode: Selects the error computation in ``validate_batch_detail``.
            lr: Adam learning rate.
            noise_magnitude: Optional scalar or per-feature std of the Gaussian noise added to the
                conditioning states in ``prepare_input``; ``None`` (default) disables the noise.
        """
        super().__init__()
        # The sub-networks are nn.Modules and are already checkpointed through state_dict;
        # Lightning recommends excluding such arguments from the saved hyperparameters.
        self.save_hyperparameters(ignore=["rnn_core", "inv_solver"])
        self.rnn_core: RNNCore = rnn_core
        self.inv_solver: NeuralINVSolver = inv_solver
        self.error_per_epoch = []
        self.error_accumulated = 0.0
        self.step_counter = 0
        self.mode = mode
        self.lr = lr
        self.penalize_grad = True  # NOTE: not read anywhere in this class
        self.grad_loss_beta = 10  # NOTE: not read anywhere in this class
        # e.g. torch.tensor([0.1]*3 + [0.2]*3 + [0.05]*3 + [0.01]*12 + [1.5]*12), ordered as
        # lin_vel, ang_vel, gravity_vector, joint_angles, joint_vels
        self.noise_magnitude = noise_magnitude

        self.save_dir = f"logs/pretrain/lightning/inv_module_training/{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
        os.makedirs(self.save_dir, exist_ok=True)

    def _predict_actions(self, x_cut, a_cut, desired_delta_x):
        """Run the RNN on the cut sequences and the solver on its last latent.

        Returns:
            (out_continuous, out_discrete, inv_actions)
        """
        out_continuous, out_discrete, h = self.rnn_core(x_cut, a_cut)
        inv_actions = self.inv_solver(h[:, -1, :], desired_delta_x)
        return out_continuous, out_discrete, inv_actions

    def forward(self, x, a):
        """Predict ``a[:, -1]`` from the full sequences (x, a), without input noise."""
        desired_delta_x = x[:, -1, :] - x[:, -2, :]
        return self._predict_actions(x[:, :-1], a[:, :-1], desired_delta_x)[-1]

    def on_train_epoch_start(self):
        self.error_accumulated = 0.0
        self.step_counter = 0
        if self.current_epoch % 20 == 0:
            self.save_model_pt(os.path.join(self.save_dir, f"epoch_{self.current_epoch:04d}.pt"))

    def save_model_pt(self, save_path):
        """
        Save this module's state dict to a .pt file.

        The dict holds ``rnn_core.*`` and ``inv_solver.*`` entries.
        :param save_path: Path to save the model.
        """
        torch.save(self.state_dict(), save_path)
        print(f"Model saved to {save_path}")

    def on_train_end(self):
        self.save_model_pt(os.path.join(self.save_dir, "final_model.pt"))
        return super().on_train_end()

    def add_noise_to_input(self, x_cut):
        """Add Gaussian noise scaled by ``noise_magnitude`` (broadcast over the last dim)."""
        if self.noise_magnitude is None:
            return x_cut
        magnitude = torch.as_tensor(self.noise_magnitude, dtype=x_cut.dtype, device=x_cut.device)
        return x_cut + torch.randn_like(x_cut) * magnitude

    def prepare_input(self, batch):
        """Split a batch ``(x, a)`` into the model inputs and the ground-truth targets."""
        x, a = batch
        x_cut = self.add_noise_to_input(x[:, :-1])
        x_t_gt = x[:, -2, :]
        a_tm1_gt = a[:, -1, :]

        a_cut = a[:, :-1]
        desired_delta_x = x[:, -1, :] - x[:, -2, :]  # shape: (batch_size, input_dim_states)
        return x_cut, a_cut, desired_delta_x, x_t_gt, a_tm1_gt

    def cal_denoise_loss(self, out_continuous, out_discrete, x_t_gt):
        # Placeholder: the denoising objective is currently disabled.
        return 0.0

    def cal_inv_loss(self, inv_actions, a_tm1_gt):
        return nn.functional.l1_loss(inv_actions, a_tm1_gt)

    def step(self, batch, batch_idx):
        """Shared train/val computation; returns (denoise_loss, total_loss, mean_abs_error)."""
        x_cut, a_cut, desired_delta_x, x_t_gt, a_tm1_gt = self.prepare_input(batch)
        out_continuous, out_discrete, inv_actions = self._predict_actions(x_cut, a_cut, desired_delta_x)
        denoise_loss = self.cal_denoise_loss(out_continuous, out_discrete, x_t_gt)
        inv_loss = self.cal_inv_loss(inv_actions, a_tm1_gt)

        total_loss = inv_loss + denoise_loss
        # logging-only metric: keep it out of the autograd graph
        mean_abs_error = torch.mean(torch.abs(inv_actions.detach() - a_tm1_gt))

        return denoise_loss, total_loss, mean_abs_error

    def training_step(self, batch, batch_idx):
        denoise_loss, loss, mean_abs_error = self.step(batch, batch_idx)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_error", mean_abs_error, on_epoch=True)
        self.error_accumulated += mean_abs_error.item()
        self.step_counter += 1
        return loss

    def validation_step(self, batch, batch_idx):
        denoise_loss, loss, mean_abs_error = self.step(batch, batch_idx)
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("val_error", mean_abs_error, on_epoch=True)
        # Do not touch the training accumulators: validation runs before on_train_epoch_end,
        # so updating them here would mix validation error into error_per_epoch.
        return loss

    def on_train_epoch_end(self):
        self.error_per_epoch.append(self.error_accumulated / self.step_counter if self.step_counter > 0 else 0.0)

    def configure_optimizers(self):
        # optimize both sub-networks, skipping any frozen parameters
        trainable = [param for param in self.parameters() if param.requires_grad]
        return torch.optim.Adam(trainable, lr=self.lr)

    @torch.no_grad()
    def validate_batch_detail(self, batch):
        """
        Validate the model on a batch of data and return the per-element errors and target magnitudes.
        :param batch: A batch of data.
        :return: A tuple of (errors, magnitudes), both flattened.
        :raises ValueError: If ``self.mode`` is not one of the supported modes.
        """
        x, a = batch
        output = self.forward(x, a)
        if self.mode == "inv":
            a_tm1 = a[:, -1, :]
            mean_abs_error = torch.abs(output - a_tm1)
            target_magnitude = torch.abs(a_tm1)
        # NOTE(review): the modes below expect `output` to be a forward-model prediction or a
        # tuple. `forward` here returns inverse actions, so these look like leftovers from
        # another model — confirm before relying on them.
        elif self.mode == "fwd":
            delta_x_t = (x[:, -1, :] - x[:, -2, :])[..., 9:21]
            mean_abs_error = torch.abs((output - delta_x_t))
            target_magnitude = torch.abs(delta_x_t)
        elif self.mode == "jacobian":
            jacobian, bias_state = output
            delta_pred = torch.bmm(jacobian, a[:, -1, :].unsqueeze(-1)).squeeze(-1) + bias_state
            delta_x_t = (x[:, -1, :] - x[:, -2, :])[..., 9:21]
            mean_abs_error = torch.abs(delta_pred - delta_x_t)
            target_magnitude = torch.abs(delta_x_t)
        elif self.mode == "dl":  # decoupled linear
            w, b = output
            delta_pred = w * a[:, -1, :].squeeze(-1) + b
            delta_x_t = (x[:, -1, :] - x[:, -2, :])[..., 9:21]
            mean_abs_error = torch.abs(delta_pred - delta_x_t)
            target_magnitude = torch.abs(delta_x_t)
        else:
            raise ValueError(f"Unsupported mode: {self.mode}")

        return mean_abs_error.flatten(), target_magnitude.flatten()




