from Env.FPA_env import *
from Env.bid_general import *
from Env.Bandit_env import *
from Algorithm.Wang_algo import *
from Algorithm.random_algo import *
from Algorithm.L2FOB import *
from Algorithm.OPT3 import *
from utils import *

import numpy as np
import copy
import argparse

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--save False`` would
    silently turn saving on.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    # The empty string is kept falsy, as it was with ``type=bool``.
    if token in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line options of the experiment script.
parser = argparse.ArgumentParser()
parser.add_argument("--save_dir", type=str, default="Bandit_results",
                    help="directory handed to save_results()")
parser.add_argument("--name", type=str, default="Bandit_test",
                    help="prefix of the result file names")
parser.add_argument("--trials", type=int, default=5,
                    help="number of trials run for each algorithm")
# nargs="?" + const=True lets both `--save` and `--save true/false` work.
parser.add_argument("--save", type=_str2bool, nargs="?", const=True, default=False,
                    help="save the result of every trial (default: off)")
args = parser.parse_args()

# Horizon (number of rounds) and number of arms used in every trial.
T = 5000
ARMS = 20

# Forward slash instead of "Data\mslr.npz": "\m" is an invalid escape sequence
# (a SyntaxWarning on recent Python versions) and a backslash is only a path
# separator on Windows, whereas "/" is accepted on every platform.
data = np.load("Data/mslr.npz")
context_dim = 136

# Warm-start set: the first 50 rows after flattening the features into
# context_dim-wide rows and the relevances into a single column.
# NOTE(review): these warm-start relevances are NOT divided by the maximum,
# unlike `whole_relev` below - confirm this scale difference is intended.
precontext = data['features'].reshape(-1, context_dim)[:50]
prerewards = data['relevances'].reshape(-1, 1)[:50]

# Keep only the first ARMS entries along the second axis of both arrays.
relev = data['relevances'][:, :ARMS]
whole_context = data['features'][:, :ARMS, :]
# Scale the kept relevances by their maximum value.
whole_relev = relev / np.max(relev)

# One seed per trial: trial i runs with seed i.
seed_list = range(args.trials)

# Hyper-parameters forwarded to the algorithms below.
ROI_gamma = 1.3
OPT3_alpha = 0.2
L2FOB_eta_rho = 0.6
L2FOB_eta_gamma = 0.6

# Total budget and its per-round share.
budget = 1000
each_budget = budget / T

# Environment built once with an initial seed.
seed = 0
env = BanditEnv(
    context=whole_context,
    relev=whole_relev,
    T=T,
    tot_budget=budget,
    cost=None,
    reward_noise=0.05,
    cost_noise=0.05,
    seed=seed,
)

# Reward and cost models; the same two objects are handed to both algorithms.
rmodel = gb5(depth=5, X_train=precontext, y_train=prerewards, update_time=50)
cmodel2 = cost_LCB(K=ARMS)

opt3_algo = OPT3(
    K=ARMS,
    dim=context_dim,
    alpha=OPT3_alpha,
    gamma=ROI_gamma,
    rmodel=rmodel,
    cmodel=cmodel2,
    D=3,
    budget=each_budget,
)
l2fob_algo = L2FOB_Bandit(
    T=T,
    B=budget,
    ARMS=ARMS,
    gamma=ROI_gamma,
    eta_rho=L2FOB_eta_rho,
    eta_gamma=L2FOB_eta_gamma,
    rho=each_budget,
    rmodel=rmodel,
    cmodel=cmodel2,
)

# (display name, algorithm instance) pairs, evaluated in this order.
alg_list = [
    ("OPT3", opt3_algo),
    ("L2FOB", l2fob_algo),
]

# Run every algorithm for one trial per seed in seed_list and report the outcome.
for alg_name, alg in alg_list:
    print(f"Parameters: OPT3_alpha={OPT3_alpha}, eta_rho={L2FOB_eta_rho}, eta_gamma={L2FOB_eta_gamma}")
    for trial_no, seed in enumerate(seed_list, start=1):
        print(f"Trial {trial_no} for {alg_name} with budget {budget}")

        # Bring the shared models, the environment and the global NumPy RNG
        # back to a known state for this trial.
        rmodel.reset(X_train=precontext, y_train=prerewards)
        env.reset(seed=seed)
        np.random.seed(seed)
        cmodel2.reset()

        # Run a private copy, taken after the resets so it carries the freshly
        # reset models. Previously the copy was created but the original `alg`
        # was passed to the runner, so algorithm state leaked between trials.
        alg_copy = copy.deepcopy(alg)
        runner = OnlineInteractionRunner(env=env, algo=alg_copy, T=T, budget=budget, gamma=ROI_gamma)
        res = runner.run()

        if args.save:
            filepath = args.save_dir
            filename = f"{args.name}_trial{trial_no}_{alg_name}.pkl"
            save_results(filepath, filename, res)

        # Guard the ratio: a run that spends nothing has no defined ROI.
        roi = res.total_reward / res.total_cost if res.total_cost else float("nan")

        # Use the same 1-based trial number as the header line and the file name.
        print(f"{alg_name} Algorithm with budget {budget} in trial No.{trial_no}:")
        print(f"total_reward: {res.total_reward:.2f} total_cost: {res.total_cost:.2f} ROI: {roi:.2f} Gamma loss: {res.total_gamma_loss:.2f}")
        print("stop_round: ", res.stop_round, "stopped:", res.stopped_by_budget)