from pathlib import Path
import torch
from typing import Callable, Dict, Type, List
from tqdm import tqdm
import multiprocessing as mp
from pathlib import Path
import argparse

from src.wpm import WPMSolver, wpm_swf
from src.kolm import KolmSolver, kolm_swf
from src.gini import GiniSolver, gini_swf

from src.armsets import BetaArmSet
from src.sampler import PiPSSampler
from src.swf_ucb import SWFUCB

from src.logger import generate_run_id, log_experiment


def run_single_experiment(
    seed: int,
    arm_set: BetaArmSet,
    n_rounds_ls: List[int],
    num_alloc: int,
    delta: float,
    weights: torch.Tensor,
    solver_cls: Type,
    swf_func: Callable,
    pow_val: float,
    objective: str,
    verbose: bool = True,
) -> Dict:
    """
    Runs a single seeded experiment and records cumulative regret at checkpoints.

    Each round selects arms with SWFUCB, samples rewards from ``arm_set``,
    updates the learner, and accumulates the gap between the optimal SWF value
    and the SWF value of the learner's allocation probabilities after the update.

    Args:
        seed (int): Random seed; seeds both the sampler's generator and torch's
            global RNG.
        arm_set (BetaArmSet): Bandit instance; its ``n_arms``, ``means`` and
            ``sample`` members are used.
        n_rounds_ls (List[int]): Round counts at which the cumulative regret is
            recorded. Sorted internally; the largest entry is the horizon.
        num_alloc (int): Number of allocations, passed to the solver and SWFUCB.
        delta (float): Parameter passed through to SWFUCB.
        weights (torch.Tensor): Weights for the social welfare function.
        solver_cls (Type): Class of the solver to use.
        swf_func (Callable): Social welfare function.
        pow_val (float): Power value for the SWF; ignored when ``objective`` is
            ``"gini"``.
        objective (str): Objective type. ``"gini"`` selects the call signatures
            without ``pow_val``; any other value includes it.
        verbose (bool, optional): Whether to show a tqdm progress bar.
            Defaults to True.

    Returns:
        Dict: ``"time_steps"`` (sorted checkpoints), ``"regret_sums"`` (cumulative
        regret at each checkpoint, same length and order as ``"time_steps"``),
        and ``"opt_swf"`` (optimal SWF value as a float).
    """
    gen = torch.Generator().manual_seed(seed)
    torch.manual_seed(seed)

    n_arms = arm_set.n_arms

    # The Gini solver/SWF take no power argument. Resolve that difference once
    # instead of re-checking the objective on every round.
    if objective == "gini":
        solver = solver_cls(weights, num_alloc)
        swf_extra_args = ()
    else:
        solver = solver_cls(weights, pow_val, num_alloc)
        swf_extra_args = (pow_val,)

    # Optimal allocation and SWF value under the true means
    opt_probs = solver.get_allocation_probabilities(arm_set.means)
    opt_swf = swf_func(arm_set.means * opt_probs, weights, *swf_extra_args)

    # Checkpoints are consumed in increasing order
    n_rounds_ls = sorted(n_rounds_ls)
    n_checkpoints = len(n_rounds_ls)
    # An empty checkpoint list means there is nothing to simulate.
    horizon = n_rounds_ls[-1] if n_rounds_ls else 0

    # Initialize sampler and UCB
    sampler = PiPSSampler(generator=gen)
    ucb = SWFUCB(n_arms, num_alloc, solver, sampler, delta)

    regret_sums = []
    regret_sum = 0.0

    # Checkpoints at or below zero are reached before any round is played, so
    # no regret has accumulated yet. Recording them keeps regret_sums aligned
    # with time_steps.
    next_cp = 0
    while next_cp < n_checkpoints and n_rounds_ls[next_cp] <= 0:
        regret_sums.append(regret_sum)
        next_cp += 1

    iterator = tqdm(range(horizon), leave=False) if verbose else range(horizon)

    for t in iterator:
        inds = ucb.select_arms()
        rewards = arm_set.sample(inds)
        ucb.update(inds, rewards)

        # Regret of the allocation probabilities held after this round's update
        current_swf = swf_func(arm_set.means * ucb.probs, weights, *swf_extra_args)
        regret_sum += (opt_swf - current_swf).item()

        # Emit one entry per checkpoint equal to the number of rounds played;
        # the inner loop also covers duplicated checkpoints.
        rounds_done = t + 1
        while next_cp < n_checkpoints and n_rounds_ls[next_cp] == rounds_done:
            regret_sums.append(regret_sum)
            next_cp += 1

    return {
        "time_steps": n_rounds_ls,
        "regret_sums": regret_sums,
        "opt_swf": opt_swf.item(),
    }


def run_parallel_experiments(
    instance_seed: int,
    seeds: List[int],
    n_arms: int,
    n_rounds_ls: List[int],
    num_alloc: int,
    delta: float,
    weights: torch.Tensor,
    solver_cls: Type,
    swf_func: Callable,
    pow_val: float,
    objective: str,
    verbose: bool = True,
) -> Dict:
    """
    Runs one experiment per seed in parallel on a shared bandit instance.

    A single ``BetaArmSet`` is built from ``instance_seed`` and handed to every
    worker, so the seeds vary only the run, not the instance.

    Args:
        instance_seed (int): Seed for the generator used to build the arm set.
        seeds (List[int]): List of random seeds, one experiment per seed.
        n_arms (int): Number of arms in the bandit problem.
        n_rounds_ls (List[int]): List of round counts to run the experiment for.
        num_alloc (int): Number of allocations to perform.
        delta (float): Parameter for the UCB algorithm.
        weights (torch.Tensor): Weights for the social welfare function.
        solver_cls (Type): Class of the solver to use.
        swf_func (Callable): Social welfare function.
        pow_val (float): Power value for the SWF.
        objective (str): Objective type for the experiment.
        verbose (bool, optional): Forwarded unchanged to every individual
            experiment. Defaults to True.

    Returns:
        Dict: ``"time_steps"`` (checkpoints in increasing order),
        ``"mean_regret"`` and ``"std_regret"`` (tensors holding the mean and
        standard deviation across seeds of the cumulative regret at each
        checkpoint).
    """
    instance_gen = torch.Generator().manual_seed(instance_seed)
    arm_set = BetaArmSet(n_arms=n_arms, gen=instance_gen)
    args_list = [
        (
            seed,
            arm_set,
            n_rounds_ls,
            num_alloc,
            delta,
            weights,
            solver_cls,
            swf_func,
            pow_val,
            objective,
            verbose,
        )
        for seed in seeds
    ]

    # Never start more workers than there are experiments; Pool needs at least one.
    n_workers = max(1, min(mp.cpu_count(), len(args_list)))
    with mp.Pool(processes=n_workers) as pool:
        results = pool.starmap(run_single_experiment, args_list)

    # Rows are seeds, columns are checkpoints
    regret_tensor = torch.tensor([res["regret_sums"] for res in results])

    return {
        # Each run reports regrets in sorted-checkpoint order, so the time
        # axis must be sorted as well to stay aligned with the columns.
        "time_steps": sorted(n_rounds_ls),
        "mean_regret": regret_tensor.mean(dim=0),
        "std_regret": regret_tensor.std(dim=0),
    }


def _print_regret_table(results: Dict) -> None:
    """Print mean and standard deviation of the regret per time step as a table."""
    print(f"{'Time Step':<12} | {'Mean Regret':<15} | {'Std Dev':<15}")
    print("-" * 46)
    for t, m, s in zip(
        results["time_steps"], results["mean_regret"], results["std_regret"]
    ):
        print(f"{t:<12} | {m.item():<15.4f} | {s.item():<15.4f}")


def main(weight_type: str, objectives: str):
    """
    Runs the regret-over-time experiments for the selected objectives.

    For every selected objective and every allocation count, runs the parallel
    experiments, prints a regret table, and logs the configuration and results
    under ``runs/<run id>``.

    Args:
        weight_type (str): ``"linear"`` for linearly decreasing weights from 2.0
            to 1.0; any other value uses geometric weights ``0.9 ** i``. The
            weights are normalized to sum to one.
        objectives (str): String of objective flags; an objective runs when its
            letter is contained in it: ``w`` (WPM), ``k`` (Kolm), ``g`` (Gini).
    """
    instance_seed = 42

    # Experiment parameters
    n_arms = 50
    n_rounds_ls = (2 ** torch.arange(9)) * 1_000
    num_alloc_ls = [1, 5, 10, 25, 40]
    delta = 0.05
    pow_val = -2.0

    if weight_type == "linear":
        weights = torch.flip(torch.linspace(1.0, 2.0, n_arms, dtype=torch.float64), dims=[0])
    else:
        weights = 0.9 ** torch.arange(n_arms, dtype=torch.float64)
    weights = weights / weights.sum()

    seeds = list(range(5))

    dir_id = generate_run_id(f"trends_T_{weight_type}")
    run_dir = Path(f"runs/{dir_id}")

    # (CLI flag, display name, objective key, solver class, SWF, power value).
    # Gini has no power parameter, so it runs and is logged with 0.0.
    objective_specs = [
        ("w", "WPM", "wpm", WPMSolver, wpm_swf, pow_val),
        ("k", "Kolm", "kolm", KolmSolver, kolm_swf, pow_val),
        ("g", "Gini", "gini", GiniSolver, gini_swf, 0.0),
    ]

    for flag, label, objective, solver_cls, swf_func, objective_pow in objective_specs:
        if flag not in objectives:
            continue

        print(f"Running experiments for {label}")
        for num_alloc in num_alloc_ls:
            print(f"Running experiments for num_alloc = {num_alloc}")

            results = run_parallel_experiments(
                instance_seed=instance_seed,
                seeds=seeds,
                n_arms=n_arms,
                n_rounds_ls=n_rounds_ls.tolist(),
                num_alloc=num_alloc,
                delta=delta,
                weights=weights,
                solver_cls=solver_cls,
                swf_func=swf_func,
                pow_val=objective_pow,
                objective=objective,
                verbose=True,
            )
            _print_regret_table(results)

            run_id = generate_run_id(f"{objective}_num_alloc_{num_alloc}")
            config = {
                "instance_seed": instance_seed,
                "n_arms": n_arms,
                "n_rounds_ls": n_rounds_ls.tolist(),
                "num_alloc": num_alloc,
                "delta": delta,
                "objective": objective,
                "weights": weights.tolist(),
                "pow_val": objective_pow,
            }
            log_experiment(config=config, results=results, run_id=run_id, run_dir=run_dir)


if __name__ == "__main__":
    # Command-line entry point: choose the weight scheme and which objectives to run.
    cli = argparse.ArgumentParser(
        description="Run experiments to analyze regret trends over time."
    )
    cli.add_argument(
        "--weight-type", type=str, choices=["linear", "geometric"], default="geometric"
    )
    cli.add_argument(
        "--objectives",
        type=str,
        default="wkg",
        help="Objectives to run: w (WPM), k (Kolm), g (Gini)",
    )
    options = cli.parse_args()
    main(options.weight_type, options.objectives)
