import argparse
import os
import pickle
import random
import numpy as np
from tqdm import tqdm
import common_args 
from easydict import EasyDict
from stable_baselines3 import SAC

def perturb_state(
    state,
    eps_ee=0.01,       # end-effector position
    eps_obj=0.005,     # object position
    eps_quat=0.01,     # object orientation (quaternion)
    eps_goal=0.005,    # goal position
    rng=None,          # optional np.random.Generator; None -> global np.random
):
    """
    Apply uniform noise to selected dimensions of a Meta-World state vector.

    Each group of dimensions receives independent noise drawn from
    U(-eps, eps). The position groups are then clipped to fixed ranges and the
    quaternion is renormalised to unit length. All other dimensions are copied
    unchanged. Because of the clipping, a coordinate that starts outside its
    clip range can move by more than ``eps``.

    Args:
        state: np.ndarray, presumably shape (39,) (Meta-World observation);
            it is not modified in place.
        eps_ee: max noise on dims [0:3]; the result is clipped to [-0.2, 0.2].
        eps_obj: max noise on dims [7:10]; the result is clipped to [0.0, 0.2].
        eps_quat: max noise on dims [10:14]; the result is renormalised.
        eps_goal: max noise on dims [35:38]; the result is clipped to [0.0, 0.2].
        rng: optional ``np.random.Generator`` to draw the noise from. When
            None, the legacy global ``np.random`` state is used, so
            ``np.random.seed(...)`` makes the perturbation reproducible.
            An unseeded ``default_rng()`` per call would not be reproducible.

    Returns:
        perturbed_state: np.ndarray with the same shape and dtype as ``state``.
    """
    # NOTE(review): the clip bounds are absolute coordinates; confirm they
    # match the coordinate ranges of the task being used.
    state = state.copy()
    sampler = np.random if rng is None else rng

    # End-effector position [0:3]
    state[0:3] += sampler.uniform(-eps_ee, eps_ee, size=3)
    state[0:3] = np.clip(state[0:3], -0.2, 0.2)

    # Object position [7:10]
    state[7:10] += sampler.uniform(-eps_obj, eps_obj, size=3)
    state[7:10] = np.clip(state[7:10], 0.0, 0.2)

    # Object orientation (quaternion) [10:14]: perturb, then restore unit norm.
    state[10:14] += sampler.uniform(-eps_quat, eps_quat, size=4)
    state[10:14] /= (np.linalg.norm(state[10:14]) + 1e-8)  # 1e-8 avoids /0

    # Goal position [35:38]
    state[35:38] += sampler.uniform(-eps_goal, eps_goal, size=3)
    state[35:38] = np.clip(state[35:38], 0.0, 0.2)

    return state

def rollin_sac(env, model, args):
    """
    Roll out a policy for ``args.horizon`` steps and record the transitions.

    Each next observation is passed through ``perturb_state`` before it is
    recorded. The same perturbed observation is then fed to the policy on the
    following step. Only the recorded observations are perturbed; the
    simulator itself is not touched. ``done`` is ignored, so exactly
    ``args.horizon`` steps are always taken.

    Args:
        env: environment with ``reset()`` and ``step(action)``. Both the older
            Gym API (``reset() -> obs``, ``step() -> 4-tuple``) and the newer
            Gym/Gymnasium API (``reset() -> (obs, info)``,
            ``step() -> 5-tuple``) are accepted.
        model: policy with ``predict(obs, deterministic=True)`` returning
            ``(action, _)``, e.g. a stable-baselines3 ``SAC``.
        args: config providing ``horizon`` (number of steps to roll out).

    Returns:
        ``(states, actions, next_states, rewards)`` as np.ndarrays, each
        holding one entry per step along the first axis.
    """
    states, actions, next_states, rewards = [], [], [], []

    # Newer Gym / Meta-World releases return (obs, info); older ones just obs.
    state = env.reset()
    if isinstance(state, tuple):
        state = state[0]

    for _ in range(args.horizon):
        action, _ = model.predict(state, deterministic=True)

        # 5-tuple: (obs, reward, terminated, truncated, info);
        # 4-tuple: (obs, reward, done, info). Only obs and reward are needed.
        step_out = env.step(action)
        next_state, reward = step_out[0], step_out[1]
        if isinstance(next_state, tuple):
            next_state = next_state[0]

        # Observation noise only: the env's internal state is unchanged.
        next_state = perturb_state(next_state)

        states.append(state)
        actions.append(action)
        next_states.append(next_state)
        rewards.append(reward)

        state = next_state

    return np.array(states), np.array(actions), np.array(next_states), np.array(rewards)
   

def get_exp_rewards(env, model, state, step, args):
    """
    Estimate the discounted return-to-go from ``state`` by Monte Carlo rollouts.

    Each of the ``args.n_samples`` rollouts restores the env to the SAME
    ``state`` and then runs the policy for ``env.horizon - step`` steps.

    Args:
        env: environment with ``set_env_state``, ``step`` and ``horizon``.
            ``step`` may return a 4-tuple (old Gym) or a 5-tuple
            (Gym >= 0.26 / Gymnasium).
        model: policy with ``predict(obs, deterministic=True) -> (action, _)``.
        state: the state to start every rollout from.
        step: number of steps already taken; sets the remaining horizon.
        args: config providing ``gamma`` (discount) and ``n_samples``.

    Returns:
        Mean over rollouts of ``sum_t gamma**t * r_t``.
    """
    gamma = args.gamma
    n_samples = args.n_samples
    start_state = state  # kept separate so every rollout restarts from it
    returns = []
    for _ in range(n_samples):
        rewards = []
        # FIXME: ml1-pick-place env does not have set_state method
        env.set_env_state(start_state)
        obs = start_state
        for _ in range(env.horizon - step):
            action, _ = model.predict(obs, deterministic=True)
            step_out = env.step(action)
            obs, reward = step_out[0], step_out[1]
            rewards.append(reward)
        rewards = np.array(rewards)
        discounts = gamma ** np.arange(len(rewards))
        returns.append(np.sum(rewards * discounts))
    return np.mean(returns)


def generate_traj(task_list, args, eval=False):
    """
    Build in-context trajectories with an expert-labelled query per task.

    For each of ``args.n_trails`` passes and each task id in ``task_list``:
    create a Meta-World ML1 pick-place env for that task and load the task's
    SAC checkpoint. Roll the policy out with ``rollin_sac`` to obtain the
    context, pick one context state uniformly at random as the query, and
    label it with the scripted ``SawyerPickPlaceV2Policy`` expert.

    Args:
        task_list: iterable of integer task ids. Each id is used both as the
            ``metaworld.ML1`` seed and as an index into ``ml1.train_tasks``.
        args: config providing ``env``, ``n_trails``, ``horizon``,
            ``model_ckpt_path`` and whatever ``rollin_sac`` reads.
        eval: currently unused; kept for call compatibility.

    Returns:
        list of dicts, one per (pass, task), with keys ``query_idx``,
        ``query_state``, ``optimal_action``, ``context_states``,
        ``context_actions``, ``context_next_states``, ``context_rewards``
        and ``task_id``.

    Raises:
        NotImplementedError: if ``args.env`` is not ``'ml1_pick_place'``.
    """
    if args.env != 'ml1_pick_place':
        # Only the ML1 pick-place setup is implemented; fail with a clear error
        # instead of a NameError on the undefined env/model/policy below.
        raise NotImplementedError(f"Unsupported env: {args.env}")

    trajs = []
    for _ in range(args.n_trails):
        for task_id in tqdm(task_list, desc='Generating data'):
            # ``metaworld`` and ``SawyerPickPlaceV2Policy`` are imported at
            # module level by the __main__ block once ``args.env`` is known.
            ml1 = metaworld.ML1('pick-place-v2', seed=task_id)
            env = ml1.train_classes['pick-place-v2']()
            env.set_task(ml1.train_tasks[task_id])
            env.max_path_length = args.horizon
            opt_policy = SawyerPickPlaceV2Policy()

            try:
                # Use the configured checkpoint root rather than a hard-coded one.
                model_ckpt_path = f'{args.model_ckpt_path}/task_{task_id}'
                model = SAC.load(
                    f'{model_ckpt_path}/task{task_id}_best',
                    # Replace the spaces stored in the checkpoint with the
                    # live env's spaces.
                    custom_objects={
                        "observation_space": env.observation_space,
                        "action_space": env.action_space,
                    },
                )

                (
                    context_states,
                    context_actions,
                    context_next_states,
                    context_rewards,
                ) = rollin_sac(env, model, args)
            finally:
                env.close()

            # Except for index 0, context states are perturbed observations,
            # so the expert labels a perturbed observation.
            query_idx = np.random.choice(len(context_states))
            query_state = context_states[query_idx]
            optimal_action = opt_policy.get_action(query_state)

            print('task_id: ', task_id)

            trajs.append({
                'query_idx': query_idx,
                'query_state': query_state,
                'optimal_action': optimal_action,
                'context_states': context_states,
                'context_actions': context_actions,
                'context_next_states': context_next_states,
                'context_rewards': context_rewards,
                'task_id': task_id,
            })
    return trajs


if __name__ == '__main__':
    # Seed the global RNGs used by the data-generation code (e.g. query sampling).
    np.random.seed(0)
    random.seed(0)

    parser = argparse.ArgumentParser()
    common_args.add_dataset_args(parser)

    # First pass: only ``env`` is needed to decide which env-specific arguments
    # to register. parse_known_args tolerates env-specific flags that are not
    # registered yet; parse_args would exit with "unrecognized arguments".
    pre_args, _ = parser.parse_known_args()

    if pre_args.env == 'ml1_pick_place':
        # Module-level imports (this block is at module scope), so
        # generate_traj can see them.
        import metaworld
        from metaworld.policies.sawyer_pick_place_v2_policy import SawyerPickPlaceV2Policy
        common_args.add_ml1_pick_place_dataset_args(parser)
    else:
        raise NotImplementedError

    # Second pass: full parse with all arguments registered.
    args = EasyDict(vars(parser.parse_args()))
    print("Args: ", args)

    os.makedirs('datasets', exist_ok=True)

    train_tasks = args.train_tasks
    test_tasks = args.test_tasks  # used only by the disabled test split below

    # NOTE: hard-coded override of the parsed ``n_trails`` value.
    args.n_trails = 2

    train_traj = generate_traj(train_tasks, args)
    train_filepath = common_args.build_data_filename('train', args)
    os.makedirs(os.path.dirname(train_filepath), exist_ok=True)
    with open(train_filepath, 'wb') as file:
        pickle.dump(train_traj, file)

    # Test-split generation is currently disabled:
    # test_traj = generate_traj(test_tasks, args)
    # test_filepath = common_args.build_data_filename('test', args)
    # os.makedirs(os.path.dirname(test_filepath), exist_ok=True)
    # with open(test_filepath, 'wb') as file:
    #     pickle.dump(test_traj, file)

