import argparse
import os
import pickle
import random
import numpy as np
from tqdm import tqdm
import common_args 
from easydict import EasyDict
from stable_baselines3 import SAC

def rollin_sac(env, model, args):
    """Roll ``model`` forward in ``env`` for exactly ``args.horizon`` steps.

    Actions come from ``model.predict(..., deterministic=True)``. The ``done``
    flag returned by ``env.step`` is never inspected, so the rollout does not
    stop early.

    Args:
        env: environment exposing ``reset()`` and a 4-tuple ``step(action)``.
        model: policy exposing ``predict(obs, deterministic=True)`` that returns
            an ``(action, extra)`` pair.
        args: must provide ``env`` (selects the reset convention) and
            ``horizon`` (number of steps to take).

    Returns:
        Tuple ``(states, actions, next_states, rewards)``. Each item is built
        with ``np.array`` from the per-step values, in step order.
    """
    # The pick-place env's reset() result is used directly as the observation;
    # every other env is expected to return an (observation, info) pair.
    if args.env != 'ml1_pick_place':
        obs, _ = env.reset()
    else:
        obs = env.reset()

    history = []  # one (state, action, next_state, reward) tuple per step
    for _ in range(args.horizon):
        act, _ = model.predict(obs, deterministic=True)
        nxt, rew, _done, _info = env.step(act)
        history.append((obs, act, nxt, rew))
        obs = nxt

    # Index-based column extraction (instead of zip(*history)) still yields
    # four empty arrays when the horizon is 0.
    columns = tuple(np.array([row[k] for row in history]) for k in range(4))
    return columns


def get_exp_rewards(env, model, state, step, args):
    """Estimate the discounted return from ``state`` by repeated rollouts.

    Each of ``args.n_samples`` rollouts restores ``env`` to ``state`` via
    ``env.set_env_state`` and then follows ``model``'s deterministic policy
    for the remaining ``env.horizon - step`` steps. Rewards are discounted by
    ``args.gamma`` starting from exponent 0.

    Args:
        env: environment exposing ``set_env_state``, ``step`` and ``horizon``.
        model: policy exposing ``predict(obs, deterministic=True)``.
        state: state to restart every rollout from.
        step: number of steps already taken; the rollout length is
            ``env.horizon - step`` (no steps if that is <= 0).
        args: must provide ``gamma`` and ``n_samples``.

    Returns:
        Mean discounted return over the ``n_samples`` rollouts.

    Raises:
        ValueError: if ``args.n_samples`` is less than 1. ``np.mean([])``
            would otherwise silently return nan.
    """
    gamma = args.gamma
    n_samples = args.n_samples
    if n_samples < 1:
        raise ValueError(f'n_samples must be >= 1, got {n_samples}')

    start_state = state
    n_cumu_rewards = []
    for _ in range(n_samples):
        # Restart every sample from the same initial state. The rollout below
        # advances `obs`, and must not leak into the next sample's start.
        env.set_env_state(start_state)  # FIXME: ml1-pick-place env does not have set_state method
        obs = start_state
        rewards = []
        for _ in range(env.horizon - step):
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, _done, _info = env.step(action)
            rewards.append(reward)
        rewards = np.array(rewards)
        cumulative_reward = np.sum(rewards * gamma ** np.arange(len(rewards)))
        n_cumu_rewards.append(cumulative_reward)
    return np.mean(n_cumu_rewards)


def _load_policies(task_id, args):
    """Load ``(model, opt_policy)`` for ``task_id``: the rollout policy and the reference policy that labels queries."""
    if args.env not in ('ml1_pick_place', 'cheetah_vel'):
        raise ValueError(f'Unsupported env: {args.env!r}')
    ckpt_dir = f'{args.model_ckpt_path}/task_{task_id}'
    model = SAC.load(f'{ckpt_dir}/sac_checkpoint_task_{task_id}_{args.policy_quality}')
    if args.env == 'ml1_pick_place':
        # Imported lazily so generate_traj does not depend on names that only
        # exist when this file runs as __main__.
        from metaworld.policies.sawyer_pick_place_v2_policy import SawyerPickPlaceV2Policy
        opt_policy = SawyerPickPlaceV2Policy()
    else:
        opt_policy = SAC.load(f'{ckpt_dir}/sac_checkpoint_task_{task_id}_best')
    return model, opt_policy


def _make_env(task_id, args):
    """Build a fresh environment instance configured for ``task_id``."""
    if args.env == 'ml1_pick_place':
        import metaworld  # lazy: optional dependency, see _load_policies
        ml1 = metaworld.ML1('pick-place-v2', seed=task_id)
        env = ml1.train_classes['pick-place-v2']()
        env.set_task(ml1.train_tasks[task_id])
        env.max_path_length = args.horizon
        return env
    if args.env == 'cheetah_vel':
        from envs.mujoco_control_envs.mujoco_control_envs import HalfCheetahVelEnv
        task_path = f'{args.model_ckpt_path}/task_{task_id}/config_cheetah_vel_task{task_id}.pkl'
        with open(task_path, 'rb') as f:
            task_info = pickle.load(f)
        if len(task_info) != 1:
            raise ValueError(f'Unexpected task info: {task_info}')
        return HalfCheetahVelEnv([task_info[0]], include_goal=False)
    raise ValueError(f'Unsupported env: {args.env!r}')


def _optimal_action(opt_policy, query_state, args):
    """Label ``query_state`` with the reference policy's action."""
    if args.env == 'ml1_pick_place':
        return opt_policy.get_action(query_state)
    action, _ = opt_policy.predict(query_state, deterministic=True)
    return action


def generate_traj(task_list, args, eval=False):
    """Generate one context trajectory per task, repeated ``args.n_trails`` times.

    For each (trial, task) pair:
      1. Build a fresh env for the task.
      2. Roll the task's SAC checkpoint (quality ``args.policy_quality``) for
         ``args.horizon`` steps with ``rollin_sac``.
      3. Pick a uniformly random context state as the query.
      4. Label that query with the reference policy's action.

    Args:
        task_list: task ids; used in checkpoint paths and as task indices.
        args: dataset config (``env``, ``model_ckpt_path``, ``policy_quality``,
            ``horizon``, ``n_trails``, ...).
        eval: accepted for backward compatibility but currently unused.
            Discounted returns are not computed here (see get_exp_rewards).

    Returns:
        List of trajectory dicts, in trial-major then task order.

    Raises:
        ValueError: for an unsupported ``args.env`` or a malformed cheetah-vel
            task config.
    """
    # Checkpoints are identical across trials, so load them once per task
    # instead of once per (trial, task).
    # NOTE(review): SB3 may reseed global RNGs on load when the saved model has
    # a seed. Loading once also stops np.random from being reset every trial,
    # which could otherwise repeat query_idx. Confirm against the checkpoints.
    policies = {}
    trajs = []
    for _ in range(args.n_trails):
        for task_id in tqdm(task_list, desc='Generating data'):
            if task_id not in policies:
                policies[task_id] = _load_policies(task_id, args)
            model, opt_policy = policies[task_id]
            env = _make_env(task_id, args)

            (
                context_states,
                context_actions,
                context_next_states,
                context_rewards,
            ) = rollin_sac(env, model, args)
            query_idx = np.random.choice(len(context_states))
            query_state = context_states[query_idx]
            optimal_action = _optimal_action(opt_policy, query_state, args)

            trajs.append({
                'query_idx': query_idx,
                'query_state': query_state,
                'optimal_action': optimal_action,
                'context_states': context_states,
                'context_actions': context_actions,
                'context_next_states': context_next_states,
                'context_rewards': context_rewards,
                'task_id': task_id,
            })
    return trajs


if __name__ == '__main__':
    np.random.seed(0)
    random.seed(0)

    parser = argparse.ArgumentParser()
    common_args.add_dataset_args(parser)
    # First pass only knows the generic dataset args. parse_known_args() keeps
    # env-specific flags on the command line from aborting with
    # "unrecognized arguments" before those args are registered below.
    args, _ = parser.parse_known_args()

    # These imports also bind module-level names used by generate_traj.
    if args.env == 'ml1_pick_place':
        import metaworld
        from metaworld.policies.sawyer_pick_place_v2_policy import SawyerPickPlaceV2Policy
        common_args.add_ml1_pick_place_dataset_args(parser)
    elif args.env == 'cheetah_vel':
        from envs.mujoco_control_envs.mujoco_control_envs import HalfCheetahVelEnv
        common_args.add_cheetah_vel_dataset_args(parser)
    else:
        parser.error(f'unsupported env: {args.env!r}')

    # Second pass: every argument is registered now, so parse strictly.
    args = EasyDict(vars(parser.parse_args()))
    print("Args: ", args)

    os.makedirs('datasets', exist_ok=True)

    train_tasks = args.train_tasks
    test_tasks = args.test_tasks
    # NOTE(review): the training set below is generated on train AND test
    # tasks; confirm this is intended and not a train/test leak.
    eval_tasks = train_tasks + test_tasks

    # args.n_trails = 1
    # test_traj = generate_traj(test_tasks, args)
    # test_filepath = common_args.build_data_filename('test', args)
    # with open(test_filepath, 'wb') as file:
    #     pickle.dump(test_traj, file)

    # args.n_trails = 4
    # eval_traj = generate_traj(eval_tasks, args, eval=True)
    # eval_filepath = common_args.build_data_filename('eval', args)
    # with open(eval_filepath, 'wb') as file:
    #     pickle.dump(eval_traj, file)

    # Overrides any n_trails value coming from the command line.
    args.n_trails = 2000
    train_traj = generate_traj(eval_tasks, args)
    train_filepath = common_args.build_data_filename('train', args)
    with open(train_filepath, 'wb') as file:
        pickle.dump(train_traj, file)

