#!/usr/bin/env python
import argparse
import os
import json
from multiprocessing import Pool

import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt

from env import LogisticBanditEnvironment
from agent_diag import LogisticBanditAgent
from utils import sigmoid


def simulate_experiment(
    algorithm: str,
    T: int,
    batch_size: int,
    d: int,
    K: int,
    latent_function: str,
    cnum: int,
    S: float,
    kappa: float,
    nu: float,
    lambda_reg: float,
    seed: int,
    m: int,
) -> list[float]:
    """Run one seeded batched bandit simulation and return its regret curve.

    The environment generates all contexts, rewards and reward probabilities
    up front. The agent then processes the rounds in consecutive chunks of
    ``batch_size``: it selects an action per round, the per-round regret
    (best probability minus the chosen action's probability) is accumulated,
    and the agent is updated before the next chunk.

    Only ``T // batch_size`` full batches are simulated; when ``T`` is not a
    multiple of ``batch_size`` the trailing ``T % batch_size`` rounds are
    never played and contribute no entries to the result.

    Args:
        algorithm: Algorithm name forwarded to ``LogisticBanditAgent``.
        T: Number of rounds for which contexts are generated.
        batch_size: Number of rounds handled between two agent updates.
        d: Context dimension (also used to flatten contexts to ``(-1, d)``).
        K: Number of arms, forwarded to the environment and the agent.
        latent_function: Latent function identifier for the environment.
        cnum: Forwarded to ``LogisticBanditEnvironment``.
        S, kappa, nu, lambda_reg, m: Hyperparameters forwarded to the agent.
        seed: Seed for NumPy, PyTorch, the environment and the agent.

    Returns:
        Running cumulative regret, one entry per simulated round.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    env = LogisticBanditEnvironment(
        device, d=d, K=K, latent_function=latent_function, cnum=cnum, seed=seed
    )
    contexts = torch.tensor(
        env.get_contexts(T, K, d), dtype=torch.float32, device=device
    )
    rewards, probs = env.get_rewards(contexts)
    # Best achievable success probability in every round.
    optimal_probs, _ = probs.max(dim=1)

    # Flattened (context, reward) pairs handed to the agent as its dataset.
    X = contexts.view(-1, d)
    r = rewards.view(-1, 1)
    dataset = torch.utils.data.TensorDataset(X, r)

    agent = LogisticBanditAgent(
        device,
        dataset,
        d=d,
        K=K,
        m=m,
        S=S,
        kappa=kappa,
        nu=nu,
        lambda_reg=lambda_reg,
        batch_size=batch_size,
        algorithm=algorithm,
        lr=1e-2,
        seed=seed,
    )

    cumulative_regret = 0.0
    regrets = []
    num_batches = T // batch_size
    for batch_t in tqdm(range(1, num_batches + 1)):
        start = (batch_t - 1) * batch_size
        end = batch_t * batch_size
        ctx = contexts[start:end]
        pr = probs[start:end]
        opt = optimal_probs[start:end]

        idx, _ = agent.select_action(ctx)
        # Gather the context / probability of the selected arm in each round.
        chosen_ctx = torch.take_along_dim(ctx, idx.unsqueeze(-1), dim=1).squeeze(1)
        chosen_pr = torch.take_along_dim(pr, idx.unsqueeze(-1), dim=1).squeeze(1)

        # A single bulk transfer to Python floats; converting element by
        # element would force one device-to-host sync per round on CUDA.
        batch_regret = (opt - chosen_pr).view(-1).detach().tolist()
        for step_regret in batch_regret:
            cumulative_regret += step_regret
            regrets.append(cumulative_regret)

        agent.record(batch_t, idx)
        agent.update_model(num_epochs=100, update_batch_size=50)
        agent.update_design_matrix(chosen_ctx)

    return regrets


def run_repeated_experiments(
    algorithm: str,
    T: int,
    batch_size: int,
    d: int,
    K: int,
    latent_function: str,
    cnum: int,
    seed: int,
    S: float,
    kappa: float,
    nu: float,
    lambda_reg: float,
    num_runs: int,
    m: int,
) -> np.ndarray:
    """Repeat ``simulate_experiment`` ``num_runs`` times with distinct seeds.

    Run number ``i`` (starting at 0) is seeded with ``seed + i``; every other
    argument is passed through unchanged.

    Returns:
        ``np.array`` built from the list of per-run cumulative-regret curves,
        one row per run.
    """
    regret_curves = [
        simulate_experiment(
            algorithm=algorithm,
            T=T,
            batch_size=batch_size,
            d=d,
            K=K,
            latent_function=latent_function,
            cnum=cnum,
            S=S,
            kappa=kappa,
            nu=nu,
            lambda_reg=lambda_reg,
            seed=seed + run_index,  # distinct, reproducible seed per run
            m=m,
        )
        for run_index in tqdm(range(num_runs), desc=f"{algorithm} runs")
    ]
    return np.array(regret_curves)


def process_algorithm(args: tuple) -> tuple:
    """Worker entry point: run all repetitions for one algorithm.

    Takes a single tuple so it can be used with ``Pool.map``. The tuple is
    ordered as ``(algorithm, T, batch_size, d, K, latent_function, cnum,
    num_runs, seed, S, kappa, nu, lambda_reg, m)``.

    Returns:
        ``(algorithm, mean, std, data)`` where ``data`` is the array returned
        by ``run_repeated_experiments`` and ``mean`` / ``std`` are taken over
        its first axis (across runs).
    """
    # NOTE: the tuple order differs from run_repeated_experiments' positional
    # order (num_runs / seed placement), hence the keyword call below.
    algorithm, T, batch_size, d, K, latent_function, cnum, num_runs, \
        seed, S, kappa, nu, lambda_reg, m = args

    regret_runs = run_repeated_experiments(
        algorithm=algorithm,
        T=T,
        batch_size=batch_size,
        d=d,
        K=K,
        latent_function=latent_function,
        cnum=cnum,
        seed=seed,
        S=S,
        kappa=kappa,
        nu=nu,
        lambda_reg=lambda_reg,
        num_runs=num_runs,
        m=m,
    )
    across_runs_mean = regret_runs.mean(axis=0)
    across_runs_std = regret_runs.std(axis=0)
    return algorithm, across_runs_mean, across_runs_std, regret_runs


def main() -> None:
    """Parse CLI options, run every algorithm in parallel, and save results.

    Each algorithm named in ``--algorithms`` is paired with its own ``--nu``
    and ``--lambda_reg`` value and dispatched to a worker process through
    ``process_algorithm``. Two PNG figures (mean cumulative regret, and the
    mean with a +/- 2 standard deviation band) plus one ``.npz`` archive with
    the per-run regret arrays are written to ``--output_dir``.
    """
    import warnings

    parser = argparse.ArgumentParser(
        description="Generate figures comparing bandit algorithms"
    )
    parser.add_argument(
        "--algorithms",
        type=str,
        default="alg1,alg2,alg3,alg4,alg5",
    )
    parser.add_argument("--T", type=int, default=2000)
    parser.add_argument("--batch_size", type=int, default=50)
    parser.add_argument("--d", type=int, default=20)
    parser.add_argument("--m", type=int, default=20)
    parser.add_argument("--K", type=int, default=5)
    parser.add_argument(
        "--latent_function",
        type=str,
        default="h1",
        choices=["h1", "h2", "h3", "h4", "h5", "h6"],
    )
    parser.add_argument("--cnum", type=int, default=500)
    parser.add_argument("--num_runs", type=int, default=10)
    parser.add_argument("--output_dir", type=str, default="results/figures")
    parser.add_argument("--S", type=float, default=1.0)
    parser.add_argument("--kappa", type=float, default=10.0)
    parser.add_argument(
        "--nu",
        nargs="+",
        type=float,
        default=[0.1, 0.1, 0.1, 1, 1],
        help="one value per algorithm (same order as --algorithms), "
        "or a single value applied to all",
    )
    parser.add_argument(
        "--lambda_reg",
        nargs="+",
        type=float,
        default=[0.1, 0.1, 0.1, 1, 1],
        help="one value per algorithm (same order as --algorithms), "
        "or a single value applied to all",
    )
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()

    # Reject settings that would otherwise only fail later, inside a worker
    # process or while plotting.
    if args.batch_size <= 0:
        parser.error("--batch_size must be a positive integer")
    if args.num_runs <= 0:
        parser.error("--num_runs must be a positive integer")
    if args.T < args.batch_size:
        parser.error("--T must be at least --batch_size")
    dropped_rounds = args.T % args.batch_size
    if dropped_rounds:
        # simulate_experiment only plays T // batch_size full batches.
        warnings.warn(
            f"T={args.T} is not a multiple of batch_size={args.batch_size}; "
            f"the last {dropped_rounds} rounds will not be simulated."
        )

    algs = [a.strip() for a in args.algorithms.split(",")]
    os.makedirs(args.output_dir, exist_ok=True)

    def per_algorithm(values, flag):
        """Return exactly one hyperparameter value per algorithm."""
        if len(values) == 1:
            return list(values) * len(algs)
        if len(values) < len(algs):
            # zip() would silently skip the algorithms without a value.
            parser.error(
                f"{flag} needs 1 value or at least {len(algs)} values "
                f"(one per algorithm), got {len(values)}"
            )
        # Surplus values are ignored, e.g. the 5 defaults with fewer algorithms.
        return list(values[: len(algs)])

    nus = per_algorithm(args.nu, "--nu")
    lambda_regs = per_algorithm(args.lambda_reg, "--lambda_reg")

    # Tuple layout must match the unpacking order in process_algorithm.
    arg_list = [
        (
            alg,
            args.T,
            args.batch_size,
            args.d,
            args.K,
            args.latent_function,
            args.cnum,
            args.num_runs,
            args.seed,
            args.S,
            args.kappa,
            nu_,
            lambda_reg_,
            args.m,
        )
        for alg, nu_, lambda_reg_ in zip(algs, nus, lambda_regs)
    ]

    with Pool() as pool:
        results = [r for r in pool.map(process_algorithm, arg_list) if r]

    suffix = (
        f"T{args.T}_bs{args.batch_size}_d{args.d}_K{args.K}_nr{args.num_runs}"
        f"_lat{args.latent_function}_cnum{args.cnum}_S{args.S}_kappa{args.kappa}"
        f"_m{args.m}_seed{args.seed}"
    )
    alg_tag = "_".join(algs)

    avg = {alg: mean for alg, mean, _, _ in results}
    std = {alg: sd for alg, _, sd, _ in results}

    # Figure 1: mean cumulative regret per algorithm.
    plt.figure(figsize=(10, 7))
    for alg, mean in avg.items():
        # The x-axis follows the actual curve length, which is shorter than T
        # whenever T is not a multiple of the batch size.
        rounds = np.arange(1, len(mean) + 1)
        plt.plot(rounds, mean, label=alg)
    plt.xlabel("Time Steps")
    plt.ylabel("Cumulative Regret")
    plt.title(f"Avg Cumulative Regret ({args.latent_function})")
    plt.legend()
    plt.grid(True)
    fig1 = os.path.join(args.output_dir, f"avg_regret_{alg_tag}_{suffix}.png")
    plt.savefig(fig1)
    plt.show()

    # Figure 2: mean cumulative regret with a +/- 2 standard deviation band.
    plt.figure(figsize=(10, 7))
    for alg, mean in avg.items():
        rounds = np.arange(1, len(mean) + 1)
        band = 2 * std[alg]
        plt.plot(rounds, mean, label=alg)
        plt.fill_between(rounds, mean - band, mean + band, alpha=0.2)
    plt.xlabel("Time Steps")
    plt.ylabel("Cumulative Regret")
    plt.title(f"Avg Regret w/ Variance ({args.latent_function})")
    plt.legend()
    plt.grid(True)
    fig2 = os.path.join(args.output_dir, f"avg_var_{alg_tag}_{suffix}.png")
    plt.savefig(fig2)
    plt.show()

    # Raw per-run regret arrays, stored under each algorithm's name.
    data_file = os.path.join(
        args.output_dir, f"regret_data_{alg_tag}_{suffix}.npz"
    )
    np.savez(data_file, **{alg: raw for alg, _, _, raw in results})


# Entry-point guard: run the CLI only when this file is executed as a script,
# not when it is imported. This also keeps multiprocessing workers that
# re-import this module from launching the experiments again.
if __name__ == "__main__":
    main()
