import numpy as np
import torch
import os
import sys
import argparse
import matplotlib.pyplot as plt

parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
from Chen2021_ABE.ABE import *
from env import *


# Experiment grid swept by the driver below: the CDF families of the
# valuation model, and the real-world datasets under ``<parent_dir>/datasets``.
cdf_type_list = ['MoN']
dataset_list = 'abalone housing wine'.split()

# Algorithm label; appears as the last component of the search/results paths.
algo_name = 'ABE'

# Candidate values for the ``c`` argument of ``ABE(d, c, T)``; consecutive
# entries differ by a factor of 4.  The first two are floats and the rest are
# ints on purpose: ``f'c={c}'`` builds directory names, so the exact repr matters.
c_set = [0.0625, 0.25, 1, 4, 16]

# Hyperparameter-search run settings (short horizon).
SEARCH_T = 2000
SEARCH_REP = 5
SEARCH_SEED = 1234

# Final evaluation run settings (longer horizon, different seed).
FINAL_T = 5000
FINAL_REP = 5
FINAL_SEED = 123


def _resolve_cdf(cdf_type):
    """Return the CDF callable registered under ``cdf_type``.

    Each name is looked up only in its own branch, so a CDF that the star
    imports at the top of the file do not provide is only an error when it is
    actually requested.

    Raises:
        NotImplementedError: if ``cdf_type`` is not a supported name.
    """
    if cdf_type == 'gaussian':
        return gaussian_cdf
    if cdf_type == 'MoU':
        return MoU_cdf
    if cdf_type == 'MoN':
        return MoN_cdf
    raise NotImplementedError(f'unsupported cdf_type: {cdf_type!r}')


def _build_env(dataset, cdf_type, cdf, device):
    """Load the ``.npy`` arrays for ``dataset`` and wrap them in a ``DataEnv``.

    All arrays are cast to float32.  The optimal price/revenue files are
    specific to ``cdf_type``.

    Returns:
        ``(env, d)`` where ``d`` is the size of the last axis of the feature
        array (passed to ``ABE`` as its dimension argument).
    """
    def load(stem):
        path = os.path.join(parent_dir, 'datasets', f'{dataset}_{stem}.npy')
        return np.load(path).astype(np.float32)

    price_data = load('price')
    feature_data = load('features')
    optimal_price_data = load(f'optimal_price_{cdf_type}')
    optimal_revenue_data = load(f'optimal_revenue_{cdf_type}')
    env = DataEnv(price_data, feature_data, optimal_price_data,
                  optimal_revenue_data, cdf, device, None)
    return env, feature_data.shape[-1]


def _seed_run(env, env_seed, algo_seed):
    """Reset the randomness used by one ``ABE.run`` call.

    The environment gets a fresh NumPy generator.  ``ABE(d, c, T)`` is not
    handed a generator by this script, so a locally built ``torch.Generator``
    would have no effect.  Instead, torch's global RNG is seeded.
    NOTE(review): this assumes ABE draws from torch's default generator;
    confirm against Chen2021_ABE.ABE.
    """
    env.generator = np.random.default_rng(seed=env_seed)
    torch.manual_seed(algo_seed)


def _select_best_c(search_dir):
    """Return the ``c`` in ``c_set`` with the smallest mean regret.

    Mean regret is the mean of ``optimal_reward - reward``, read from
    ``<search_dir>/c=<c>/{reward,optimal_reward}.npy``.  Each candidate's
    regret is printed.  On ties, the earliest candidate in ``c_set`` wins.
    NaN regrets never win.

    Raises:
        RuntimeError: if no candidate has a comparable (non-NaN) regret.
    """
    best_regret = np.inf
    best_c = None
    for c in c_set:
        run_dir = f'{search_dir}/c={c}'
        reward = np.load(run_dir + '/reward.npy')
        optimal_reward = np.load(run_dir + '/optimal_reward.npy')
        mean_regret = np.mean(optimal_reward - reward)
        if mean_regret < best_regret:
            best_regret = mean_regret
            best_c = c
        print(f'c={c}: {mean_regret}')
    if best_c is None:
        raise RuntimeError(f'no valid regret found for any c under {search_dir}')
    return best_c


def main():
    """Search ``c`` for ABE on each real dataset, then rerun with the best ``c``.

    With ``--search``, ABE is run for every value in ``c_set`` and its outputs
    are written under ``search/real/...``.  The best ``c`` is always chosen
    from whatever is stored there.  A final longer run is then written under
    ``results/real/...``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--cuda', type=str, default='0')
    parser.add_argument('--search', action='store_true')
    args = parser.parse_args()

    # Set before the first CUDA query below so the device mask takes effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
    device = "cuda" if torch.cuda.is_available() else "cpu"

    for dataset in dataset_list:
        for cdf_type in cdf_type_list:
            print(f'dataset={dataset}_cdf={cdf_type}')

            search_dir = f'search/real/{dataset}/cdf={cdf_type}/{algo_name}'
            os.makedirs(search_dir, exist_ok=True)

            cdf = _resolve_cdf(cdf_type)
            env, d = _build_env(dataset, cdf_type, cdf, device)

            if args.search:
                for c in c_set:
                    print(f'c={c}')
                    # Same seeds for every c, so candidates are compared on
                    # identical randomness.
                    _seed_run(env, SEARCH_SEED, SEARCH_SEED)
                    algo = ABE(d, c, SEARCH_T)
                    algo.run(rep=SEARCH_REP, env=env, basedir=f'{search_dir}/c={c}')

            # Run with the best parameter.
            best_c = _select_best_c(search_dir)
            print(f'best parameter : c={best_c}')

            result_dir = f'results/real/{dataset}/cdf={cdf_type}/{algo_name}'
            _seed_run(env, FINAL_SEED, FINAL_SEED)
            algo = ABE(d, best_c, FINAL_T)
            algo.run(rep=FINAL_REP, env=env, basedir=result_dir)


if __name__ == '__main__':
    main()