"""
Module containing model implementations for hierarchical reinforcement learning.
"""
from stable_baselines3 import PPO, DQN
from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy
from stable_baselines3.her import HerReplayBuffer
from stable_baselines3.dqn import MlpPolicy as DqnMlpPolicy
from stable_baselines3.ppo import MlpPolicy as PpoMlpPolicy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.vec_env import VecNormalize, DummyVecEnv
import torch
import os
from hrl.common import *
import os
import torch
import pandas as pd
import imageio
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
import json
import yaml

# Root directory for experiment outputs, read from the BASE_HRL_DIR environment
# variable at import time. os.getenv returns None when the variable is unset.
BASE_HRL_DIR = os.getenv('BASE_HRL_DIR')

# Maps activation names (plain strings, as they would appear in a JSON/YAML
# config) to the corresponding torch.nn module classes. Values are the classes
# themselves, not instances.
ACTIVATION_FN_MAP = {
    "ReLU": torch.nn.ReLU,
    "Tanh": torch.nn.Tanh,
    "LeakyReLU": torch.nn.LeakyReLU,
    "ELU": torch.nn.ELU,
    "GELU": torch.nn.GELU,
}

import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class SequentialEmbeddingExtractor(BaseFeaturesExtractor):
    """Embed a fixed-length sequence of discrete state ids and project it with an MLP.

    Each observation is treated as ``sequence_length`` integer ids. The value
    ``-1`` is a padding token and is remapped to the extra embedding row at
    index ``num_states``. The per-step embeddings are concatenated and passed
    through a two-layer MLP whose output width is ``hidden_dim``.

    Args:
        observation_space: Observation space forwarded to ``BaseFeaturesExtractor``.
        embedding_dim: Size of each id embedding.
        hidden_dim: Width of the MLP and of the returned feature vector.
        num_states: Number of distinct (non-padding) ids. Defaults to the
            previously hard-coded 256.
        sequence_length: Number of ids per observation. Defaults to the
            previously hard-coded 10.
    """

    def __init__(self, observation_space, embedding_dim=64, hidden_dim=128,
                 num_states=256, sequence_length=10):
        # The extractor's advertised feature size is the MLP output width.
        super().__init__(observation_space, hidden_dim)

        self.num_states = num_states
        self.sequence_length = sequence_length
        self.embedding_dim = embedding_dim

        # One extra row so the padding token (-1) has its own embedding.
        self.embedding = nn.Embedding(self.num_states + 1, embedding_dim)

        self.fc = nn.Sequential(
            nn.Linear(embedding_dim * self.sequence_length, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, observations):
        """Return a ``(batch_size, hidden_dim)`` feature tensor for ``observations``."""
        batch_size = observations.shape[0]

        # Remap padding (-1) to the last embedding index. masked_fill keeps the
        # tensor's dtype, avoiding scalar/tensor type-promotion issues that
        # torch.where with a bare Python int has on some torch versions.
        obs_mapped = observations.masked_fill(observations == -1, self.num_states)

        embedded = self.embedding(obs_mapped.long())  # (batch, seq_len, embedding_dim)
        # reshape (not view) so a non-contiguous tensor cannot raise here.
        embedded_flat = embedded.reshape(batch_size, -1)  # (batch, seq_len * embedding_dim)
        return self.fc(embedded_flat)
    
class CustomDQNPolicy(DqnMlpPolicy):
    """DQN MLP policy that uses ``SequentialEmbeddingExtractor`` by default.

    The extractor is built with ``embedding_dim=32`` and ``hidden_dim=256``
    unless the caller supplies its own ``features_extractor_kwargs``.
    """

    def __init__(self, *args, **kwargs):
        # Use setdefault instead of hard-coded keywords: passing the same key
        # both via **kwargs and explicitly raises "got multiple values for
        # keyword argument" (e.g. when the policy is rebuilt from stored
        # constructor parameters that already contain these keys).
        extractor_class = kwargs.setdefault(
            "features_extractor_class", SequentialEmbeddingExtractor
        )
        # Only apply our extractor defaults when our extractor is in use.
        if extractor_class is SequentialEmbeddingExtractor:
            kwargs.setdefault(
                "features_extractor_kwargs",
                {"embedding_dim": 32, "hidden_dim": 256},
            )
        super().__init__(*args, **kwargs)

class CustomPPOPolicy(PpoMlpPolicy):
    """PPO MLP policy that uses ``SequentialEmbeddingExtractor`` by default.

    The extractor is built with ``embedding_dim=32`` and ``hidden_dim=256``
    unless the caller supplies its own ``features_extractor_kwargs``.
    """

    def __init__(self, *args, **kwargs):
        # Use setdefault instead of hard-coded keywords: passing the same key
        # both via **kwargs and explicitly raises "got multiple values for
        # keyword argument" (e.g. when the policy is rebuilt from stored
        # constructor parameters that already contain these keys).
        extractor_class = kwargs.setdefault(
            "features_extractor_class", SequentialEmbeddingExtractor
        )
        # Only apply our extractor defaults when our extractor is in use.
        if extractor_class is SequentialEmbeddingExtractor:
            kwargs.setdefault(
                "features_extractor_kwargs",
                {"embedding_dim": 32, "hidden_dim": 256},
            )
        super().__init__(*args, **kwargs)

def create_learning_rate_schedule(schedule_config):
    """
    Build a learning-rate schedule callable from a config mapping.

    The returned callable takes a single ``progress`` value. All schedules
    follow the convention of the "linear" schedule: ``progress == 1`` yields
    the initial value and ``progress == 0`` yields the final value (this
    matches a "progress remaining" argument that decreases from 1 to 0).

    Args:
        schedule_config: Mapping with a ``"type"`` key:
            - ``"linear"``: needs ``"initial_value"`` and ``"final_value"``;
              linear interpolation between them.
            - ``"exponential"``: needs ``"initial_value"`` and
              ``"final_value"`` (both > 0); geometric interpolation, i.e. the
              value is multiplied by a constant factor per unit of progress.
            - ``"constant"``: needs ``"value"``.

    Returns:
        A function ``progress -> learning rate``.

    Raises:
        ValueError: If the type is unknown, or an exponential schedule is
            requested with a non-positive endpoint.
        KeyError: If a key required by the chosen type is missing.
    """
    schedule_type = schedule_config["type"]

    if schedule_type == "constant":
        value = schedule_config["value"]

        def constant_schedule(progress):
            return value

        return constant_schedule

    if schedule_type not in ("linear", "exponential"):
        raise ValueError(f"Unknown schedule type: {schedule_type}")

    initial_value = schedule_config["initial_value"]
    final_value = schedule_config["final_value"]

    if schedule_type == "linear":

        def linear_schedule(progress):
            return final_value + progress * (initial_value - final_value)

        return linear_schedule

    # Exponential: the previous implementation was a linear ramp running in the
    # opposite direction to "linear" (it returned final_value at progress == 1).
    # Geometric interpolation needs strictly positive endpoints.
    if initial_value <= 0 or final_value <= 0:
        raise ValueError(
            "Exponential schedule requires positive initial_value and final_value, "
            f"got initial_value={initial_value}, final_value={final_value}"
        )
    ratio = initial_value / final_value

    def exponential_schedule(progress):
        # progress == 1 -> initial_value, progress == 0 -> final_value.
        return final_value * ratio ** progress

    return exponential_schedule

def resolve_special_values(config):
    """
    Replace underscore-prefixed placeholder keys with runtime objects.

    Works on a shallow copy, so the caller's mapping is not modified.

    - ``"_activation_fn"`` (a name) is removed and, if the name is a key of
      ``ACTIVATION_FN_MAP``, stored as ``"activation_fn"`` (the mapped class).
      An unknown name emits a warning and no ``"activation_fn"`` is set.
    - ``"_learning_rate_schedule"`` (a schedule config) is removed and the
      callable built by ``create_learning_rate_schedule`` is stored as
      ``"learning_rate"``.

    Args:
        config: Configuration mapping, possibly containing the keys above.

    Returns:
        A new dict with the placeholder keys resolved.
    """
    resolved_config = config.copy()

    # Handle activation function
    if "_activation_fn" in resolved_config:
        fn_name = resolved_config.pop("_activation_fn")
        if fn_name in ACTIVATION_FN_MAP:
            resolved_config["activation_fn"] = ACTIVATION_FN_MAP[fn_name]
        else:
            # Previously dropped without any notice, so a typo in the config
            # silently fell back to the policy's default activation.
            import warnings

            warnings.warn(
                f"Unknown activation function {fn_name!r} ignored; "
                f"expected one of {sorted(ACTIVATION_FN_MAP)}",
                stacklevel=2,
            )

    # Handle learning rate schedule
    if "_learning_rate_schedule" in resolved_config:
        schedule_config = resolved_config.pop("_learning_rate_schedule")
        resolved_config["learning_rate"] = create_learning_rate_schedule(schedule_config)

    return resolved_config

def setup_callbacks(tensorboard_path, eval_vec_env, eval_freq, model_prefix):
    """
    Create the evaluation and checkpoint callbacks used during training.

    All outputs are placed in sibling folders of ``tensorboard_path`` (that
    is, under ``os.path.dirname(tensorboard_path)``): ``best_model``,
    ``eval_results`` and ``checkpoints``.

    Args:
        tensorboard_path: TensorBoard log directory; its parent receives the outputs.
        eval_vec_env: Environment handed to ``EvalCallback``.
        eval_freq: Used both as evaluation frequency and checkpoint save frequency.
        model_prefix: Checkpoint files are prefixed with ``"<model_prefix>_model"``.

    Returns:
        ``[eval_callback, checkpoint_callback]``.
    """
    run_dir = os.path.dirname(tensorboard_path)

    periodic_eval = EvalCallback(
        eval_vec_env,
        best_model_save_path=os.path.join(run_dir, "best_model"),
        log_path=os.path.join(run_dir, "eval_results"),
        eval_freq=eval_freq,
        n_eval_episodes=10,
        deterministic=True,
        verbose=1,
    )

    periodic_checkpoint = CheckpointCallback(
        save_freq=eval_freq,
        save_path=os.path.join(run_dir, "checkpoints"),
        name_prefix=f"{model_prefix}_model",
        verbose=1,
    )

    return [periodic_eval, periodic_checkpoint]

def get_experiment_paths(args, algorithm_name):
    """
    Derive the experiment, model and TensorBoard directories for a run.

    The experiment folder lives under ``BASE_HRL_DIR`` and is named
    ``<env>_<hrl_exp_name>_<variant>_<algorithm_name>``, where the variant is
    ``"hard"`` when ``args.hrl_hard`` is truthy and
    ``"std_<format_number(args.hrl_std)>"`` otherwise.

    Returns:
        Tuple ``(exp_path, model_path, tensorboard_path)``.
    """
    if args.hrl_hard:
        variant = "hard"
    else:
        variant = f"std_{format_number(args.hrl_std)}"

    experiment_dir = os.path.join(
        BASE_HRL_DIR,
        f"{args.env}_{args.hrl_exp_name}_{variant}_{algorithm_name}",
    )

    return (
        experiment_dir,
        os.path.join(experiment_dir, "model"),
        os.path.join(experiment_dir, "tensorboard"),
    )

def load_config(filepath=None, algorithm=None, config_name="default"):
    """
    Load one named configuration for an algorithm from a JSON or YAML file.

    The file must contain a two-level mapping of the form
    ``{algorithm: {config_name: configuration}}``.

    Args:
        filepath: Path to the configuration file (``.json``, ``.yaml`` or
            ``.yml``). Required; the ``None`` default only exists for
            signature compatibility and is rejected.
        algorithm: Top-level key selecting the algorithm section.
        config_name: Key inside the algorithm section.

    Returns:
        The object stored at ``configs[algorithm][config_name]``.

    Raises:
        ValueError: If ``filepath`` is ``None``, the extension is unsupported,
            the file does not contain a mapping, or the algorithm /
            configuration name is not present.
    """
    # There is no built-in default configuration; fail with a clear message
    # rather than an opaque TypeError from os.path.splitext(None).
    if filepath is None:
        raise ValueError("No configuration file given: filepath is required")

    ext = os.path.splitext(filepath)[1].lower()
    if ext == '.json':
        parse = json.load
    elif ext in ('.yaml', '.yml'):
        parse = yaml.safe_load
    else:
        raise ValueError(f"Unsupported file extension: {ext}. Use .json, .yaml, or .yml")

    with open(filepath, 'r', encoding='utf-8') as f:
        configs = parse(f)

    # An empty YAML file parses to None and a JSON file may hold a list; both
    # would otherwise fail below with a confusing TypeError.
    if not isinstance(configs, dict):
        raise ValueError(
            f"Config file {filepath} must contain a mapping of algorithms, "
            f"got {type(configs).__name__}"
        )

    if algorithm not in configs:
        raise ValueError(f"Algorithm {algorithm} not found in config file {filepath}")

    algorithm_configs = configs[algorithm]
    if not isinstance(algorithm_configs, dict) or config_name not in algorithm_configs:
        raise ValueError(f"Configuration name {config_name} not found for algorithm {algorithm} in config file {filepath}")

    return algorithm_configs[config_name]

def build_ppo(vec_env, eval_vec_env, eval_freq, tensorboard_path='/tensorboard/', **kwargs):
    """
    Build a PPO model and its training callbacks.

    Args:
        vec_env: Training environment passed to ``PPO``.
        eval_vec_env: Evaluation environment passed to the callbacks.
        eval_freq: Evaluation / checkpoint frequency for the callbacks.
        tensorboard_path: TensorBoard log directory.
        **kwargs: Configuration values. The following keys are consumed here;
            everything else is forwarded unchanged to ``PPO``:
            - ``policy_type``: policy name (default ``"MlpPolicy"``); the
              string ``"CustomPPOPolicy"`` selects the class defined in this
              module.
            - ``ortho_init`` (default ``True``), ``net_arch``,
              ``activation_fn``: moved into ``policy_kwargs``.
            - ``device`` (default ``"cpu"``) and ``verbose`` (default ``1``).

    Returns:
        Tuple ``(model, callbacks)``.
    """
    # Taken from kwargs so a config may override them; hard-coding them while
    # also forwarding **kwargs raised a duplicate-keyword TypeError whenever
    # the config contained either key. Defaults match the previous constants.
    device = kwargs.pop('device', "cpu")
    verbose = kwargs.pop('verbose', 1)

    policy_type = kwargs.pop('policy_type', "MlpPolicy")
    if policy_type == "CustomPPOPolicy":
        policy_type = CustomPPOPolicy

    # Settings that belong to the policy rather than to the algorithm.
    policy_kwargs = {
        'ortho_init': kwargs.pop('ortho_init', True)
    }
    for policy_key in ('net_arch', 'activation_fn'):
        if policy_key in kwargs:
            policy_kwargs[policy_key] = kwargs.pop(policy_key)

    model = PPO(
        policy_type,
        env=vec_env,
        device=device,
        tensorboard_log=tensorboard_path,
        policy_kwargs=policy_kwargs,
        verbose=verbose,
        **kwargs  # Pass all remaining kwargs directly to PPO
    )

    callbacks = setup_callbacks(tensorboard_path, eval_vec_env, eval_freq, "ppod")

    return model, callbacks

def build_dqn_her(vec_env, eval_vec_env, eval_freq, tensorboard_path='/tensorboard/', **kwargs):
    """
    Build a DQN model backed by a ``HerReplayBuffer`` plus its callbacks.

    Consumes ``n_sampled_goal`` (default 12) and ``goal_selection_strategy``
    (default ``"future"``) for the replay buffer and ``net_arch`` for the
    policy; all other keyword arguments go straight to ``DQN``. The policy is
    always ``"MultiInputPolicy"`` with ``normalize_images=False``.

    Returns:
        Tuple ``(model, callbacks)``.
    """
    # Replay-buffer (HER) settings
    her_buffer_kwargs = {
        "n_sampled_goal": kwargs.pop("n_sampled_goal", 12),
        "goal_selection_strategy": kwargs.pop("goal_selection_strategy", "future"),
    }

    # Policy settings
    dqn_policy_kwargs = {'normalize_images': False}
    if 'net_arch' in kwargs:
        dqn_policy_kwargs['net_arch'] = kwargs.pop('net_arch')

    agent = DQN(
        "MultiInputPolicy",
        env=vec_env,
        replay_buffer_class=HerReplayBuffer,
        replay_buffer_kwargs=her_buffer_kwargs,
        tensorboard_log=tensorboard_path,
        policy_kwargs=dqn_policy_kwargs,
        verbose=1,
        **kwargs  # whatever is left is forwarded to DQN untouched
    )

    training_callbacks = setup_callbacks(tensorboard_path, eval_vec_env, eval_freq, "dqn_her")
    return agent, training_callbacks

def build_dqn(vec_env, eval_vec_env, eval_freq, tensorboard_path='/tensorboard/', **kwargs):
    """
    Build a DQN model and its training callbacks.

    Args:
        vec_env: Training environment passed to ``DQN``.
        eval_vec_env: Evaluation environment passed to the callbacks.
        eval_freq: Evaluation / checkpoint frequency for the callbacks.
        tensorboard_path: TensorBoard log directory.
        **kwargs: Configuration values. ``policy_type`` (default
            ``"MlpPolicy"``) and ``net_arch`` are consumed here; the rest is
            forwarded unchanged to ``DQN``. The string ``"CustomDQNPolicy"``
            selects the class defined in this module.

    Returns:
        Tuple ``(model, callbacks)``.
    """
    policy_type = kwargs.pop('policy_type', "MlpPolicy")
    # Mirror build_ppo: a config file can only carry the policy as a string,
    # so resolve this module's custom policy name to the class here.
    if policy_type == "CustomDQNPolicy":
        policy_type = CustomDQNPolicy

    # Settings that belong to the policy rather than to the algorithm.
    policy_kwargs = {}
    if 'net_arch' in kwargs:
        policy_kwargs['net_arch'] = kwargs.pop('net_arch')

    print(f"kwargs: {kwargs}")

    model = DQN(
        policy_type,
        env=vec_env,
        tensorboard_log=tensorboard_path,
        policy_kwargs=policy_kwargs,
        verbose=1,
        **kwargs  # Pass all remaining kwargs directly to DQN
    )

    callbacks = setup_callbacks(tensorboard_path, eval_vec_env, eval_freq, "dqn")

    return model, callbacks

# Dictionary mapping algorithm names to their build functions.
# Every build function shares the signature
# (vec_env, eval_vec_env, eval_freq, tensorboard_path, **kwargs) and returns
# (model, callbacks). "ppo" and "ppo_lrs" use the same builder; presumably they
# differ only in the configuration section loaded for them — TODO confirm.
ALGORITHM_MAP = {
    "ppo": build_ppo,
    "ppo_lrs": build_ppo,
    "dqn": build_dqn,
    "dqn_her": build_dqn_her,
}


def run_test(args, model_class):
    """
    Load a saved model and run it in a freshly created test environment.

    Args:
        args: Command line arguments. Reads ``hrl_model_path`` and, when
            present, ``hrl_test_episodes`` (default 3) and ``hrl_fps``
            (default 30).
        model_class: The model class to load (e.g., PPO, DQN)

    Returns:
        Whatever ``test_model`` returns.
    """
    test_env = setup_hrl_environment(args, render_mode="rgb_array")

    # try/finally so the environment is released even when loading the model
    # or running the test raises.
    try:
        loaded_model = model_class.load(args.hrl_model_path, env=test_env)
        return test_model(
            model=loaded_model,
            env=test_env,
            # Two path components above the model file.
            base_dir="/".join(args.hrl_model_path.split("/")[:-2]),
            num_episodes=getattr(args, 'hrl_test_episodes', 3),
            save_video=False,
            fps=getattr(args, 'hrl_fps', 30),
        )
    finally:
        test_env.close()

def write_config(args, config, exp_path):
    """
    Print the active configuration and save it to ``<exp_path>/config.yaml``.

    The saved file has two sections: ``sb3_config`` (the given ``config``) and
    ``hrl_args`` (every attribute of ``args`` whose name starts with ``hrl``).

    Args:
        args: Parsed command line arguments; ``hrl_algo`` is used for the log
            header when present.
        config: Configuration mapping to print and save.
        exp_path: Experiment directory; created if it does not exist.
    """
    # Take the algorithm name from args. The previous code read a module-level
    # `algorithm` variable that only exists when this file runs as a script,
    # so calling this function after an import raised NameError.
    algorithm = getattr(args, "hrl_algo", None)

    # Log the configuration being used
    print(f"Using configuration for {algorithm}:")
    for key, value in config.items():
        if key not in ("net_arch", "activation_fn"):  # Skip complex objects for readability
            print(f"  {key}: {value}")

    hrl_args = {
        arg_name: arg_value
        for arg_name, arg_value in vars(args).items()
        if arg_name.startswith("hrl")
    }

    print(config)

    os.makedirs(exp_path, exist_ok=True)
    # Save both config and hrl_args to the config file
    combined_config = {
        "sb3_config": config,
        "hrl_args": hrl_args
    }
    with open(os.path.join(exp_path, "config.yaml"), 'w') as f:
        yaml.dump(combined_config, f)
    
# Script entry point: parse arguments, load the configuration for the selected
# algorithm, then either test a saved model or train a new one.
if __name__ == "__main__":
    # Parse arguments
    args = read_args()
    # NOTE(review): these two attributes are forced regardless of the command
    # line; their consumers are not visible in this file — confirm intent.
    args.test = True
    args.model = "mop"
    
    # Get the build function for the selected algorithm
    algorithm = args.hrl_algo
    if algorithm not in ALGORITHM_MAP:
        raise ValueError(f"Unknown algorithm: {algorithm}. Available algorithms: {list(ALGORITHM_MAP.keys())}")
    
    build_fn = ALGORITHM_MAP[algorithm]
    
    # Load the named configuration section and turn placeholder keys
    # (activation name, learning-rate schedule) into runtime objects.
    config = resolve_special_values(load_config(
        filepath="hrl/config.yml",
        algorithm=algorithm,
        config_name=args.hrl_config_name
    ))
    
    # Create a build function with parameters
    # (binds the loaded config so the caller only supplies envs and paths)
    def configured_build_fn(vec_env, eval_vec_env, eval_freq, tensorboard_path='/tensorboard/'):
        return build_fn(vec_env, eval_vec_env, eval_freq, tensorboard_path, **config)
    
    # Setup logging
    setup_logging(args.log_level)
    # Sets the thread count to the value torch already reports.
    torch.set_num_threads(torch.get_num_threads())
    
    if args.hrl_test:
        # Any algorithm name starting with "ppo" loads with PPO, others with DQN.
        model_class = PPO if algorithm.startswith("ppo") else DQN
        run_test(args, model_class)
    else:
        exp_path, model_path, tensorboard_path = get_experiment_paths(args, algorithm)
        # Persist the configuration next to the experiment outputs before training.
        write_config(args, config, exp_path)
        train(args, configured_build_fn, args.hrl_timesteps, args.hrl_eval_freq, tensorboard_path, model_path)

     
