import argparse
import os
import json
import gym
import torch
import numpy as np
import sys
import wandb
import yaml

current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(current_dir, '../../..'))

from opelab.core.baselines.simple import OnPolicy, WeightedIS, ISStepwise, WeightedISStepwise, WeightedISStepwiseV2
from opelab.core.baselines.model_based_rollout import MBR
from opelab.core.baselines.diffuser import Diffuser
from opelab.core.baselines.fqe import FQE 
from opelab.core.baselines.doubly_robust import DoublyRobustEstimator
from opelab.core.baselines.dice import BestDice
from opelab.core.baselines.pgd.pgdiffusion import PolicyGuidedDiffusion
from opelab.core.policy import D4RLPolicy, D4RLSACPolicy
from opelab.examples.helpers import evaluate_policies, create_baselines, evaluate_method_for_sweep

# Registry of OPE baselines, keyed by the name an experiment config supplies in
# ``baseline_configs["name"]``; ``main`` looks the class up here and constructs it
# with the sweep's hyperparameters.
# NOTE: the "WeightedISStepwise" key resolves to WeightedISStepwiseV2, not to the
# imported WeightedISStepwise class.
BASELINE_CLASSES = dict(
    OnPolicy=OnPolicy,
    WeightedIS=WeightedIS,
    ISStepwise=ISStepwise,
    WeightedISStepwise=WeightedISStepwiseV2,
    Diffuser=Diffuser,
    PolicyGuidedDiffusion=PolicyGuidedDiffusion,
    MBR=MBR,
    FQE=FQE,
    BestDice=BestDice,
    DoublyRobust=DoublyRobustEstimator,
)

def main(config_path, device):
    """Run a Weights & Biases hyperparameter sweep for one OPE baseline.

    Args:
        config_path: Path to a JSON experiment config. Required keys read here:
            ``env_name``, ``guidance_hyperparams``, ``target_policy_paths``,
            ``baseline_configs`` (its ``name`` entry selects a class from
            ``BASELINE_CLASSES``), ``experiment_params`` (forwarded as keyword
            arguments to ``evaluate_method_for_sweep``) and
            ``sweep_config_file`` (path to a YAML W&B sweep definition).
            Optional keys: ``wandb_project`` (default ``"dice_sweep"``) and
            ``sweep_count`` (number of agent trials, default 50).
        device: Torch device string that policies and tensors are moved to.

    Raises:
        KeyError: if a required config key is missing, or if
            ``baseline_configs["name"]`` is not a key of ``BASELINE_CLASSES``
            (the latter surfaces inside each sweep trial).
    """
    with open(config_path, 'r') as f:
        config = json.load(f)

    env_name = config["env_name"]
    # NOTE(review): read (so a missing key still fails fast) but not used below.
    guidance_hyperparams = config["guidance_hyperparams"]
    target_policy_paths = config["target_policy_paths"]
    baseline_configs = config["baseline_configs"]
    experiment_params = config["experiment_params"]
    # Optional sweep settings; the defaults are the values previously hard-coded.
    project = config.get("wandb_project", "dice_sweep")
    sweep_count = config.get("sweep_count", 50)

    env = gym.make(env_name)

    # NOTE(review): action_bounds is built but not passed to the baseline.
    env_min = torch.tensor(env.action_space.low, dtype=torch.float32, device=device)
    env_max = torch.tensor(env.action_space.high, dtype=torch.float32, device=device)
    action_bounds = [env_min, env_max]

    behavior_policy = D4RLPolicy(env_name).to(device)
    target_policies = [D4RLSACPolicy(path).to(device) for path in target_policy_paths]

    # Only terminate_fn is forwarded to the evaluator; reward_fn is unused here.
    reward_fn, terminate_fn = get_environment_specific_functions(env_name)

    def compute_normalization(env_name):
        """Return per-dimension (mean, std) of the offline dataset.

        Both arrays are laid out as [observation dims..., action dims...].
        """
        dataset = gym.make(env_name).get_dataset()
        observations = dataset['observations']
        actions = dataset['actions']
        mean_state, std_state = np.mean(observations, axis=0), np.std(observations, axis=0)
        mean_action, std_action = np.mean(actions, axis=0), np.std(actions, axis=0)
        mean = np.concatenate((mean_state, mean_action))
        std = np.concatenate((std_state, std_action))
        return mean, std

    mean, std = compute_normalization(env_name)
    mean = torch.tensor(mean).to(device)
    std = torch.tensor(std).to(device)
    # NOTE(review): normalize_fn / unnormalize_fn are not passed to the baseline.
    normalize_fn = lambda x: (x - mean) / std
    unnormalize_fn = lambda x: x * std + mean

    # Context manager so the sweep-definition file handle is closed.
    with open(config['sweep_config_file'], 'rb') as f:
        sweep_configuration = yaml.safe_load(f)

    def eval_baseline(sweep_params):
        """Build the configured baseline from ``sweep_params`` and evaluate it.

        Returns:
            The ``(mse, corr)`` pair produced by ``evaluate_method_for_sweep``.
        """
        name_of_policy = baseline_configs['name']
        baseline = BASELINE_CLASSES[name_of_policy](**sweep_params)

        mse, corr = evaluate_method_for_sweep(
            env=env,
            target_policies=target_policies,
            behavior_policy=behavior_policy,
            baseline=baseline,
            terminate_fn=terminate_fn,
            **experiment_params,
        )
        # sweep_fn unpacks this result, so it must be returned.
        return mse, corr

    def sweep_fn():
        """One sweep trial: evaluate with this run's hyperparameters and log metrics."""
        wandb.init(project=project)
        mse, corr = eval_baseline(wandb.config)
        wandb.log({'mse': mse, 'correlation': corr})

    sweep_id = wandb.sweep(sweep=sweep_configuration, project=project)
    wandb.agent(sweep_id, function=sweep_fn, count=sweep_count)
    

    

def _state_to_numpy(state):
    """Return ``state`` as a NumPy array; non-tensor inputs are returned unchanged.

    Torch tensors are detached before conversion, because ``Tensor.numpy()``
    refuses tensors that require grad; ``.cpu()`` handles accelerator tensors.
    """
    if isinstance(state, torch.Tensor):
        return state.detach().cpu().numpy()
    return state


def _mujoco_step_reward(env, observation, action):
    """Recompute the one-step reward of taking ``action`` from ``observation``.

    Side effects: resets ``env``, overwrites its simulator state and takes one
    ``env.step(action)`` (4-tuple step API). The observation is treated as
    ``qpos[1:]`` followed by ``qvel``; the missing first ``qpos`` entry is set to 0.
    NOTE(review): this presumes the reward does not depend on that dropped
    coordinate's absolute value — confirm for each environment it is used with.
    """
    env.reset()
    qpos_dim = env.model.nq
    qvel_dim = env.model.nv
    qpos = np.concatenate(([0], observation[:qpos_dim - 1]))
    qvel = observation[qpos_dim - 1 : qpos_dim - 1 + qvel_dim]
    env.set_state(qpos, qvel)
    _, reward, _, _ = env.step(action)
    return reward


def get_environment_specific_functions(env_name):
    """Return ``(reward_fn, terminate_fn)`` for a supported MuJoCo task.

    The task family is chosen by case-insensitive substring match on
    ``env_name``, checked in order: ``hopper``, ``cheetah``, ``walker2d``.

    * ``reward_fn(env, observation, action)`` replays ``action`` in ``env`` to
      recompute a one-step reward; it is ``None`` for hopper.
    * ``terminate_fn(state)`` accepts a NumPy array or torch tensor and returns
      ``True`` when the state fails the task's health check. ``state[0]`` is
      treated as height and ``state[1]`` as angle.

    Raises:
        NotImplementedError: if ``env_name`` matches none of the families.
    """
    name = env_name.lower()

    if "hopper" in name:
        def terminate_fn(state):
            state_np = _state_to_numpy(state)
            height, ang = state_np[0], state_np[1]
            healthy = (
                np.isfinite(state_np).all()
                and (np.abs(state_np[2:]) < 100).all()
                and height > 0.7
                and abs(ang) < 0.2
            )
            return not healthy

        return None, terminate_fn

    if "cheetah" in name:
        def terminate_fn(state):
            # No early-termination condition is applied for cheetah.
            return False

        return _mujoco_step_reward, terminate_fn

    if "walker2d" in name:
        def terminate_fn(state):
            state_np = _state_to_numpy(state)
            height, ang = state_np[0], state_np[1]
            healthy = (
                np.isfinite(state_np).all()
                and 0.8 < height < 2.0
                and abs(ang) < 1.0
            )
            return not healthy

        return _mujoco_step_reward, terminate_fn

    raise NotImplementedError(f"Environment {env_name} not supported.")
    
if __name__ == "__main__":
    # CLI entry point: parse arguments and launch the sweep.
    parser = argparse.ArgumentParser(
        description="Run a W&B hyperparameter sweep for an OPE baseline."
    )
    parser.add_argument("--config", type=str, required=True, help="Path to config file")
    parser.add_argument(
        "--device",
        type=str,
        # Fall back to CPU so the default works on machines without CUDA.
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run the experiment on",
    )
    args = parser.parse_args()
    main(args.config, args.device)
