from rsl_rl.addons.invdynamics.inv_dynamics_utils import *
from typing import List, Dict
from rsl_rl.addons.invdynamics.inv_dynamics_dataset_paths import dataset_paths
from rsl_rl.addons.invdynamics.inv_dynamics_module import InvDynamicsMLP, InvDynamicsTransformer
from torch.utils.data import random_split
from lightning.pytorch.callbacks.early_stopping import EarlyStopping


def plot_epoch_errors(epoch_errors: List[List[float]], labels: List[str]):
    """Plot per-epoch error curves and save the figure as a PNG.

    One line is drawn per entry of ``epoch_errors``. Each legend entry shows
    the curve's label followed by its last recorded error. The figure is
    written to ``logs/analysis/plots/inv_epoch_errors_plot.png``; the target
    directory is created if it does not exist yet.

    Args:
        epoch_errors: One sequence of error values per training run, ordered
            by epoch. A run with no recorded epochs is still listed in the
            legend, without a final error.
        labels: Legend label for each run, matched by position. Runs without
            a matching label are labelled ``run_<index>``.
    """
    import os

    import matplotlib.pyplot as plt

    save_path = 'logs/analysis/plots/inv_epoch_errors_plot.png'

    fig = plt.figure(figsize=(10, 5))
    try:
        for i, errors in enumerate(epoch_errors):
            # Fall back to a generic name instead of raising IndexError when
            # the caller passed fewer labels than curves.
            name = labels[i] if i < len(labels) else f'run_{i}'
            if len(errors) > 0:
                legend_entry = f'{name}, final error: {errors[-1]:.2f}'
            else:
                # errors[-1] would raise on an empty history.
                legend_entry = f'{name}, no epochs recorded'
            plt.plot(errors, label=legend_entry)
        plt.xlabel('Epoch')
        plt.ylabel('Error')
        plt.title('Training Errors Over Epochs')
        plt.legend()
        plt.grid()

        # savefig does not create missing parent directories.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        plt.savefig(save_path, dpi=300)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    print(f"Plot saved as '{save_path}'")


def reinitialized_and_train_model_mlp_offline(model: InvDynamicsMLP, dataset: DynamicSlidingWindowDataset, 
                                              epochs: int = 10, batch_size: int = 32, replacement: bool = False, 
                                              save_path: str = None, wandb_log=False, lr=1e-4, symmetry_augmentation_anymal=False,
                                              embodiment="anymal"):
    """Re-initialize ``model`` and train it offline on ``dataset`` with Lightning.

    The dataset is split 90/10 into training and validation subsets after
    seeding the global torch RNG with 42 (note: this reseeds the RNG for the
    whole process). Both subsets are drawn through a ``RandomSampler``.

    Args:
        model: Model to train. ``model.reinitialize_weights()`` is called
            first and ``model.mode`` is forwarded to the Lightning module.
        dataset: Full dataset to split into train/validation subsets.
        epochs: Maximum number of training epochs.
        batch_size: Batch size of both dataloaders.
        replacement: Whether the samplers draw with replacement.
        save_path: If not ``None``, the trainer checkpoint is saved here
            after training.
        wandb_log: If true, log to the W&B project ``inv_dynamics_mlp``. No
            early-stopping callback is attached in that case; without W&B,
            early stopping on ``val_error`` (patience 5) is used.
        lr: Forwarded to ``INVLightningModule``.
        symmetry_augmentation_anymal: Forwarded to ``INVLightningModule``.
        embodiment: Forwarded to ``INVLightningModule``.

    Returns:
        The ``error_per_epoch`` attribute of the Lightning module after
        training.
    """
    model.reinitialize_weights()

    # Fixed seed so the train/validation split is the same on every run.
    torch.manual_seed(42)
    train_dataset, validation_dataset = random_split(dataset, [0.9, 0.1])

    sampler_train = RandomSampler(train_dataset, replacement=replacement, num_samples=len(train_dataset))

    print(f"Training samples num: {len(train_dataset)}, Validation samples num: {len(validation_dataset)}")

    dataloader_train = DataLoader(train_dataset, batch_size=batch_size, num_workers=0, sampler=sampler_train)

    sampler_val = RandomSampler(validation_dataset, replacement=replacement, num_samples=len(validation_dataset))
    dataloader_val = DataLoader(validation_dataset, batch_size=batch_size, num_workers=0, sampler=sampler_val)

    l_model = INVLightningModule(model=model, mode=model.mode, lr=lr, 
                                 symmetry_augmentation_anymal=symmetry_augmentation_anymal, 
                                 embodiment=embodiment)

    # Options shared by both logging modes.
    trainer_kwargs = dict(max_epochs=epochs, log_every_n_steps=10, gradient_clip_val=5.0)
    if wandb_log:
        # W&B runs validate every epoch and run without early stopping.
        trainer_kwargs["logger"] = WandbLogger(project="inv_dynamics_mlp")
        trainer_kwargs["check_val_every_n_epoch"] = 1
    else:
        trainer_kwargs["callbacks"] = [
            EarlyStopping(monitor="val_error", min_delta=0.00, patience=5, verbose=True, mode="min", strict=True)
        ]
    trainer = L.Trainer(**trainer_kwargs)

    try:
        trainer.fit(model=l_model, train_dataloaders=dataloader_train, val_dataloaders=dataloader_val)
    finally:
        # Close the W&B run even when training fails, so the next run in the
        # same process does not inherit a dangling run.
        if wandb_log:
            wandb.finish()

    if save_path is not None:
        trainer.save_checkpoint(save_path)
        print(f"Model saved to {save_path}")
    return l_model.error_per_epoch


def train_single_model_mlp(inv_dynamics_cfg, save_path, lr, dataset_path, symmetry_augmentation_anymal=False,
                           embodiment="anymal", epochs=100, batch_size=1024, wandb_log=True,
                           max_samples=6_000_000):
    """Build one model from ``inv_dynamics_cfg`` and train it on an HDF5 dataset.

    Args:
        inv_dynamics_cfg: Model configuration. ``"class_name"`` selects the
            model class, ``"input_timesteps"`` is used as the dataset window
            size, and the whole dict is passed to the model constructor as
            keyword arguments.
        save_path: Where the trained checkpoint is written.
        lr: Learning rate forwarded to the training routine.
        dataset_path: Path of the HDF5 file to load (fully into memory).
        symmetry_augmentation_anymal: Forwarded to the training routine.
        embodiment: Forwarded to the training routine.
        epochs: Maximum number of training epochs.
        batch_size: Training/validation batch size.
        wandb_log: Whether the run is logged to W&B.
        max_samples: Upper bound on the number of samples used. Larger
            datasets are randomly subsampled to roughly this many samples to
            keep training time bounded. ``None`` disables the cap.

    Returns:
        The per-epoch error history returned by
        ``reinitialized_and_train_model_mlp_offline``.
    """
    dataset = DynamicSlidingWindowDataset(h5_path=dataset_path, 
                                          window_size=inv_dynamics_cfg["input_timesteps"], 
                                          load_into_memory=True, 
                                          )

    # Cap the dataset size: keep a random fraction so that about
    # ``max_samples`` samples remain.
    if max_samples is not None and len(dataset) > max_samples:
        ratio = max_samples / len(dataset)
        dataset, _ = random_split(dataset, [ratio, 1-ratio])
        print(f"Dataset too large, using only {ratio*100:.2f}% of it: {len(dataset)} samples.")

    print(f"Samples num: {len(dataset)}")

    # SECURITY NOTE(review): eval() resolves the class from a config string.
    # This is only safe while configs are trusted, in-repo dicts; never feed
    # externally supplied config here.
    model = eval(inv_dynamics_cfg["class_name"])(device="cuda", **inv_dynamics_cfg)

    epoch_error = reinitialized_and_train_model_mlp_offline(model, dataset, epochs=epochs, batch_size=batch_size, 
                                                            replacement=False, save_path=save_path, wandb_log=wandb_log, 
                                                            lr=lr, 
                                                            symmetry_augmentation_anymal=symmetry_augmentation_anymal, 
                                                            embodiment=embodiment)
    return epoch_error


def train_PIDM():
    """Train the MLP inverse-dynamics model on the selected dataset and plot the errors.

    Runs a (currently single-point) sweep over learning rate and hidden-layer
    size factor, trains one model per combination and plots all error curves.
    """
    # Dataset selection. Other keys that have been used here:
    #   "Pedipulation Init (no random) Contact Obs", "Fixed-Base (Pedi)",
    #   "Pedipulation Init (no random)", "Pedipulation Init It 100",
    #   "Pedipulation Init (Absolute, Noise)", "Exploration Flat",
    #   "Exploration Rough", "Exploration Flat and Rough",
    #   "Locomotion Go1 100 Iter", "Pedi 100 Absolute"
    selected_dataset = dataset_paths["Exploration Go1 240 Iter"]
    use_symmetry_augmentation = False

    cfg = dict(
        class_name="InvDynamicsMLP",
        dim_states=33,            # earlier note: 33 + 9 with contact booleans
        dim_actions=12,
        representation_dim=256,
        mode="inv",               # alternatives noted: "fwd", "dl"
        lstm_core=False,          # True for LSTM, False for MLP
        activation_name="elu",    # or "siren"
        input_timesteps=5,
    )

    # (learning rate, size factor) combinations, learning rate varying slowest.
    sweep = [(lr, factor) for lr in (1e-4,) for factor in (1,)]

    curves = []
    curve_labels = []
    for lr, factor in sweep:
        # Hidden widths scale linearly with the size factor.
        cfg["hidden_dims"] = [width * factor for width in (512, 256, 128)]

        ckpt_path = f"logs/pretrain/lightning/{cfg['mode']}_vanilla_pedi_MLP_size_factor_{factor}.ckpt"
        print(f"Training model with lr={lr}")
        curves.append(
            train_single_model_mlp(cfg, ckpt_path, lr, selected_dataset,
                                   symmetry_augmentation_anymal=use_symmetry_augmentation,
                                   embodiment="anymal")
        )
        curve_labels.append(f"MLP_lr={lr}")

    plot_epoch_errors(curves, curve_labels)



def train_PIDM_G1():
    """Train the MLP inverse-dynamics model for the G1 embodiment and plot the errors.

    Same sweep layout as the quadruped routine, with the G1 state/action
    dimensions and symmetry augmentation switched off.
    """
    # Alternative dataset key: "Exploration G1 230 Iter"
    selected_dataset = dataset_paths["Locomotion G1 100 Iter"]

    cfg = dict(
        class_name="InvDynamicsMLP",
        dim_states=83,
        dim_actions=37,
        representation_dim=256,
        mode="inv",               # alternatives noted: "fwd", "dl"
        lstm_core=False,          # True for LSTM, False for MLP
        activation_name="elu",    # or "siren"
        input_timesteps=5,
    )

    # (learning rate, size factor) combinations, learning rate varying slowest.
    sweep = [(lr, factor) for lr in (1e-4,) for factor in (1,)]

    curves = []
    curve_labels = []
    for lr, factor in sweep:
        # Hidden widths scale linearly with the size factor.
        cfg["hidden_dims"] = [width * factor for width in (512, 256, 128)]

        # NOTE(review): same checkpoint filename pattern as train_PIDM, so the
        # two routines overwrite each other's checkpoint — confirm intended.
        ckpt_path = f"logs/pretrain/lightning/{cfg['mode']}_vanilla_pedi_MLP_size_factor_{factor}.ckpt"
        print(f"Training model with lr={lr}")
        curves.append(
            train_single_model_mlp(cfg, ckpt_path, lr, selected_dataset,
                                   symmetry_augmentation_anymal=False, embodiment="G1")
        )
        curve_labels.append(f"MLP_lr={lr}")

    plot_epoch_errors(curves, curve_labels)




def train_model_tf_sweep():
    """Train the transformer dynamics model for each ``input_timesteps`` value and plot the errors.

    Every sweep point reuses the same checkpoint path, so only the model of
    the last sweep point remains on disk.
    """
    # train_single_model_mlp requires an explicit learning rate and dataset
    # path. NOTE(review): 1e-4 mirrors the MLP sweeps in this file; the
    # dataset key is an assumption — confirm it is the intended training set.
    lr = 1e-4
    dataset_path = dataset_paths["Pedipulation Init (no random)"]

    inv_dynamics_cfg = {
        "class_name": "InvDynamicsTransformer",
        "dim_states": 33,
        "dim_actions": 12,
        "representation_dim": 128,
        "feedforward_dim": 512,
        "num_layers": 3,
        "num_heads": 4,
        "mode": "fwd",
        "activation_name": "gelu",
        "input_timesteps": 5,  # overwritten by the sweep below
    }

    errors_to_plot = []
    errors_labels = []

    for input_timesteps in [10]:
        save_path = f"logs/pretrain/lightning/{inv_dynamics_cfg['mode']}_vanilla_pedi_TF.ckpt"
        label = f"TF_input_timesteps={input_timesteps}"
        inv_dynamics_cfg["input_timesteps"] = input_timesteps
        print(f"Training model with input_timesteps={input_timesteps}")
        epoch_error = train_single_model_mlp(inv_dynamics_cfg, save_path, lr, dataset_path)
        errors_to_plot.append(epoch_error)
        errors_labels.append(label)

    plot_epoch_errors(errors_to_plot, errors_labels)


def train_ensemble(inv_dynamics_cfg, vis=False):
    """Retrain the inverse-dynamics ensemble offline and plot its epoch errors.

    Args:
        inv_dynamics_cfg: Model configuration for the ensemble; its
            ``"input_timesteps"`` entry is used as the training window size.
        vis: If true, additionally visualize dataset samples using the
            ensemble's intrinsic reward as the hue function.
    """
    ensemble = INVModelsEnsemble(inv_dynamics_cfg, "cuda")

    retrain_kwargs = dict(
        dataset_path=dataset_paths["Pedipulation Init (no random)"],
        model_save_dir="logs/pretrain/new_inv_dynamics_offline/mlp_length_10",
        epochs=10,
        window_size=inv_dynamics_cfg["input_timesteps"],
    )
    # The second return value (full dataset size) is not needed here.
    errors_per_epoch, _ = ensemble.retrain_models(**retrain_kwargs)

    plot_epoch_errors(errors_per_epoch, ["NAME"])

    if not vis:
        return

    dataset_name = "Pedipulation Init"
    sample_visualizer = InvSamplesVisualization(n_neighbors=10, min_dist=0.5, random_state=42, use_PCA=False)
    sample_visualizer.visualize_samples_hue_function(
        dict_dataset_paths={dataset_name: dataset_paths[dataset_name]},
        hue_function=ensemble.get_intrinsic_reward_ensemble_for_batch,
        vis_samples_per_dataset={dataset_name: 2000},
    )


def check_ensemble(inv_dynamics_cfg):
    """Load a saved ensemble, check its parameters and visualize its single-model reward.

    Args:
        inv_dynamics_cfg: Model configuration used to construct the ensemble
            before its weights are loaded.
    """
    ensemble = INVModelsEnsemble(inv_dynamics_cfg, "cuda")
    ensemble.load_models("logs/pretrain/new_inv_dynamics_offline/iteration_0000")
    ensemble.check_parameter_explosion()

    dataset_name = "Pedipulation Init"
    sample_visualizer = InvSamplesVisualization(n_neighbors=1, min_dist=0.5, random_state=42, use_PCA=False)

    # Hue is the single-model intrinsic reward; the ensemble variant
    # (get_intrinsic_reward_ensemble_for_batch) can be swapped in instead.
    sample_visualizer.visualize_samples_hue_function(
        dict_dataset_paths={dataset_name: dataset_paths[dataset_name]},
        hue_function=ensemble.get_intrinsic_reward_single_for_batch,
        vis_samples_per_dataset={dataset_name: 2000},
    )

# Script entry point: runs exactly one experiment; switch by (un)commenting.
if __name__ == "__main__":
    train_PIDM_G1()
    # train_PIDM()
    # train_model_tf_sweep()
    # NOTE: train_ensemble and check_ensemble both take a required
    # ``inv_dynamics_cfg`` argument, which the calls below do not pass yet.
    # train_ensemble(vis=False)
    # check_ensemble()