import argparse
import numpy as np
import os
import tensorflow as tf
import json
import pickle
import cv2
import matplotlib.pyplot as plt

import metaworld
import gymnasium as gym
from gymnasium.vector import AsyncVectorEnv

from data.metaworld.config import *

from data.utils.multi_env_interface import InferenceWrapper
from data.utils.language_tokenizer import *

import mediapy

# Request headless EGL rendering for MuJoCo.
# NOTE(review): `metaworld` (and through it MuJoCo) is imported above this line;
# if the rendering backend is chosen at import time this assignment comes too
# late to have any effect — confirm, or move it above the imports.
os.environ["MUJOCO_GL"] = "egl"

# Output size of `process_image` (cv2.resize target, width x height).
IMAGE_WIDTH = 256
IMAGE_HEIGHT = 256
# Camera name constant; presumably meant for rendering (make_env passes the
# same value literally) — verify before relying on it.
CAMERA_VIEW = 'corner2'
# Fraction of the frame removed from each border by `process_image`.
CROP_RATIO = 0.2

# Keep JAX from grabbing all GPU memory up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Cap TensorFlow on the first visible GPU at 3 GB so it coexists with JAX.
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    first_gpu = gpus[0]
    memory_cap = [tf.config.LogicalDeviceConfiguration(memory_limit=3072)]
    tf.config.set_logical_device_configuration(first_gpu, memory_cap)


def make_env(env_cls, initial_state, camera_name=None):
    """Return a zero-argument factory that builds one configured Meta-World env.

    `evaluate` passes a list of these factories to `AsyncVectorEnv`.

    Args:
        env_cls: Environment class to instantiate.
        initial_state: Task handed to ``env.set_task`` after construction.
        camera_name: Camera to render from. ``None`` (the default) uses the
            module-level ``CAMERA_VIEW`` so the camera is configured in one place.

    Returns:
        A callable taking no arguments that creates the env with
        ``render_mode='rgb_array'``, sets its task and returns it.
    """
    # Resolve the default here (in the calling process) rather than inside the
    # factory, so the value is fixed when the factory is created.
    camera = CAMERA_VIEW if camera_name is None else camera_name

    def _init():
        env = env_cls(render_mode='rgb_array', camera_name=camera)
        env.set_task(initial_state)
        return env

    return _init


def process_image(image):
    """Crop, resize and vertically flip a single rendered frame.

    A border of ``CROP_RATIO`` of the height/width is cut from every side, the
    remaining region is resized to ``(IMAGE_WIDTH, IMAGE_HEIGHT)`` with
    ``cv2.INTER_LINEAR``, the rows are reversed (upside-down flip) and the
    result is cast to ``uint8``.
    """
    rows, cols = image.shape[:2]
    top, bottom = int(rows * CROP_RATIO), int(rows * (1 - CROP_RATIO))
    left, right = int(cols * CROP_RATIO), int(cols * (1 - CROP_RATIO))
    cropped = image[top:bottom, left:right]
    # cv2 takes the target size as (width, height).
    resized = cv2.resize(cropped, (IMAGE_WIDTH, IMAGE_HEIGHT), interpolation=cv2.INTER_LINEAR)
    # Presumably the renderer returns frames upside down for this camera — confirm.
    flipped = resized[::-1]
    return flipped.astype(np.uint8)


def load_model(model_path, input_rng=0, step=None, action_ensemble=False, crop=False, image_horizon=1, split='train'):
    """Load a pretrained checkpoint and wrap it for Meta-World inference.

    Args:
        model_path: Checkpoint directory passed to ``load_pretrained``.
        input_rng: Forwarded to ``InferenceWrapper`` as ``init_rng``.
        step: Checkpoint step passed to ``load_pretrained``.
        action_ensemble: Forwarded to ``InferenceWrapper``.
        crop: Forwarded to ``InferenceWrapper``.
        image_horizon: Forwarded to ``InferenceWrapper`` as ``horizon``.
        split: ``'single_task'`` loads a ``BaseModel``; any other value loads
            a ``HyperVLA``.

    Returns:
        The ``InferenceWrapper`` around the loaded model.
    """
    # Import lazily so only the model class that is actually needed gets imported.
    if split != 'single_task':
        from hypervla.model import HyperVLA as model_cls
    else:
        from hypervla.base_model import BaseModel as model_cls
    pretrained = model_cls.load_pretrained(model_path, step=step)
    return InferenceWrapper(
        model=pretrained,
        policy_setup='metaworld',
        init_rng=input_rng,
        action_ensemble=action_ensemble,
        horizon=image_horizon,
        crop=crop,
    )


# Splits accepted by `evaluate`.
_VALID_SPLITS = ('train', 'test', 'single_task')


def _get_task_classes(split, model_path, seed):
    """Return the ``{task_name: env_cls}`` mapping to evaluate for ``split``.

    ``'train'``/``'test'`` use the ML45 train/test classes. ``'single_task'``
    builds an ML1 benchmark for the task named by the part of the last path
    component of ``model_path`` before its first ``_``.

    Raises:
        ValueError: if ``split`` is not one of ``_VALID_SPLITS``.
    """
    if split == 'train':
        return metaworld.ML45(seed=seed).train_classes
    if split == 'test':
        return metaworld.ML45(seed=seed).test_classes
    if split == 'single_task':
        task_name = model_path.split('/')[-1].split('_')[0]
        return metaworld.ML1(task_name, seed=seed).train_classes
    raise ValueError(f"unknown split {split!r}; expected one of {_VALID_SPLITS}")


def _render_frames(envs):
    """Render all sub-environments and return the processed frames stacked in one array."""
    # np.stack requires a sequence; recent NumPy raises TypeError for a generator.
    return np.stack([process_image(image) for image in envs.render()])


def _rollout(model, envs, env_num, max_step, keep_history):
    """Run one episode in every sub-environment and track which ones succeed.

    Returns:
        ``(finished_tasks, episode_length, images_history)``, where
        ``finished_tasks[k]`` tells whether env ``k`` reported success at any
        step, ``episode_length[k]`` is the 1-based step of its first success
        (``max_step`` if it never succeeded), and ``images_history`` holds the
        stacked frames per step (empty unless ``keep_history``).
    """
    envs.reset()
    # Step with zero actions for a while to get a steady state before the policy acts.
    dummy_action = np.zeros((env_num, 4))
    for _ in range(20):
        envs.step(dummy_action)

    images = _render_frames(envs)
    # Frames are only needed for video export; skipping them saves a lot of memory.
    images_history = [images] if keep_history else []

    finished_tasks = [False] * env_num
    episode_length = [max_step] * env_num
    for step in range(max_step):
        _raw_actions, actions, _, _ = model.step(images)
        _, _, _, _, infos = envs.step(actions)
        for k in range(env_num):
            # Record only the first success; a task counts as solved once it succeeds.
            if not finished_tasks[k] and int(infos["success"][k]) == 1:
                finished_tasks[k] = True
                episode_length[k] = step + 1
        if all(finished_tasks):
            break
        images = _render_frames(envs)
        if keep_history:
            images_history.append(images)
    return finished_tasks, episode_length, images_history


def _save_videos(video_path, finished_tasks, episode_length, images_history):
    """Write one ``<index>_<success|fail>.mp4`` per environment into ``video_path``."""
    os.makedirs(video_path, exist_ok=True)
    # Drop videos from a previous evaluation of this task so stale files don't linger.
    for fname in os.listdir(video_path):
        fpath = os.path.join(video_path, fname)
        if fname.endswith('.mp4') and os.path.isfile(fpath):
            os.remove(fpath)
    for i, finished in enumerate(finished_tasks):
        result = 'success' if finished else 'fail'
        frames = [x[i] for x in images_history[:episode_length[i]]]
        mediapy.write_video(f'{video_path}/{i + 1}_{result}.mp4', frames, fps=10)


def evaluate(model_path, seed=0, checkpoint_step=None, split='train', save_video=False, env_num=5, action_ensemble=False, image_horizon=1, recompute=False, max_step=300):
    """Evaluate a checkpoint on Meta-World and record per-task success rates.

    For every task of ``split``, ``env_num`` environments are started from the
    ML1 test initial states and the policy is rolled out for up to ``max_step``
    steps. The fraction of environments that reported success at any step is
    stored. Results are written after every task to
    ``<eval_path>/success_rate_<split>[_action_ensemble]_horizon_<h>.json`` so an
    interrupted run can resume; tasks already present are skipped unless
    ``recompute`` is set.

    Args:
        model_path: Checkpoint directory; ``finetune_saves`` in it is replaced
            by ``eval_results`` to build the output directory.
        seed: Seed for the model RNG and the Meta-World benchmarks.
        checkpoint_step: Checkpoint step to load (also part of the output path).
        split: ``'train'``, ``'test'`` or ``'single_task'``.
        save_video: Write one mp4 per environment for each task.
        env_num: Number of parallel evaluation environments per task.
        action_ensemble: Forwarded to ``load_model``.
        image_horizon: Horizon of image history, forwarded to ``load_model``.
        recompute: Re-evaluate tasks that already have a stored result.
        max_step: Maximum number of policy steps per episode.

    Raises:
        ValueError: if ``split`` is not recognised (checked before loading the model).
    """
    if split not in _VALID_SPLITS:
        raise ValueError(f"unknown split {split!r}; expected one of {_VALID_SPLITS}")

    save_dir = model_path.replace('finetune_saves', 'eval_results')
    eval_path = f'{save_dir}/eval_step_{checkpoint_step}/{seed}'
    os.makedirs(eval_path, exist_ok=True)

    save_file_name = f'success_rate_{split}'
    if action_ensemble:
        save_file_name += '_action_ensemble'
    save_file_name += f'_horizon_{image_horizon}'
    results_path = f'{eval_path}/{save_file_name}.json'
    # Resume from an earlier (possibly interrupted) run.
    if os.path.exists(results_path):
        with open(results_path, 'r') as f:
            all_tasks_success_rate = json.load(f)
    else:
        all_tasks_success_rate = {}

    model = load_model(model_path, seed, step=checkpoint_step, action_ensemble=action_ensemble, image_horizon=image_horizon, split=split)

    # Cached instruction embeddings, keyed by the tuple of instruction token ids.
    with open('dataset/token_embedding_metaworld.pkl', 'rb') as f:
        token_embeddings = pickle.load(f)
    tokenizer, token_embedding_model, params = get_language_tokenizer()

    tasks = _get_task_classes(split, model_path, seed)

    for name, env_cls in tasks.items():
        if name in all_tasks_success_rate and not recompute:
            continue

        language_instruction = policies[name][1]
        print(f'===== {name}: {language_instruction} =====')

        # Embed instructions missing from the cache so model.reset can look them up.
        tokens = tokenizer.encode(language_instruction)
        token_key = tuple(tokens['input_ids'][0])
        if token_key not in token_embeddings:
            instruction_embedding = token_to_embedding(token_embedding_model, params, tokens)
            token_embeddings[token_key] = instruction_embedding.squeeze()
        model.reset(language_instruction, token_embeddings=token_embeddings)

        init_states = metaworld.ML1(name, seed=seed).test_tasks
        envs = AsyncVectorEnv([make_env(env_cls, init_states[i]) for i in range(env_num)])
        try:
            finished_tasks, episode_length, images_history = _rollout(
                model, envs, env_num, max_step, keep_history=save_video)
        finally:
            # Always close the vector env, even if the rollout raises, so its
            # sub-environments are not leaked.
            envs.close()

        if save_video:
            _save_videos(f"{eval_path}/video/{split}/{name}", finished_tasks, episode_length, images_history)

        all_tasks_success_rate[name] = sum(finished_tasks) / env_num
        for item in sorted(all_tasks_success_rate.items(), key=lambda kv: kv[1]):
            print(item)
        # Persist after every task so progress survives interruption.
        with open(results_path, 'w') as f:
            json.dump(all_tasks_success_rate, f)

    full_results = sorted(all_tasks_success_rate.items(), key=lambda kv: kv[0])
    for item in full_results:
        print(item)
    if full_results:
        print(np.mean([rate for _, rate in full_results]))



if __name__ == '__main__':

    # Add arguments
    parser = argparse.ArgumentParser(description="A simple example of argparse")
    parser.add_argument("--model_path", type=str, default='', help="The path of the custom model (only useful for octo-custom?)")
    parser.add_argument("--seeds", type=str, default='0+1+2+3', help="seeds for policy and env")
    parser.add_argument("--step", type=int, default=None, help="checkpoint step to evaluate")
    parser.add_argument("--split", type=str, default='train', help="evaluate on the train or test split")
    parser.add_argument("--save_video", action='store_true', help="save evaluation video or not")
    parser.add_argument("--action_ensemble", action='store_true', help="Use action ensemble or not")
    parser.add_argument("--image_horizon", type=int, default=2, help="The horizon of image history")
    parser.add_argument("--recompute", action='store_true', help="Whether to recompute for existing results")
    parser.add_argument("--env_num", type=int, default=20, help="The number of evaluation episodes")
    # Parse the arguments
    args = parser.parse_args()

    seeds = [eval(seed) for seed in args.seeds.split('+')]
    for seed in seeds:
        evaluate(
            args.model_path, 
            checkpoint_step=args.step, 
            seed=seed, 
            split=args.split, 
            save_video=args.save_video, 
            action_ensemble=args.action_ensemble, 
            image_horizon=args.image_horizon, 
            recompute=args.recompute,
            env_num=args.env_num,
        )
