import numpy as np
import os
import matplotlib.pyplot as plt
from scipy.stats import sem, t
from matplotlib.ticker import MaxNLocator


# Environments to plot, mapped to the number of environment steps shown for
# each one. The main loop uses this value both to truncate the learning curves
# and as the upper limit of the x-axis.
env_names = {
    'Alien-v5': 100_000,
    'Amidar-v5': 200_000,
    'BankHeist-v5': 200_000,
    'Breakout-v5': 100_000,
    'ChopperCommand-v5': 200_000,
    'CrazyClimber-v5': 300_000,
    'Freeway-v5': 300_000,
    'Hero-v5': 300_000,
    'Jamesbond-v5': 200_000,
    'Krull-v5': 100_000,
    'MsPacman-v5': 100_000,
    'Pong-v5': 500_000,
    'Qbert-v5': 300_000,
    'Seaquest-v5': 300_000,
    'UpNDown-v5': 200_000,
}

# Per-environment reference scores that define the 100% mark of the normalised
# score printed by the main loop:
#     100 * (score - random_scores[env]) / (human_scores[env] - random_scores[env])
human_scores = {
    'Alien-v5': 7127.7,
    'Amidar-v5': 1719.5,
    'BankHeist-v5': 753.1,
    'Breakout-v5': 30.5,
    'ChopperCommand-v5': 7387.8,
    'CrazyClimber-v5': 35829.4,
    'DemonAttack-v5': 1971.0,
    'Freeway-v5': 29.6,
    'Frostbite-v5': 4334.7,
    'Hero-v5': 30826.4,
    'Jamesbond-v5': 302.8,
    'Krull-v5': 2665.5,
    'MsPacman-v5': 6951.6,
    'Pong-v5': 14.6,
    'Qbert-v5': 13455.0,
    'Seaquest-v5': 4254.7,
    'UpNDown-v5': 11693.2,
}

# Per-environment baseline scores that define the 0% mark of the normalised
# score printed by the main loop:
#     100 * (score - random_scores[env]) / (human_scores[env] - random_scores[env])
random_scores = {
    'Alien-v5': 227.8,
    'Amidar-v5': 5.8,
    'BankHeist-v5': 14.2,
    'Breakout-v5': 1.7,
    'ChopperCommand-v5': 811.0,
    'CrazyClimber-v5': 10780.5,
    'DemonAttack-v5': 252.1,
    'Freeway-v5': 0.0,
    'Frostbite-v5': 65.2,
    'Hero-v5': 1027.0,
    'Jamesbond-v5': 29.0,
    'Krull-v5': 1598.0,
    'MsPacman-v5': 307.3,
    'Pong-v5': -20.7,
    'Qbert-v5': 163.9,
    'Seaquest-v5': 68.4,
    'UpNDown-v5': 533.4,
}

for env_name, num_step in env_names.items():
    # Start this environment's row of the tab-separated results table.
    # NOTE: str.strip('-v5') removes any of the characters '-', 'v', '5' from
    # BOTH ends of the string instead of the literal suffix, so the suffix is
    # cut off explicitly here.
    display_name = env_name[:-len('-v5')] if env_name.endswith('-v5') else env_name
    print('\n' + display_name, end='\t')

    # Baseline runs: sub-directories of ../model_base/<env> whose name starts
    # with 'DQN'. Plain files are skipped because every run is later read as a
    # directory containing .npy files.
    base_root = os.path.join('..', 'model_base', env_name)
    base_DQN_folders = [
        os.path.join(base_root, entry)
        for entry in os.listdir(base_root)
        if entry.startswith('DQN') and os.path.isdir(os.path.join(base_root, entry))
    ]

    # LLM-assisted runs: every sub-directory of ../model_main/<env>.
    main_root = os.path.join('..', 'model_main', env_name)
    llm_main_folders = [
        os.path.join(main_root, entry)
        for entry in os.listdir(main_root)
        if os.path.isdir(os.path.join(main_root, entry))
    ]

    # print('base_DQN_folders:', base_DQN_folders)
    # print('llm_main_folders:', llm_main_folders)


    def sliding_window_average(data, window_size=100):
        """Return the trailing moving average of ``data`` as a NumPy array.

        Element ``i`` of the result is the mean of the (at most)
        ``window_size`` most recent entries ending at ``data[i]``. Near the
        start of the sequence the window simply covers every entry seen so
        far, so the output has the same length as the input.
        """
        smoothed = [
            np.mean(data[max(0, stop - window_size):stop])
            for stop in range(1, len(data) + 1)
        ]
        return np.array(smoothed)

    def compute_confidence_interval(data, confidence=0.95):
        """Summarise ``data`` along axis 0 with a Student-t confidence interval.

        ``data.shape[0]`` is taken as the number of samples. Returns the tuple
        ``(mean, standard_error, lower_bound, upper_bound)``, where the bounds
        are ``mean -/+ standard_error * t_crit`` and ``t_crit`` is the
        two-sided critical value for ``confidence`` with ``n - 1`` degrees of
        freedom.
        """
        sample_count = data.shape[0]
        center = np.mean(data, axis=0)
        std_err = sem(data, axis=0)
        critical = t.ppf((1 + confidence) / 2., sample_count - 1)
        half_width = std_err * critical
        return center, std_err, center - half_width, center + half_width

    def load_and_process_data(folders, num_step):
        """Load the smoothed learning curve of every run on a common step axis.

        For each folder this reads ``episode_rewards.npy`` (one return per
        episode), smooths it with ``sliding_window_average`` and reads the
        ``reward_record-episode<i>.npy`` files, whose lengths are accumulated
        to obtain the environment step at which each episode ended. The
        smoothed returns are then linearly interpolated onto every integer
        step.

        Args:
            folders: run directories to load.
            num_step: step budget; curves are cut to at most ``num_step + 1``
                points (and to the shortest run) before sub-sampling.

        Returns:
            2-D array with one row per folder, holding every 1000th step of
            the truncated curves.

        Raises:
            ValueError: if ``folders`` is empty or a run has no readable
                ``reward_record-episode*.npy`` file.
        """
        if not folders:
            raise ValueError('load_and_process_data: no run folders were given')

        data_list = []
        data_len_list = []
        for folder in folders:
            episode_data = np.load(os.path.join(folder, 'episode_rewards.npy'))
            num_episodes = len(episode_data)

            episode_data = sliding_window_average(episode_data)
            episode_data_x = []

            # Cumulative number of environment steps at the end of each episode.
            x = 0
            for i in range(num_episodes):
                try:
                    episode_reward = np.load(
                        os.path.join(folder, f'reward_record-episode{i}.npy'))
                except (OSError, ValueError, EOFError):
                    # Records may stop before the last episode (missing or
                    # unreadable file); use the episodes available so far.
                    break
                x += len(episode_reward)
                episode_data_x.append(x)

            if not episode_data_x:
                raise ValueError(
                    f'{folder}: no readable reward_record-episode*.npy files')

            # Interpolate the per-episode values onto every environment step.
            data_list.append(np.interp(
                np.arange(x), episode_data_x, episode_data[:len(episode_data_x)]))
            data_len_list.append(x)

        # Align all runs to the shortest one, capped at the step budget.
        min_len = min(min(data_len_list), num_step + 1)
        data_list = [data[:min_len] for data in data_list]

        return np.array(data_list)[:, ::1000]
    
    def load_token_consumption(folders):
        """Return the mean total input and output token counts over ``folders``.

        For every folder the stage-1 and stage-2 token arrays
        (``token_input_stage{1,2}.npy`` / ``token_output_stage{1,2}.npy``) are
        summed, and the per-folder totals are averaged.

        Returns:
            ``(mean_input_tokens, mean_output_tokens)``, or ``(None, None)``
            when ``folders`` is empty or any token file is missing or
            unreadable.
        """
        if not folders:
            # Nothing to average; avoid np.mean([]) producing nan plus a warning.
            return None, None

        input_tokens = []
        output_tokens = []
        for folder in folders:
            try:
                input_total = (
                    np.sum(np.load(os.path.join(folder, 'token_input_stage1.npy')))
                    + np.sum(np.load(os.path.join(folder, 'token_input_stage2.npy'))))
                output_total = (
                    np.sum(np.load(os.path.join(folder, 'token_output_stage1.npy')))
                    + np.sum(np.load(os.path.join(folder, 'token_output_stage2.npy'))))
            except (OSError, ValueError, EOFError):
                # Token statistics are optional: a run without them makes the
                # whole group report "not available".
                return None, None
            input_tokens.append(input_total)
            output_tokens.append(output_total)

        return np.mean(input_tokens), np.mean(output_tokens)

    # One figure is produced per task; each task maps a curve label to the run
    # folders that are averaged into that curve.
    tasks = {
        'main': {
            'Double-DQN': base_DQN_folders,
            'Double-DQN+LLM-Exp': llm_main_folders,
        },
    }

    # NOTE: str.strip('-v5') strips the characters '-', 'v' and '5' from both
    # ends rather than the literal suffix, so the suffix is removed explicitly.
    plot_title = env_name[:-len('-v5')] if env_name.endswith('-v5') else env_name

    for task_name, task_folders in tasks.items():
        plt.figure(figsize=(4, 4))
        ymin = 1e10
        ymax = -1e10

        lines = []
        labels = []

        for label, folders in task_folders.items():
            data = load_and_process_data(folders, num_step)
            mean, se, _, _ = compute_confidence_interval(data)
            ymin = min(ymin, mean.min())
            ymax = max(ymax, mean.max())

            # Curves hold one point per 1000 environment steps; the x-axis is
            # expressed in units of 1e5 steps.
            steps = np.arange(len(mean)) * 1000 / 1e5
            line = plt.plot(steps, mean, linewidth=2)
            # Shaded band is +/- one standard error around the mean.
            plt.fill_between(steps, mean - se, mean + se, alpha=0.3)

            # NOTE: token statistics are loaded but not reported at the moment.
            token_input, token_output = load_token_consumption(folders)

            # Final score and its normalisation between the random (0%) and
            # human (100%) reference scores, appended to the results row.
            score_range = human_scores[env_name] - random_scores[env_name]
            normalized = 100 * (mean[-1] - random_scores[env_name]) / score_range
            print(np.round(mean[-1], 2), np.round(normalized, 2), sep='\t', end='\t')

            lines.append(line[0])
            labels.append(label)

        # Pad the y-range by 10% away from the data, whatever the sign.
        if ymax > 0:
            ymax = ymax * 1.1
        else:
            ymax = ymax * 0.9

        if ymin > 0:
            ymin = ymin * 0.9
        else:
            ymin = ymin * 1.1

        plt.xlabel('Environment steps ($10^5$)', fontsize=18)
        plt.xlim(0, num_step / 1e5)
        plt.ylim(ymin, ymax)
        plt.xticks(fontsize=13)
        plt.yticks(fontsize=13)
        plt.gca().xaxis.set_major_locator(MaxNLocator(5))
        plt.gca().yaxis.set_major_locator(MaxNLocator(5))
        plt.ylabel('Episode return', fontsize=18)
        plt.title(plot_title, fontsize=18)
        plt.grid(linestyle='--', alpha=0.5)
        plt.tight_layout()
        plt.savefig(os.path.join('..', 'figure', env_name + '-' + task_name + '.pdf'))
        plt.close()

        # The legend is saved as its own figure so it can be shared by all
        # plots of the task. The column count depends on the NUMBER of curves
        # (len(labels)), not on the length of the last label string.
        fig_legend = plt.figure(figsize=(8, 1))
        fig_legend.legend(lines, labels, loc='center', fontsize=18,
                          ncol=(3 if len(labels) % 3 == 0 else 2), frameon=True)
        fig_legend.savefig(os.path.join('..', 'figure', f'legend_{task_name}.pdf'),
                           bbox_inches='tight')
        plt.close(fig_legend)