"""Transfer experiment script for evaluating reward models and imitation policies under perturbed environments.

Two modes:
1. Reward model path (--reward-model-path): Load reward model, retrain PPO on perturbed env, evaluate
2. Imitation model path (--imitation-model-path): Load Q-value model from checkpoint, evaluate directly

Usage:
    # Mode 1: Evaluate reward model transfer
    python transfer.py --env-id LunarLander-v3 --env-params '{"wind_power": 10.0}' \
        --reward-model-path /path/to/best_model.pt \
        --expert-policy-path ~/umfavi/expert_policies/ppo/LunarLander-v3_1/best_model.zip
    
    # Mode 2: Evaluate imitation learning transfer
    python transfer.py --env-id LunarLander-v3 --env-params '{"wind_power": 10.0}' \
        --imitation-model-path /path/to/best_model.pt \
        --expert-policy-path ~/umfavi/expert_policies/ppo/LunarLander-v3_1/best_model.zip
"""

import argparse
import json
from pathlib import Path
from typing import Optional, Callable
import gymnasium as gym
import torch
import wandb

from umfavi.encoder.feature_modules import MLPFeatureModule
from umfavi.encoder.reward_encoder import RewardEncoder
from umfavi.evaluation.regret import compute_regret
from umfavi.multi_fb_model import MultiFeedbackTypeModel
from umfavi.loglikelihoods.make_nll import make_nll
from umfavi.types import FeedbackType
from umfavi.utils.feature_transforms import get_action_transform, get_observation_transform
from umfavi.utils.gym import get_act_dim, get_obs_dim
from umfavi.utils.reproducibility import seed_everything
from umfavi.envs.make_env import make_env
from umfavi.utils.policies import create_policy
from umfavi.learned_reward_wrapper import LearnedRewardWrapper
from umfavi.envs.env_types import TabularEnv



def load_checkpoint(checkpoint_path: "Path | str") -> dict:
    """Load a checkpoint file and return its contents.

    Args:
        checkpoint_path: Location of the checkpoint. May be a ``Path`` or a
            plain string; a leading ``~`` is expanded to the user's home.

    Returns:
        Whatever object was serialized in the file (the rest of this script
        treats it as a dict), with tensors mapped onto the CPU.

    Raises:
        FileNotFoundError: If nothing exists at the expanded path.
    """
    # Coerce first so that string paths (e.g. straight from argparse) work too.
    checkpoint_path = Path(checkpoint_path).expanduser()
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    # SECURITY NOTE: weights_only=False performs full unpickling, which can
    # execute arbitrary code. Only load checkpoints from trusted sources.
    return torch.load(checkpoint_path, map_location='cpu', weights_only=False)


def reconstruct_multi_fb_model(
    checkpoint: dict,
    env: gym.Env,
    act_transform: Optional[Callable] = None,
    obs_transform: Optional[Callable] = None,
    device: str = "cpu",
) -> MultiFeedbackTypeModel:
    """
    Rebuild a MultiFeedbackTypeModel and restore its weights from a checkpoint.

    The architecture is derived from the training arguments stored under
    checkpoint["args"]; the weights come from checkpoint["model_state_dict"].
    (Per the original author's note, this is meant to mirror the model
    construction in train.py.)

    Args:
        checkpoint: Loaded checkpoint dictionary (from load_checkpoint)
        env: Environment used to derive observation/action dimensions and to
            decide whether the action space is discrete
        act_transform: Optional action transform, forwarded to get_act_dim
        obs_transform: Optional observation transform, forwarded to get_obs_dim
        device: Device the model is moved to before the weights are loaded

    Returns:
        MultiFeedbackTypeModel with the checkpoint's state dict loaded

    Raises:
        ValueError: If the checkpoint has no (or empty) 'args', or if none of
            the sample counts in those args is positive.
    """
    n_obs = get_obs_dim(env, obs_transform)
    n_act = get_act_dim(env, act_transform)

    # The saved training args are the only source of architecture information.
    saved_args = checkpoint.get("args", {})
    if not saved_args:
        raise ValueError("Checkpoint does not contain 'args'. Cannot reconstruct model architecture.")

    hidden = saved_args.get("encoder_hidden_sizes", [256, 256])
    domain = saved_args.get("reward_domain", "sa")

    # A feedback type counts as active when its sample count is positive.
    sample_count_keys = (
        (FeedbackType.PREFERENCE, "n_pref_samples"),
        (FeedbackType.DEMONSTRATION, "n_demo_samples"),
        (FeedbackType.RATING, "n_rating_samples"),
    )
    enabled_types = [
        fb_type
        for fb_type, arg_name in sample_count_keys
        if saved_args.get(arg_name, 0) > 0
    ]
    if len(enabled_types) == 0:
        raise ValueError("No active feedback types found in checkpoint args. Cannot reconstruct decoders.")

    discrete = isinstance(env.action_space, gym.spaces.Discrete)

    # Reward encoder: MLP features wrapped in a RewardEncoder.
    encoder = RewardEncoder(
        MLPFeatureModule(n_obs, n_act, hidden, reward_domain=domain)
    )

    # Q-value network: for discrete actions the final layer has one output per
    # action and only the state is fed in; otherwise the action is an input
    # and the final layer has a single output.
    if discrete:
        q_kwargs = {
            "action_dim": None,
            "hidden_sizes": hidden + [n_act],
            "reward_domain": 's',
        }
    else:
        q_kwargs = {
            "action_dim": n_act,
            "hidden_sizes": hidden + [1],
            "reward_domain": 'sa',
        }
    q_net = MLPFeatureModule(state_dim=n_obs, activate_last_layer=False, **q_kwargs)

    # One negative-log-likelihood decoder per active feedback type.
    nll_by_type = {}
    for fb_type in enabled_types:
        nll_by_type[fb_type] = make_nll(fb_type, actions_discrete=discrete)

    model = MultiFeedbackTypeModel(
        encoder=encoder,
        q_model=q_net,
        decoders=nll_by_type,
        actions_discrete=discrete,
    )
    model.to(device)

    # Restore the trained weights.
    model.load_state_dict(checkpoint["model_state_dict"])
    return model


def make_perturbed_env(env_id: str, env_params: dict, seed: int = 0) -> gym.Env:
    """
    Build the (possibly perturbed) environment used for the transfer experiment.

    Dispatches on ``env_id``:
      * ids containing "acrobot" (case-insensitive) are created with
        ``gym.make`` and wrapped in ``AcrobotTransferEnv`` with ``env_params``;
      * ids starting with "grid" build a ``GridEnv`` whose reward type is the
        second underscore-separated token of the id;
      * everything else goes through ``gym.make`` with ``env_params`` as
        keyword overrides (minus the "env_id" and "seed" keys).

    Args:
        env_id: Environment ID (e.g., "LunarLander-v3")
        env_params: Dictionary of environment parameters to override
        seed: Random seed for environment

    Returns:
        The constructed environment
    """
    if "acrobot" in env_id.lower():
        from umfavi.envs.wrappers.acrobot_transfer_env import AcrobotTransferEnv

        acrobot_env = AcrobotTransferEnv(gym.make(env_id), **env_params)
        acrobot_env.reset(seed=seed)
        return acrobot_env

    if env_id.startswith("grid"):
        # Custom grid environments (grid_cliff, grid_sparse, grid_trap, etc.):
        # the seed is handed to the constructor, no reset call is made here.
        from umfavi.envs.grid_env.env import GridEnv

        reward_type = env_id.split("_")[1]
        grid_kwargs = dict(env_params)
        grid_kwargs["reward_type"] = reward_type
        grid_kwargs["seed"] = seed
        return GridEnv(**grid_kwargs)

    # Generic gymnasium environment: drop bookkeeping keys that are not
    # constructor arguments before forwarding the overrides.
    reserved_keys = ["env_id", "seed"]
    overrides = {}
    for key, value in env_params.items():
        if key not in reserved_keys:
            overrides[key] = value
    generic_env = gym.make(env_id, render_mode=None, **overrides)
    generic_env.reset(seed=seed)
    return generic_env


def run_transfer(args: argparse.Namespace) -> dict[str, float]:
    """
    Run transfer experiment based on provided arguments.

    Steps: seed everything, parse the environment overrides, optionally start
    a W&B run, load the feedback-model checkpoint, rebuild the model against
    the base environment, call ``compute_regret`` with factories for the
    perturbed environment, then print and (optionally) log the results.

    Args:
        args: Parsed command line arguments (see ``main`` for the fields)

    Returns:
        Dictionary with the keys "regret", "mean_reward" and
        "discounted_value" as returned by ``compute_regret``

    Raises:
        ValueError: If ``--env_params`` is valid JSON but not a JSON object.
    """
    seed_everything(args.seed)

    # Parse environment parameters; they are later merged into keyword
    # arguments, so anything other than a JSON object is a usage error.
    env_params = json.loads(args.env_params) if args.env_params else {}
    if not isinstance(env_params, dict):
        raise ValueError(
            f"--env_params must be a JSON object (dict), got {type(env_params).__name__}"
        )

    # Initialize W&B if enabled
    if args.log_wandb:
        wandb_config = {
            "env_id": args.env_id,
            "env_params": env_params,
            "fb_model_path": args.fb_model_path,
            "optimal_policy_path": args.optimal_policy_path,
            "gamma": args.gamma,
            "num_samples": args.num_samples,
            "max_num_steps": args.max_num_steps,
            "seed": args.seed,
            "act_transform": args.act_transform,
            "obs_transform": args.obs_transform,
            "mode": args.mode,
        }
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=args.wandb_name,
            config=wandb_config,
        )

    # Everything after wandb.init runs under try/finally so the W&B run is
    # always closed, and is marked as failed when an exception escapes.
    succeeded = False
    try:
        # Load checkpoint
        checkpoint = load_checkpoint(Path(args.fb_model_path))

        # Base env is used for feature transforms and dimension calculation
        base_env = make_env(**(env_params | {"env_id": args.env_id, "seed": args.seed}))

        act_transform = get_action_transform(args, base_env)
        obs_transform = get_observation_transform(args, base_env)

        # Reconstruct the full model from checkpoint
        fb_model = reconstruct_multi_fb_model(
            checkpoint=checkpoint,
            env=base_env,
            act_transform=act_transform,
            obs_transform=obs_transform,
            device="cpu",
        )
        fb_model.eval()

        def make_perturbed_env_fn():
            # Fresh perturbed environment, always seeded with the run seed.
            return make_perturbed_env(args.env_id, env_params, seed=args.seed)

        def make_wrapped_env(seed: int):
            # Create a fresh environment instance each time for reproducibility
            perturbed_env_instance = make_perturbed_env_fn()
            return LearnedRewardWrapper(
                perturbed_env_instance,
                fb_model.encoder,
                seed=seed,
                act_transform=act_transform,
                obs_transform=obs_transform,
            )

        is_tabular = isinstance(base_env, TabularEnv)
        true_optimal_policy = None  # in the tabular setting we don't need a reference policy
        if not is_tabular:
            true_optimal_policy = create_policy(
                args.optimal_policy_path, float("inf"), env=base_env, gamma=args.gamma
            )
        regret, mean_rew, discounted_value, _ = compute_regret(
            true_optimal_policy=true_optimal_policy,
            train_env_fn=make_wrapped_env,
            eval_env_fn=make_perturbed_env_fn,
            is_tabular=is_tabular,
            is_imitation=args.mode == "imitation",
            fb_model=fb_model,
            gamma=args.gamma,
            n_regret_samples=args.num_samples,
            max_num_steps=args.max_num_steps,
            seed_fn=lambda i: args.seed * args.num_samples + i,  # ensure distinct results across seeds
            true_reward_threshold=args.retrain_reward_thresh,
            verbose=args.retrain_verbose,
            progress_bar=not args.no_progress_bar,
            ppo_seed=args.seed,
            reference_env_name=args.env_id,
        )
        results = {
            "regret": regret,
            "mean_reward": mean_rew,
            "discounted_value": discounted_value,
        }

        # Print results
        print("\nResults:")
        print(f"  Regret: {results['regret']:.4f}")
        # Compare against None explicitly: a mean reward of exactly 0.0 is a
        # legitimate value and must still be reported.
        if results["mean_reward"] is not None:
            print(f"  Mean reward: {results['mean_reward']:.4f}")

        # Log results to W&B if enabled
        if args.log_wandb:
            wandb.log(results)
        succeeded = True
    finally:
        if args.log_wandb:
            wandb.finish(exit_code=0 if succeeded else 1)

    return results


def main(argv: Optional[list[str]] = None) -> None:
    """Parse command line arguments and run the transfer experiment.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour when run as a script.
    """
    parser = argparse.ArgumentParser(
        description="Transfer experiment: evaluate reward models or imitation policies under perturbed environments"
    )

    # Environment configuration
    parser.add_argument("--env_id", type=str, required=True, help="Base environment ID (e.g., LunarLander-v3)")
    parser.add_argument("--reward_domain", type=str, default="sa", help="Reward domain (s or sa)")
    parser.add_argument("--env_params", type=str, default="{}", help="JSON dict of environment parameters to override (e.g., '{\"wind_power\": 10.0}')")

    # Model checkpoint (a required group that currently holds a single option)
    model_group = parser.add_mutually_exclusive_group(required=True)
    model_group.add_argument("--fb_model_path", type=str, help="Path to feedback model checkpoint (.pt file)")

    # Mode
    parser.add_argument("--mode", type=str, default="reward_model", choices=["reward_model", "imitation"], help="Mode to run the transfer experiment in (default: reward_model)")

    # Model configuration
    parser.add_argument("--hidden_sizes", type=int, nargs="+", default=[256, 256], help="Hidden sizes for the encoder MLP (default: [256, 256])")

    # Expert policy for regret computation
    parser.add_argument("--optimal_policy_path", type=str, required=True, help="Path to optimal policy for regret computation")

    # Evaluation parameters
    parser.add_argument("--num_samples", type=int, default=300, help="Number of MC samples for regret estimation (default: 300)")
    parser.add_argument("--max_num_steps", type=int, default=1000, help="Maximum steps per episode (default: 1000)")
    parser.add_argument("--gamma", type=float, default=0.99, help="Discount factor (default: 0.99)")

    # Reproducibility
    parser.add_argument("--seed", type=int, default=0, help="Random seed (default: 0)")

    # Training verbosity (for Mode 1)
    parser.add_argument("--retrain_verbose", type=int, default=1, help="Verbosity level for PPO retraining (default: 1)")
    parser.add_argument("--no_progress_bar", action="store_true", help="Disable progress bar during PPO retraining")
    parser.add_argument("--retrain_reward_thresh", type=float, default=None, help="True reward threshold for early stopping during PPO retraining (default: None)")

    # Feature transforms
    parser.add_argument("--act_transform", type=str, default=None, help="Action transform to apply (default: None)")
    parser.add_argument("--obs_transform", type=str, default=None, help="Observation transform to apply (default: None)")

    # Weights and Biases logging
    parser.add_argument("--log_wandb", action="store_true", help="Enable Weights and Biases logging")
    parser.add_argument("--wandb_project", type=str, default="umfavi-transfer", help="W&B project name (default: umfavi-transfer)")
    parser.add_argument("--wandb_entity", type=str, default=None, help="W&B entity/team name (default: None)")
    parser.add_argument("--wandb_name", type=str, default=None, help="W&B run name (default: auto-generated)")

    args = parser.parse_args(argv)
    run_transfer(args)


if __name__ == "__main__":
    main()
