import pickle
import numpy as np

import matplotlib.pyplot as plt
import ipdb

from utils import (
    build_bandit_data_filename,
    build_bandit_model_filename,
    build_linear_bandit_data_filename,
    build_linear_bandit_model_filename,
    build_darkroom_data_filename,
    build_darkroom_model_filename,
    build_miniworld_data_filename,
    build_miniworld_model_filename,
    build_linear_bandit_data_filename_custom,
)




def load_pickled_sequence(path, count):
    """Return a list of the first `count` objects pickled back-to-back in `path`.

    The file is opened read-only. `pickle.load` raises `EOFError` if the file
    holds fewer than `count` records.
    """
    # NOTE: unpickling can execute arbitrary code; only use on trusted files.
    with open(path, 'rb') as f:
        return [pickle.load(f) for _ in range(count)]


def mean_pulls_first_last(segment, steps):
    """Average per-arm pull counts over the first and the last `steps` rows.

    `segment` is indexed as (time step, environment, arm). For each window the
    entries are summed over time and then averaged over environments.

    Returns a `(first, last)` pair of 1-D arrays with one value per arm.
    """
    first = np.mean(np.sum(segment[:steps], axis=0), axis=0)
    last = np.mean(np.sum(segment[-steps:], axis=0), axis=0)
    return first, last


def pulled_arm_indices(segment):
    """Return an (environment, time step) array of the arm index marked 1.

    `segment` is indexed as (time step, environment, arm).
    Assumes every (time step, environment) row is one-hot; otherwise the
    assignment below generally fails with a NumPy broadcasting error.
    """
    horizon, num_envs, _ = np.shape(segment)
    arm_idx = np.zeros((num_envs, horizon))
    for env in range(num_envs):
        # Column indices of the entries equal to 1, in time order.
        arm_idx[env] = np.where(segment[:, env, :] == 1)[1]
    return arm_idx


def distinct_arms_remaining(arm_idx, env_ids):
    """Count, per selected environment and time t, the distinct arms in arm_idx[env, t:].

    Returns an array of shape (len(env_ids), horizon).
    """
    horizon = np.shape(arm_idx)[1]
    counts = np.zeros((len(env_ids), horizon))
    for row, env in enumerate(env_ids):
        for t in range(horizon):
            counts[row, t] = len(np.unique(arm_idx[env, t:]))
    return counts


def plot_distinct_arms_curve(segment, env_ids, label, band_scale):
    """Plot the mean distinct-arms-remaining curve with a shaded std band.

    The band is `band_scale` standard deviations (across the selected
    environments) above and below the mean. Returns the segment's horizon.
    """
    counts = distinct_arms_remaining(pulled_arm_indices(segment), env_ids)
    mean_counts = np.mean(counts, axis=0)
    spread = band_scale * np.std(counts, axis=0)
    horizon = np.shape(counts)[1]

    plt.plot(mean_counts, label=label)
    plt.fill_between(np.arange(0, horizon), mean_counts - spread, mean_counts + spread, alpha=0.2)
    return horizon


def main():
    """Load predicted/optimal actions and save the two exploration plots."""
    # Number of pickled records to read from a_pred.pkl.
    # Original note: this should be 4 times the horizon.
    num_records = 204
    # Length of each per-algorithm slice of the action array.
    # NOTE(review): 4 * 25 != 204 -- confirm that 25 is the right segment
    # length for the data_large run; the value is kept as it was.
    segment_len = 25
    # Other dataset directories that have been used with this script:
    #   script/new_arms/data, script/new_arms/data_large, script/new_arms_final/data
    data_dir = 'script/new_arms_final/data_large'

    steps = 10       # window length for the first/last bar chart
    bar_width = 0.15
    k = 0.1          # scale applied to the std band in the second figure

    actions = np.array(load_pickled_sequence(f'{data_dir}/a_pred.pkl', num_records))
    actions_opt = load_pickled_sequence(f'{data_dir}/opt_a.pkl', 1)[0]

    _, num_arms = np.shape(actions_opt)

    # Arm that is marked optimal in the largest number of environments.
    opt_action = np.argmax(np.sum(actions_opt, axis=0))
    #### if set manually, use array index format (subtract 1) ####
    # opt_action = 16

    print("opt action", opt_action+1)
    envs_with_opt_action = np.where(actions_opt[:, opt_action] == 1)[0]
    print(envs_with_opt_action, len(envs_with_opt_action), opt_action)

    # Consecutive slices of the record axis, one per algorithm.
    dpt_segment = actions[0:segment_len]
    pred_segment = actions[segment_len:2 * segment_len]
    pred_tau_segment = actions[2 * segment_len:3 * segment_len]

    ############## bar chart: pulls per arm, first vs last `steps` ############
    x = np.arange(1, num_arms+1)
    # (label prefix, segment, bar offset in units of bar_width); the order
    # fixes both the colour cycle and the legend order.
    bar_specs = [
        ('Pred', pred_segment, 0.5),
        ('Pred-$\\tau$', pred_tau_segment, 1.5),
        ('DPT-Greedy', dpt_segment, 2.5),
    ]
    for name, segment, offset in bar_specs:
        first, last = mean_pulls_first_last(segment[:, envs_with_opt_action, :], steps)
        plt.bar(x - offset*bar_width, first, width=bar_width, label=f'{name} (first {steps})')
        plt.bar(x + offset*bar_width, last, width=bar_width, label=f'{name} (last {steps})')

    plt.xlabel('Arms')
    plt.ylabel(f'Average number of pulls in {steps} rounds')
    plt.title('Exploration by algorithms')
    plt.xticks(x, x)
    plt.legend()

    plt.savefig(f'{data_dir}/analysis_pull.png')

    # Start the second plot on a fresh figure.
    plt.close()
    plt.figure()

    ############## line plot: distinct arms pulled from time t onward ############
    curve_specs = [
        ('DPT-Greedy', dpt_segment),
        ('PreDeToR (ours)', pred_segment),
        ('PreDeToR-$\\tau$ (ours)', pred_tau_segment),
    ]
    horizon = segment_len
    for label, segment in curve_specs:
        horizon = plot_distinct_arms_curve(segment, envs_with_opt_action, label, k)

    plt.xlabel('Time')
    plt.ylabel('Average number of arms pulled in time forward')
    plt.legend()

    # NOTE(review): the curves are drawn at x = 0..horizon-1 while the ticks
    # start at 1 -- confirm whether a 1-based time axis was intended.
    x = np.arange(1, horizon+1, 5)
    plt.xticks(x, x)

    plt.savefig(f'{data_dir}/analysis_time_1.png')
    plt.close()


if __name__ == '__main__':
    main()


    