import argparse
import os
import pickle
import random
import glob


# import gym
import gymnasium as gym
import numpy as np
from skimage.transform import resize
from IPython import embed

import common_args
from envs import darkroom_env, bandit_env, miniworld_env
from ctrls.ctrl_bandit import ThompsonSamplingPolicy
from evals import eval_bandit
from utils import (
    build_bandit_data_filename,
    build_linear_bandit_data_filename,
    build_darkroom_data_filename,
    build_miniworld_data_filename,
    build_overcooked_data_filename,
)

from configs.config import get_config
from configs.overcooked_config import get_overcooked_args

from envs.env_wrappers import ShareDummyVecEnv, ShareSubprocDummyBatchVecEnv
from envs.overcooked.Overcooked_Env import Overcooked
from envs.overcooked_new.Overcooked_Env import Overcooked as Overcooked_new

from pyvirtualdisplay import Display
# Start a headless virtual X display (1024x768) as soon as this module is
# imported; presumably needed by the Overcooked env / rendering code — TODO
# confirm. It is stopped via display.stop() at the end of the __main__ block.
display = Display(visible=False, size=(1024,768))
display.start()

# Number of rollouts concatenated into each in-context history.
CTX_ROLLOUTS = 5
# When True, a random number of the leading context rollouts are replaced by
# all-zero rollouts (see generate_overcooked_histories).
MASK_ROLLOUT = True
# Length of each sampled query window (consecutive states / actions).
NUM_QUERY = 6
# Prefix prepended to the "/ZSC/results/..." trajectory paths ('' means the
# paths are absolute from the filesystem root).
STORAGE_PREFIX = '' 
# Number of rollouts per training-agent result directory; also part of the
# directory names and the upper bound for sampled rollout indices.
NUM_TRAIN_ROLL = 150
# Number of rollouts per evaluation-agent result directory; also part of the
# directory names.
NUM_EVAL_ROLL = 50

def parse_args(args, parser):
    """Parse Overcooked environment arguments and derive ``old_dynamics``.

    Args:
        args: list of command-line tokens to parse.
        parser: base argument parser; Overcooked-specific options are added to
            it with ``get_overcooked_args`` before parsing.

    Returns:
        The parsed namespace, with ``old_dynamics`` set to True exactly when
        ``layout_name`` appears in ``zsceval.overcooked_config.OLD_LAYOUTS``.
    """
    namespace = get_overcooked_args(parser).parse_args(args)

    # Deferred import of the list of layouts that use the old dynamics.
    from zsceval.overcooked_config import OLD_LAYOUTS

    namespace.old_dynamics = namespace.layout_name in OLD_LAYOUTS
    return namespace

# Position of the 1 in the 6-way one-hot encoding for each Overcooked action.
_ACTION_ONE_HOT_INDEX = {
    (0, -1): 0,  # North
    (0, 1): 1,   # South
    (1, 0): 2,   # East
    (-1, 0): 3,  # West
    (0, 0): 4,   # Stay
    "interact": 5,
}


def action2num(action):
    """Return the one-hot list encoding of an Overcooked action.

    Args:
        action: the string ``"interact"`` or a 2-element direction such as
            ``(0, -1)``. Any sequence type (tuple, list, 1-D array) is accepted
            for directions.

    Returns:
        A new 6-element list of ints with a single 1, ordered
        [North, South, East, West, Stay, Interact].

    Raises:
        ValueError: if ``action`` is not one of the six known actions.
    """
    try:
        # Normalise direction sequences to tuples so lists/arrays match too.
        key = action if isinstance(action, str) else tuple(action)
        index = _ACTION_ONE_HOT_INDEX[key]
    except (TypeError, KeyError):
        # TypeError: not iterable / unhashable elements; KeyError: unknown action.
        raise ValueError("unidentified action!") from None
    output = [0] * len(_ACTION_ONE_HOT_INDEX)
    output[index] = 1
    return output

def get_obs(context_trajs):
    """Split a flat trajectory sequence into per-step states/actions/rewards/next states.

    The sequence is read in strides of 4 entries per step, followed by one
    trailing entry. For step t:

    - entry ``4t`` is the state;
    - entry ``4t+1`` is the actions (indexed per agent);
    - entry ``4t+2`` is the reward;
    - entry ``4t+4`` (the next step's first entry) is the next state.

    Entry ``4t+3`` is never read. States and actions are taken at index
    ``[1]``; presumably that is the second agent's view — confirm against the
    trajectory writer.

    Returns:
        ``(states, actions, rewards, next_states)``, four lists of length
        ``(len(context_trajs) - 1) // 4``. Actions are numpy one-hot arrays
        from ``action2num``. Rewards are 1-element numpy arrays wrapping the
        raw reward entry.
    """
    n_steps = (len(context_trajs) - 1) // 4
    states, actions, rewards, next_states = [], [], [], []
    for step_start in range(0, 4 * n_steps, 4):
        states.append(context_trajs[step_start][1])
        # Only this agent's own action is used; the partner's entry is not read.
        actions.append(np.array(action2num(context_trajs[step_start + 1][1])))
        # The reward entry is wrapped as-is (no per-agent indexing).
        rewards.append(np.array([context_trajs[step_start + 2]]))
        next_states.append(context_trajs[step_start + 4][1])
    return states, actions, rewards, next_states


# Per-layout observation shape used for the all-zero "masked" context rollouts.
_MASK_OBS_SHAPES = {
    "random1_m": (5, 5, 25),
    "random0_m": (5, 5, 25),
    "random1": (5, 5, 20),
    "random0_medium": (8, 5, 20),
    "random0": (5, 5, 20),
    "random3": (8, 5, 20),
    "unident_s": (9, 5, 20),
}


def _glob_trajs(traj_dir):
    """Return the ``traj_*.pkl`` files in *traj_dir*, sorted.

    ``glob.glob`` returns files in arbitrary, filesystem-dependent order;
    sorting makes index-based sampling reproducible under a fixed seed.
    """
    return sorted(glob.glob(traj_dir + "traj_*.pkl"))


def _train_traj_files(layout, exp_name, agent_id, agent_type):
    """Return ``(mid_files, final_files)`` for a training agent of type "hsp" or "mep"."""
    run_root = f"{STORAGE_PREFIX}/ZSC/results/Overcooked/{layout}/mappo/{exp_name}"
    mid_files = _glob_trajs(
        f"{run_root}/{agent_type}_mid_none_{NUM_TRAIN_ROLL}/run_{agent_id:02d}/trajs/")
    final_files = _glob_trajs(
        f"{run_root}/{agent_type}_final_none_{NUM_TRAIN_ROLL}/run_{agent_id:02d}/trajs/")
    return mid_files, final_files


def _masked_rollout(obs_shape, horizon):
    """Return all-zero ``(states, actions, rewards)`` lists, each of length *horizon*."""
    states = [np.zeros(obs_shape) for _ in range(horizon)]
    actions = [np.zeros(6) for _ in range(horizon)]  # 6 = one-hot action size
    rewards = [np.zeros(1) for _ in range(horizon)]
    return states, actions, rewards


def generate_overcooked_histories(agent_ids, n_hists, n_samples, layout, exp_name, oc_env, agent_group, agent_id_start=0, bias_agent_len=0, rollin_type='expert'):
    """Build in-context histories and query windows from pickled Overcooked rollouts.

    For every agent and each of ``n_hists`` histories:

    - Build a context of ``CTX_ROLLOUTS`` rollouts, concatenated step-wise.
      When ``MASK_ROLLOUT`` is set, a random number of the leading rollouts
      are all-zero.
    - Draw ``n_samples`` query windows of ``NUM_QUERY`` consecutive states
      and actions from rollouts not used for the context.

    Args:
        agent_ids: agent run ids. For ``"train"``, agents are expected to be
            ordered HSP-biased first, then MEP.
        n_hists: histories per agent.
        n_samples: query windows per history.
        layout: Overcooked layout name. It selects the result directories and
            the masked-rollout observation shape.
        exp_name: experiment directory name under the results root.
        oc_env: environment; only ``oc_env.episode_length`` is read. It is
            used as the per-rollout context horizon.
        agent_group: ``"train"`` or ``"test"``.
        agent_id_start: 1-based offset of ``agent_ids[0]`` in the full
            bias+MEP training list. It decides the HSP/MEP split.
        bias_agent_len: number of HSP-biased agents in the full training list.
        rollin_type: unused; kept for interface compatibility.

    Returns:
        A 5-tuple ``(query_states, optimal_actions, context_actions,
        context_states, context_rewards)``:

        - ``query_states`` and ``optimal_actions`` have one uint8 array of
          length ``NUM_QUERY`` per sample.
        - The three ``context_*`` entries are lists with one entry per
          history; each entry is the concatenated per-step arrays.

    Raises:
        ValueError: for an unknown ``agent_group``, a layout with no known
            masked shape (when ``MASK_ROLLOUT``), or an empty query trajectory.
    """
    if agent_group not in ("train", "test"):
        raise ValueError(f"agent_group must be 'train' or 'test', got {agent_group!r}")
    if MASK_ROLLOUT and layout not in _MASK_OBS_SHAPES:
        # Previously this silently reused the last real rollout as the "mask".
        raise ValueError(f"No masked-rollout observation shape defined for layout {layout!r}")

    ctx_rollouts = CTX_ROLLOUTS
    horizon = oc_env.episode_length
    n_agents = len(agent_ids)
    mask_shape = _MASK_OBS_SHAPES.get(layout)

    # Power-law prior over the number of masked leading rollouts:
    # P(m) proportional to 1 / (m + 1)^1.5 for m = 0..ctx_rollouts.
    mask_weights = [1 / (m + 1) ** 1.5 for m in range(ctx_rollouts + 1)]
    mask_prob = np.array(mask_weights) / sum(mask_weights)

    trajs_query_states = []
    trajs_optimal_actions = []
    trajs_context_actions = []
    trajs_context_states = []
    trajs_context_rewards = []

    for i, agent_id in enumerate(agent_ids):
        if agent_group == "train":
            # Absolute position in the bias+MEP list decides which result directories to read.
            agent_type = "hsp" if i + agent_id_start - 1 < bias_agent_len else "mep"
            mid_files, final_files = _train_traj_files(layout, exp_name, agent_id, agent_type)
            n_rolls = NUM_TRAIN_ROLL
        else:
            eval_files = _glob_trajs(
                f"{STORAGE_PREFIX}/ZSC/results/Overcooked/{layout}/mappo/{exp_name}"
                f"/hsp_eval_none_{NUM_EVAL_ROLL}/run_{agent_id:02d}/trajs/")
            n_rolls = NUM_EVAL_ROLL

        print(f"Generating {agent_group} histories for agent {i+1}/{n_agents}")

        for _ in range(n_hists):
            # Rollout indices (in [0, n_rolls)) backing the context.
            ind_list = [random.randint(0, n_rolls - 1) for _ in range(ctx_rollouts)]
            if agent_group == "train":
                # Each context rollout comes from a mid-training (25%) or final (75%) checkpoint.
                context_traj_files = [
                    random.choices([mid_files, final_files], weights=[0.25, 0.75], k=1)[0][idx]
                    for idx in ind_list
                ]
            else:
                context_traj_files = [eval_files[idx] for idx in ind_list]

            num_mask_rollout = np.random.choice(ctx_rollouts + 1, p=mask_prob) if MASK_ROLLOUT else 0

            ctx_states, ctx_actions, ctx_rewards = [], [], []
            for k in range(ctx_rollouts):
                if k < num_mask_rollout:
                    states, actions, rewards = _masked_rollout(mask_shape, horizon)
                else:
                    with open(context_traj_files[k - num_mask_rollout], 'rb') as f:
                        states, actions, rewards, _ = get_obs(pickle.load(f))
                    states, actions, rewards = states[:horizon], actions[:horizon], rewards[:horizon]
                ctx_states.extend(states)
                ctx_actions.extend(actions)
                ctx_rewards.extend(rewards)

            trajs_context_actions.append(ctx_actions)
            trajs_context_states.append(ctx_states)
            trajs_context_rewards.append(ctx_rewards)

            # Queries avoid the rollout indices used for the context, for both groups.
            # Fall back to every rollout if nothing would remain.
            query_pool = sorted(set(range(n_rolls)) - set(ind_list)) or list(range(n_rolls))
            for _ in range(n_samples):
                query_ind = random.choice(query_pool)
                if agent_group == "train":
                    query_files = random.choices([mid_files, final_files], weights=[0.25, 0.75], k=1)[0]
                else:
                    query_files = eval_files
                query_path = query_files[query_ind]
                with open(query_path, 'rb') as f:
                    q_states, q_actions, _, _ = get_obs(pickle.load(f))
                if not q_states:
                    raise ValueError(f"Query trajectory has no steps: {query_path}")

                # Left-pad with NUM_QUERY-1 empty steps so a window may start before
                # the first real step. The padding must match the layout's observation
                # shape, or the uint8 array below would be ragged.
                pad_states = [np.zeros(np.shape(q_states[0])) for _ in range(NUM_QUERY - 1)]
                pad_actions = [np.zeros(6) for _ in range(NUM_QUERY - 1)]
                query_states = pad_states + q_states
                query_actions = pad_actions + q_actions

                start = random.randint(0, len(query_states) - NUM_QUERY)
                query_state = query_states[start:start + NUM_QUERY]
                optimal_action = query_actions[start:start + NUM_QUERY]
                assert len(query_state) == NUM_QUERY

                trajs_query_states.append(np.array(query_state, dtype=np.uint8))
                trajs_optimal_actions.append(np.array(optimal_action, dtype=np.uint8))

    return trajs_query_states, trajs_optimal_actions, trajs_context_actions, trajs_context_states, trajs_context_rewards

if __name__ == '__main__':
    # Generate the Overcooked in-context train/test datasets and pickle them as
    # <stem>_{query_s,optimal_a,context_s,context_a,context_r}.pkl.
    np.random.seed(0)
    random.seed(0)

    parser = argparse.ArgumentParser()
    common_args.add_dataset_args(parser)
    args = vars(parser.parse_args())
    print("Args: ", args)

    env = args['env']
    n_hists = args['hists']
    n_samples = args['samples']
    horizon = args['H']

    config = {
        'n_hists': n_hists,
        'n_samples': n_samples,
        'horizon': horizon,
    }

    # layout -> (HSP-biased training agent runs, held-out test agent runs).
    LAYOUT_AGENTS = {
        "random1": (
            [1, 4, 5, 6, 7, 11, 13, 14, 19, 22, 26, 28, 30, 38, 41, 42, 43, 44, 49, 51, 52],
            [2, 8, 12, 15, 20, 50],
        ),
        "random0": (
            [2, 6, 8, 12, 14, 15, 18, 20, 21, 22, 23, 24, 25, 27],
            [1, 3, 4, 5, 7, 9, 17, 26],
        ),
        "random0_m": (
            [1, 3, 4, 6, 34, 35, 36, 37, 38, 40, 41, 43, 44, 61, 63, 67, 71],
            [2, 5, 7, 26, 28, 32, 39, 46, 49],
        ),
        "random0_medium": (
            [1, 3, 5, 9, 11, 13, 15, 16, 17, 19, 24, 30, 33, 35, 40, 41, 43, 51, 52, 53, 54],
            [12, 21, 29, 36, 44, 46],
        ),
        "random1_m": (
            [2, 6, 8, 10, 11, 12, 21, 23, 27, 28, 36, 37, 38, 39, 48, 61, 63, 65, 66, 68, 70],
            [4, 5, 9, 13, 18, 24, 40, 47, 51, 69],
        ),
        "random3": (
            [5, 16, 34, 52, 76, 78, 96, 104, 112, 116, 118, 125, 131, 134, 135, 138, 149, 154, 157, 158, 159],
            [32, 42, 51, 98, 127, 152, 162],
        ),
        "unident_s": (
            [10, 12, 18, 24, 40, 41, 42, 45, 52, 54, 63, 78, 80, 84, 85, 92, 127, 141, 155, 157, 163],
            [13, 22, 25, 46, 74, 105, 126, 147],
        ),
    }
    # MEP training agents are runs 1-15 for every layout.
    MEP_AGENT_LIST = list(range(1, 16))

    def _check_traj_files(traj_path, agent_id, min_files):
        """Fail fast unless *traj_path* holds at least *min_files* trajectory pickles.

        Rollouts are later picked by index in [0, min_files), so fewer files
        would cause an IndexError mid-generation.
        """
        n_found = len(glob.glob(traj_path + "traj_*.pkl"))
        if n_found == 0:
            raise FileNotFoundError(
                f"No trajectory files found for agent {agent_id}, \ntraj_path: {traj_path}")
        if n_found < min_files:
            raise FileNotFoundError(
                f"Only {n_found} trajectory files (need {min_files}) for agent {agent_id}, \ntraj_path: {traj_path}")

    def _save_split(filepath_stem, histories):
        """Pickle the five lists returned by generate_overcooked_histories."""
        query_s, optimal_a, context_a, context_s, context_r = histories
        for suffix, data in (
            ("_query_s.pkl", query_s),
            ("_optimal_a.pkl", optimal_a),
            ("_context_s.pkl", context_s),
            ("_context_a.pkl", context_a),
            ("_context_r.pkl", context_r),
        ):
            with open(filepath_stem + suffix, 'wb') as file:
                pickle.dump(data, file)

    try:
        if env != 'overcooked':
            raise NotImplementedError

        layout_name = args['layout_name']
        exp_name = args['exp_name']
        agent_id_start = args['agent_id_start']
        agent_id_end = args['agent_id_end']
        dataset_prefix = args['dataset_prefix']
        shuffle = args['shuffle_dataset']
        episode_length = args['episode_length']

        if layout_name not in LAYOUT_AGENTS:
            raise ValueError(f"Unsupported layout: {layout_name!r}")
        train_bias_agent_list, test_agent_list = LAYOUT_AGENTS[layout_name]
        train_mep_agent_list = MEP_AGENT_LIST

        train_agent_ids = np.array(train_bias_agent_list + train_mep_agent_list)
        train_agent_ids = train_agent_ids[agent_id_start-1:agent_id_end]
        test_agent_ids = np.array(test_agent_list)

        # HSP vs MEP is decided by position (i + agent_id_start - 1 < number of
        # biased agents), both below and in generate_overcooked_histories. A
        # global shuffle would mislabel agents, so shuffle within each group only.
        n_bias_selected = min(len(train_agent_ids),
                              max(0, len(train_bias_agent_list) - (agent_id_start - 1)))
        if shuffle:
            bias_ids = list(train_agent_ids[:n_bias_selected])
            mep_ids = list(train_agent_ids[n_bias_selected:])
            random.shuffle(bias_ids)
            random.shuffle(mep_ids)
            train_agent_ids = np.array(bias_ids + mep_ids, dtype=train_agent_ids.dtype)

        config.update({
            'layout': layout_name,
            'evaluation': False,
            'use_random_player_pos': True,
            'use_random_terrain_state': True,
            'shuffle': shuffle,
            'ctx_rollouts': CTX_ROLLOUTS,
        })

        result_root = f"{STORAGE_PREFIX}/ZSC/results/Overcooked/{layout_name}/mappo/{exp_name}"
        print("Checking trajectory files for training...")
        for i, agent_id in enumerate(train_agent_ids):
            agent_type = "hsp" if i + agent_id_start - 1 < len(train_bias_agent_list) else "mep"
            # Check both the final and the mid-training checkpoints. The mep "mid"
            # check used to point at mep_final_none by mistake.
            for stage in ("final", "mid"):
                _check_traj_files(
                    f"{result_root}/{agent_type}_{stage}_none_{NUM_TRAIN_ROLL}/run_{agent_id:02d}/trajs/",
                    agent_id, NUM_TRAIN_ROLL)

        print("Checking trajectory files for testing...")
        for agent_id in test_agent_ids:
            _check_traj_files(
                f"{result_root}/hsp_eval_none_{NUM_EVAL_ROLL}/run_{agent_id:02d}/trajs/",
                agent_id, NUM_EVAL_ROLL)
        print("Done!!!")

        train_filepath = build_overcooked_data_filename(
            env, agent_id_start, agent_id_end, config, mode=0, layout=layout_name, prefix=dataset_prefix)
        test_filepath = build_overcooked_data_filename(
            env, agent_id_start, agent_id_end, config, mode=1, layout=layout_name, prefix=dataset_prefix)

        env_parser = get_config()
        env_args = ['--env', 'Overcooked', '--algorithm_name', 'mappo', '--experiment_name', 'test',
                    '--layout_name', layout_name, '--cnn_layers_params', '32,3,1,1 64,3,1,1 32,3,1,1',
                    '--use_recurrent_policy', '--episode_length', str(episode_length)]
        env_all_args = parse_args(env_args, env_parser)
        run_dir = "run_dir/test/"

        # The "_m" layouts run on the new Overcooked implementation.
        if layout_name in ("random0_m", "random1_m"):
            overcooked_env = Overcooked_new(env_all_args, run_dir, evaluation=True)
        else:
            overcooked_env = Overcooked(env_all_args, run_dir, evaluation=True)

        os.makedirs('datasets/{}/{}'.format(layout_name, dataset_prefix), exist_ok=True)
        # NOTE(review): this keeps only the text before the FIRST '.', so a dot
        # elsewhere in the path would truncate it; kept as-is because loaders may
        # rely on the same stem.
        train_filepath = train_filepath.split(".")[0]
        test_filepath = test_filepath.split(".")[0]

        _save_split(train_filepath, generate_overcooked_histories(
            train_agent_ids, n_hists, n_samples, layout_name, exp_name, overcooked_env,
            bias_agent_len=len(train_bias_agent_list), agent_group="train", agent_id_start=agent_id_start))

        _save_split(test_filepath, generate_overcooked_histories(
            test_agent_ids, n_hists, n_samples, layout_name, exp_name, overcooked_env,
            bias_agent_len=len(train_bias_agent_list), agent_group="test", agent_id_start=agent_id_start))
    finally:
        # Always tear down the virtual display, even when generation fails.
        display.stop()