import argparse
import os
import pickle
import random
import numpy as np
from tqdm import tqdm
import metaworld
from stable_baselines3 import SAC
from metaworld.policies.sawyer_pick_place_v2_policy import SawyerPickPlaceV2Policy
from gym.wrappers import TimeLimit
import torch

def parse_args(argv=None):
    """Parse command line arguments.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. When ``None`` (the default) argparse
        falls back to ``sys.argv[1:]``, which is the behaviour callers of
        ``parse_args()`` already rely on.

    Returns
    -------
    argparse.Namespace
        Namespace with ``device``, ``train_tasks``, ``eval_tasks``,
        ``max_episode_steps``, ``n_trajs``, ``p_good`` and ``p_bad``.
    """
    parser = argparse.ArgumentParser(description="Collect data from MetaWorld with SAC advantage labels")
    parser.add_argument('--device', type=str, default='cuda', help="Device to run on (cuda/cpu)")
    # Task ids are given as space-separated integers, e.g. ``--train_tasks 1 2 3``.
    # ``type=list`` would split each token into single characters
    # ("12" -> ['1', '2']), so ``nargs='+'`` with ``type=int`` is used instead.
    # Default training tasks: every id in 0..49 that is not a multiple of 10
    # (the multiples of 10 are the default evaluation tasks).
    parser.add_argument('--train_tasks', type=int, nargs='+',
                        default=[task for task in range(50) if task % 10 != 0],
                        help="Tasks to train on")
    parser.add_argument('--eval_tasks', type=int, nargs='+',
                        default=[0, 10, 20, 30, 40], help="Tasks to eval on")
    parser.add_argument('--max-episode-steps', type=int, default=200, help="Max episode steps")
    parser.add_argument('--n-trajs', type=int, default=1000, help="Number of trajectories to collect for each task")
    # p-good / p-bad are integers; p_good is later formatted into the
    # checkpoint file name of the behaviour policy.
    parser.add_argument('--p-good', type=int, default=80, help="Good policy probability")
    parser.add_argument('--p-bad', type=int, default=20, help="Bad policy probability")
    return parser.parse_args(argv)


def overload_env_step(self, action):
    """Step the wrapped env and fold the step budget into ``done``.

    Meant to be bound to a wrapper object exposing ``env``,
    ``_elapsed_steps`` and ``_max_episode_steps``. The inner environment is
    expected to return a 4-tuple ``(observation, reward, terminated, info)``.

    Returns
    -------
    tuple
        ``(observation, reward, done, info)`` where ``done`` is truthy when
        the inner env reported termination or the step budget is used up.
    """
    observation, reward, terminated, info = self.env.step(action)

    # Count this step, then check whether the budget has been reached.
    self._elapsed_steps += 1
    out_of_time = bool(self._elapsed_steps >= self._max_episode_steps)

    return observation, reward, terminated or out_of_time, info


def set_env(task_id, max_episode_steps):
    """Build a time-limited MetaWorld ``pick-place-v2`` environment.

    Parameters
    ----------
    task_id : int
        Used both as the ML1 benchmark seed and as the index into the
        benchmark's ``train_tasks`` list.
    max_episode_steps : int
        Step budget; applied to the raw env (``max_path_length``) and to the
        ``TimeLimit`` wrapper.

    Returns
    -------
    TimeLimit
        Wrapped environment whose ``step`` is replaced by
        ``overload_env_step``.
    """
    # NOTE: the benchmark is seeded with task_id for reproducibility.
    benchmark = metaworld.ML1('pick-place-v2', seed=task_id)
    base_env = benchmark.train_classes['pick-place-v2']()
    base_env.set_task(benchmark.train_tasks[task_id])
    base_env.max_path_length = max_episode_steps

    wrapped = TimeLimit(base_env, max_episode_steps=max_episode_steps)
    # Bind the module-level function as this instance's step method.
    wrapped.step = overload_env_step.__get__(wrapped, type(wrapped))
    return wrapped


def load_sac_model(task_id, p_good, p_bad):
    """Load the reference and behaviour SAC checkpoints for one task.

    Parameters
    ----------
    task_id : int
        Selects the checkpoint directory ``./ml1-pick-place-ckpts/task_<id>``.
    p_good : int
        Suffix of the behaviour checkpoint file. The value ``100`` selects
        the ``_best`` checkpoint, i.e. the same file as the reference model.
    p_bad : int
        Accepted for interface compatibility; currently unused (no "bad"
        model is loaded).

    Returns
    -------
    tuple
        ``(ref_model, good_model)``; the two are loaded separately even when
        they come from the same file.
    """
    ref_model = SAC.load(f"./ml1-pick-place-ckpts/task_{task_id}/task{task_id}_best")

    good_model_path = (
        f"./ml1-pick-place-ckpts/task_{task_id}/task{task_id}_best"
        if p_good == 100
        else f"./ml1-pick-place-ckpts/task_{task_id}/task{task_id}_{p_good}"
    )
    return ref_model, SAC.load(good_model_path)


def rollin_mdp(env, ref_model, good_model, task_id):
    """
    Roll out one episode with the behaviour policy and label every visited
    state with the reference policy's action.

    Parameters:
    -----------
    env: environment exposing ``seed``, ``reset`` and a 4-tuple ``step``
    ref_model: model queried with ``deterministic=True`` for the label action
    good_model: model queried with ``deterministic=False``; its action is the
        one actually executed in the environment
    task_id: used as the environment seed and stored in the result

    Returns:
    --------
    trajectory: dictionary of per-step arrays (leading dimension = number of
        steps taken until ``done``) plus the ``task_id``
    """
    states, next_states = [], []
    rewards = []
    behaviour_actions, reference_actions = [], []

    # Re-seed and reset before every rollout.
    # NOTE(review): reset is called twice; the first call looks redundant —
    # confirm before removing it.
    env.seed(task_id)
    env.reset()
    current_obs = env.reset()

    while True:
        states.append(current_obs)

        # Behaviour action is sampled (exploration); label action is greedy.
        sampled_action, _ = good_model.predict(current_obs, deterministic=False)
        label_action, _ = ref_model.predict(current_obs, deterministic=True)
        behaviour_actions.append(sampled_action)
        reference_actions.append(label_action)

        # Only the behaviour action is executed.
        current_obs, step_reward, finished, _ = env.step(sampled_action)
        rewards.append(step_reward)
        next_states.append(current_obs)

        if finished:
            break

    return {
        'context_states': np.array(states),
        'context_rewards': np.array(rewards),
        'context_next_states': np.array(next_states),
        'context_actions': np.array(behaviour_actions),
        'optimal_actions': np.array(reference_actions),
        'task_id': task_id,
    }


def generate_metaworld_histories(task_ids, max_episode_steps, n_trajs, p_good=0.7, p_bad=0.3):
    """Collect ``n_trajs`` rollouts for every task in ``task_ids``.

    For each task the SAC checkpoints are loaded once and one environment is
    built once; both are then reused for all rollouts of that task.

    Parameters
    ----------
    task_ids : iterable of int
        Tasks to collect data for.
    max_episode_steps : int
        Step budget passed to ``set_env``.
    n_trajs : int
        Number of rollouts per task.
    p_good, p_bad :
        Forwarded unchanged to ``load_sac_model``.
        NOTE(review): the defaults are fractions while ``load_sac_model``
        compares ``p_good`` with 100 and formats it into a file name —
        confirm the defaults are ever used.

    Returns
    -------
    list of dict
        Trajectory dictionaries from ``rollin_mdp``, grouped by task in the
        order of ``task_ids``.
    """
    collected = []
    for task_id in tqdm(task_ids, desc='Generating trajectories'):
        ref_model, good_model = load_sac_model(task_id, p_good, p_bad)
        env = set_env(task_id, max_episode_steps)
        collected.extend(
            rollin_mdp(env, ref_model, good_model, task_id)
            for _ in range(n_trajs)
        )
    return collected


def build_metaworld_data_filename(n_tasks, n_trajs, config, mode=0):
    """Return the pickle path for a dataset split under ``datasets_new/``.

    Parameters
    ----------
    n_tasks : int
        Number of tasks in the dataset (embedded in the name).
    n_trajs : int
        Trajectories per task (embedded in the name).
    config : mapping
        Must provide ``'p_good'``, which is embedded in the name.
    mode : int
        ``0`` -> ``train``, ``1`` -> ``test``, anything else -> ``eval``.
    """
    if mode == 0:
        mode_str = 'train'
    elif mode == 1:
        mode_str = 'test'
    else:
        mode_str = 'eval'
    return f'datasets_new/metaworld_DIT_tasks{n_tasks}_trajs{n_trajs}_p{config["p_good"]}_{mode_str}.pkl'


def _collect_and_save(task_ids, n_trajs, config, mode):
    """Generate trajectories for ``task_ids``, pickle them, and return the path.

    ``config`` supplies ``max_episode_steps``, ``p_good`` and ``p_bad``;
    ``mode`` selects the split name used in the file name (0 train, 1 test,
    otherwise eval). The trajectory list is local to this helper, so it can
    be freed before the next split is generated.
    """
    trajs = generate_metaworld_histories(
        task_ids,
        max_episode_steps=config['max_episode_steps'],
        n_trajs=n_trajs,
        p_good=config['p_good'],
        p_bad=config['p_bad'],
    )
    filepath = build_metaworld_data_filename(len(task_ids), n_trajs, config, mode=mode)
    with open(filepath, 'wb') as file:
        pickle.dump(trajs, file)
    print(f"Saved to {filepath}.")
    return filepath


def main():
    """Main function to collect data.

    Writes three kinds of pickle files under ``datasets_new/``:
    a training split (``--n-trajs`` rollouts per training task), a test
    split (10 rollouts per training task) and one evaluation split per
    behaviour-policy level (50 rollouts per evaluation task).
    """
    np.random.seed(42)
    random.seed(42)
    # The behaviour policy is sampled with deterministic=False, so torch's
    # global generator is seeded too; seeding numpy/random alone does not
    # cover it.
    torch.manual_seed(42)

    args = parse_args()
    print("Args: ", args)

    # exist_ok=True already makes these calls safe when the folders exist.
    os.makedirs('datasets_new', exist_ok=True)
    os.makedirs('datasets', exist_ok=True)

    # Configuration
    config = {
        'n_trajs': args.n_trajs,
        'max_episode_steps': args.max_episode_steps,
        'p_good': args.p_good,
        'p_bad': args.p_bad,
    }

    train_task_ids = args.train_tasks
    # NOTE: the test split deliberately reuses the training tasks.
    test_task_ids = args.train_tasks
    eval_task_ids = args.eval_tasks

    # Generate train trajectories
    print("Generating training trajectories...")
    _collect_and_save(train_task_ids, config['n_trajs'], config, mode=0)

    # Generate test trajectories
    print("Generating test trajectories...")
    _collect_and_save(test_task_ids, 10, config, mode=1)

    # Generate eval trajectories
    for p_good in [20, 40, 50, 60, 80, 100]:
        print(f"Generating evaluation trajectories for p_good={p_good}...")
        # Use a per-level copy so the shared config is not mutated.
        _collect_and_save(eval_task_ids, 50, {**config, 'p_good': p_good}, mode=2)

    


# Run the data-collection pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()