import gym
import os
import json

import numpy as np
import torch as th
from tqdm import tqdm
from trainer import gen_rnn_fixed
from tester import gen_fixed
from copy import deepcopy
from diffusion_human_ai.ldm.vae import VAE, HyperDecoder, Conv1dEncoder
from diffusion_human_ai.ghn.core import hyperActor
from pantheonrl.common.observation import Observation
from pantheonrl.common.agents import StaticPolicyAgent, StaticRNNPolicyAgent
from sb3_contrib.common.recurrent.type_aliases import RNNStates
from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy
from stable_baselines3.common.utils import obs_as_tensor

from diffusion_human_ai.ldm.ddpm import DDPMSampler
from diffusion_human_ai.ldm.diffusion import Diffusion
from diffusion_human_ai.ldm.utils import *
from diffusion_human_ai.scripts.utils import *



def _compose_ego_obs(env, ego_raw_obs, partner_action, instruction, obs_with_action):
    """Wrap the ego's raw observation, optionally extended with extra features.

    Precedence: a non-None ``instruction`` is appended first; otherwise, when
    ``obs_with_action`` is set, a one-hot encoding of ``partner_action`` is
    appended; otherwise the raw observation is used unchanged.
    """
    if instruction is not None:
        return Observation(np.concatenate([ego_raw_obs, instruction]))
    if not obs_with_action:
        return Observation(np.array(ego_raw_obs))
    # One-hot vector sized by the first agent's action space.
    partner_onehot = np.zeros(env.action_space[0].n)
    partner_onehot[partner_action] = 1
    return Observation(np.concatenate([ego_raw_obs, partner_onehot]))


def evaluate_agent_pair(env, ego, partner, max_steps=100, instruction=None, obs_with_action=False):
    """Roll out a single episode with ``ego`` and ``partner`` and score the ego.

    Each step the partner acts on ``obs[1]`` and the ego acts on ``obs[0]``
    (optionally augmented, see ``_compose_ego_obs``); both actions are sent
    to ``env.step`` as the tuple ``(ego_action, partner_action)``.

    Args:
        env: Two-agent environment; ``reset()`` returns per-agent
            observations and ``step()`` returns ``(obs, rew, dones, info)``.
        ego: Agent exposing ``policy`` and ``get_action``.
        partner: Agent exposing ``get_action``.
        max_steps: Upper bound on the number of environment steps.
        instruction: Optional array concatenated onto the ego observation.
        obs_with_action: If true (and ``instruction`` is None), append the
            partner's one-hot action to the ego observation.

    Returns:
        Sum of ``rew[0]`` over the steps taken. The rollout stops after
        ``max_steps`` steps or as soon as any entry of ``dones`` is truthy.

    NOTE(review): the flag handed to a recurrent ego is always False, since
    the episode ends the moment termination is observed. If ``get_action``
    expects an episode-start signal for resetting recurrent state, confirm
    against the agent implementation.
    """
    device = ego.policy.device
    obs = env.reset()
    ego_is_recurrent = isinstance(ego.policy, RecurrentActorCriticPolicy)
    episode_over = False  # never flips before being read; see NOTE above
    ego_return = 0

    for _ in range(max_steps):
        with th.no_grad():
            partner_action = partner.get_action(Observation(np.array(obs[1])))
            ego_obs = _compose_ego_obs(
                env, obs[0], partner_action, instruction, obs_with_action
            )
            if ego_is_recurrent:
                ego_action = ego.get_action(
                    ego_obs, th.tensor(episode_over, device=device)
                )
            else:
                ego_action = ego.get_action(ego_obs)

        obs, rew, dones, _ = env.step((ego_action, partner_action))
        ego_return += rew[0]
        if any(dones):
            break

    return ego_return

if __name__ == '__main__':
    # Evaluation script. For every layout it pairs the layout's trained
    # partner with (row 0) the layout's trained ego and (row 1) an ego whose
    # policy parameters come from generate_params(...), then writes the mean
    # and std of the episode returns to <model_dir>/results.json.
    device = "cuda" if th.cuda.is_available() else "cpu"
    model_dir = "diffusion_human_ai/models/lbf_spread/"

    num_layouts = 8          # layouts are named lbf_spread_1 .. lbf_spread_8
    episodes_per_pair = 100  # rollouts averaged per (ego, partner) pairing
    meta_size = 8
    latent_dim = 64

    idx2env = []
    idx2ego = []
    idx2alt = []
    for idx in range(1, num_layouts + 1):
        layout_name = f"lbf_spread_{idx}-v0"
        env = gym.make(layout_name, test_mode=True)
        idx2env.append(env)

        ego_load = os.path.join(model_dir, f"ego_{idx}.zip")
        ego = gen_fixed({}, "PPO", ego_load)
        idx2ego.append(ego)

        alt_load = os.path.join(model_dir, f"partner_{idx}.zip")
        alt = gen_fixed({}, "PPO", alt_load)
        idx2alt.append(alt)

    # Model dimensions are read from the first environment.
    # NOTE(review): this presumes all layouts share observation/action
    # spaces -- confirm against the environment registrations.
    raw_obs_dim = idx2env[0].observation_space[0].shape[0]
    n_actions = idx2env[0].action_space[0].n

    encoder = Conv1dEncoder(
        bias_shapes=[64, 64, 6],
        latent_dim=latent_dim,
        # Smallest multiple of 4 strictly greater than raw_obs_dim.
        obs_dim=4 * (raw_obs_dim // 4 + 1),
        action_dim=n_actions,
    ).to(device)
    hyper_actor = hyperActor(
        act_dim=n_actions,
        obs_dim=raw_obs_dim,
        latent_dim=latent_dim,
        allowable_layers=np.array([64]),
        # Previously a literal 8 while meta_size sat unused; same value.
        meta_batch_size=meta_size,
        device=device
    )
    decoder = HyperDecoder(hyper_actor)
    vae = VAE(encoder=encoder, decoder=decoder).to(device)
    vae_path = os.path.join(model_dir, "vae.pth")
    # map_location lets checkpoints saved on a GPU load on the CPU fallback
    # chosen above; without it torch tries to restore the original device.
    vae.load_state_dict(th.load(vae_path, map_location=device))

    diffusion_path = os.path.join(model_dir, "diffusion_robust.pth")
    diffusion = Diffusion(d_context=320, context_embed=True).to(device)
    diffusion.load_state_dict(th.load(diffusion_path, map_location=device))
    # NOTE(review): neither model is switched to eval() here; confirm that
    # generate_params handles this if the models contain dropout/batch norm.

    # Fixed seed so the generated parameters are reproducible across runs.
    generator = th.Generator(device=device).manual_seed(42)
    # One prompt per loaded ego, indexed 0 .. len(idx2ego) - 1.
    prompt_list = get_prompt_embeddings(th.tensor(range(len(idx2ego)))).cpu().numpy()
    params_list = generate_params(prompt_list, vae, generator, diffusion, device)

    # Row 0: trained ego. Row 1: ego with generated parameters.
    rewards_mean = np.zeros((2, num_layouts))
    rewards_std = np.zeros((2, num_layouts))

    # Copy of the first ego used as a shell; its policy is overwritten with
    # the generated parameters for each layout in turn.
    haland_ego = deepcopy(idx2ego[0])

    for idx in range(num_layouts):
        print('Partner ', idx)
        env = idx2env[idx]
        ego = idx2ego[idx]
        alt = idx2alt[idx]

        rew = [evaluate_agent_pair(env, ego, alt) for _ in range(episodes_per_pair)]
        rewards_mean[0, idx] = np.mean(rew)
        rewards_std[0, idx] = np.std(rew)

        haland_ego.policy = set_params(haland_ego.policy, params_list[idx])
        rew = [evaluate_agent_pair(env, haland_ego, alt) for _ in range(episodes_per_pair)]
        rewards_mean[1, idx] = np.mean(rew)
        rewards_std[1, idx] = np.std(rew)

    with open(os.path.join(model_dir, 'results.json'), 'w', encoding='utf-8') as f:
        json.dump({'mean': rewards_mean.tolist(), 'std': rewards_std.tolist()}, f)