import os
import gym
import json
import assistive_gym
import torch
import numpy as np
from numpngw import write_apng
from IPython.display import Image
import warnings
from tqdm import tqdm
from copy import deepcopy

import importlib
from trainer import gen_fixed
from pantheonrl.common.observation import Observation
from stable_baselines3.common.utils import obs_as_tensor

from diffusion_human_ai.ldm.vae import VAE, HyperDecoder, Conv1dEncoder
from diffusion_human_ai.ghn.core import hyperActor
from diffusion_human_ai.ldm.utils import *
from diffusion_human_ai.scripts.utils import *

from diffusion_human_ai.ldm.ddpm import DDPMSampler
from diffusion_human_ai.ldm.diffusion import Diffusion


def _pad_observation(raw_obs, target_dim):
    """Zero-pad ``raw_obs`` up to ``target_dim`` entries.

    The observation is returned unchanged when its last axis already holds
    at least ``target_dim`` entries.
    """
    missing = target_dim - raw_obs.shape[-1]
    if missing > 0:
        return np.concatenate([raw_obs, np.zeros(missing)])
    return raw_obs


def evaluate_agent_pair(ego, partner, env, n_games=10, n_steps=200, instruction=None):
    """Roll out an ego/partner pair in ``env`` and collect episode rewards.

    Each game seeds the environment with ``123 * game_index``, resets it and
    steps it with the concatenation of the ego action and the partner action
    until the environment reports ``done`` or ``n_steps`` steps were taken.
    The ego agent receives the observation zero-padded to the size of its
    policy's observation space; the partner receives it unpadded.

    Args:
        ego: Agent exposing ``get_action(obs, instruction=...)`` and
            ``policy.observation_space``.
        partner: Agent exposing ``get_action(obs)``.
        env: Environment exposing ``set_seed``, ``reset`` and ``step``.
        n_games: Number of episodes to run.
        n_steps: Maximum number of steps per episode.
        instruction: Forwarded unchanged to ``ego.get_action``.

    Returns:
        A list with the summed reward of every episode, in game order.
    """
    # The ego policy's input size does not change between steps; look it up once.
    ego_obs_dim = ego.policy.observation_space.shape[0]

    episode_rewards = []
    for game in range(n_games):
        env.set_seed(123 * game)
        # Only the first element of the reset result is used as observation.
        raw_obs = env.reset()[0]
        ego_obs = Observation(_pad_observation(raw_obs, ego_obs_dim))
        obs = Observation(raw_obs)

        episode_reward = 0
        step = 0
        done = False
        while not done:
            joint_action = np.concatenate([
                ego.get_action(ego_obs, instruction=instruction),
                partner.get_action(obs),
            ])
            raw_obs, reward, done, _ = env.step(joint_action)
            ego_obs = Observation(_pad_observation(raw_obs, ego_obs_dim))
            obs = Observation(raw_obs)

            step += 1
            episode_reward += reward
            # Truncate episodes that the environment does not end by itself.
            if step >= n_steps:
                break
        episode_rewards.append(episode_reward)
    return episode_rewards


if __name__ == '__main__':
    # Evaluate, for every assistive task, the loaded ego policy and the same
    # policy with generated parameters, then write mean/std rewards to JSON.
    module = importlib.import_module('assistive_gym.envs')
    warnings.filterwarnings('ignore')
    os.environ["USE_TF"] = 'None'

    model_dir = "diffusion_human_ai/models/assistive"

    task_list = ["scratchitch", "feeding", "drinking", "bedbathing"]
    task2env = {"drinking": "DrinkingJacoHuman-v1", "feeding": "FeedingJacoHuman-v1",
                "scratchitch": "ScratchItchJacoHuman-v1", "bedbathing": "BedBathingJacoHuman-v1"}

    env_list = [task2env[task] for task in task_list]
    n_tasks = len(env_list)

    idx2env = []
    idx2ego = []
    idx2partner = []
    for env_name in env_list:
        # "FooJacoHuman-v1" -> class "FooJacoHumanEnv" in assistive_gym.envs.
        env_class = getattr(module, env_name.split('-')[0] + 'Env')
        env = env_class()

        ego_load = os.path.join(model_dir, f"ego_{env_name}.zip")
        partner_load = os.path.join(model_dir, f"partner_{env_name}.zip")
        ego = gen_fixed({}, "PPO", ego_load)
        partner = gen_fixed({}, "PPO", partner_load)

        idx2env.append(env)
        idx2ego.append(ego)
        idx2partner.append(partner)

    meta_size = 8
    latent_dim = 64
    device = "cuda" if torch.cuda.is_available() else "cpu"
    hyper_actor = hyperActor(
        act_dim=7,
        obs_dim=64,
        latent_dim=latent_dim,
        allowable_layers=np.array([64]),
        meta_batch_size=meta_size,
        device=device
    )
    decoder = HyperDecoder(hyper_actor).to(device)
    encoder = Conv1dEncoder(bias_shapes=[64, 64, 7], latent_dim=latent_dim, obs_dim=64, action_dim=7).to(device)
    vae = VAE(encoder, decoder).to(device)
    vae_path = os.path.join(model_dir, "vae.pth")
    # map_location lets a checkpoint saved on GPU load on a CPU-only host,
    # matching the "cuda"/"cpu" fallback chosen for `device` above.
    vae.load_state_dict(torch.load(vae_path, map_location=device))

    diffusion_path = os.path.join(model_dir, "diffusion_robust.pth")
    diffusion = Diffusion(d_context=32).to(device)
    diffusion.load_state_dict(torch.load(diffusion_path, map_location=device))

    generator = torch.Generator(device=device).manual_seed(42)
    # One prompt index per task, in task_list order.
    prompt_list = get_prompt_embeddings(torch.tensor(np.arange(n_tasks)), dim=32).numpy()
    params_list = generate_params(prompt_list, vae, generator, diffusion, device)

    # Row 0: loaded ego policy; row 1: ego policy with generated parameters.
    rewards = np.zeros((2, n_tasks))
    rewards_std = np.zeros((2, n_tasks))

    for i, (env, ego, partner) in enumerate(zip(idx2env, idx2ego, idx2partner)):
        # evaluate_agent_pair returns a single list of per-episode rewards.
        episode_rewards = evaluate_agent_pair(ego, partner, env)
        rewards[0, i] = np.mean(episode_rewards)
        rewards_std[0, i] = np.std(episode_rewards)

        # Work on a copy so the loaded ego agent keeps its original policy.
        haland_ego = deepcopy(ego)
        haland_ego.policy = set_params(haland_ego.policy, params_list[i])
        episode_rewards = evaluate_agent_pair(haland_ego, partner, env)
        rewards[1, i] = np.mean(episode_rewards)
        rewards_std[1, i] = np.std(episode_rewards)

    with open(os.path.join(model_dir, "results.json"), "w") as f:
        json.dump({"rewards": rewards.tolist(), "rewards_std": rewards_std.tolist()}, f)