import numpy as np
import os
import matplotlib.pyplot as plt
from scipy.stats import sem, t
from matplotlib.ticker import MaxNLocator

# Environments to process. Each key doubles as the sub-directory name that is
# looked up under the ../model_* result folders; each value holds the
# per-environment settings read at the top of the processing loop.
env_names = {
    'Alien-v5': {
        'llm_temperature': 1.0,
        'adjust_frequency': 1,
        'sample_rate': 100,
        'num_step': 100000,
    },
    'Freeway-v5': {
        'llm_temperature': 1.0,
        'adjust_frequency': 1,
        'sample_rate': 100,
        'num_step': 300000,
    },
    'MsPacman-v5': {
        'llm_temperature': 1.0,
        'adjust_frequency': 1,
        'sample_rate': 100,
        'num_step': 100000,
    },
}

# Per-game human reference scores, keyed by environment id. In this file they
# are only referenced by the (currently commented-out) normalized-score report:
# 100 * (score - random) / (human - random).
human_scores = {
    'Alien-v5': 7127.7,
    'Amidar-v5': 1719.5,
    'BankHeist-v5': 753.1,
    'Breakout-v5': 30.5,
    'ChopperCommand-v5': 7387.8,
    'CrazyClimber-v5': 35829.4,
    'DemonAttack-v5': 1971.0,
    'Freeway-v5': 29.6,
    'Frostbite-v5': 4334.7,
    'Hero-v5': 30826.4,
    'Jamesbond-v5': 302.8,
    'Krull-v5': 2665.5,
    'MsPacman-v5': 6951.6,
    'Pong-v5': 14.6,
    'Qbert-v5': 13455.0,
    'Seaquest-v5': 4254.7,
    'UpNDown-v5': 11693.2,
}

# Per-game random-policy reference scores, keyed by environment id. In this
# file they are only referenced by the (currently commented-out)
# normalized-score report: 100 * (score - random) / (human - random).
random_scores = {
    'Alien-v5': 227.8,
    'Amidar-v5': 5.8,
    'BankHeist-v5': 14.2,
    'Breakout-v5': 1.7,
    'ChopperCommand-v5': 811.0,
    'CrazyClimber-v5': 10780.5,
    'DemonAttack-v5': 252.1,
    'Freeway-v5': 0.0,
    'Frostbite-v5': 65.2,
    'Hero-v5': 1027.0,
    'Jamesbond-v5': 29.0,
    'Krull-v5': 1598.0,
    'MsPacman-v5': 307.3,
    'Pong-v5': -20.7,
    'Qbert-v5': 163.9,
    'Seaquest-v5': 68.4,
    'UpNDown-v5': 533.4,
}

def _collect_run_folders(group, env_name, patterns=()):
    """Sort the run folders of one experiment group into pattern buckets.

    Lists the directory ``../<group>/<env_name>`` (relative to the current
    working directory) and assigns every entry to the first pattern, in the
    given order, that occurs as a substring of the entry name. This mirrors an
    ``if``/``elif`` chain: an entry lands in at most one bucket.

    Args:
        group: Experiment directory name, e.g. ``'model_base'``.
        env_name: Environment sub-directory to list.
        patterns: Substrings to look for, in priority order.

    Returns:
        ``(buckets, unmatched)``: ``buckets`` is a list with one list of paths
        per pattern; ``unmatched`` lists the paths that matched no pattern
        (every path when ``patterns`` is empty).

    Raises:
        FileNotFoundError: If the directory does not exist (from ``os.listdir``).
    """
    root = os.path.join('..', group, env_name)
    buckets = [[] for _ in patterns]
    unmatched = []
    for folder in os.listdir(root):
        path = os.path.join(root, folder)
        for bucket, pattern in zip(buckets, patterns):
            if pattern in folder:
                bucket.append(path)
                break
        else:
            unmatched.append(path)
    return buckets, unmatched


for env_name, paras in env_names.items():
    # str.strip('-v5') would remove any of the characters '-', 'v', '5' from
    # both ends of the name; cut the literal version suffix instead.
    env_label = env_name[:-len('-v5')] if env_name.endswith('-v5') else env_name
    print('\n' + env_label, end='\t')
    llm_temperature = paras['llm_temperature']
    adjust_frequency = paras['adjust_frequency']
    sample_rate = paras['sample_rate']
    num_step = paras['num_step']

    # Folder-name markers of the DQN variants, in matching priority order.
    dqn_variant_patterns = ('Vanilla_DQN', 'Dueling_DQN', 'PER_DQN', 'Rainbow', 'CURL')

    # Baseline runs: a folder that names no variant is treated as the default
    # baseline (labelled 'Double-DQN' in the task table).
    base_buckets, base_DQN_folders = _collect_run_folders('model_base', env_name, dqn_variant_patterns)
    (base_VanillaDQN_folders,
     base_DuelingDQN_folders,
     base_PERDQN_folders,
     base_Rainbow_folders,
     base_CURL_folders) = base_buckets

    # Main LLM runs: every folder is used.
    _, llm_main_folders = _collect_run_folders('model_main', env_name)

    # Runs grouped by the LLM named in the folder; other folders are ignored.
    llm_type_buckets, _ = _collect_run_folders(
        'model_llmtype',
        env_name,
        ('gpt-3.5-turbo-0125', 'gpt-4o-2024-08-06', 'Llama-3.1-405B', 'Llama-3.1-70B'),
    )
    (llm_model_GPT35_folders,
     llm_model_GPT4o_folders,
     llm_model_Llama31_405B_folders,
     llm_model_Llama31_70B_folders) = llm_type_buckets

    _, llm_prompt_name_folders = _collect_run_folders('model_prompt', env_name)

    _, llm_ablation_folders = _collect_run_folders('model_ablation', env_name)

    # Runs with 2x / 3x the configured adjust frequency.
    (llm_adjust_2x_folders, llm_adjust_3x_folders), _ = _collect_run_folders(
        'model_adjust',
        env_name,
        (f'adjust{2*adjust_frequency}', f'adjust{3*adjust_frequency}'),
    )

    # Runs with 2x / 0.5x the configured sample rate.
    (llm_sample_2x_folders, llm_sample_05x_folders), _ = _collect_run_folders(
        'model_sample',
        env_name,
        (f'sample{2*sample_rate}', f'sample{int(0.5*sample_rate)}'),
    )

    # LLM runs on top of each DQN variant; folders naming no variant are ignored.
    algorithm_buckets, _ = _collect_run_folders('model_adaption', env_name, dqn_variant_patterns)
    (llm_algorithm_VanillaDQN_folders,
     llm_algorithm_DuelingDQN_folders,
     llm_algorithm_PERDQN_folders,
     llm_algorithm_Rainbow_folders,
     llm_algorithm_CURL_folders) = algorithm_buckets


    def sliding_window_average(data, window_size=100):
        """Return the trailing moving average of ``data`` as a NumPy array.

        Element ``k`` of the result is the mean of the last ``window_size``
        values up to and including ``data[k]``. Near the start, where fewer
        values exist, the mean is taken over whatever is available, so the
        output has the same length as the input.
        """
        smoothed = [
            np.mean(data[max(0, stop - window_size):stop])
            for stop in range(1, len(data) + 1)
        ]
        return np.array(smoothed)

    def compute_confidence_interval(data, confidence=0.95):
        """Summarize repeated runs along axis 0 with a Student-t interval.

        Args:
            data: Array whose first axis indexes the runs.
            confidence: Two-sided confidence level of the interval.

        Returns:
            ``(mean, standard_error, lower, upper)``, each reduced over axis 0;
            ``lower``/``upper`` are ``mean`` minus/plus the standard error
            scaled by the t quantile with ``n - 1`` degrees of freedom.
        """
        num_runs = data.shape[0]
        center = np.mean(data, axis=0)
        std_err = sem(data, axis=0)
        critical_value = t.ppf((1 + confidence) / 2., num_runs - 1)
        half_width = std_err * critical_value
        return center, std_err, center - half_width, center + half_width

    def load_and_process_data(folders,num_step):
        """Load smoothed learning curves of several runs on a common step axis.

        For every run folder, ``episode_rewards.npy`` is smoothed with
        ``sliding_window_average`` and each episode is placed at the cumulative
        number of environment steps reached at its end (the lengths of the
        ``reward_record-episode{i}.npy`` files, read until the first missing
        file). The curve is linearly interpolated to one value per step,
        truncated to the shortest run (at most ``num_step + 1`` steps) and
        sub-sampled every 1000 steps.

        Args:
            folders: Paths of the run folders to load.
            num_step: Number of environment steps to keep.

        Returns:
            ``(curves, enough_step_index)``. ``curves`` is a 2-D array with one
            row per folder. ``enough_step_index`` is the index of the first
            episode of the *last* folder whose cumulative step count exceeds
            ``num_step``; if that run never gets that far, the number of
            episodes found for it is returned instead.

        Raises:
            ValueError: If ``folders`` is empty or a run has no per-episode
                reward files.
        """
        if len(folders) == 0:
            raise ValueError('load_and_process_data: no run folders were given')

        curves = []
        steps_per_run = []
        episode_end_steps = []
        for folder in folders:
            smoothed = sliding_window_average(np.load(os.path.join(folder,'episode_rewards.npy')))

            # Cumulative environment steps at the end of each episode.
            episode_end_steps = []
            total_steps = 0
            for i in range(len(smoothed)):
                try:
                    episode_reward = np.load(os.path.join(folder,f'reward_record-episode{i}.npy'))
                except FileNotFoundError:
                    # Per-episode records stop here; later episodes are ignored.
                    break
                total_steps += len(episode_reward)
                episode_end_steps.append(total_steps)

            if not episode_end_steps:
                raise ValueError(f'no reward_record-episode*.npy files found in {folder}')

            # Resample the per-episode curve to one value per environment step.
            curves.append(np.interp(np.arange(total_steps), episode_end_steps, smoothed[:len(episode_end_steps)]))
            steps_per_run.append(total_steps)

        # Episode index at which the last loaded run passes num_step
        # (same reference run as before, now without relying on a leaked loop variable).
        reached = np.flatnonzero(np.array(episode_end_steps) > num_step)
        enough_step_index = reached[0] if reached.size else len(episode_end_steps)

        # Cut every run to a common length so the curves stack into a 2-D array.
        min_len = min(min(steps_per_run), num_step + 1)
        curves = [curve[:min_len] for curve in curves]

        return np.array(curves)[:,::1000], enough_step_index
    
    def load_token_consumption(folders, enough_step_index, factor=1):
        """Average the LLM token usage of several runs.

        For each folder the first ``int(enough_step_index / factor)`` entries
        of the token logs are summed. Two file layouts are tried in order:
        the staged one (``token_input_stage1/2.npy`` and
        ``token_output_stage1/2.npy``) and the single-file one
        (``token_input.npy`` and ``token_output.npy``). A layout is used only
        if all of its files exist, so a folder is never counted twice.

        Args:
            folders: Paths of the run folders.
            enough_step_index: Number of log entries to count before scaling.
            factor: Divisor applied to ``enough_step_index``.

        Returns:
            ``(mean_input_tokens, mean_output_tokens)`` over the folders, or
            ``(None, None)`` if ``folders`` is empty or some folder has
            neither layout. Errors other than a missing file propagate.
        """
        limit = int(enough_step_index/factor)

        def summed_tokens(folder, file_names):
            # Sum the first `limit` entries of every file; a missing file
            # raises FileNotFoundError from np.load.
            return sum(np.sum(np.load(os.path.join(folder, name))[:limit]) for name in file_names)

        layouts = (
            (('token_input_stage1.npy', 'token_input_stage2.npy'),
             ('token_output_stage1.npy', 'token_output_stage2.npy')),
            (('token_input.npy',), ('token_output.npy',)),
        )

        input_tokens = []
        output_tokens = []
        for folder in folders:
            for input_files, output_files in layouts:
                try:
                    # Compute both sums before recording either one, so a
                    # half-present layout cannot leave a stray entry behind.
                    folder_input = summed_tokens(folder, input_files)
                    folder_output = summed_tokens(folder, output_files)
                except FileNotFoundError:
                    continue
                input_tokens.append(folder_input)
                output_tokens.append(folder_output)
                break
            else:
                # Neither layout is complete for this folder.
                return None,None

        if not input_tokens:
            return None,None

        return np.mean(input_tokens),np.mean(output_tokens)

    # One figure per task: task name -> {legend label -> run folders of that curve}.
    # The task name is also used in the output file names.
    tasks = {
        'consumption': {
            'Double-DQN': base_DQN_folders,
            'Double-DQN+LLM-Exp (full)': llm_main_folders,
            'Double-DQN+LLM-Exp (w/o summarize & suggestion)': llm_ablation_folders,
            'Double-DQN+LLM-Exp (w/o environmental information)': llm_prompt_name_folders,
        },
        'llm_type': {
            'Double-DQN': base_DQN_folders,
            'Double-DQN+LLM-Exp (GPT-4o mini)': llm_main_folders,
            'Double-DQN+LLM-Exp (GPT-4o)': llm_model_GPT4o_folders,
            'Double-DQN+LLM-Exp (GPT-3.5)': llm_model_GPT35_folders,
            'Double-DQN+LLM-Exp (Llama-3.1-405B)': llm_model_Llama31_405B_folders,
            'Double-DQN+LLM-Exp (Llama-3.1-70B)': llm_model_Llama31_70B_folders,
        },
        'adjust_freq': {
            'Double-DQN': base_DQN_folders,
            'Double-DQN+LLM-Exp  ($K$=1)': llm_main_folders,
            'Double-DQN+LLM-Exp  ($K$=2)': llm_adjust_2x_folders,
            'Double-DQN+LLM-Exp  ($K$=3)': llm_adjust_3x_folders,
        },
        # NOTE(review): the 2x-sample-rate folders carry the label $M$=50 and the
        # 0.5x folders the label $M$=200 -- presumably $M$ is inversely related
        # to sample_rate; confirm the labels are not swapped.
        'sample_rate': {
            'Double-DQN': base_DQN_folders,
            'Double-DQN+LLM-Exp ($M$=100)': llm_main_folders,
            'Double-DQN+LLM-Exp ($M$=50)': llm_sample_2x_folders,
            'Double-DQN+LLM-Exp ($M$=200)': llm_sample_05x_folders,
        },
        # Baseline vs. LLM-assisted run for each DQN variant.
        'adaptiabilty_DQN': {
            'DQN variant': base_VanillaDQN_folders,
            'DQN variant+LLM-Exp': llm_algorithm_VanillaDQN_folders,
        },
        'adaptiabilty_Dueling-DQN': {
            'DQN variant': base_DuelingDQN_folders,
            'DQN variant+LLM-Exp': llm_algorithm_DuelingDQN_folders,
        },
        'adaptiabilty_PER-DQN': {
            'DQN variant': base_PERDQN_folders,
            'DQN variant+LLM-Exp': llm_algorithm_PERDQN_folders,
        },
        'adaptiabilty_Rainbow': {
            'DQN variant': base_Rainbow_folders,
            'DQN variant+LLM-Exp': llm_algorithm_Rainbow_folders,
        },
        'adaptiabilty_CURL': {
            'DQN variant': base_CURL_folders,
            'DQN variant+LLM-Exp': llm_algorithm_CURL_folders,
        },
        'base_compare': {
            'Double-DQN': base_DQN_folders,
            'DQN': base_VanillaDQN_folders,
            'Dueling-DQN': base_DuelingDQN_folders,
            'PER-DQN': base_PERDQN_folders,
            'Rainbow': base_Rainbow_folders,
            'CURL': base_CURL_folders,
        },
    }

    # Render one learning-curve figure and one stand-alone legend per task.
    # str.strip() removes a set of characters rather than a literal
    # prefix/suffix, so the title parts are derived with explicit slicing.
    env_title = env_name[:-len('-v5')] if env_name.endswith('-v5') else env_name
    adaptability_prefix = 'adaptiabilty_'  # spelling must match the keys of `tasks`
    figure_dir = os.path.join('..', 'figure_full')
    os.makedirs(figure_dir, exist_ok=True)
    curve_stride = 1000  # load_and_process_data keeps one point per 1000 steps

    for task_name, task_folders in tasks.items():
        is_adaptability_task = 'adaptiabilty' in task_name

        # Adaptability figures are drawn smaller than the other figures.
        fig = plt.figure(figsize=(4, 4) if is_adaptability_task else (7, 6))

        ymin = 1e10
        ymax = -1e10

        lines = []
        labels = []

        for label, folders in task_folders.items():
            if not folders:
                # Nothing was found on disk for this curve; leave it out
                # instead of failing on folders[0] / an empty load.
                print(f'[skip] {env_name}/{task_name}: no run folders for "{label}"')
                continue

            data, enough_step_index = load_and_process_data(folders, num_step)

            # Divisor for the token-log index of runs whose folder name marks
            # a 2x / 3x adjust frequency (presumably those runs log fewer
            # entries -- confirm against the training code).
            factor = 1
            if f'adjust{2*adjust_frequency}' in folders[0]:
                factor = 2
            elif f'adjust{3*adjust_frequency}' in folders[0]:
                factor = 3

            mean, se, _, _ = compute_confidence_interval(data)
            ymin = min(ymin, mean.min())
            ymax = max(ymax, mean.max())

            # x axis in units of 1e5 environment steps; shaded band is +/- one standard error.
            steps = np.arange(len(mean))*curve_stride/1e5
            line = plt.plot(steps, mean, linewidth=2)
            plt.fill_between(steps, mean-se, mean+se, alpha=0.3)

            # Only consumed by the optional report lines below.
            token_input, token_output = load_token_consumption(folders, enough_step_index, factor)

            # print(np.round(100*(mean[-1]-random_scores[env_name])/(human_scores[env_name]-random_scores[env_name]),2),end='\t')
            # print(np.round(100*(mean[-1]-random_scores[env_name])/(human_scores[env_name]-random_scores[env_name]),2), np.round(token_input/1000,2), np.round(token_output/1000,2), sep='\t', end='\t')

            lines.append(line[0])
            labels.append(label)

        if not lines:
            # Every curve of this task was skipped; there is nothing to save.
            plt.close(fig)
            continue

        # Add a 10% margin that widens the range whatever the sign of the bound.
        if ymax > 0:
            ymax = ymax * 1.1
        else:
            ymax = ymax * 0.9

        if ymin > 0:
            ymin = ymin * 0.9
        else:
            ymin = ymin * 1.1

        plt.xlabel('Environment steps ($10^5$)', fontsize=18)
        plt.xlim(0, num_step/1e5)
        plt.ylim(ymin, ymax)
        plt.xticks(fontsize=13)
        plt.yticks(fontsize=13)
        plt.gca().xaxis.set_major_locator(MaxNLocator(5))
        plt.gca().yaxis.set_major_locator(MaxNLocator(5))
        plt.ylabel('Episode return', fontsize=18)

        if is_adaptability_task:
            if task_name.startswith(adaptability_prefix):
                variant_name = task_name[len(adaptability_prefix):]
            else:
                variant_name = task_name
            plt.title(variant_name+'@'+env_title, fontsize=18)
        else:
            plt.title(env_title, fontsize=18)

        plt.grid(linestyle='--', alpha=0.5)
        plt.tight_layout()
        plt.savefig(os.path.join(figure_dir, env_name + '-' + task_name + '.pdf'))
        plt.close(fig)

        # The legend goes into its own figure. Its file name does not include
        # the environment, so each environment overwrites the same file.
        fig_legend = plt.figure(figsize=(8, 1))
        fig_legend.legend(lines, labels, loc='center', fontsize=13, ncol=(3 if len(labels)%3==0 else 2), frameon=True)
        fig_legend.savefig(os.path.join(figure_dir, f'legend_{task_name}.pdf'), bbox_inches='tight')
        plt.close(fig_legend)

    print('--------------------------------')