import os
import numpy as np
import json
from PIL import Image
from pathlib import Path
from collections import Counter, defaultdict

import hydra
from omegaconf import OmegaConf
from termcolor import colored

from calvin.calvin_env.calvin_env.envs.play_table_env import get_env


def make_calvin_env(dataset_path):
    """Create a CALVIN environment from the ``validation`` split of a dataset.

    Args:
        dataset_path: Root directory of the CALVIN dataset (str or path-like).

    Returns:
        Whatever ``get_env`` builds for ``<dataset_path>/validation`` with the
        GUI disabled.
    """
    return get_env(Path(dataset_path).joinpath("validation"), show_gui=False)

def evaluate_policy(env, playbook_agent, available_tasks, config):
    """Run ``playbook_agent`` on every evaluation sequence built from ``available_tasks``.

    The task oracle is instantiated from
    ``<parent of this file's directory>/calvin/conf/callbacks/rollout/tasks/new_playtable_tasks.yaml``.
    Evaluation sequences come from ``get_tasks_from_tacorl`` using
    ``config.data_dir``, ``config.len_task_chain`` and ``config.eval_episodes``.
    After every episode the running mean success rate per chain position is
    printed.

    Args:
        env: Environment passed through to ``evaluate_sequence``.
        playbook_agent: Agent passed through to ``evaluate_sequence``.
        available_tasks: Task names a sequence may contain.
        config: Reads ``data_dir``, ``len_task_chain`` and ``eval_episodes``
            (and whatever ``evaluate_sequence`` reads).

    Returns:
        ``(results, plans)``. ``results`` is a list with one per-episode result
        (as returned by ``evaluate_sequence``). ``plans`` is an empty
        ``defaultdict(list)`` that is never filled here; it is kept only so the
        return signature stays unchanged.
    """
    conf_dir = Path(__file__).absolute().parents[1] / "calvin/conf"
    task_cfg = OmegaConf.load(conf_dir / "callbacks/rollout/tasks/new_playtable_tasks.yaml")
    task_oracle = hydra.utils.instantiate(task_cfg)

    eval_sequences = get_tasks_from_tacorl(
        available_tasks,
        config.data_dir,
        len_tasks=config.len_task_chain,
        max_episodes=config.eval_episodes,
    )
    plans = defaultdict(list)  # unused; returned for backward compatibility

    results = []
    for epi_idx, eval_sequence in enumerate(eval_sequences):
        task = eval_sequence["completed_tasks"]
        result = evaluate_sequence(
            env, playbook_agent, task, task_oracle, eval_sequence, epi_idx, config
        )
        results.append(result)
        # Running mean over all episodes so far: computed once per episode
        # instead of once per chain position.
        running_mean = np.mean(results, axis=0)
        rates = "  ->".join(
            "  {:.3f}".format(running_mean[i]) for i in range(config.len_task_chain)
        )
        print("    [Success-Rate]" + rates)
    print()
    return results, plans


def evaluate_sequence(
    env, playbook_agent, goal_task, task_checker, eval_sequence, episode_index, config
):
    """Reset ``env`` to a sequence's start state and roll out toward its goal image.

    Args:
        env: Environment; ``reset`` is called with the sequence's
            ``robot_obs`` and ``scene_obs``.
        playbook_agent: Agent forwarded to ``rollout``.
        goal_task: Task names the rollout is checked against.
        task_checker: Task oracle forwarded to ``rollout``.
        eval_sequence: Dict providing ``robot_obs``, ``scene_obs`` and ``goal_obs``.
        episode_index: 0-based episode number forwarded to ``rollout``.
        config: Forwarded to ``rollout``.

    Returns:
        The value returned by ``rollout`` for this episode.
    """
    env.reset(
        robot_obs=eval_sequence["robot_obs"],
        scene_obs=eval_sequence["scene_obs"],
    )
    return rollout(
        env,
        playbook_agent,
        eval_sequence["goal_obs"],
        goal_task,
        task_checker,
        episode_index,
        config,
    )


def _episode_status_line(episode_index, step, ep_len, remain_tasks, num_bars=50):
    """Format the progress-bar status text shown after 0-based ``step`` of ``ep_len`` steps."""
    done = step + 1
    progress = int(done / ep_len * num_bars)
    percent = done / ep_len * 100
    return '    [EPISODE#{:03d}][Progress {}{}:{:.1f}%] Remains: {}'.format(
        episode_index + 1, '█' * progress, ' ' * (num_bars - progress), percent, remain_tasks)


def _print_task_verdicts(results):
    """Print a green "success" or red "fail" tag per chain position, two spaces apart."""
    print("  ".join(
        colored("success", "green") if r else colored("fail", "red") for r in results
    ))


def rollout(env, playbook_agent, goal_obs, subtask, task_oracle, episode_index, config):
    """Run one episode toward ``goal_obs`` and report how many chained tasks were completed.

    The episode lasts at most 180 steps when ``config.len_task_chain == 1`` and
    360 steps otherwise. It stops early once the oracle reports every task in
    ``subtask`` as done.

    Args:
        env: Environment exposing ``get_obs``, ``get_info`` and ``step``.
        playbook_agent: Object whose ``get_action(frame, goal_obs, play, emb_idx)``
            returns ``(action, play, emb_idx)``.
        goal_obs: Goal image; saved as ``goal_state.jpg`` (presumably an
            HxWx3 uint8 array, since it goes through ``Image.fromarray``).
        subtask: Ordered task names; forwarded to the oracle.
        task_oracle: Provides
            ``get_task_info_for_set(start_info, current_info, subtask)``.
        episode_index: 0-based index used in file names and progress output.
        config: Reads ``window_size``, ``loadname`` and ``len_task_chain``.

    Side effects:
        Writes the goal image and every observed frame as JPEG under
        ``./results/calvin/<loadname>/record/<n>task/<task order>/episodeNNN``,
        and prints progress and per-task verdicts.

    Returns:
        ``np.ndarray`` of shape ``(len(subtask),)``. Entry ``i`` is 1.0 if at
        some step the oracle reported at least ``i + 1`` completed tasks, and
        0.0 otherwise.
    """
    obs = env.get_obs()
    start_info = env.get_info()

    task_name = " -> ".join(subtask)
    print()
    print("[TASK-ORDER: {}]".format(task_name))

    # Never reset to 0, so it records the highest completion count reached.
    results = np.zeros((len(subtask),))

    # save_trajectory / debug are hard-coded switches (frame dumps, per-task verdicts).
    H, play, save_trajectory, debug = config.window_size, None, True, True
    if save_trajectory:
        save_trajectory_path = "./results/calvin/{}/record/{}task/{}/episode{:03d}".format(
            config.loadname, len(subtask), task_name, episode_index + 1)
        os.makedirs(save_trajectory_path, exist_ok=True)
        Image.fromarray(goal_obs).save(save_trajectory_path + "/goal_state.jpg")

    ep_len = 180 if config.len_task_chain == 1 else 360
    remain_tasks = len(subtask)
    for step in range(ep_len):
        frame = obs['rgb_obs']['rgb_static']
        # Every H (= config.window_size) steps, play/emb_idx are reset to None
        # before being handed back to the agent.
        if step % H == 0:
            play, emb_idx = None, None
        if save_trajectory:
            Image.fromarray(frame).save(save_trajectory_path + "/step{:03d}.jpg".format(step))

        action, play, emb_idx = playbook_agent.get_action(frame, goal_obs, play, emb_idx)
        obs, _, _, current_info = env.step(action)

        current_task_info = task_oracle.get_task_info_for_set(start_info, current_info, subtask)
        remain_tasks = len(subtask) - len(current_task_info)
        results[:len(current_task_info)] = 1.0

        # '\r' keeps the live progress bar on a single terminal line.
        print(_episode_status_line(episode_index, step, ep_len, remain_tasks) + '       ', end='\r')
        if len(current_task_info) == len(subtask):
            break

    # ep_len > 0, so ``step`` is always bound here (last executed step).
    print(_episode_status_line(episode_index, step, ep_len, remain_tasks) + '  ', end="")
    if debug:
        _print_task_verdicts(results)
    return results


def get_subtasks():
    """Return the CALVIN task names used for evaluation, in a fixed order.

    A fresh list is built on every call, so callers may mutate the result.
    Only blue-block variants of the block-manipulation tasks are included;
    pink- and red-block variants are not part of this list.
    """
    task_groups = (
        ("close_drawer", "open_drawer"),                      # drawer
        ("move_slider_left", "move_slider_right"),            # slider
        ("turn_on_led", "turn_off_led"),                      # led
        ("turn_on_lightbulb", "turn_off_lightbulb"),          # lightbulb
        ("stack_block", "unstack_block",                      # blocks
         "place_in_drawer", "place_in_slider", "push_into_drawer"),
        ("lift_blue_block_drawer", "lift_blue_block_slider",  # blue block
         "lift_blue_block_table",
         "push_blue_block_left", "push_blue_block_right",
         "rotate_blue_block_left", "rotate_blue_block_right"),
    )
    return [name for group in task_groups for name in group]


def get_tasks_from_tacorl(available_tasks, data_dir, len_tasks=1, max_episodes=100,
                          max_seq_len=300, stride=5):
    """Build evaluation start/goal pairs from a ``start_end_tasks.json`` index.

    ``<data_dir>/start_end_tasks.json`` maps a start frame index (str) to a
    dict of end frame index (str) -> list of task names completed between the
    two frames. A pair is kept when all of the following hold:

    * it completes exactly ``len_tasks`` tasks;
    * every one of those tasks is in ``available_tasks``;
    * ``min_seq_len < end - start < max_seq_len``, where ``min_seq_len`` is
      120 for 3-task chains, 100 for 2-task chains and 16 otherwise.

    Kept pairs are sorted by sequence length (ties keep index order). Every
    ``stride``-th pair among the first ``max_episodes * stride`` is returned.
    Only the returned pairs have their ``episode_XXXXXXX.npz`` frames loaded.

    Args:
        available_tasks: Iterable of allowed task names.
        data_dir: Directory holding the index and the episode ``.npz`` files
            (str or path-like).
        len_tasks: Required number of completed tasks per pair.
        max_episodes: Maximum number of pairs to return.
        max_seq_len: Exclusive upper bound on ``end - start`` (default 300).
        stride: Subsampling step over the length-sorted pairs (default 5).

    Returns:
        List of dicts with keys ``completed_tasks``, ``scene_obs``,
        ``robot_obs`` (from the start frame), ``goal_obs`` (the end frame's
        ``rgb_static``), ``start_step``, ``end_step`` and ``seq_len``.

    Raises:
        ValueError: If no pair passes the filters.
    """
    data_dir = Path(data_dir)
    with open(data_dir / "start_end_tasks.json") as f:
        start_end_tasks = json.load(f)

    min_seq_len = {3: 120, 2: 100}.get(len_tasks, 16)
    allowed = set(available_tasks)

    # Filter on index data only; frames are loaded later for the selected pairs.
    candidates = []
    for start_idx, end_tasks in start_end_tasks.items():
        for end_idx, completed_tasks in end_tasks.items():
            if len(completed_tasks) != len_tasks:
                continue
            if not all(task in allowed for task in completed_tasks):
                continue
            start_step = int(start_idx)
            end_step = int(end_idx)
            seq_len = end_step - start_step
            if max_seq_len > seq_len > min_seq_len:
                candidates.append((seq_len, start_step, end_step, completed_tasks))

    if not candidates:
        raise ValueError(
            "no evaluation sequences with {} task(s) found in {}".format(len_tasks, data_dir))

    # Stable sort on length only, so ties keep JSON iteration order.
    candidates.sort(key=lambda c: c[0])
    selected = candidates[:max_episodes * stride:stride]

    eval_sequences = []
    for seq_len, start_step, end_step, completed_tasks in selected:
        start_file = data_dir / "episode_{:07d}.npz".format(start_step)
        end_file = data_dir / "episode_{:07d}.npz".format(end_step)
        # np.load on an .npz returns an NpzFile holding an open file handle;
        # ``with`` closes it. Item access reads the array into memory first.
        with np.load(start_file) as init_state, np.load(end_file) as goal_state:
            eval_sequences.append({
                "completed_tasks": completed_tasks,
                "scene_obs": init_state["scene_obs"],
                "robot_obs": init_state["robot_obs"],
                "goal_obs": goal_state["rgb_static"],
                "start_step": start_step,
                "end_step": end_step,
                "seq_len": seq_len,
            })
    return eval_sequences


def evaluate_calvin(env, playbook_agent, config):
    """Evaluate ``playbook_agent`` on the CALVIN subtasks and print success rates.

    ``config.eval_type`` selects the protocol:

    * ``"individually"``: one ``evaluate_policy`` run per task from
      ``get_subtasks()``. Prints per-task and overall success rates and
      returns the list of per-task result lists.
    * ``"in_a_row"``: a single ``evaluate_policy`` run over sequences drawn
      from all subtasks. Prints and returns the mean success per chain
      position (``np.ndarray``).

    Raises:
        ValueError: If ``config.eval_type`` is not one of the values above.
            (Previously this fell through and raised ``UnboundLocalError``.)
    """
    if config.eval_type not in ("individually", "in_a_row"):
        raise ValueError(
            "unknown eval_type {!r}; expected 'individually' or 'in_a_row'".format(
                config.eval_type))

    available_tasks = get_subtasks()

    if config.eval_type == "individually":
        goal_results, total_results = [], []
        for goal in available_tasks:
            result, _ = evaluate_policy(env, playbook_agent, [goal], config)
            goal_results.append(result)
            total_results.extend(result)

        for goal, result in zip(available_tasks, goal_results):
            # np.size counts every recorded outcome (episodes x chain positions).
            print("[{}] Success Rate: {:.3f} ({}/{}) :D ".format(
                goal.upper(), np.mean(result), int(np.sum(result)), int(np.size(result))))
        print()

        print("[TOTAL] Success Rate: {:.3f} ({}/{}) :D ".format(
            np.mean(total_results), int(np.sum(total_results)), len(total_results)))
        print()
    else:  # "in_a_row"
        goal_results, _ = evaluate_policy(env, playbook_agent, available_tasks, config)
        print()

        goal_results = np.mean(goal_results, axis=0)
        rates = " ->".join(" {:.3f}".format(rate) for rate in goal_results)
        print("[CHAIN] Success Rate:" + rates + "  :D")
        print()

    return goal_results


