import os
import numpy as np
import click
import json
import torch
from meta_test_algo.network import es_policy2
from meta_test_algo.render_utils import initialize_viewer,render
from rlkit.envs import ENVS
from rlkit.envs.wrappers import NormalizedBoxEnv
from configs.default import default_config

from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter

        
    

def deep_update_dict(fr, to):
    '''Recursively merge the (possibly nested) dict ``fr`` into ``to``.

    Values in ``fr`` override those in ``to``. When both sides hold a dict
    under the same key, the merge recurses so sibling keys in ``to`` are
    preserved. When ``to`` has no dict under that key (missing or a leaf
    value), the value from ``fr`` is assigned directly.

    Note: ``to`` is modified in place; it is also returned for convenience.

    Args:
        fr: dict whose values take precedence.
        to: dict to update.

    Returns:
        ``to``, after the update.
    '''
    for k, v in fr.items():
        if isinstance(v, dict) and isinstance(to.get(k), dict):
            deep_update_dict(v, to[k])
        else:
            # Leaf value, or no dict to merge into: override/insert directly.
            to[k] = v
    return to

# Supported --test_env values -> (experiment config file, env name used for the
# ./reference_data/<name>/ checkpoint directory and the output figure).
_ENV_SPECS = {
    'cheetah-dir': ('./configs/cheetah-dir.json', 'cheetah-dir'),
    'cheetah-vel': ('./configs/cheetah-vel.json', 'cheetah-vel'),
    'ant-goal': ('./configs/ant-goal.json', 'ant-goal'),
    'ant-dir': ('./configs/ant-dir.json', 'ant-dir'),
    'humanoid-dir': ('./configs/humanoid-dir.json', 'humanoid-dir'),
    'walker-rand-params': ('./configs/walker_rand_params.json', 'walker_rand_params'),
    'hopper-rand-params': ('./configs/hopper_rand_params.json', 'hopper_rand_params'),
}


def _set_test_task(env, test_env):
    '''Configure ``env`` with the fixed evaluation task for ``test_env``.'''
    if test_env == 'cheetah-vel':
        env.set_velocity(-2.0)  # target velocity = -2
    elif test_env == 'cheetah-dir':
        env.set_direction(-1)  # direction = -1
    elif test_env == 'ant-goal':
        env.set_goal_position(1.5 * np.pi, 3)  # goal angle = 1.5 pi, radius = 3
    elif test_env in ('ant-dir', 'humanoid-dir'):
        env.set_direction(1.5 * np.pi)  # direction angle = 1.5 pi
    elif 'params' in test_env:
        env.set_test_task()


def _rollout(policy, env, device, max_path_length, stop_on_done):
    '''Run one episode from ``env.reset()`` and return its undiscounted return.

    The episode lasts at most ``max_path_length`` steps and ends early on
    ``done`` only when ``stop_on_done`` is true.
    '''
    obs = env.reset()
    episode_return = 0
    for _ in range(max_path_length):
        obs_t = torch.Tensor(obs).to(device)
        action = policy(obs_t).cpu().numpy()
        obs, reward, done, _env_info = env.step(action)
        episode_return += reward
        if done and stop_on_done:
            break
    return episode_return


@click.command()
@click.option('--test_env', default=None)
def main(test_env):
    '''Evaluate the 'proposed' and 'original' BC policies on a fixed test task.

    For each policy type, loads the checkpoints
    ``./reference_data/<env>/bc_policy_<type>(<seed>).pt`` for seeds 1-10.
    Each checkpoint's episode return is averaged over 20 environment seeds.
    The script prints the mean, 95% CI half-width and sample std across
    checkpoints, then saves a comparison bar chart via ``plot``.

    Args:
        test_env: one of the keys of ``_ENV_SPECS`` (e.g. ``'ant-goal'``).

    Raises:
        click.BadParameter: if ``test_env`` is missing or unsupported.
    '''
    import copy  # only needed here

    # Without this check an unknown env previously crashed with UnboundLocalError.
    if test_env not in _ENV_SPECS:
        raise click.BadParameter(
            f"expected one of {', '.join(_ENV_SPECS)}, got {test_env!r}",
            param_hint='--test_env',
        )
    config, test_env = _ENV_SPECS[test_env]

    # load config; deep_update_dict mutates its target, so merge into a copy
    # to leave the shared default_config untouched.
    with open(config) as f:
        exp_params = json.load(f)
    variant = deep_update_dict(exp_params, copy.deepcopy(default_config))

    # set env and fix the evaluation task
    env = NormalizedBoxEnv(ENVS[variant['env_name']](**variant['env_params']))
    _set_test_task(env, test_env)

    # Fixed evaluation horizon; variant['meta_test_params']['max_path_length'] is not used.
    max_path_length = 1000
    # ant-goal episodes always run the full horizon; other envs stop at `done`.
    stop_on_done = test_env != 'ant-goal'

    # set network parameters
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    latent_action_dim = variant['sr_params']['latent_action_dim']
    net_size = variant['net_size']
    w = variant['esq_params']['w']
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # build policy (weights are loaded per checkpoint below)
    policy = es_policy2(obs_dim, action_dim, net_size,
                        latent_dim=latent_action_dim,
                        w=w).to(device)
    policy.eval()

    seeds = list(range(1, 11))                # checkpoint seeds in the file names
    random_env_seeds = list(range(101, 121))  # env seeds each checkpoint is evaluated on
    types = ['proposed', 'original']

    # run evaluation
    means, stds, ci95s = [], [], []
    for t in types:
        score = []  # one mean return per checkpoint seed
        for seed in tqdm(seeds, desc='Evaluation'):
            ckpt_path = f'./reference_data/{test_env}/bc_policy_{t}({seed}).pt'
            # map_location lets GPU-saved checkpoints load on a CPU-only machine.
            policy.load_state_dict(torch.load(ckpt_path, map_location=device))
            with torch.no_grad():
                returns = []
                for random_env_seed in random_env_seeds:
                    env.set_seed(random_env_seed)
                    returns.append(
                        _rollout(policy, env, device, max_path_length, stop_on_done))
            score.append(sum(returns) / len(returns))
        mean = np.mean(score)
        std = np.std(score, ddof=1)  # sample std across checkpoint seeds
        ci95 = 1.96 * std / np.sqrt(len(score))  # normal-approximation 95% CI half-width
        print(f"{t}: mean={mean:.1f}, 95%CI=±{ci95:.1f}, std={std:.1f}")
        means.append(mean)
        stds.append(std)
        ci95s.append(ci95)
    plot(means=means, stds=stds, ci95s=ci95s, test_env=test_env)


def plot(means, stds, ci95s, test_env='ant-goal', min_return=None):
    '''Save a bar chart of mean returns with symmetric 95% CI error bars.

    The two bars are labelled as the frozen-backbone + linear-layer BC policy
    and the full-policy BC baseline, in that order. All bars are shifted up
    by a constant offset ``c`` so that every bar, including the lowest CI
    lower bound, starts at or above 0. The y tick labels subtract ``c`` back,
    so the axis still reads in the original return scale. The figure is
    written to ``f"{test_env}_bc.png"``.

    Args:
        means: per-bar mean returns.
        stds: per-bar standard deviations (not drawn; kept for interface compatibility).
        ci95s: per-bar 95% CI half-widths.
        test_env: env name; selects the y-axis floor and the title.
        min_return: y-axis floor in the original return scale. If None, a
            per-env preset is used; when no preset exists, the floor is the
            lowest CI lower bound, capped at 0.
    '''
    means = np.asarray(means, float)
    stds = np.asarray(stds, float)    # not drawn, kept for signature compatibility
    ci95s = np.asarray(ci95s, float)  # 95% CI half-widths (symmetric error bars)

    # 1) choose the y-axis floor and the offset c that lifts every (mean - ci) to >= 0
    lo = means - ci95s
    if min_return is None:
        presets = {'ant-goal': -3000, 'ant-dir': 0, 'cheetah-vel': -2000, 'cheetah-dir': 0}
        if test_env in presets:
            min_return = presets[test_env]
        elif 'params' in test_env:
            min_return = 0
        else:
            # No preset (e.g. 'humanoid-dir'); previously this raised UnboundLocalError.
            min_return = min(0.0, float(np.min(lo)))
    c = max(-np.min(lo), -min_return)

    # 2) draw bars shifted by c (the error-bar half-widths are unchanged)
    x = np.arange(len(means))
    means_plot = means + c

    fig, ax = plt.subplots()
    ax.bar(x, means_plot)
    ax.errorbar(x, means_plot, yerr=ci95s, fmt='none',
                capsize=8, elinewidth=2, ecolor='black', zorder=3)

    # 3) tick labels undo the offset so the axis reads in the original scale
    ax.yaxis.set_major_formatter(FuncFormatter(lambda v, pos: f"{v - c:.1f}"))

    # pin the bottom at min_return (original scale), i.e. min_return + c in
    # plotted coordinates; the top stays automatic
    ax.set_ylim(bottom=min_return + c)

    ax.set_xticks(x)
    ax.set_xticklabels(['Pretrained backbone (frozen)\n + linear layers (BC)', 'Full-policy BC\n(all layers)'])
    ax.set_ylabel("Mean Return")
    titles = {
        'ant-goal': "Ant Goal",
        'ant-dir': "Ant Direction",
        'cheetah-vel': "Cheetah Velocity",
        'cheetah-dir': "Cheetah Direction",
        'humanoid-dir': "Humanoid Direction",
        'walker_rand_params': "Walker Random Params",
        'hopper_rand_params': "Hopper Random Params",
    }
    if test_env in titles:
        ax.set_title(titles[test_env])
    fig.tight_layout()
    fig.subplots_adjust(right=0.66)
    fig.savefig(f"{test_env}_bc.png", dpi=600, bbox_inches='tight')
    plt.close(fig)  # release the figure; it is not reused after saving

if __name__ == "__main__":
    # main is a click command: click parses --test_env from the command line.
    main()