import argparse
import os
import pickle
import random
import numpy as np
from tqdm import tqdm
import metaworld
from stable_baselines3 import SAC
from metaworld.policies.sawyer_pick_place_v2_policy import SawyerPickPlaceV2Policy
from gym.wrappers import TimeLimit
import torch

def parse_args():
    """
    Parse command line arguments.

    Task-id options accept one or more space-separated integers, e.g.
    ``--train_tasks 1 2 3``.

    Returns:
    --------
    args: argparse.Namespace with the attributes ``device``, ``train_tasks``,
        ``eval_tasks``, ``max_episode_steps``, ``n_trajs``, ``p_good`` and ``p_bad``
    """
    parser = argparse.ArgumentParser(description="Collect data from MetaWorld with SAC advantage labels")
    parser.add_argument('--device', type=str, default='cuda', help="Device to run on (cuda/cpu)")
    # NOTE: ``type=list`` would turn the raw argument string into a list of
    # single characters ("12" -> ['1', '2']); ``nargs='+'`` with ``type=int``
    # yields a proper list of integer task ids instead.
    parser.add_argument('--train_tasks', type=int, nargs='+',
                        # every id in 0..49 except the multiples of 10 (used for eval)
                        default=[task for task in range(50) if task % 10 != 0],
                        help="Tasks to train on")
    parser.add_argument('--eval_tasks', type=int, nargs='+', default=[0, 10, 20, 30, 40],
                        help="Tasks to eval on")
    parser.add_argument('--max-episode-steps', type=int, default=200, help="Max episode steps")
    parser.add_argument('--n-trajs', type=int, default=1000, help="Number of trajectories to collect for each task")
    parser.add_argument('--p-good', type=int, default=80, help="Good policy probability")
    parser.add_argument('--p-bad', type=int, default=20, help="Bad policy probability")
    return parser.parse_args()


def overload_env_step(self, action):
    """
    Step the wrapped environment and fold the step limit into ``done``.

    Forwards ``action`` to ``self.env.step`` (which is expected to return a
    4-tuple), counts the step in ``self._elapsed_steps`` and reports ``done``
    as soon as the inner environment terminates or the counter reaches
    ``self._max_episode_steps``.

    Returns:
    --------
    (observation, reward, done, info)
    """
    obs, rew, terminated, info = self.env.step(action)
    self._elapsed_steps += 1
    # The episode is cut off once the step budget is used up
    hit_step_limit = bool(self._elapsed_steps >= self._max_episode_steps)
    return obs, rew, terminated or hit_step_limit, info


def set_env(task_id, max_episode_steps):
    """
    Build a step-limited MetaWorld 'pick-place-v2' environment for one task.

    Parameters:
    -----------
    task_id: index into the ML1 benchmark's train tasks; also used as the benchmark seed
    max_episode_steps: step budget applied to both the raw environment and the TimeLimit wrapper

    Returns:
    --------
    env: TimeLimit-wrapped environment whose ``step`` is replaced by ``overload_env_step``
    """
    benchmark = metaworld.ML1('pick-place-v2', seed=task_id)  # NOTE: seed=task_id for reproducibility!!
    raw_env = benchmark.train_classes['pick-place-v2']()
    raw_env.set_task(benchmark.train_tasks[task_id])
    raw_env.max_path_length = max_episode_steps
    limited_env = TimeLimit(raw_env, max_episode_steps=max_episode_steps)
    # Bind the custom step function to this instance so it is called as a method
    limited_env.step = overload_env_step.__get__(limited_env, limited_env.__class__)
    return limited_env


def load_sac_model(task_id, p_good, p_bad):
    """
    Load the reference, "good" and "bad" SAC checkpoints of one task.

    Parameters:
    -----------
    task_id: task whose checkpoint directory is read
    p_good: checkpoint suffix of the good model; the value 100 selects the "best" checkpoint
    p_bad: checkpoint suffix of the bad model

    Returns:
    --------
    (ref_model, good_model, bad_model): the reference model is always the "best" checkpoint
    """
    ref_model = SAC.load(f"./ml1-pick-place-ckpts/task_{task_id}/task{task_id}_best")
    good_model_path = (
        f"./ml1-pick-place-ckpts/task_{task_id}/task{task_id}_best"
        if p_good == 100
        else f"./ml1-pick-place-ckpts/task_{task_id}/task{task_id}_{p_good}"
    )
    bad_model_path = f"./ml1-pick-place-ckpts/task_{task_id}/task{task_id}_{p_bad}"
    return ref_model, SAC.load(good_model_path), SAC.load(bad_model_path)


def _min_q_value(model, obs, act):
    """Return the element-wise minimum over all critic heads as a 1-D tensor [N]."""
    q_estimates = model.critic(obs, act)
    # Flatten each head so both [N, 1] and [N] critic outputs are handled
    return torch.stack([q.reshape(-1) for q in q_estimates], dim=0).min(dim=0).values


def calculate_advantage(model, states, actions=None, n_samples=256):
    """
    Calculate the advantage of actions in states (batch version):
    A(s,a) = Q(s,a) - V(s)

    Where V(s) = E_{a~π(·|s)}[Q(s,a) - α log π(a|s)], estimated separately for
    every state from ``n_samples`` actions drawn from the model's actor.

    Parameters:
    -----------
    model: SAC model; its ``actor``, ``critic``, ``device`` and entropy coefficient are used
    states: A single state [state_dim] or a batch of states [batch_size, state_dim]
    actions: Optional action(s) matching ``states``. If None, samples from policy
    n_samples: Number of samples to use for expectation calculation

    Returns:
    --------
    advantages: A Python float when exactly one state is evaluated, otherwise a
        NumPy array with the advantage values A(s,a) [batch_size]
    """
    # Check input dimensions and convert to tensor
    states = np.asarray(states, dtype=np.float32)
    if states.ndim == 1:
        states = states.reshape(1, -1)  # Single state to batch
        if actions is not None:
            actions = np.asarray(actions, dtype=np.float32).reshape(1, -1)  # Single action to batch
    state_tensor = torch.as_tensor(states, device=model.device)
    batch_size = state_tensor.shape[0]
    with torch.no_grad():
        # If actions not provided, sample from policy
        if actions is None:
            action_tensor, _ = model.actor.action_log_prob(state_tensor)
        else:
            action_tensor = torch.as_tensor(
                np.asarray(actions, dtype=np.float32), device=model.device)
        # Q(s,a) for the given state-action pairs, shape [batch_size]
        q_values = _min_q_value(model, state_tensor, action_tensor)
        # Repeat each state n_samples times (samples of one state stay contiguous)
        repeated_states = state_tensor.repeat_interleave(n_samples, dim=0)
        sample_actions, log_probs = model.actor.action_log_prob(repeated_states)
        sample_q = _min_q_value(model, repeated_states, sample_actions)
        # Stable-Baselines3 keeps ``log_ent_coef`` as None when the entropy
        # coefficient is fixed and stores the value in ``ent_coef_tensor``.
        log_ent_coef = getattr(model, "log_ent_coef", None)
        if log_ent_coef is not None:
            alpha = torch.exp(log_ent_coef)
        else:
            alpha = model.ent_coef_tensor
        # Value term: Q(s,a) - α log π(a|s)
        values = sample_q - alpha * log_probs.reshape(-1)
        # Reshape to [batch_size, n_samples] and average across samples only,
        # so every state gets its own V(s) instead of a batch-wide mean
        v_values = values.reshape(batch_size, n_samples).mean(dim=1)
        advantages = q_values - v_values
    if advantages.numel() == 1:
        return advantages.item()
    return advantages.cpu().numpy()


def yield_rollin_mdp(obs, anchor, policy):
    """
    Sample one action from ``policy`` and score it with ``anchor``.

    Parameters:
    -----------
    obs: current observation/state
    anchor: SAC model whose critic/actor are used to estimate the advantage
    policy: SAC model that proposes the action

    Returns:
    --------
    action: the sampled action
    advantage: advantage of that action in ``obs`` according to ``anchor``
    """
    # NOTE: deterministic=False for exploration!!
    sampled_action, _ = policy.predict(obs, deterministic=False)
    return sampled_action, calculate_advantage(anchor, obs, sampled_action)


def _preference_probability(advantage_good, advantage_bad):
    """Return the two-way softmax weight of ``advantage_good``, computed without overflow."""
    # Shifting by the maximum keeps both exponents <= 0, so np.exp cannot
    # overflow to inf and produce an inf/inf = NaN probability.
    shift = max(advantage_good, advantage_bad)
    weight_good = np.exp(advantage_good - shift)
    weight_bad = np.exp(advantage_bad - shift)
    return weight_good / (weight_good + weight_bad)


def rollin_mdp(env, ref_model, good_model, bad_model, task_id):
    """
    Generate one trajectory in which, at every step, the good and the bad policy
    each propose an action and one of the two is executed.

    The executed ("preferred") action is drawn at random: the good action is
    chosen with the softmax probability of its advantage against the bad
    action's advantage, both estimated with ``ref_model``.

    Parameters:
    -----------
    env: environment instance (``step`` must return ``(obs, reward, done, info)``)
    ref_model: SAC model used for advantage estimation and for the deterministic "optimal" action
    good_model: SAC model proposing the "good" action
    bad_model: SAC model proposing the "bad" action
    task_id: task identifier; used to seed the environment and stored in the result

    Returns:
    --------
    trajectory: dictionary with the keys ``context_states``, ``context_rewards``,
        ``context_next_states``, ``prefer`` (1 if the good action was executed,
        0 otherwise), ``preferred_actions``, ``non_preferred_actions``,
        ``optimal_actions``, ``good_actions``, ``good_advantages``,
        ``bad_actions``, ``bad_advantages`` (all NumPy arrays with one entry
        per step) and ``task_id``
    """
    # Lists to save context information
    context_states = []
    context_next_states = []
    context_rewards = []
    # preferred action
    prefer = []
    preferred_actions = []
    non_preferred_actions = []
    optimal_actions = []
    # good and bad actions and advantages
    good_actions = []
    good_advantages = []
    bad_actions = []
    bad_advantages = []
    # reset environment
    env.seed(task_id)
    # NOTE(review): the environment is reset twice and the first observation is
    # discarded; kept as-is in case the extra reset is intentional -- confirm.
    env.reset()
    obs = env.reset()
    done = False
    # generate trajectory
    while not done:
        context_states.append(obs)
        # Generate actions and advantages
        action_good, advantage_good = yield_rollin_mdp(obs, ref_model, good_model)
        action_bad, advantage_bad = yield_rollin_mdp(obs, ref_model, bad_model)
        optimal_action, _ = ref_model.predict(obs, deterministic=True)
        optimal_actions.append(optimal_action)
        good_actions.append(action_good)
        bad_actions.append(action_bad)
        good_advantages.append(advantage_good)
        bad_advantages.append(advantage_bad)
        # sample from the softmax distribution of good and bad advantages
        p_prefer_good = _preference_probability(advantage_good, advantage_bad)
        choice = np.random.choice(range(2), p=[1 - p_prefer_good, p_prefer_good])
        prefer.append(choice)
        # Choose preferred action based on the sampled preference
        if choice == 1:
            action, rejected_action = action_good, action_bad
        else:
            action, rejected_action = action_bad, action_good
        preferred_actions.append(action)
        non_preferred_actions.append(rejected_action)
        # Actually take the preferred action in the environment
        obs, reward, done, _ = env.step(action)
        context_rewards.append(reward)
        context_next_states.append(obs)
    # create trajectory dictionary
    traj_dict = {
        # context
        'context_states': np.array(context_states),
        'context_rewards': np.array(context_rewards),
        'context_next_states': np.array(context_next_states),
        # preferred action
        'prefer': np.array(prefer),
        'preferred_actions': np.array(preferred_actions),
        'non_preferred_actions': np.array(non_preferred_actions),
        'optimal_actions': np.array(optimal_actions),
        # NOTE: we store the actions for the good and bad policies separately
        'good_actions': np.array(good_actions),
        'good_advantages': np.array(good_advantages),
        'bad_actions': np.array(bad_actions),
        'bad_advantages': np.array(bad_advantages),
        # task id
        'task_id': task_id,
    }
    return traj_dict


def generate_metaworld_histories(task_ids, max_episode_steps, n_trajs, p_good=0.7, p_bad=0.3):
    """
    Collect ``n_trajs`` trajectories for every task in ``task_ids``.

    Parameters:
    -----------
    task_ids: iterable of task ids to collect data for
    max_episode_steps: step limit of each episode
    n_trajs: number of trajectories collected per task
    p_good: forwarded to ``load_sac_model`` to select the "good" checkpoint
    p_bad: forwarded to ``load_sac_model`` to select the "bad" checkpoint

    NOTE(review): the defaults are fractions, while the callers in this file
    pass integer values (e.g. 80 / 20) that end up in checkpoint file names --
    confirm which convention the checkpoints on disk follow.

    Returns:
    --------
    trajs: list of trajectory dictionaries, grouped by task in ``task_ids`` order
    """
    trajs = []
    for task_id in tqdm(task_ids, desc='Generating trajectories'):
        # Load the SAC models of this task
        ref_model, good_model, bad_model = load_sac_model(task_id, p_good, p_bad)
        env = set_env(task_id, max_episode_steps)
        try:
            # Generate trajectories for this task
            trajs.extend(
                rollin_mdp(env, ref_model, good_model, bad_model, task_id)
                for _ in range(n_trajs)
            )
        finally:
            # One environment is created per task; release it before moving on
            env.close()
    return trajs


def build_metaworld_data_filename(n_tasks, n_trajs, config, mode=0):
    """
    Return the pickle path of a dataset inside ``datasets_new/``.

    Parameters:
    -----------
    n_tasks: number of tasks in the dataset
    n_trajs: number of trajectories per task
    config: mapping providing the ``p_good`` and ``p_bad`` values written into the name
    mode: 0 -> 'train', 1 -> 'test', anything else -> 'eval'

    Returns:
    --------
    filename: relative path of the dataset file
    """
    if mode == 0:
        mode_str = 'train'
    elif mode == 1:
        mode_str = 'test'
    else:
        mode_str = 'eval'
    return f'datasets_new/metaworld_tasks{n_tasks}_trajs{n_trajs}_pg{config["p_good"]}_pb{config["p_bad"]}_{mode_str}.pkl'


def _save_trajectories(trajs, filepath):
    """Pickle ``trajs`` to ``filepath`` and report where they were written."""
    with open(filepath, 'wb') as file:
        pickle.dump(trajs, file)
    print(f"Saved to {filepath}.")


def main():
    """
    Main function to collect data.

    Generates and pickles three kinds of datasets into ``datasets_new/``:
    training trajectories, a small test set on the training tasks, and one
    evaluation set per "good" checkpoint level on the evaluation tasks.
    """
    np.random.seed(42)
    random.seed(42)
    # Actions and advantage estimates are sampled through torch, so its RNG
    # has to be seeded as well for the collected data to be reproducible.
    torch.manual_seed(42)

    args = parse_args()
    print("Args: ", args)

    os.makedirs('datasets_new', exist_ok=True)
    os.makedirs('datasets', exist_ok=True)

    # Configuration
    config = {
        'n_trajs': args.n_trajs,
        'max_episode_steps': args.max_episode_steps,
        'p_good': args.p_good,
        'p_bad': args.p_bad,
    }
    n_test_trajs = 10  # trajectories per task in the test split
    n_eval_trajs = 50  # trajectories per task in each eval split

    train_task_ids = args.train_tasks
    test_task_ids = args.train_tasks # NOTE:for testing dataset used for training, we use the same tasks for training and testing
    eval_task_ids = args.eval_tasks

    # Generate train trajectories
    print("Generating training trajectories...")
    train_trajs = generate_metaworld_histories(
        train_task_ids,
        max_episode_steps=config['max_episode_steps'],
        n_trajs=config['n_trajs'],
        p_good=config['p_good'],
        p_bad=config['p_bad']
    )
    _save_trajectories(
        train_trajs,
        build_metaworld_data_filename(len(train_task_ids), config['n_trajs'], config, mode=0))

    # Generate test trajectories
    print("Generating test trajectories...")
    test_trajs = generate_metaworld_histories(
        test_task_ids,
        max_episode_steps=config['max_episode_steps'],
        n_trajs=n_test_trajs,
        p_good=config['p_good'],
        p_bad=config['p_bad']
    )
    _save_trajectories(
        test_trajs,
        build_metaworld_data_filename(len(test_task_ids), n_test_trajs, config, mode=1))

    # Generate eval trajectories
    for p_good in [20, 40, 50, 60, 80, 100]:
        print(f"Generating evaluation trajectories for p_good={p_good}...")
        eval_trajs = generate_metaworld_histories(
            eval_task_ids,
            max_episode_steps=config['max_episode_steps'],
            n_trajs=n_eval_trajs,
            p_good=p_good,
            p_bad=config['p_bad']
        )
        # Name the file after this iteration's p_good without mutating the shared config
        eval_config = {**config, 'p_good': p_good}
        _save_trajectories(
            eval_trajs,
            build_metaworld_data_filename(len(eval_task_ids), n_eval_trajs, eval_config, mode=2))


# Script entry point: collect the datasets only when run directly, not on import.
if __name__ == "__main__":
    main()