#!/usr/bin/env python
import argparse
import os
import json
from multiprocessing import Pool

import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt

from env import LogisticBanditRealEnvironment
from agent_diag import LogisticBanditAgent
from utils import sigmoid


def simulate_experiment(
    algorithm: str,
    T: int,
    batch_size: int,
    d: int,  # forwarded to the environment; replaced below by the context width
    K: int,
    latent_function: str,
    cnum: int,
    S: float,
    kappa: float,
    nu: float,
    lambda_reg: float,
    seed: int,
    m: int,
) -> list[float]:
    """Run one batched bandit simulation and return its cumulative regret.

    NumPy and torch are seeded with ``seed``, a
    ``LogisticBanditRealEnvironment`` is built, and contexts and rewards for
    ``T`` rounds are drawn up front. A ``LogisticBanditAgent`` then plays
    ``T // batch_size`` full batches; rounds left over when ``T`` is not a
    multiple of ``batch_size`` are never simulated.

    Args:
        algorithm: Algorithm name forwarded to the agent.
        T: Number of rounds requested from the environment.
        batch_size: Number of rounds played between two agent updates.
        d: Dimension hint passed to the environment. The agent is built with
            the dimension actually found in the generated contexts.
        K: Arm-count hint passed to the environment. The agent is built with
            the arm count actually found in the generated contexts.
        latent_function: Dataset / latent function name for the environment.
        cnum: Forwarded unchanged to the environment.
        S, kappa, nu, lambda_reg, m: Forwarded unchanged to the agent.
        seed: Seed for NumPy, torch, the environment and the agent.

    Returns:
        Running totals of per-round regret (best success probability minus
        the chosen arm's probability), in play order.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    env = LogisticBanditRealEnvironment(
        device, d=d, K=K, latent_function=latent_function, cnum=cnum, seed=seed
    )
    contexts = torch.tensor(
        env.get_contexts(T), dtype=torch.float32, device=device
    )
    rewards, probs = env.get_rewards(T)
    optimal_probs, _ = probs.max(dim=1)

    # Trust the generated data over the CLI hints: from here on ``d`` and
    # ``K`` describe the contexts that were actually produced.
    d = contexts.size(2)
    K = contexts.size(1)
    X = contexts.view(-1, d)
    r = rewards.view(-1, 1)
    dataset = torch.utils.data.TensorDataset(X, r)

    agent = LogisticBanditAgent(
        device=device,
        dataset=dataset,
        d=d,
        K=K,
        m=m,
        S=S,
        kappa=kappa,
        nu=nu,
        lambda_reg=lambda_reg,
        batch_size=batch_size,
        algorithm=algorithm,
        lr=1e-2,
        seed=seed,
    )

    cumulative_regret = 0.0
    regrets = []
    num_batches = T // batch_size
    for batch_t in tqdm(range(1, num_batches + 1)):
        start = (batch_t - 1) * batch_size
        end = batch_t * batch_size
        ctx = contexts[start:end]
        pr = probs[start:end]
        opt = optimal_probs[start:end]

        idx, _ = agent.select_action(ctx)
        chosen_pr = torch.take_along_dim(pr, idx.unsqueeze(-1), dim=1).squeeze(1)
        batch_regret = (opt - chosen_pr).view(-1)

        # One bulk transfer to Python floats instead of one device
        # synchronisation per element; the running sum stays in Python
        # floats exactly as before.
        for step_regret in batch_regret.tolist():
            cumulative_regret += step_regret
            regrets.append(cumulative_regret)

        # Order matters: log the choices, refit the model, then fold the
        # chosen contexts into the design matrix.
        agent.record(batch_t, idx)
        agent.update_model(num_epochs=100, update_batch_size=50)
        agent.update_design_matrix(
            torch.take_along_dim(ctx, idx.unsqueeze(-1), dim=1).squeeze(1)
        )

    return regrets


def run_repeated_experiments(
    algorithm: str,
    T: int,
    batch_size: int,
    d: int,
    K: int,
    latent_function: str,
    cnum: int,
    seed: int,
    S: float,
    kappa: float,
    nu: float,
    lambda_reg: float,
    num_runs: int,
    m: int,
) -> np.ndarray:
    """Repeat ``simulate_experiment`` ``num_runs`` times and stack the curves.

    Run number ``i`` (counting from 0) uses seed ``seed + i``; every other
    argument is passed through unchanged.

    Returns:
        ``np.array`` built from the list of per-run cumulative-regret lists,
        one row per run.
    """
    progress = tqdm(range(num_runs), desc=f"{algorithm} runs")
    curves = [
        simulate_experiment(
            algorithm=algorithm,
            T=T,
            batch_size=batch_size,
            d=d,
            K=K,
            latent_function=latent_function,
            cnum=cnum,
            S=S,
            kappa=kappa,
            nu=nu,
            lambda_reg=lambda_reg,
            seed=seed + offset,
            m=m,
        )
        for offset in progress
    ]
    return np.array(curves)


def process_algorithm(args: tuple) -> tuple:
    """Pool worker: run every repetition for one algorithm and summarise it.

    ``args`` is a flat 14-tuple so the job can be handed to ``Pool.map``.
    Its field order (``num_runs`` and ``seed`` directly after ``cnum``)
    differs from the parameter order of ``run_repeated_experiments``, so the
    call below binds by keyword.

    Returns:
        ``(algorithm, mean, std, data)`` where ``data`` is the array returned
        by ``run_repeated_experiments`` and ``mean`` / ``std`` are taken over
        its first axis.
    """
    (
        algorithm, T, batch_size, d, K, latent_function, cnum,
        num_runs, seed, S, kappa, nu, lambda_reg, m,
    ) = args

    regret_curves = run_repeated_experiments(
        algorithm=algorithm,
        T=T,
        batch_size=batch_size,
        d=d,
        K=K,
        latent_function=latent_function,
        cnum=cnum,
        seed=seed,
        S=S,
        kappa=kappa,
        nu=nu,
        lambda_reg=lambda_reg,
        num_runs=num_runs,
        m=m,
    )
    return (
        algorithm,
        regret_curves.mean(axis=0),
        regret_curves.std(axis=0),
        regret_curves,
    )


def _expand_per_algorithm(values: list, count: int, flag: str) -> list:
    """Return exactly ``count`` hyperparameter values, one per algorithm.

    A single value is broadcast to every algorithm. Surplus values are
    ignored, which matches the previous ``zip`` behaviour. Too few values
    raises ``ValueError`` rather than silently dropping algorithms.
    """
    if len(values) == 1:
        return list(values) * count
    if len(values) < count:
        raise ValueError(
            f"{flag} got {len(values)} value(s) for {count} algorithm(s); "
            "pass a single value or at least one value per algorithm"
        )
    return list(values[:count])


def _plot_regret(means: dict, title: str, path: str, stds=None) -> None:
    """Plot one regret curve per algorithm, save the figure to ``path``, show it.

    When ``stds`` is given, a +/- one standard deviation band is drawn
    around each curve.
    """
    plt.figure(figsize=(10, 7))
    for alg, mean in means.items():
        # Take the x-axis from the curve itself: a run only covers
        # (T // batch_size) * batch_size rounds, which is shorter than T
        # whenever T is not a multiple of the batch size.
        rounds = np.arange(1, len(mean) + 1)
        plt.plot(rounds, mean, label=alg)
        if stds is not None:
            plt.fill_between(
                rounds, mean - stds[alg], mean + stds[alg], alpha=0.2
            )
    plt.xlabel("Timesteps")
    plt.ylabel("Cumulative Regret")
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.savefig(path)
    plt.show()


def main():
    """Parse CLI options, run every algorithm in a process pool, save results.

    Outputs written to ``--output_dir``:
      * ``regret_data_*.npz`` - raw per-run regret arrays keyed by algorithm,
      * ``avg_real_*.png`` - mean cumulative regret per algorithm,
      * ``avg_var_real_*.png`` - the same curves with a +/- one std band.

    ``--nu`` and ``--lambda_reg`` are matched to ``--algorithms`` by position.
    A single value applies to every algorithm; giving fewer values than
    algorithms (other than exactly one) is rejected as a usage error.
    """
    parser = argparse.ArgumentParser(
        description="Generate figures for real environment experiments"
    )
    parser.add_argument(
        "--algorithms",
        type=str,
        default="alg1,alg2,alg3,alg4,alg5",
    )
    parser.add_argument("--T", type=int, default=2000)
    parser.add_argument("--batch_size", type=int, default=50)
    parser.add_argument("--d", type=int, default=20)
    parser.add_argument("--m", type=int, default=20)
    parser.add_argument("--K", type=int, default=5)
    parser.add_argument(
        "--latent_function",
        type=str,
        default="mnist",
        choices=["mnist", "mushroom", "shuttle"],
    )
    parser.add_argument("--cnum", type=int, default=500)
    parser.add_argument("--num_runs", type=int, default=10)
    parser.add_argument("--output_dir", type=str, default="results/figures")
    parser.add_argument("--S", type=float, default=1.0)
    parser.add_argument("--kappa", type=float, default=10.0)
    parser.add_argument(
        "--nu",
        nargs="+",
        type=float,
        default=[0.1, 0.1, 0.1, 1, 1],
    )
    parser.add_argument(
        "--lambda_reg",
        nargs="+",
        type=float,
        default=[0.1, 0.1, 0.1, 1, 1],
    )
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()

    # Ignore empty names such as the one produced by a trailing comma.
    algs = [name.strip() for name in args.algorithms.split(",") if name.strip()]
    if not algs:
        parser.error("--algorithms must name at least one algorithm")

    # Validate before any expensive work so a usage mistake fails fast.
    try:
        nus = _expand_per_algorithm(args.nu, len(algs), "--nu")
        lambda_regs = _expand_per_algorithm(
            args.lambda_reg, len(algs), "--lambda_reg"
        )
    except ValueError as err:
        parser.error(str(err))

    os.makedirs(args.output_dir, exist_ok=True)

    # Field order must match the tuple unpacking in process_algorithm.
    arg_list = [
        (
            alg,
            args.T,
            args.batch_size,
            args.d,
            args.K,
            args.latent_function,
            args.cnum,
            args.num_runs,
            args.seed,
            args.S,
            args.kappa,
            nu_,
            lambda_reg_,
            args.m,
        )
        for alg, nu_, lambda_reg_ in zip(algs, nus, lambda_regs)
    ]

    with Pool() as pool:
        results = [r for r in pool.map(process_algorithm, arg_list) if r]

    suffix = (
        f"T{args.T}_bs{args.batch_size}_d{args.d}_K{args.K}_nr{args.num_runs}"
        f"_lat{args.latent_function}_cnum{args.cnum}_S{args.S}_kappa{args.kappa}"
        f"_m{args.m}_seed{args.seed}"
    )
    tag = f"{'_'.join(algs)}_{suffix}"

    # Save the raw data first so a plotting failure or a closed display
    # cannot discard the simulation results.
    data_file = os.path.join(args.output_dir, f"regret_data_{tag}.npz")
    np.savez(data_file, **{alg: raw for alg, _, _, raw in results})

    avg = {alg: mean for alg, mean, _, _ in results}
    std = {alg: spread for alg, _, spread, _ in results}

    _plot_regret(
        avg,
        f"Avg Regret ({args.latent_function})",
        os.path.join(args.output_dir, f"avg_real_{tag}.png"),
    )
    _plot_regret(
        avg,
        f"Avg Regret w/ Var ({args.latent_function})",
        os.path.join(args.output_dir, f"avg_var_real_{tag}.png"),
        stds=std,
    )


# Run the CLI only when executed as a script. The guard also keeps
# multiprocessing workers that re-import this module (e.g. under the
# "spawn" start method) from launching main() again.
if __name__ == "__main__":
    main()
