import numpy as np
import torch
import os
import sys
import argparse
import matplotlib.pyplot as plt

parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
from Fan import Fan
from env import *

# Demand-CDF families and real datasets swept by this script.
cdf_type_list = ['MoU', 'MoN']
dataset_list = 'abalone diamonds energy housing obesity wine'.split()

algo_name = 'Fan'

# Search grid. l0 covers the powers of two 32..512. In C1 the values 1, 2, 4
# stay ints: str(C1) is embedded in result directory names ('C1=1', not
# 'C1=1.0'), so changing their type would change the on-disk layout.
l0_set = [2 ** k for k in range(5, 10)]
C1_set = [0.25, 0.5, 1, 2, 4]


# Per-phase experiment settings. Within a phase, the environment RNG and the
# algorithm RNG are seeded with the same value.
SEARCH_T = 2000
SEARCH_REP = 5
SEARCH_SEED = 1234
FINAL_T = 5000
FINAL_REP = 5
FINAL_SEED = 123


def load_data_env(dataset, cdf_type, device):
    """Build a ``DataEnv`` for one real dataset and one demand-CDF family.

    Loads ``datasets/{dataset}_price.npy``, ``datasets/{dataset}_features.npy``
    and the precomputed ``optimal_price_{cdf_type}`` /
    ``optimal_revenue_{cdf_type}`` arrays from ``parent_dir``, all cast to
    float32.

    Returns:
        ``(env, d)``: the environment, created with generator ``None`` (the
        caller assigns ``env.generator`` before each run), and the feature
        dimension ``d`` (last axis of the feature array).

    Raises:
        NotImplementedError: if ``cdf_type`` is not 'gaussian', 'MoU' or 'MoN'.
    """
    # CDF names are resolved lazily so that a CDF that is never requested
    # (e.g. 'gaussian') does not need to exist in ``env``.
    if cdf_type == 'gaussian':
        cdf = gaussian_cdf
    elif cdf_type == 'MoU':
        cdf = MoU_cdf
    elif cdf_type == 'MoN':
        cdf = MoN_cdf
    else:
        raise NotImplementedError(f'unsupported cdf_type: {cdf_type!r}')

    def _load(suffix):
        # One ``datasets/{dataset}_{suffix}.npy`` array as float32.
        path = os.path.join(parent_dir, f'datasets/{dataset}_{suffix}.npy')
        return np.load(path).astype(np.float32)

    price_data = _load('price')
    feature_data = _load('features')
    optimal_price_data = _load(f'optimal_price_{cdf_type}')
    optimal_revenue_data = _load(f'optimal_revenue_{cdf_type}')
    env = DataEnv(price_data, feature_data, optimal_price_data,
                  optimal_revenue_data, cdf, device, None)
    return env, feature_data.shape[-1]


def run_fan(env, *, l0, C1, d, T, rep, env_seed, algo_seed, device, basedir):
    """Reseed ``env`` and a fresh torch generator, then run ``Fan`` once.

    ``env.generator`` is replaced by ``np.random.default_rng(env_seed)``. The
    algorithm receives a ``torch.Generator`` on ``device`` seeded with
    ``algo_seed``, so calls with equal seeds start from identical RNG state.
    Results are written by ``Fan.run`` under ``basedir``.
    """
    env.generator = np.random.default_rng(seed=env_seed)
    algo_generator = torch.Generator(device=device)
    algo_generator.manual_seed(algo_seed)
    algo = Fan(l0=l0, C1=C1, d=d, T=T, generator=algo_generator)
    algo.run(rep=rep, env=env, basedir=basedir)


def select_best_param(basedir, l0_values, C1_values):
    """Return the ``(l0, C1)`` pair with the lowest mean regret under ``basedir``.

    For each pair, reads ``reward.npy`` and ``optimal_reward.npy`` from
    ``{basedir}/l0={l0}_C1={C1}``, scores it by
    ``mean(optimal_reward - reward)`` and prints the score. Ties keep the
    first pair in iteration order. NaN scores never win.

    Raises:
        FileNotFoundError: if a result file is missing.
        RuntimeError: if no pair scored below infinity (e.g. all NaN).
    """
    best_regret = float('inf')
    best_param = None
    for l0 in l0_values:
        for C1 in C1_values:
            run_dir = f'{basedir}/l0={l0}_C1={C1}'
            reward = np.load(run_dir + '/reward.npy')
            optimal_reward = np.load(run_dir + '/optimal_reward.npy')
            mean_regret = np.mean(optimal_reward - reward)
            print(f'l0={l0}_C1={C1}: {mean_regret}')
            if mean_regret < best_regret:
                best_regret = mean_regret
                best_param = (l0, C1)
    if best_param is None:
        # Fail loudly instead of reusing a stale choice from a previous
        # dataset/CDF iteration.
        raise RuntimeError(f'no configuration under {basedir} has a finite mean regret')
    return best_param


def main():
    """Tune (l0, C1) for Fan on every dataset/CDF pair, then rerun the best.

    With ``--search`` the full grid is run first. Without it, results already
    saved under ``search/real/...`` are re-scored. The selected pair is then
    run with the final-phase settings into ``results/real/...``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--cuda', type=str, default='0',
                        help='value for CUDA_VISIBLE_DEVICES')
    parser.add_argument('--search', action='store_true',
                        help='run the hyper-parameter grid before selecting')
    args = parser.parse_args()

    # Restrict visible GPUs before the first CUDA query below.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
    device = "cuda" if torch.cuda.is_available() else "cpu"

    for dataset in dataset_list:
        for cdf_type in cdf_type_list:
            print(f'dataset={dataset}_cdf={cdf_type}')

            search_dir = f'search/real/{dataset}/cdf={cdf_type}/{algo_name}'
            os.makedirs(search_dir, exist_ok=True)
            env, d = load_data_env(dataset, cdf_type, device)

            if args.search:
                for l0 in l0_set:
                    for C1 in C1_set:
                        print(f'l0={l0}_C1={C1}')
                        run_fan(env, l0=l0, C1=C1, d=d, T=SEARCH_T,
                                rep=SEARCH_REP, env_seed=SEARCH_SEED,
                                algo_seed=SEARCH_SEED, device=device,
                                basedir=f'{search_dir}/l0={l0}_C1={C1}')

            # run with the best parameter
            best_l0, best_C1 = select_best_param(search_dir, l0_set, C1_set)
            print(f'best parameter : l0={best_l0}, C1={best_C1}')

            run_fan(env, l0=best_l0, C1=best_C1, d=d, T=FINAL_T,
                    rep=FINAL_REP, env_seed=FINAL_SEED, algo_seed=FINAL_SEED,
                    device=device,
                    basedir=f'results/real/{dataset}/cdf={cdf_type}/{algo_name}')


if __name__ == '__main__':
    main()