import os
from typing import Any, Union, Dict, OrderedDict
import pprint
import tqdm

import cv2
import gym
import numpy as np
import pickle as pkl
import metaworld
from metaworld.policies import *
from metaworld.policies.policy import Policy

from stable_baselines3.common import env_util
from stable_baselines3.common import vec_env
from stable_baselines3.common import base_class, policies

from wrapper import MetaWorldWrapper


def check_env(
    config: Dict[str, Any],
    benchmark,
):
    """Report whether ``config.env`` is one of the benchmark's training tasks.

    The set of known task names (``benchmark.train_classes.keys()``) is
    printed to stdout before the check.

    Args:
        config: Configuration object exposing an ``env`` attribute (task name).
        benchmark: Benchmark object exposing a ``train_classes`` mapping.

    Returns:
        ``True`` if ``config.env`` is a key of ``benchmark.train_classes``,
        ``False`` otherwise.
    """
    known_tasks = benchmark.train_classes.keys()
    print(known_tasks)
    return config.env in known_tasks


def make_env(
    config: Dict[str, Any],
    benchmark,
    seed,
    expert=None,
    dict_obs=False, 
):
    """Construct a single ``MetaWorldWrapper`` from the run configuration.

    Args:
        config: Configuration object; the attributes ``env``, ``max_step``,
            ``terminate``, ``xwind_id``, ``gravity_id``, ``fov``,
            ``bright_id``, ``contrast_id``, ``saturation_id``, ``hue_id``
            and ``source`` are read and forwarded to the wrapper.
        benchmark: Benchmark object, forwarded as the wrapper's ``mt`` argument.
        seed: Seed forwarded to the wrapper.
        expert: Optional expert policy forwarded to the wrapper.
        dict_obs: Forwarded to the wrapper's ``dict_obs`` argument.

    Returns:
        The constructed ``MetaWorldWrapper`` instance.
    """
    return MetaWorldWrapper(
        mt=benchmark,
        task=config.env,
        max_step=config.max_step,
        terminate=config.terminate,
        expert=expert,
        dict_obs=dict_obs,
        xwind_id=config.xwind_id,
        gravity_id=config.gravity_id,
        camera_id=config.fov,  # the wrapper calls the field-of-view setting "camera_id"
        bright_id=config.bright_id,
        contrast_id=config.contrast_id,
        saturation_id=config.saturation_id,
        hue_id=config.hue_id,
        set_source=config.source,
        seed=seed,
    )
            
            
def make_vec_env(
    config: Dict[str, Any],
    benchmark,
    seed,
    expert=None,
    dict_obs=False, 
    training=True,
    norm_obs=True,
    norm_reward=True,
):
    """Build a ``VecNormalize``-wrapped vectorized ``MetaWorldWrapper`` environment.

    The keyword arguments handed to the wrapper are pretty-printed to stdout
    before the environment is created.

    Args:
        config: Configuration object; the attributes ``env``, ``max_step``,
            ``terminate``, ``xwind_id``, ``gravity_id``, ``fov``,
            ``bright_id``, ``contrast_id``, ``saturation_id``, ``hue_id``
            and ``source`` are read and forwarded to the wrapper.
        benchmark: Benchmark object, forwarded as the wrapper's ``mt`` argument.
        seed: Seed forwarded to the wrapper.
        expert: Optional expert policy forwarded to the wrapper.
        dict_obs: Forwarded to the wrapper's ``dict_obs`` argument.
        training: Forwarded to ``VecNormalize``.
        norm_obs: Forwarded to ``VecNormalize``.
        norm_reward: Forwarded to ``VecNormalize``.

    Returns:
        The ``vec_env.VecNormalize`` instance wrapping the vectorized env.
    """
    wrapper_kwargs = dict(
        mt=benchmark,
        task=config.env,
        max_step=config.max_step,
        terminate=config.terminate,
        expert=expert,
        dict_obs=dict_obs,
        xwind_id=config.xwind_id,
        gravity_id=config.gravity_id,
        camera_id=config.fov,  # the wrapper calls the field-of-view setting "camera_id"
        bright_id=config.bright_id,
        contrast_id=config.contrast_id,
        saturation_id=config.saturation_id,
        hue_id=config.hue_id,
        set_source=config.source,
        seed=seed,
    )
    pprint.pprint(wrapper_kwargs)

    base_venv = env_util.make_vec_env(MetaWorldWrapper, env_kwargs=wrapper_kwargs)
    return vec_env.VecNormalize(
        base_venv,
        training=training,
        norm_obs=norm_obs,
        norm_reward=norm_reward,
    )


def evaluate_env(
    env: gym.Env,
    policy: Union[Policy, policies.BasePolicy] = None,
    n_eval_episodes: int = 100, 
    verbose: bool = False,
):
    """Roll out ``policy`` for ``n_eval_episodes`` episodes and summarize them.

    Action selection depends on the policy type:
      * ``None``: a random sample from ``env.action_space``;
      * scripted ``Policy``: ``policy.get_action(env.envs[0].expert_obs)``;
      * ``policies.BasePolicy``: ``policy.predict(obs)``.

    An episode counts as successful when ``info['is_success']`` is truthy on
    its final step. For every failed episode the final ``info['metadata']``
    is kept and tallied per domain factor.

    Args:
        env: Environment to evaluate on. ``env.envs[0]`` is accessed for
            scripted policies, so a vectorized env is expected in that case.
        policy: Policy to evaluate, or ``None`` for random actions.
        n_eval_episodes: Number of episodes to run.
        verbose: When true, print the summary statistics and the per-factor
            failure tallies.

    Returns:
        Tuple ``(avg_rew, avg_len, avg_suc)``: mean episode return, mean
        episode length and the fraction of successful episodes.

    Raises:
        NotImplementedError: If ``policy`` is of an unsupported type.
    """
    returns, successes, lengths, failed_meta = [], [], [], []

    progress = tqdm.tqdm(
        range(n_eval_episodes),
        total=n_eval_episodes,
        desc='Evaluation',
        ncols=70,
        ascii=' =',
        leave=True,
    )
    for _ in progress:
        obs = env.reset()
        ep_return, ep_steps = 0, 0
        done = False
        while not done:
            if policy is None:
                act = [env.action_space.sample()]
            elif isinstance(policy, Policy):
                act = [policy.get_action(env.envs[0].expert_obs)]
            elif isinstance(policy, policies.BasePolicy):
                act, _ = policy.predict(obs)
            else:
                raise NotImplementedError
            obs, rew, done, info = env.step(act)
            if isinstance(info, list):
                # Vectorized envs return one info dict per sub-env; use the first.
                info = info[0]
            ep_return += rew
            ep_steps += 1
            # Overwritten every step, so only the final step's flag survives.
            succeeded = bool(info['is_success'])

        if not succeeded:
            # Save metadata for unsuccessful episodes.
            failed_meta.append(info['metadata'])
        returns.append(ep_return)
        successes.append(succeeded)
        lengths.append(ep_steps)

    avg_rew, std_rew = np.mean(returns), np.std(returns)
    avg_len, std_len = np.mean(lengths), np.std(lengths)
    avg_suc = np.sum(successes) / n_eval_episodes

    if verbose:
        print(f"Average Return: {avg_rew:.2f} +/- {std_rew:.2f}")
        print(f"Average Length: {avg_len:.2f} +/- {std_len:.2f}")
        print(f"Average Succes: {avg_suc*100:.2f}%")

    # (printed label, metadata key) per domain factor; each tally has 10 slots
    # and is indexed by the value stored under the metadata key.
    factors = (
        ('WIND', 'xwind'),
        ('GRAV', 'gravity'),
        ('BRIG', 'bright'),
        ('CONT', 'contrast'),
        ('SATU', 'saturation'),
        ('HUEE', 'hue'),
    )
    fail_counts = {key: [0] * 10 for _, key in factors}
    for meta in failed_meta:
        for _, key in factors:
            fail_counts[key][meta[key]] += 1

    if verbose:
        print("<<< ! ! ! Failed Cases ! ! ! >>>")
        for label, key in factors:
            print(f"{label} {fail_counts[key]}")

    return avg_rew, avg_len, avg_suc

# Identifiers of the form 'cam<group>-<index>': two groups of five.
# NOTE(review): presumably field-of-view / camera names -- confirm against the wrapper.
FOVS = [f'cam{group}-{index}' for group in range(2) for index in range(5)]
# Identifiers 'cam00', 'cam04', ..., 'cam36' (zero-padded, step 4); entries are
# passed as the first argument of ``set_dfs`` by the DFS evaluation below.
CAMS = [f'cam{number:02d}' for number in range(0, 40, 4)]

def evaluate_dfs_env(
    env: gym.Env,
    policy: policies.BasePolicy,
    path: str,
    n_eval_episodes: int=10,
    verbose: bool=True,
):
    """Evaluate ``policy`` on the full grid of domain-factor settings.

    Every combination of camera (each entry of ``CAMS``), wind, gravity,
    brightness and contrast level (``range(10)`` each) is applied through
    ``env.envs[0].set_dfs`` and evaluated with ``evaluate_env``. One result
    line per combination, followed by a final average line, is appended to
    ``<path>/dfs.txt``.

    Args:
        env: Vectorized environment; ``env.envs[0]`` must provide ``set_dfs``.
        policy: Policy handed to ``evaluate_env``.
        path: Directory that receives ``dfs.txt`` (opened in append mode).
        n_eval_episodes: Episodes per grid point.
        verbose: Forwarded to ``evaluate_env``; also echoes each result line.
    """
    log_path = path + '/dfs.txt'
    avg_rews, avg_lens, avg_sucs = [], [], []
    for cam in CAMS: # fov
        for widx in range(10): # xwind
            for gidx in range(10): # gravity
                for bidx in range(10): # brightness
                    for cidx in range(10): # contrast
                        # NOTE(review): evaluate_short_dfs_env calls set_dfs with seven
                        # arguments (adding saturation and hue); confirm that this
                        # five-argument form is still accepted by the wrapper.
                        env.envs[0].set_dfs(cam, widx, gidx, bidx, cidx)
                        avg_rew, avg_len, avg_suc = evaluate_env(env, policy, n_eval_episodes, verbose)
                        avg_rews.append(avg_rew)
                        avg_lens.append(avg_len)
                        avg_sucs.append(avg_suc)

                        text = f"{cam}-WIND{widx}-GRAV{gidx}-BRIGHT{bidx}-CONTRAST{cidx}: {avg_rew:.2f} {avg_len:.2f} {avg_suc:.2f}\n"
                        with open(log_path, "a") as f:
                            f.write(text)
                        if verbose:
                            print(text)

    summary = f"Averge: {np.mean(avg_rews):.2f}; {np.mean(avg_lens):.2f}; {np.mean(avg_sucs):.2f}"
    print(summary)
    with open(log_path, "a") as f:
        # The file is opened in append mode, so terminate the line; otherwise
        # the first record of a later run is glued onto this summary.
        f.write(summary + "\n")


def evaluate_short_dfs_env(
    env: gym.Env,
    policy: policies.BasePolicy,
    path: str,
    n_eval_episodes: int=10,
    verbose: bool=True,
):
    """Evaluate ``policy`` while sweeping one domain factor at a time.

    For each factor in turn (wind, gravity, brightness, contrast, saturation,
    hue) the factor's level runs over ``range(10)`` while the camera is fixed
    to ``'cam00'`` and every other level is 0. Each setting is applied with
    ``env.envs[0].set_dfs`` and evaluated with ``evaluate_env``. One line per
    level plus one average line per factor is appended to
    ``<path>/dfs_short.txt``.

    Args:
        env: Vectorized environment; ``env.envs[0]`` must provide ``set_dfs``.
        policy: Policy handed to ``evaluate_env``.
        path: Directory that receives ``dfs_short.txt`` (append mode).
        n_eval_episodes: Episodes per setting.
        verbose: Forwarded to ``evaluate_env``; also echoes every logged line.
    """
    path = path + '/dfs_short.txt'

    # (log label, position of the swept level in the ``set_dfs`` argument list).
    # Position 0 is the camera name. A camera sweep over ``CAMS`` used to come
    # first but was disabled in the original implementation and is not run.
    sweeps = (
        ('WIND', 1),
        ('GRAV', 2),
        ('BRIG', 3),
        ('CONT', 4),
        ('SAT', 5),
        ('HUE', 6),
    )

    for label, slot in sweeps:
        avg_rews, avg_lens, avg_sucs = [], [], []
        for level in range(10):
            # Baseline setting with only the swept factor changed.
            dfs_args = ['cam00', 0, 0, 0, 0, 0, 0]
            dfs_args[slot] = level
            env.envs[0].set_dfs(*dfs_args)

            avg_rew, avg_len, avg_suc = evaluate_env(env, policy, n_eval_episodes, verbose)
            avg_rews.append(avg_rew)
            avg_lens.append(avg_len)
            avg_sucs.append(avg_suc)
            _append_dfs_line(
                path,
                f"{label}{level}: {avg_rew:.2f} {avg_len:.2f} {avg_suc:.2f}\n",
                verbose,
            )

        _append_dfs_line(
            path,
            f"{label}-Averge: {np.mean(avg_rews):.2f}; {np.mean(avg_lens):.2f}; {np.mean(avg_sucs):.2f}\n\n",
            verbose,
        )


def _append_dfs_line(log_path, text, verbose):
    """Append ``text`` to ``log_path`` and echo it to stdout when ``verbose``."""
    with open(log_path, "a") as f:
        f.write(text)
    if verbose:
        print(text)


def render_env(
    config: Dict[str, Any],
    env: gym.Env,
    policy: policies.BasePolicy, 
    fov: str,
    path: str,
    collect: bool = False,
    verbose: bool = False,
):
    """Run one episode, rendering every step, and save it as data or video.

    With ``collect=True`` the rendered frames and the per-step observations,
    actions, rewards, done flags and infos are pickled to a ``.pkl`` file and
    no video is written. Otherwise the frames are encoded into an ``.avi``
    file.

    Args:
        config: Configuration object; ``env``, ``domain_factor``, ``tag`` and,
            depending on the domain factor, ``xwind_id`` / ``gravity_id`` are
            read to build the output file name.
        env: Vectorized environment; ``env.envs[0].render`` is used for frames.
        policy: ``None`` (random actions), a scripted ``Policy`` or a
            ``policies.BasePolicy``.
        fov: Camera name passed to ``render``; also part of the file name when
            ``config.domain_factor == 'FOV'``.
        path: Currently ignored: the output directory is always derived from
            ``config`` (``./datasets/...``). Kept for interface compatibility.
        collect: Save a pickled rollout instead of a video.
        verbose: Print the success flag and video location after saving a video.

    Raises:
        ValueError: If ``config.domain_factor`` is not one of ``'FOV'``,
            ``'XWIND'``, ``'GRAV'`` or ``'TEST'``.
        NotImplementedError: If ``policy`` is of an unsupported type.
    """
    # Resolve the output locations first so that an unsupported domain factor
    # fails immediately instead of after the whole rollout.
    out_dir, data_path, video_path = _render_output_paths(config, fov)

    # For data collection
    obss, acts, rews, dones, infos = [], [], [], [], []

    # Save 1-episode images
    images = []
    obs = env.reset()
    while True:
        images.append(env.envs[0].render(camera_name=fov, resolution=(224, 224), original=True)) # for vec env
        if policy is None:
            act = [env.action_space.sample()]
        elif isinstance(policy, Policy):
            # NOTE(review): evaluate_env feeds scripted policies
            # env.envs[0].expert_obs instead of obs[0]; confirm which is intended.
            act = [policy.get_action(obs[0])]
        elif isinstance(policy, policies.BasePolicy):
            act, _ = policy.predict(obs)
        else:
            raise NotImplementedError 
        next_obs, rew, done, info = env.step(act)

        if collect:
            obss.append(obs)
            acts.append(act)
            rews.append(rew)
            dones.append(done)
            infos.append(info)
        obs = next_obs

        if isinstance(info, list):
            info = info[0]
        # Overwritten every step, so only the final step's flag survives.
        epi_suc = bool(info['is_success'])
        if done:
            break

    os.makedirs(out_dir, exist_ok=True)

    # Collect Data
    if collect:
        data_dict = {
            'images': np.array(images),
            'observations': np.array(obss),
            'actions': np.array(acts),
            'rewards': np.array(rews),
            'terminals': np.array(dones),
            'infos': infos,
        }
        with open(data_path, 'wb') as f:
            pkl.dump(data_dict, f)
        print(f"Success ?: {epi_suc}")
        print(f"Env({config.env}) data saved at {data_path}")
        return 

    # Save video
    h, w = images[0].shape[:2]
    video = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'DIVX'), 20, (w, h))
    try:
        for image in images:
            video.write(image)
    finally:
        # Release the writer even if a frame fails to encode.
        video.release()

    if verbose:
        print(f"Success ?: {epi_suc}")
        print(f"Env({config.env}) video saved at {video_path}")


def _render_output_paths(config, fov):
    """Return ``(out_dir, data_path, video_path)`` for ``config.domain_factor``."""
    factor = config.domain_factor
    # ``config.env[:-3]`` drops the last three characters of the task name
    # (e.g. a "-v2" suffix).
    out_dir = f'./datasets/{config.env[:-3]}/{factor}'
    if factor == 'FOV':
        variant = fov
    elif factor == 'XWIND':
        variant = config.xwind_id
    elif factor == 'GRAV':
        variant = config.gravity_id
    elif factor == 'TEST':
        # Test rollouts share one flat directory and are keyed by gravity id.
        out_dir = f'./datasets/{factor}'
        variant = config.gravity_id
    else:
        raise ValueError(
            f"Unsupported domain_factor {factor!r}; "
            "expected one of 'FOV', 'XWIND', 'GRAV', 'TEST'"
        )
    stem = f'{out_dir}/{config.env}_{variant}_{config.tag}'
    return out_dir, stem + '.pkl', stem + '.avi'
    

def get_expert(
    config: Dict[str, Any],
    random: bool = False,
):
    """Return the scripted expert policy for the task named by ``config.env``.

    Args:
        config: Configuration object exposing an ``env`` attribute (task name).
        random: When true, no expert is used and ``None`` is returned.

    Returns:
        A new instance of the matching ``Sawyer*V2Policy``, or ``None`` when
        ``random`` is true.

    Raises:
        NotImplementedError: If no expert is registered for ``config.env``.
    """
    if random:
        return None

    # The lambdas defer the class-name lookup until the entry is called, so
    # only the requested policy class has to be available.
    factories = {
        'reach-v2': lambda: SawyerReachV2Policy(),
        'push-v2': lambda: SawyerPushV2Policy(),
        'pick-place-v2': lambda: SawyerPickPlaceV2Policy(),
        'window-open-v2': lambda: SawyerWindowOpenV2Policy(),
        'window-close-v2': lambda: SawyerWindowCloseV2Policy(),
        'drawer-open-v2': lambda: SawyerDrawerOpenV2Policy(),
        'drawer-close-v2': lambda: SawyerDrawerCloseV2Policy(),
        'button-press-topdown-v2': lambda: SawyerButtonPressTopdownV2Policy(),
        'door-open-v2': lambda: SawyerDoorOpenV2Policy(),
        'peg-insert-side-v2': lambda: SawyerPegInsertSideV2Policy(),
        'reach-wall-v2': lambda: SawyerReachWallV2Policy(),
        'plate-slide-v2': lambda: SawyerPlateSlideV2Policy(),
    }

    task = config.env
    make_policy = factories.get(task)
    if make_policy is None:
        raise NotImplementedError(
            f"No scripted expert is registered for task {task!r}; "
            f"supported tasks: {', '.join(sorted(factories))}"
        )
    return make_policy()
