from Env.FPA_env import *
from Env.bid_general import *
from Algorithm.Wang_algo import *
from Algorithm.random_algo import *
from Algorithm.L2FOB import *
from utils import *

import numpy as np
import math
import copy
import argparse

def _str2bool(value):
    """Convert a command-line string to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is True for
    every non-empty string (including "False"). This converter interprets
    the usual spellings explicitly instead.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in {"true", "t", "yes", "y", "1"}:
        return True
    if text in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line options for the experiment script.
parser = argparse.ArgumentParser()
parser.add_argument("--save_dir", type=str, default="FPA_results")
parser.add_argument("--name", type=str, default="FPA_test")
parser.add_argument("--trials", type=int, default=5)
# Accepts a bare `--save` (-> True) as well as `--save True` / `--save False`,
# so existing invocations that pass an explicit value keep working.
parser.add_argument(
    "--save",
    type=_str2bool,
    nargs="?",
    const=True,
    default=False,
    help="save each trial's results (bare flag, or an explicit true/false value)",
)
args = parser.parse_args()

# ---- Experiment configuration -------------------------------------------
# T and B are handed to the algorithms and to the runner (B as `budget`).
T = 10_000
B = 100
vbar = 1.0
# One seed per trial: 0, 1, ..., args.trials - 1.
seed_list = range(args.trials)
delta = 0.01
rho = 0.01
# Passed as `gamma` to the L2FOB algorithm and to the runner.
ROI_gamma = 1.8
# Shared by the fixed-bid benchmark and by both algorithms.
grid_size = 81
use_optimism = True

# ---- Step sizes ------------------------------------------------------------
# NOTE(review): the original line carried a trailing "# 10" comment, presumably
# an alternative value that was tried for this step size -- confirm.
FPA_Wang_eta = 1 / math.sqrt(T)
L2FOB_eta_rho = 0.6
L2FOB_eta_gamma = 0.6

# Build the environment once with seed 0; it is reset with a per-trial seed
# before every run further down.
seed = 0
env = FPAEnvSimple(
    vbar=vbar,
    feedback="full",
    value_kind="normal",
    seed=seed,
)

# Fixed-bid benchmark: report the bid returned by compute_optimal_fixed_bid
# together with its reward, cost and reward/cost ratio (0 when cost is 0).
b_star, opt_r, opt_c = compute_optimal_fixed_bid(
    env,
    T,
    B,
    ROI_gamma,
    grid_size=grid_size,
    seed=seed,
)
print(f"Optimal fixed bid: {b_star:.3f}, reward={opt_r:.2f}, cost={opt_c:.2f}, ROI={opt_r/opt_c if opt_c>0 else 0:.2f}")

# (display name, algorithm instance) pairs. Each instance serves as a template:
# the trial loop deep-copies it before every run.
alg_list = [
    (
        "Wang",
        WangBidding(
            T=T,
            B=B,
            eta=FPA_Wang_eta,
            rho=rho,
            grid_size=grid_size,
            delta=delta,
            use_optimism=use_optimism,
        ),
    ),
    (
        "L2FOB",
        L2FOB_FPA(
            T=T,
            B=B,
            rho=rho,
            gamma=ROI_gamma,
            eta_rho=L2FOB_eta_rho,
            eta_gamma=L2FOB_eta_gamma,
            grid_size=grid_size,
            delta=delta,
            use_optimism=use_optimism,
        ),
    ),
]

# Run every algorithm for one trial per seed in seed_list, printing a summary
# per trial and optionally saving the result object.
for alg_name, alg in alg_list:
    print(f"Parameters: Wang_eta={FPA_Wang_eta}, eta_rho={L2FOB_eta_rho}, eta_gamma={L2FOB_eta_gamma}, use_optimism={use_optimism}")
    for i, seed in enumerate(seed_list):
        print(f"Trial {i+1} for {alg_name}")
        # Fresh copy so state accumulated in one trial cannot leak into the next.
        alg_copy = copy.deepcopy(alg)
        # Re-seed both the environment and NumPy's global RNG for this trial.
        env.reset(seed=seed)
        np.random.seed(seed)
        runner = OnlineInteractionRunner(env=env, algo=alg_copy, T=T, budget=B, gamma=ROI_gamma)
        res = runner.run()
        if args.save:
            # Honour --save_dir; the directory used to be hard-coded to
            # "FPA_results", which silently ignored the command-line option.
            filename = f"{args.name}_trial{i+1}_{alg_name}.pkl"
            save_results(args.save_dir, filename, res)
        print(f"{alg_name} total_reward: {res.total_reward:.2f} total_cost: {res.total_cost:.2f} ROI: {res.total_reward/res.total_cost if res.total_cost>0 else 0:.2f} Gamma loss: {res.total_gamma_loss:.2f}")
        print("stop_round: ", res.stop_round, "stopped:", res.stopped_by_budget)