import numpy as np
import torch
import os
import argparse
import matplotlib.pyplot as plt
import sys

parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
from Zhu2022_SmoothIGW.SmoothIGW import *
from env import *
from oracle import *

# CDF types evaluated for every dataset; each name is mapped to a CDF
# function before the environment is built.
cdf_type_list = ['MoU', 'MoN']
# Real datasets; arrays are read from <parent_dir>/datasets/<name>_*.npy.
dataset_list = ['abalone', 'diamonds', 'energy', 'housing', 'obesity', 'wine']

# Final path component of the search/ and results/ output directories.
algo_name = 'SmoothIGW'

# Passed through to SqMLPOracle as ``num_grad``.
num_grad = 2
# Grid-search candidates: ``gamma`` goes to SmoothIGW, ``lr`` to SqMLPOracle.
gamma_set = [4, 16, 64, 256, 1024]
lr_set = [2e-3, 1e-2, 5e-2]


# Grid-search settings (shorter horizon) and final-evaluation settings.
SEARCH_T = 2000
SEARCH_REP = 5
SEARCH_SEED = 1234
FINAL_T = 5000
FINAL_REP = 5
FINAL_SEED = 123


def resolve_cdf(cdf_type):
    """Return the CDF function named ``f'{cdf_type}_cdf'``.

    CDF functions such as ``MoU_cdf`` arrive in this module through the
    ``from env import *`` star import, so they are looked up in the module
    globals. This covers every entry of ``cdf_type_list`` (including 'MoN')
    as long as the matching ``*_cdf`` function exists.
    NOTE(review): assumes ``env`` defines ``MoN_cdf`` -- confirm.

    Raises:
        NotImplementedError: if no ``{cdf_type}_cdf`` name is in scope.
    """
    cdf = globals().get(f'{cdf_type}_cdf')
    if cdf is None:
        raise NotImplementedError(f'unsupported cdf_type: {cdf_type!r}')
    return cdf


def load_env(dataset, cdf_type, cdf, device):
    """Load the ``datasets/{dataset}_*.npy`` arrays as float32 into a DataEnv.

    Returns:
        ``(env, d)`` where ``d`` is the size of the feature array's last axis.
    """
    def load(suffix):
        path = os.path.join(parent_dir, 'datasets', f'{dataset}_{suffix}.npy')
        return np.load(path).astype(np.float32)

    price_data = load('price')
    feature_data = load('features')
    optimal_price_data = load(f'optimal_price_{cdf_type}')
    optimal_revenue_data = load(f'optimal_revenue_{cdf_type}')
    env = DataEnv(price_data, feature_data, optimal_price_data,
                  optimal_revenue_data, cdf, device, None)
    return env, feature_data.shape[-1]


def run_algo(env, d, gamma, lr, T, rep, env_seed, algo_seed, device, basedir):
    """Run SmoothIGW (``rep`` repetitions, horizon ``T``) and save to ``basedir``.

    Both generators are re-created from their seeds on every call, so each
    configuration starts from identically seeded RNG state.
    """
    env.generator = np.random.default_rng(seed=env_seed)
    algo_generator = torch.Generator(device=device)
    algo_generator.manual_seed(algo_seed)

    oracle = SqMLPOracle(d=d, T=T, lr=lr, num_grad=num_grad)
    algo = SmoothIGW(gamma=gamma, T=T, oracle=oracle, generator=algo_generator)
    algo.run(rep=rep, env=env, basedir=basedir)


def select_best_params(basedir, gammas, lrs):
    """Return the ``(gamma, lr)`` pair with the lowest mean regret.

    Each pair's directory ``{basedir}/gamma={gamma}_lr={lr}`` must contain
    ``reward.npy`` and ``optimal_reward.npy``. The regret is the mean of
    ``optimal_reward - reward``. NaN regrets never win the comparison.

    Raises:
        RuntimeError: if every mean regret is NaN or +inf.
    """
    # Fresh per call: a module-level/loop-carried best would leak the previous
    # dataset's choice when no candidate beats the sentinel.
    best_regret = float('inf')
    best_param = None
    for gamma in gammas:
        for lr in lrs:
            run_dir = os.path.join(basedir, f'gamma={gamma}_lr={lr}')
            reward = np.load(os.path.join(run_dir, 'reward.npy'))
            optimal_reward = np.load(os.path.join(run_dir, 'optimal_reward.npy'))
            mean_regret = np.mean(optimal_reward - reward)
            if mean_regret < best_regret:
                best_regret = mean_regret
                best_param = (gamma, lr)
            print(f'gamma={gamma}_lr={lr}: {mean_regret}')
    if best_param is None:
        raise RuntimeError(f'no usable regret found under {basedir}')
    return best_param


def main():
    """Tune (gamma, lr) per dataset/CDF type, then rerun with the best pair."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--cuda', type=str, default='0')
    parser.add_argument('--search', action='store_true')
    args = parser.parse_args()

    # Set before the first CUDA call so CUDA only exposes the requested device(s).
    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
    device = "cuda" if torch.cuda.is_available() else "cpu"

    for dataset in dataset_list:
        for cdf_type in cdf_type_list:
            print(f'dataset={dataset}_cdf={cdf_type}')
            # Resolve first so an unsupported type fails before any file I/O.
            cdf = resolve_cdf(cdf_type)

            search_dir = f'search/real/{dataset}/cdf={cdf_type}/{algo_name}'
            os.makedirs(search_dir, exist_ok=True)

            env, d = load_env(dataset, cdf_type, cdf, device)

            if args.search:
                # Grid search over parameters.
                for gamma in gamma_set:
                    for lr in lr_set:
                        print(f'gamma={gamma}_lr={lr}')
                        run_algo(env, d, gamma, lr,
                                 T=SEARCH_T, rep=SEARCH_REP,
                                 env_seed=SEARCH_SEED, algo_seed=SEARCH_SEED,
                                 device=device,
                                 basedir=os.path.join(search_dir, f'gamma={gamma}_lr={lr}'))

            gamma, lr = select_best_params(search_dir, gamma_set, lr_set)
            print(f'best parameter : gamma={gamma}, lr={lr}')

            # Final run with the best parameters on a longer horizon.
            run_algo(env, d, gamma, lr,
                     T=FINAL_T, rep=FINAL_REP,
                     env_seed=FINAL_SEED, algo_seed=FINAL_SEED,
                     device=device,
                     basedir=f'results/real/{dataset}/cdf={cdf_type}/{algo_name}')


if __name__ == '__main__':
    main()
                
