from pathlib import Path
import torch
from typing import Callable, Dict, Type, List
from tqdm import tqdm
import multiprocessing as mp
from pathlib import Path
import argparse

from src.wpm import WPMSolver, wpm_swf
from src.kolm import KolmSolver, kolm_swf
from src.gini import GiniSolver, gini_swf

from src.armsets import BetaArmSet
from src.sampler import PiPSSampler
from src.swf_ucb import SWFUCB

from src.logger import generate_run_id, log_experiment

def run_single_experiment(
        seed: int,
        arm_set: BetaArmSet,
        n_rounds: int,
        num_alloc: int,
        delta: float,
        weights: torch.Tensor,
        solver_cls: Type,
        swf_func: Callable,
        pow_val: float,
        objective: str,
        verbose: bool = True,
) -> Dict:
    """
    Run one SWF-UCB trajectory on a fixed arm set and accumulate its regret.

    Args:
        seed (int): Seeds both a dedicated generator (handed to the sampler)
            and torch's global RNG.
        arm_set (BetaArmSet): Bandit instance; its ``n_arms``, ``means`` and
            ``sample(inds)`` members are used here.
        n_rounds (int): Number of select/sample/update rounds to play.
        num_alloc (int): Allocation size passed to the solver and to SWFUCB.
        delta (float): Parameter forwarded to SWFUCB.
        weights (torch.Tensor): Weights passed to the solver and to the SWF.
        solver_cls (Type): Solver class. Built as ``solver_cls(weights,
            num_alloc)`` when ``objective == "gini"``, otherwise as
            ``solver_cls(weights, pow_val, num_alloc)``.
        swf_func (Callable): Social welfare function. Called as
            ``swf_func(values, weights)`` for ``"gini"`` and as
            ``swf_func(values, weights, pow_val)`` otherwise.
        pow_val (float): Power value; unused when ``objective == "gini"``.
        objective (str): Objective name; only ``"gini"`` is special-cased.
        verbose (bool, optional): Wrap the round loop in a tqdm progress bar.
            Defaults to True.

    Returns:
        Dict: ``{'regret_sum': float, 'opt_swf': float}`` where ``regret_sum``
        is the sum over rounds of ``opt_swf - swf(current allocation)``.
    """
    torch.manual_seed(seed)
    sampler_gen = torch.Generator().manual_seed(seed)

    # The gini objective has no power parameter, so both the solver
    # constructor and the welfare function take one argument fewer.
    if objective == "gini":
        solver = solver_cls(weights, num_alloc)
        swf_extra_args = ()
    else:
        solver = solver_cls(weights, pow_val, num_alloc)
        swf_extra_args = (pow_val,)

    def welfare(probs):
        # SWF value of the mean rewards scaled by allocation probabilities.
        return swf_func(arm_set.means * probs, weights, *swf_extra_args)

    # Benchmark: welfare of the solver's allocation for the true means.
    opt_swf = welfare(solver.get_allocation_probabilities(arm_set.means))

    # Learner under evaluation.
    ucb = SWFUCB(
        arm_set.n_arms,
        num_alloc,
        solver,
        PiPSSampler(generator=sampler_gen),
        delta,
    )

    rounds = range(n_rounds)
    if verbose:
        rounds = tqdm(rounds, leave = False)

    regret_sum = 0.0
    for _ in rounds:
        chosen = ucb.select_arms()
        ucb.update(chosen, arm_set.sample(chosen))

        # Regret is measured on the learner's probabilities after the update.
        round_regret = opt_swf - welfare(ucb.probs)
        regret_sum += round_regret.item()

    return {
        'regret_sum': regret_sum,
        'opt_swf': opt_swf.item(),
    }

def run_parallel_experiments(
        instance_seed: int,
        seeds: List[int],
        n_arms: int,
        n_rounds: int,
        num_alloc: int,
        delta: float,
        weights: torch.Tensor,
        solver_cls: Type,
        swf_func: Callable,
        pow_val: float,
        objective: str,
        verbose: bool = True,
) -> Dict:
    """
    Runs one experiment per seed in parallel on a shared bandit instance.

    A single ``BetaArmSet`` is built from ``instance_seed`` and handed to
    every worker, so all seeds are evaluated on the same arm set.

    Args:
        instance_seed (int): Seed of the generator used to build the arm set.
        seeds (List[int]): One random seed per experiment.
        n_arms (int): Number of arms in the bandit problem.
        n_rounds (int): Number of rounds to run each experiment for.
        num_alloc (int): Number of allocations to perform.
        delta (float): Parameter forwarded to each single experiment.
        weights (torch.Tensor): Weights for the social welfare function.
        solver_cls (Type): Class of the solver to use.
        swf_func (Callable): Social welfare function.
        pow_val (float): Power value for the SWF.
        objective (str): Objective type for the experiment.
        verbose (bool, optional): Forwarded unchanged to every single
            experiment (each worker then shows its own progress bar).
            Defaults to True.

    Returns:
        Dict: ``"regret_sums"`` (list with one cumulative regret per seed) and
        ``"mean_regret"`` / ``"std_regret"`` (0-dim float64 tensors computed
        over the seeds).
    """
    instance_gen = torch.Generator().manual_seed(instance_seed)
    arm_set = BetaArmSet(n_arms=n_arms, gen=instance_gen)
    args_list = [
        (
            seed,
            arm_set,
            n_rounds,
            num_alloc,
            delta,
            weights,
            solver_cls,
            swf_func,
            pow_val,
            objective,
            verbose,  # forwarded as given by the caller
        )
        for seed in seeds
    ]

    # More workers than tasks only costs start-up time; Pool also rejects a
    # worker count below 1, hence the lower bound for an empty seed list.
    n_workers = max(1, min(len(args_list), mp.cpu_count()))
    with mp.Pool(processes=n_workers) as pool:
        # starmap blocks until all tasks finish and already returns a list.
        results = pool.starmap(run_single_experiment, args_list)

    regret_sums = torch.tensor([res['regret_sum'] for res in results], dtype=torch.float64)
    mean_regret = regret_sums.mean(dim=0)
    std_regret = regret_sums.std(dim=0)

    res_dict = {
        "regret_sums": regret_sums.tolist(),
        "mean_regret": mean_regret,
        "std_regret": std_regret,
    }
    return res_dict



def _run_power_sweep(
        *,
        label: str,
        objective: str,
        solver_cls: Type,
        swf_func: Callable,
        pow_val_ls: torch.Tensor,
        instance_seed: int,
        seeds: List[int],
        n_arms: int,
        n_rounds: int,
        num_alloc_ls: List[int],
        delta: float,
        weights: torch.Tensor,
        run_dir: Path,
) -> None:
    """Sweep the SWF power value for one objective and log one run per allocation size.

    Args:
        label (str): Human-readable objective name used in progress output.
        objective (str): Objective key; forwarded to the experiments, stored
            in the config and used as the run-id prefix.
        solver_cls (Type): Solver class for this objective.
        swf_func (Callable): Social welfare function for this objective.
        pow_val_ls (torch.Tensor): 1-D tensor of power values to sweep.
        instance_seed (int): Seed of the shared bandit instance.
        seeds (List[int]): Seeds of the repeated experiments.
        n_arms (int): Number of arms.
        n_rounds (int): Rounds per experiment.
        num_alloc_ls (List[int]): Allocation sizes; one logged run each.
        delta (float): Parameter forwarded to the experiments.
        weights (torch.Tensor): Weights of the social welfare function.
        run_dir (Path): Directory handed to ``log_experiment``.
    """
    # Plain Python floats: used for printing, as experiment arguments and in
    # the logged dictionaries.
    pow_vals = pow_val_ls.tolist()

    print(f"Running experiments for {label}")
    for num_alloc in num_alloc_ls:
        print(f"Running experiments for num_alloc = {num_alloc}")
        sweep_results = {
            "pow_val_ls": list(pow_vals),
            "regret_sums": [],
            "mean_regret": [],
            "std_regret": [],
        }
        for pow_val in pow_vals:
            print(f"Running experiments for pow_val = {pow_val}")
            res = run_parallel_experiments(
                instance_seed=instance_seed,
                seeds=seeds,
                n_arms=n_arms,
                n_rounds=n_rounds,
                num_alloc=num_alloc,
                delta=delta,
                weights=weights,
                solver_cls=solver_cls,
                swf_func=swf_func,
                pow_val=pow_val,
                objective=objective,
                verbose=True,
            )
            sweep_results["regret_sums"].append(res["regret_sums"])
            sweep_results["mean_regret"].append(res["mean_regret"].item())
            sweep_results["std_regret"].append(res["std_regret"].item())

        # Summary table for this allocation size.
        print(f"{'Power Value':<12} | {'Mean Regret':<15} | {'Std Dev':<15}")
        print("-" * 46)
        for pow_val, m, s in zip(pow_vals, sweep_results["mean_regret"], sweep_results["std_regret"]):
            print(f"{pow_val:<12.4f} | {m:<15.4f} | {s:<15.4f}")

        run_id = generate_run_id(objective)
        config = {
            "instance_seed": instance_seed,
            "n_arms": n_arms,
            "n_rounds": n_rounds,
            "num_alloc": num_alloc,
            "delta": delta,

            "objective": objective,
            "weights": weights.tolist(),
            # Stored as a list, like `weights`, so the config holds only
            # built-in types (presumably serialized by log_experiment).
            "pow_val_ls": list(pow_vals),
        }
        log_experiment(config=config, results=sweep_results, run_id=run_id, run_dir=run_dir)


def main(weight_type: str, objectives: str):
    """Run the power-value regret sweeps for the selected objectives.

    Args:
        weight_type (str): ``"linear"`` gives weights decreasing linearly
            from 2 to 1; any other value gives weights ``0.9 ** i``. Both are
            normalized to sum to 1.
        objectives (str): Objective codes; the WPM sweep runs if it contains
            ``'w'`` and the Kolm sweep runs if it contains ``'k'``.
    """
    instance_seed = 42

    # Experiment parameters
    n_arms = 50
    n_rounds = 10_000
    num_alloc_ls = [1, 5, 10, 25, 40]
    delta = 0.05

    if weight_type == "linear":
        weights = torch.flip(torch.linspace(1.0, 2.0, n_arms, dtype=torch.float64), dims=[0])
    else:
        weights = 0.9 ** torch.arange(n_arms, dtype=torch.float64)
    weights = weights / weights.sum()

    seeds = list(range(5))

    dir_id = generate_run_id(f"trends_pow_{weight_type}")
    run_dir = Path(f"runs/{dir_id}")

    # Settings shared by every objective's sweep.
    shared = dict(
        instance_seed=instance_seed,
        seeds=seeds,
        n_arms=n_arms,
        n_rounds=n_rounds,
        num_alloc_ls=num_alloc_ls,
        delta=delta,
        weights=weights,
        run_dir=run_dir,
    )
    # Every sweep starts with the limiting power value -inf.
    neg_inf = torch.tensor(-torch.inf).unsqueeze(0)

    if 'w' in objectives:
        _run_power_sweep(
            label="WPM",
            objective="wpm",
            solver_cls=WPMSolver,
            swf_func=wpm_swf,
            pow_val_ls=torch.cat((neg_inf, torch.linspace(-5.0, 1.0, steps=13)), dim=0),
            **shared,
        )

    if 'k' in objectives:
        _run_power_sweep(
            label="Kolm",
            objective="kolm",
            solver_cls=KolmSolver,
            swf_func=kolm_swf,
            pow_val_ls=torch.cat((neg_inf, torch.linspace(-5.0, 0.0, steps=11)), dim=0),
            **shared,
        )


if __name__ == "__main__":
    # Command-line entry point: parse the sweep options and run main().
    parser = argparse.ArgumentParser(description="Run experiments to analyze regret trends over time.")
    # The default is one of the declared choices; main() builds the
    # 0.9 ** i weights for any value other than "linear".
    parser.add_argument(
        "--weight-type",
        type=str,
        choices=["linear", "geometric"],
        default="geometric",
        help="Shape of the SWF weights (default: %(default)s).",
    )
    # Required: main() tests membership on this string, so a missing value
    # (None) would fail with a TypeError instead of a usage message.
    parser.add_argument(
        "--objectives",
        type=str,
        required=True,
        help="Objective codes to run: include 'w' for WPM and/or 'k' for Kolm, e.g. 'wk'.",
    )
    args = parser.parse_args()
    main(args.weight_type, args.objectives)
