import pickle
import numpy as np

import matplotlib.pyplot as plt
import ipdb

from utils import (
    build_bandit_data_filename,
    build_bandit_model_filename,
    build_linear_bandit_data_filename,
    build_linear_bandit_model_filename,
    build_darkroom_data_filename,
    build_darkroom_model_filename,
    build_miniworld_data_filename,
    build_miniworld_model_filename,
    build_linear_bandit_data_filename_custom,
)

from dataset_new import Dataset, ImageDataset, Dataset_wt, Dataset_pred_reward, Dataset_pred_reward_opt_a
from tqdm import tqdm
import torch

def load_pickle_sequence(path, count):
    """Return the first ``count`` objects pickled back-to-back in ``path``.

    The prediction files are read with repeated ``pickle.load`` calls on the
    same handle (presumably one pickle per round was dumped), so this reads
    them the same way.

    Raises:
        EOFError: if the file holds fewer than ``count`` pickles.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only use on trusted files.
    # Opened read-only: the previous 'rb+' mode needlessly required write access.
    with open(path, 'rb') as f:
        return [pickle.load(f) for _ in range(count)]


def per_arm_error_when_optimal(preds, means):
    """Mean absolute prediction error of each arm over the envs where it is optimal.

    Args:
        preds: array-like indexed ``[env][arm]`` holding predicted arm values;
            only the first ``len(means)`` envs are used.
        means: ``(n_envs, n_arms)`` array of true arm means.

    Returns:
        ``(n_arms,)`` float array. Entry ``a`` averages
        ``|preds[i][a] - means[i][a]|`` over every env ``i`` whose optimal arm
        (``argmax(means[i])``, lowest index on ties) is ``a``. Arms that are
        never optimal get 0.
    """
    preds = np.asarray(preds, dtype=float)
    means = np.asarray(means, dtype=float)
    n_envs, n_arms = means.shape
    best = np.argmax(means, axis=1)
    envs = np.arange(n_envs)
    errors = np.abs(preds[envs, best] - means[envs, best])
    totals = np.bincount(best, weights=errors, minlength=n_arms)
    counts = np.bincount(best, minlength=n_arms)
    # Divide only where the arm was optimal at least once. The old code seeded
    # the counts with ones, which shrank every average by a factor k/(k+1).
    return np.divide(totals, counts, out=np.zeros(n_arms), where=counts > 0)


def empirical_best_arms(context_rewards, context_actions):
    """Return, per env, the arm with the largest summed observed reward.

    Args:
        context_rewards: ``(n_envs, horizon, 1)`` observed rewards.
        context_actions: ``(n_envs, horizon, n_arms)`` action encodings
            (presumably one-hot per step -- TODO confirm).

    Returns:
        ``(n_envs,)`` int array of arm indices (ties go to the lowest index).
    """
    # Broadcasting the trailing size-1 reward axis attributes each step's
    # reward to the arm that was pulled at that step.
    per_arm_reward = np.asarray(context_rewards) * np.asarray(context_actions)
    return np.argmax(per_arm_reward.sum(axis=1), axis=1)


def main():
    """Compare reward predictions of LinUCB and PreDeToR models on linear bandits.

    Loads the last-round reward predictions of three models, the evaluation
    environments and the training dataset, then writes five figures into
    ``data_dir``:

    * ``analysis.png``: per-arm error for a single evaluation environment;
    * ``analysis_total.png``: per-arm mean error over envs where the arm is optimal;
    * ``analysis_arm_dist.png``: distribution of empirically best arms;
    * ``optimal_action_train.png`` / ``optimal_action_test.png``: how often
      each arm is optimal in the train / eval datasets.
    """
    # ------------------------------------------------------------------
    # Configuration. Active setting: "large horizon".
    #   smaller horizon: data_dir='script/online/data',        H=25,  dim=10, lin_d=2, n_envs=100000
    #   larger horizon:  data_dir='script/online/data_larger', H=200, dim=20, lin_d=5, n_envs=150080
    # ------------------------------------------------------------------
    data_dir = 'script/online/data_large'
    H = 25                 # must match the number of pickles per prediction file
    horizon = H
    n_eval = 200           # 205 was listed for one large-horizon run
    dim = 10
    lin_d = 2
    var = 0.3
    cov = 0.0
    envname = 'linear_bandit_train_lookahead_pred_reward'

    n_envs = 200000
    state_dim = 1
    action_dim = dim       # the action space is the set of arms
    n_layer = 4
    n_embd = 32
    n_head = 4
    shuffle = False
    dropout = 0.1
    n_hists = 1
    n_samples = 1

    arm_ticks = (np.arange(0, dim, 1), np.arange(1, dim + 1, 1))  # 1-based arm labels
    last = H - 1

    # ---- model predictions: one pickle per round, each indexed [env][arm] ----
    pred_linucb = load_pickle_sequence(f'{data_dir}/r_pred_without_head_linucb.pkl', H)
    print(np.shape(pred_linucb))
    pred_predetor = load_pickle_sequence(f'{data_dir}/r_pred_without_head_pred.pkl', H)
    print(np.shape(pred_predetor))
    pred_predetor_tau = load_pickle_sequence(f'{data_dir}/r_pred_without_head_pred_tau.pkl', H)
    print(np.shape(pred_predetor_tau))

    # ---- evaluation environments ----
    # Key order matters only if the filename builder iterates the dict; it is
    # kept identical to the original construction.
    dataset_config = {'horizon': horizon, 'dim': dim}
    dataset_config.update({'lin_d': lin_d, 'var': var, 'cov': cov})
    eval_filepath = build_linear_bandit_data_filename(
        envname, n_eval, dataset_config, mode=2)
    print(eval_filepath)
    with open(eval_filepath, 'rb') as f:
        eval_trajs = pickle.load(f)

    context_rewards = np.zeros((n_eval, horizon, 1))
    context_actions = np.zeros((n_eval, horizon, dim))
    context_means = np.zeros((n_eval, dim))
    for i in range(n_eval):
        traj = eval_trajs[i]
        context_rewards[i, :, 0] = traj['context_rewards']
        context_actions[i, :, :] = traj['context_actions']
        context_means[i, :] = traj['means']

    best_arm = np.argmax(context_means, axis=1)
    print(np.unique(best_arm, return_counts=True))

    best_arm_emp = empirical_best_arms(context_rewards, context_actions)
    print(best_arm_emp.tolist())
    print(np.unique(best_arm_emp, return_counts=True))

    # ---- figure 1: per-arm prediction error for one environment ----
    env_plot = 3
    plt.plot(abs(pred_linucb[last][env_plot] - context_means[env_plot]), label='LinUCB')
    plt.plot(abs(pred_predetor[last][env_plot] - context_means[env_plot]), label='PreDeToR (ours)')
    print(eval_trajs[env_plot]['means'])
    plt.xlabel('Arms')
    plt.ylabel('Prediction error')
    plt.title(f'Prediction error of each arm at the last round for environment {env_plot}')
    plt.xticks(*arm_ticks)
    plt.legend()
    plt.savefig(f'{data_dir}/analysis.png')
    plt.close()

    # ---- figure 2: per-arm mean error over envs where that arm is optimal ----
    err_linucb = per_arm_error_when_optimal(pred_linucb[last], context_means)
    err_predetor = per_arm_error_when_optimal(pred_predetor[last], context_means)
    err_predetor_tau = per_arm_error_when_optimal(pred_predetor_tau[last], context_means)

    arms = np.arange(0, dim, 1)
    width = 0.25
    plt.bar(x=arms, height=err_linucb, width=width, label='LinUCB', color='red', alpha=0.5)
    plt.bar(x=arms - width, height=err_predetor, width=width, label='PreDeToR (ours)', color='orange', alpha=0.5)
    plt.bar(x=arms + width, height=err_predetor_tau, width=width, label='PreDeToR-$\\tau$ (ours)', color='blue', alpha=0.5)
    plt.xlabel('Arms')
    plt.ylabel('Prediction error')
    plt.title(f'Mean error of each arm when it is optimal, over {n_eval} envs')
    plt.xticks(*arm_ticks)
    plt.legend()
    plt.savefig(f'{data_dir}/analysis_total.png')
    plt.close()

    # ---- figure 3: distribution of empirically best arms ----
    arms_seen, arm_counts = np.unique(best_arm_emp, return_counts=True)
    # Bars show the relative frequency (the old fixed 0.01 scale assumed 100
    # envs); the label on each bar is the raw count.
    bars = plt.bar(x=arms_seen, height=arm_counts / n_eval, label='Emp best arm', color='green', alpha=0.5)
    for rect, count in zip(bars, arm_counts):
        plt.text(rect.get_x() + rect.get_width() / 2.0, rect.get_height(),
                 f'{int(count)}', ha='center', va='bottom')
    plt.xlabel('Arms')
    plt.ylabel('Frequency')
    plt.title('Arms Distribution')
    plt.xticks(*arm_ticks)
    plt.legend()
    plt.savefig(f'{data_dir}/analysis_arm_dist.png')
    plt.close()

    # ---- training / eval datasets ----
    dataset_config.update({'lin_d': lin_d, 'var': var, 'cov': cov,
                           'n_hists': n_hists, 'n_samples': n_samples})
    path_train = build_linear_bandit_data_filename(
        envname, n_envs, dataset_config, mode=0)
    eval_filepath = build_linear_bandit_data_filename(
        envname, n_eval, dataset_config, mode=2)

    train_config = {
        'horizon': horizon,
        'state_dim': state_dim,
        'action_dim': action_dim,
        'n_layer': n_layer,
        'n_embd': n_embd,
        'n_head': n_head,
        'shuffle': shuffle,
        'dropout': dropout,
        'test': False,
        'store_gpu': True,
    }
    train_dataset = Dataset_pred_reward(path_train, train_config)
    eval_dataset = Dataset_pred_reward(eval_filepath, train_config)

    # Only the first 80% of the training envs are inspected (presumably the
    # training split -- TODO confirm).
    n_train = int(0.8 * n_envs)
    opt_action = torch.zeros(n_train, action_dim)
    for env in tqdm(range(n_train)):
        opt_action[env] = train_dataset[env]['optimal_actions']

    # opt_a_reward_train.pkl caches the observed reward of the optimal action
    # per training env and step. It was produced by this (slow) loop;
    # uncomment to regenerate it:
    #   reward_opt_action = torch.zeros(n_envs, horizon, action_dim)
    #   for env in tqdm(range(n_train)):
    #       sample = train_dataset[env]
    #       x = torch.repeat_interleave(sample['context_rewards'], action_dim, dim=1)
    #       y = torch.repeat_interleave(sample['optimal_actions'].reshape(1, -1), horizon, dim=0)
    #       reward_opt_action[env] = x * y * sample['context_actions']
    #   with open(f'{data_dir}/opt_a_reward_train.pkl', 'wb') as f:
    #       pickle.dump(reward_opt_action, f)
    with open(f'{data_dir}/opt_a_reward_train.pkl', 'rb') as f:
        reward_opt_action = pickle.load(f)

    # Average over the first two axes (envs, then steps), leaving one value per arm.
    avg_reward = torch.mean(torch.mean(reward_opt_action, dim=0), dim=0)
    print("Actual observed average reward of optimal action in train dataset", avg_reward)

    opt_action_test = torch.zeros(n_eval, action_dim)
    for env in tqdm(range(n_eval)):
        opt_action_test[env] = eval_dataset[env]['optimal_actions']

    # ---- figures 4 & 5: how often each arm is optimal in train / test data ----
    bar_width = 0.35
    action_arms = np.arange(0, action_dim, 1)
    for actions, label, title, filename in (
            (opt_action, 'train dataset',
             'Frequency of optimal actions in training dataset', 'optimal_action_train.png'),
            (opt_action_test, 'test dataset',
             'Frequency of optimal actions in test dataset', 'optimal_action_test.png'),
    ):
        freq = torch.sum(actions, dim=0).detach().numpy()
        plt.bar(action_arms, freq, width=bar_width, label=label)
        plt.xlabel('Arms')
        plt.ylabel('Frequency')
        plt.title(title)
        plt.xticks(action_arms, np.arange(1, action_dim + 1, 1))
        plt.savefig(f'{data_dir}/{filename}')
        plt.close()


if __name__ == '__main__':
    main()