import numpy as np
import torch
import os
import argparse
import matplotlib.pyplot as plt
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from Zhu2022_SmoothIGW.SmoothIGW import *
from env import *
from oracle import *

# Experimental settings swept by this script (every combination is run).
model_list = ['linear', 'loglinear', 'PH']
cdf_type_list = ['gaussian', 'MoU']
context_type_list = ['gaussian', 'uniform', 'binary']

# Used as the last path component of every result directory.
algo_name = 'SmoothIGW'

# Passed to SqMLPOracle as `num_grad` for every run.
num_grad = 2
# Hyper-parameter grid searched per setting; the pair with the lowest mean
# regret is reused for the final run.
gamma_set = [4, 16, 64, 256, 1024]
lr_set = [2e-3, 1e-2, 5e-2]


def make_env(model, cdf_type, context_type, d, generator, device):
    """Build the simulated environment for one (model, cdf, context) setting.

    Args:
        model: valuation model name: 'linear', 'loglinear' or 'PH'.
        cdf_type: noise CDF name: 'gaussian' or 'MoU'.
        context_type: context distribution name: 'gaussian', 'uniform' or 'binary'.
        d: context dimension.
        generator: torch generator handed to both the context distribution and
            the Env, so reseeding it controls the environment's randomness.
        device: torch device string ("cuda" or "cpu").

    Returns:
        The constructed ``Env``.

    Raises:
        NotImplementedError: if any of the three setting names is unknown.
    """
    if context_type == 'gaussian':
        context_dist = GaussianContext(d=d, sigma=1, generator=generator, device=device)
    elif context_type == 'uniform':
        context_dist = UniformContext(d=d, device=device, generator=generator)
    elif context_type == 'binary':
        context_dist = BinaryContext(d=d, device=device, generator=generator)
    else:
        raise NotImplementedError(f'unknown context_type: {context_type!r}')

    cdfs = {'gaussian': gaussian_cdf, 'MoU': MoU_cdf}
    if cdf_type not in cdfs:
        raise NotImplementedError(f'unknown cdf_type: {cdf_type!r}')
    cdf = cdfs[cdf_type]

    model_classes = {'linear': LinearModel, 'loglinear': LogLinearModel, 'PH': PHModel}
    if model not in model_classes:
        raise NotImplementedError(f'unknown model: {model!r}')
    valuation_model = model_classes[model](d=d, cdf=cdf, device=device)

    return Env(generator=generator, context_dist=context_dist, valuation_model=valuation_model)


def run_smooth_igw(env, *, d, T, rep, gamma, lr, basedir,
                   env_generator, env_seed, algo_generator, algo_seed):
    """Reseed both generators, then run SmoothIGW with a freshly built oracle.

    A new ``SqMLPOracle`` is created on every call so runs do not share
    learned state. ``basedir`` is forwarded to ``algo.run``; the search
    results are later read back from that directory by ``load_mean_regret``.
    """
    env_generator.manual_seed(env_seed)
    algo_generator.manual_seed(algo_seed)

    oracle = SqMLPOracle(d=d, T=T, lr=lr, num_grad=num_grad)
    algo = SmoothIGW(gamma=gamma, T=T, oracle=oracle, generator=algo_generator)
    algo.run(rep=rep, env=env, basedir=basedir)


def load_mean_regret(run_dir):
    """Return the mean of ``optimal_reward - reward`` stored in ``run_dir``.

    Reads ``reward.npy`` and ``optimal_reward.npy`` from ``run_dir``.
    These are expected to come from a previous search run.

    Raises:
        FileNotFoundError: if either file is missing, e.g. when the script is
            run without ``--search`` before any search results exist.
    """
    reward = np.load(os.path.join(run_dir, 'reward.npy'))
    optimal_reward = np.load(os.path.join(run_dir, 'optimal_reward.npy'))
    return float(np.mean(optimal_reward - reward))


def select_best_param(basedir, gammas, lrs):
    """Return the (gamma, lr) pair with the lowest mean regret under ``basedir``.

    Each pair is looked up in ``{basedir}/gamma={gamma}_lr={lr}``, and its mean
    regret is printed. On ties the first pair in grid order wins.

    Raises:
        FileNotFoundError: if a result directory or file is missing.
        ValueError: if no pair has a mean regret below +inf (e.g. all NaN).
    """
    best_regret = float('inf')  # any finite regret beats this, however large
    best_param = None
    for gamma in gammas:
        for lr in lrs:
            mean_regret = load_mean_regret(f'{basedir}/gamma={gamma}_lr={lr}')
            if mean_regret < best_regret:
                best_regret = mean_regret
                best_param = (gamma, lr)
            print(f'gamma={gamma}_lr={lr}: {mean_regret}')
    if best_param is None:
        raise ValueError(f'no finite mean regret found under {basedir!r}')
    return best_param


def main():
    """Tune SmoothIGW per setting on short runs, then rerun with the best pair.

    With ``--search`` every (gamma, lr) in the grid is run first. Without it,
    previously saved search results are reused.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--cuda', type=str, default='0')
    parser.add_argument('--search', action='store_true')
    args = parser.parse_args()

    # Set before the first CUDA query below so the device mask takes effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda

    device = "cuda" if torch.cuda.is_available() else "cpu"
    env_generator = torch.Generator(device=device)
    algo_generator = torch.Generator(device=device)

    d = 5
    # Short horizon for hyper-parameter search, longer horizon for the final run.
    search_T, search_rep, search_seed = 2000, 5, 1234
    final_T, final_rep, final_seed = 5000, 5, 123

    for model in model_list:
        for cdf_type in cdf_type_list:
            for context_type in context_type_list:
                print(f'model={model}_cdf={cdf_type}_context={context_type}')

                setting = f'model={model}/cdf={cdf_type}/context={context_type}/{algo_name}'
                search_dir = f'search/sim_d={d}_T={search_T}/{setting}'
                os.makedirs(search_dir, exist_ok=True)

                env = make_env(model, cdf_type, context_type, d, env_generator, device)

                if args.search:
                    # grid search over parameters
                    for gamma in gamma_set:
                        for lr in lr_set:
                            print(f'gamma={gamma}_lr={lr}')
                            run_smooth_igw(
                                env, d=d, T=search_T, rep=search_rep,
                                gamma=gamma, lr=lr,
                                basedir=f'{search_dir}/gamma={gamma}_lr={lr}',
                                env_generator=env_generator, env_seed=search_seed,
                                algo_generator=algo_generator, algo_seed=search_seed,
                            )

                best_gamma, best_lr = select_best_param(search_dir, gamma_set, lr_set)

                # run with the best parameter
                print(f'best parameter : gamma={best_gamma}, lr={best_lr}')
                run_smooth_igw(
                    env, d=d, T=final_T, rep=final_rep,
                    gamma=best_gamma, lr=best_lr,
                    basedir=f'results/sim_d={d}_T={final_T}/{setting}',
                    env_generator=env_generator, env_seed=final_seed,
                    algo_generator=algo_generator, algo_seed=final_seed,
                )


if __name__ == '__main__':
    main()
