import os
import gym
import json
import re
import torch
import numpy as np
import argparse
from tqdm import tqdm
from copy import deepcopy

from diffusion_human_ai.ldm.vae import VAE, HyperDecoder, Conv1dEncoder
from diffusion_human_ai.ghn.core import hyperActor

from overcooked_ai_py.agents.benchmarking import AgentEvaluator
from overcooked_ai_py.agents.agent import AgentPair, PantheonRLAgent, RNNAgent

from .utils import *
from diffusion_human_ai.ldm.utils import *
from diffusion_human_ai.ldm.ddpm import DDPMSampler
from diffusion_human_ai.ldm.diffusion import Diffusion

from diffusion_human_ai.translator.vae_fc import Translator, VAE


from trainer import DIVERSE_COORPERATION_STYLE_LIST, DIVERSE_ORDERS_STYLE_LIST, CENTER_POTS_STYLE_LIST, CROSSWAY_STYLE_LIST

# Event-info dimensionality per layout; looked up by layout name and passed as
# ``event_info_dim`` when the ``Translator`` is constructed.
EVENT_INFOS_DIM = dict(
    crossway=23,
    diverse_coordination=19,
    diverse_orders=24,
    center_pots=18,
)

def evaluate_agent_pair(env, ego, alt, num_games=5, horizon=400, instruction=None, recurrent_ego=False, obs_with_action=False):
    """Roll out ``ego`` with partner ``alt`` on ``env``'s MDP and return per-episode returns.

    Args:
        env: Overcooked multi-agent env; its ``base_env.mdp`` builds the evaluator and the
            env itself is handed to every agent wrapper.
        ego: Ego model. It is wrapped in ``PantheonRLAgent``, or in ``RNNAgent`` when
            ``recurrent_ego`` is set and no ``instruction`` is given.
        alt: Partner model, wrapped in ``PantheonRLAgent``.
        num_games: Number of evaluation episodes.
        horizon: Episode horizon given to ``AgentEvaluator.from_mdp``.
        instruction: Optional instruction forwarded to the ego ``PantheonRLAgent``. When
            not ``None`` it takes precedence over ``recurrent_ego``.
        recurrent_ego: Wrap the ego in ``RNNAgent`` (ignored if ``instruction`` is given).
        obs_with_action: Forwarded unchanged to ``AgentPair``.

    Returns:
        ``(sparse_returns, dense_returns)``: the evaluator's ``'ep_returns'`` and
        ``'ep_dense_returns'`` entries, one value per episode.
    """
    evaluator = AgentEvaluator.from_mdp(env.base_env.mdp, {'horizon': horizon})

    # Pick the ego wrapper: an instruction-conditioned agent wins over the recurrent one.
    if instruction is not None:
        ego_agent = PantheonRLAgent(ego, env, instruction)
    elif recurrent_ego:
        ego_agent = RNNAgent(ego, env)
    else:
        ego_agent = PantheonRLAgent(ego, env)
    alt_agent = PantheonRLAgent(alt, env)
    agent_pair = AgentPair(ego_agent, alt_agent, obs_with_action=obs_with_action)

    trajs = evaluator.evaluate_agent_pair(agent_pair, num_games=num_games)
    # Episode lengths ('ep_lengths') are also available in ``trajs`` but are not needed here.
    return trajs['ep_returns'], trajs['ep_dense_returns']



def get_event_infos_from_translator(desc_path, translator):
    """Get event-based descriptions for the 'eval' human descriptions via the translator.

    Reads the JSON file at ``desc_path`` and takes, for every entry under its ``'eval'``
    key (in file order), the first description. It then runs them through
    ``translator.convert`` with gradients disabled and post-processes the result with
    ``process_infos_with_ranking``.

    Note: ``translator`` is switched to eval mode and left there.

    Args:
        desc_path: Path to the descriptions JSON file (read as UTF-8).
        translator: Model exposing ``eval()`` and ``convert(list_of_str)``. ``convert`` is
            expected to return a tensor, because ``.cpu()`` is called on its result.

    Returns:
        ``np.ndarray`` of processed event infos. Presumably there is one row per 'eval'
        entry; this depends on ``process_infos_with_ranking``, so confirm.
    """
    # Human-written descriptions may contain non-ASCII text; don't rely on the locale encoding.
    with open(desc_path, 'r', encoding='utf-8') as f:
        desc_dict = json.load(f)['eval']

    desc_list = [desc[0] for desc in desc_dict.values()]
    translator.eval()
    with torch.no_grad():
        event_infos = translator.convert(desc_list)
        processed_infos = process_infos_with_ranking(event_infos.cpu().numpy())

    # process_infos_with_ranking may return a tensor or an array-like. Normalise it to a
    # host-side ndarray: a tensor is moved to CPU first, so a CUDA tensor cannot
    # break ``.numpy()``.
    if isinstance(processed_infos, torch.Tensor):
        return processed_infos.detach().cpu().numpy()
    return np.asarray(processed_infos)


if __name__ == '__main__':
    # For every partner in the 'eval' split of a layout, evaluate two egos against it:
    #   row 0: the partner's trained best-response ego (ego_<n>.zip);
    #   row 1: a shared ego shell whose policy weights are set from ``generate_params``
    #          output (built from translator prompts, the diffusion model and the VAE).
    # Per-partner mean/std of sparse and dense returns are saved as .npy files in model_dir.

    # Import the latent VAE under an explicit alias. The module-level name ``VAE`` is
    # rebound by the later ``from diffusion_human_ai.translator.vae_fc import Translator, VAE``,
    # so using the bare name here would construct the translator's VAE class instead.
    from diffusion_human_ai.ldm.vae import VAE as LdmVAE

    parser = argparse.ArgumentParser()
    parser.add_argument('--layout', type=str, default='crossway')
    parser.add_argument('--eval_games', type=int, default=50)
    parser.add_argument('--horizon', type=int, default=400)
    parser.add_argument('--gpu', type=str, default='1',
                        help='value for CUDA_VISIBLE_DEVICES (default keeps the previous hard-coded "1")')
    args = parser.parse_args()

    # CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is initialised.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Masked-event style list per layout, indexed by the partner's position in the 'eval' split.
    style_lists = {
        "diverse_coordination": DIVERSE_COORPERATION_STYLE_LIST,
        "diverse_orders": DIVERSE_ORDERS_STYLE_LIST,
        "center_pots": CENTER_POTS_STYLE_LIST,
        "crossway": CROSSWAY_STYLE_LIST,
    }
    if args.layout not in style_lists:
        raise ValueError("Invalid layout name")
    style_list = style_lists[args.layout]

    model_dir = 'diffusion_human_ai/models/%s' % (args.layout)
    desc_path = os.path.join(model_dir, 'diverse_descriptions.json')
    results_dir = os.path.join(model_dir, 'results')
    # NOTE(review): results_dir is created but the outputs below are written to model_dir.
    os.makedirs(results_dir, exist_ok=True)

    # Iterate the same 'eval' split that get_event_infos_from_translator reads, so that
    # partner i lines up with params_list[i]. The top-level keys (e.g. 'eval' itself) are
    # not partner names.
    with open(desc_path, 'r', encoding='utf-8') as f:
        eval_descs = json.load(f)['eval']
    if not eval_descs:
        raise ValueError("No 'eval' descriptions found in %s" % desc_path)

    idx2env = []
    idx2ego = []
    idx2partner = []
    for i, name in enumerate(eval_descs):
        env_config = {"layout_name": args.layout, "masked_events": style_list[i]}
        env = gym.make('OvercookedMultiEnv-v0', **env_config)

        partner_load = os.path.join(model_dir, name + "_eval.zip")
        partner = gen_fixed({}, "PPO", partner_load)

        # Use the whole first number in the name (not just its first digit) so that
        # e.g. "partner_12" maps to ego_12.zip rather than ego_1.zip.
        match = re.search(r'[0-9]+', name)
        if match is None:
            raise ValueError("Cannot infer ego index from description name %r" % name)
        ego_load = os.path.join(model_dir, f"ego_{match.group()}.zip")
        ego = gen_fixed({}, "PPO", ego_load)

        env.add_partner_agent(partner)

        idx2env.append(env)
        idx2ego.append(ego)
        idx2partner.append(partner)

    ### Load VAE ###
    meta_size = 8
    latent_dim = 64
    device = "cuda" if torch.cuda.is_available() else "cpu"

    hyper_actor = hyperActor(
        act_dim=idx2env[0].action_space.n,
        obs_dim=idx2env[0].observation_space.shape[0],
        latent_dim=latent_dim,
        allowable_layers=np.array([64]),
        meta_batch_size=meta_size,
        device=device,
    )

    decoder = HyperDecoder(hyper_actor).to(device)
    # NOTE(review): action_dim/obs_dim are hard-coded rather than read from the env;
    # confirm they match the trained checkpoint.
    encoder = Conv1dEncoder(bias_shapes=[64, 64, 6], latent_dim=latent_dim, action_dim=6, obs_dim=96).to(device)
    vae = LdmVAE(encoder, decoder).to(device)
    # map_location lets checkpoints saved on GPU load on a CPU-only machine.
    vae.load_state_dict(torch.load(os.path.join(model_dir, "vae.pth"), map_location=device))

    ### Load Diffusion ###
    diffusion = Diffusion(d_context=320, context_embed=True).to(device)
    diffusion.load_state_dict(torch.load(os.path.join(model_dir, "diffusion_robust.pth"), map_location=device))
    generator = torch.Generator(device=device).manual_seed(42)

    ### Load Translator ###
    translator = Translator(event_info_dim=EVENT_INFOS_DIM[args.layout], max_seq_len=32, pooler_output=False).to(device)
    translator.load_state_dict(torch.load(os.path.join(model_dir, "translator.pth"), map_location=device))

    prompt_list = get_event_infos_from_translator(desc_path, translator)
    params_list = generate_params(prompt_list, vae, generator, diffusion, device)

    num_partners = len(idx2partner)
    pair_sparse_rewards = np.zeros((2, num_partners))
    pair_dense_rewards = np.zeros((2, num_partners))
    sparse_rewards_std = np.zeros((2, num_partners))
    dense_rewards_std = np.zeros((2, num_partners))

    def _record(row, col, sparse_rews, dense_rews):
        """Store mean/std of one pairing's episode returns at (row, col)."""
        pair_sparse_rewards[row, col] = np.mean(sparse_rews)
        pair_dense_rewards[row, col] = np.mean(dense_rews)
        sparse_rewards_std[row, col] = np.std(sparse_rews)
        dense_rewards_std[row, col] = np.std(dense_rews)

    # One reusable ego shell; its policy weights are overwritten for every partner.
    recon_ego = deepcopy(idx2ego[0])
    for i, (env, ego, partner) in enumerate(zip(idx2env, idx2ego, idx2partner)):
        print("Partner:", i)

        # play with best response
        sparse_rews, dense_rews = evaluate_agent_pair(
            env, ego, partner, num_games=args.eval_games, horizon=args.horizon)
        _record(0, i, sparse_rews, dense_rews)

        # play with ego reconstructed via haland
        recon_ego.policy = set_params(recon_ego.policy, params_list[i])
        sparse_rews, dense_rews = evaluate_agent_pair(
            env, recon_ego, partner, num_games=args.eval_games, horizon=args.horizon)
        _record(1, i, sparse_rews, dense_rews)

    outputs = {
        'pair_sparse_rewards.npy': pair_sparse_rewards,
        'pair_dense_rewards.npy': pair_dense_rewards,
        'sparse_rewards_std.npy': sparse_rewards_std,
        'dense_rewards_std.npy': dense_rewards_std,
    }
    for filename, array in outputs.items():
        with open(os.path.join(model_dir, filename), 'wb') as f:
            np.save(f, array)