from __future__ import annotations

import argparse
import os
import pickle
import time
import warnings
import traceback
from typing import Dict, List

import numpy as np

def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
    """Report a warning together with the call stack that was active when it fired.

    The signature matches ``warnings.showwarning`` so the function can be
    installed as a drop-in replacement for it.

    Args:
        message: The warning message (or warning instance) to report.
        category: The warning class; its ``__name__`` is included in the output.
        filename: Path of the file the warning is attributed to.
        lineno: Line number the warning is attributed to.
        file: Optional writable text stream. When ``None`` the report goes to
            ``sys.stdout`` (the behaviour of ``print`` with ``file=None``).
        line: Accepted for signature compatibility; not used.
    """
    header = f"{filename}:{lineno}: {category.__name__}: {message}\n"
    stack = ''.join(traceback.format_stack())
    # Honour an explicitly supplied stream instead of silently dropping it.
    print(header + stack, file=file)

# Replace the default warning printer so every warning is reported through
# warn_with_traceback (which includes the current call stack in its output).
warnings.showwarning = warn_with_traceback

# The plotting helpers are treated as optional: any exception raised while
# importing them is swallowed and recorded as HAVE_PLOT = False.
try:
    from plot import plot_regret, plot_times
    HAVE_PLOT = True
except Exception:
    HAVE_PLOT = False

from NSContextBandit import NSContextBandit
from environment_context import EnvironmentContext

from DAL_CBs import RegCB, SquareCB, Cover, DALContext
from ADAILTCBp import ADAILTCBPlusBandit, LinearArgmaxPolicy

from MASTERContext import MASTERContext

def make_dab_regcb(K: int, T: int, noise_variance:float, seed: int) -> DALContext:
    """Build a DALContext whose base factory produces RegCB learners.

    Args:
        K: Number of actions, forwarded to RegCB as ``num_actions``.
        T: Horizon, forwarded to both RegCB and DALContext.
        noise_variance: Forwarded unchanged to DALContext.
        seed: Seed given to every RegCB the factory creates.

    Returns:
        The DALContext instance, with ``name`` set to ``"DAB(RegCB)"``.
    """
    def new_regcb():
        # RegCB is always created in "elimination" mode with the same seed.
        return RegCB(num_actions=K, horizon=T, mode="elimination", seed=seed)

    wrapper = DALContext(
        T=T,
        delta=1 / np.sqrt(T),
        noise_variance=noise_variance,
        base_factory=new_regcb,
    )
    wrapper.name = "DAB(RegCB)"
    return wrapper

def make_dab_square(K: int, T: int, noise_variance:float, seed: int) -> DALContext:
    """Build a DALContext whose base factory produces SquareCB learners.

    Args:
        K: Number of actions, forwarded to SquareCB as ``num_actions``.
        T: Horizon, forwarded to both SquareCB and DALContext.
        noise_variance: Forwarded unchanged to DALContext.
        seed: Seed given to every SquareCB the factory creates.

    Returns:
        The DALContext instance, with ``name`` set to ``"DAB(Square)"``.
    """
    def new_squarecb():
        # Fixed SquareCB hyper-parameters used by this experiment.
        return SquareCB(
            num_actions=K,
            horizon=T,
            gamma_scale=200.0,
            gamma_exponent=1,
            elim=False,
            mellowness=0.005,
            seed=seed,
        )

    wrapper = DALContext(
        T=T,
        delta=1 / np.sqrt(T),
        noise_variance=noise_variance,
        base_factory=new_squarecb,
    )
    wrapper.name = "DAB(Square)"
    return wrapper

def make_dab_cover_nu(K: int, T: int, noise_variance:float, seed: int) -> DALContext:
    """Build a DALContext whose base factory produces Cover learners (``nounif=True``).

    Args:
        K: Number of actions, forwarded to Cover as ``num_actions``.
        T: Horizon, forwarded to both Cover and DALContext.
        noise_variance: Forwarded unchanged to DALContext.
        seed: Seed given to every Cover instance the factory creates.

    Returns:
        The DALContext instance, with ``name`` set to ``"DAB(Cover-NU)"``.
    """
    def new_cover():
        # Fixed Cover configuration: m=24, psi=1.0, nounif=True, cb_type "mtr".
        return Cover(
            num_actions=K,
            horizon=T,
            m=24,
            psi=1.0,
            nounif=True,
            seed=seed,
            cb_type="mtr",
        )

    wrapper = DALContext(
        T=T,
        delta=1 / np.sqrt(T),
        noise_variance=noise_variance,
        base_factory=new_cover,
    )
    wrapper.name = "DAB(Cover-NU)"
    return wrapper

def make_ada_finite(K: int, d_ctx: int, T: int, P: int, seed: int) -> ADAILTCBPlusBandit:
    """Build an ADAILTCBPlusBandit in "finite" ERM mode over a random policy bank.

    Args:
        K: Number of actions.
        d_ctx: Context dimension; each policy is built from a ``(K, d_ctx)`` matrix.
        T: Horizon.
        P: Number of policies in the bank.
        seed: Seed for the generator that draws the policy matrices.

    Returns:
        The algorithm instance, with ``init_params`` holding the constructor
        keyword arguments and ``name`` set to ``"ADA-ILTCB+ finite Π"``.
    """
    generator = np.random.default_rng(seed)
    bank = [
        LinearArgmaxPolicy(generator.normal(size=(K, d_ctx)))
        for _ in range(P)
    ]

    # One keyword set serves both the constructor call and the stored init_params.
    settings = dict(
        num_actions=K, horizon=T, erm_mode="finite",
        policies=bank, featurize_ctx=None, delta=0.05
    )
    learner = ADAILTCBPlusBandit(**settings)
    learner.init_params = settings
    learner.name = "ADA-ILTCB+ finite Π"
    return learner

def make_ada_vw(K: int, T: int) -> ADAILTCBPlusBandit:
    """Build an ADAILTCBPlusBandit in "vw" ERM mode.

    Args:
        K: Number of actions.
        T: Horizon.

    Returns:
        The algorithm instance, with ``init_params`` holding a separate copy of
        the constructor keyword arguments and ``name`` set to
        ``"ADA-ILTCB+ VW-ERM"``.
    """
    def fresh_settings():
        # Return brand-new dicts on every call so the constructor arguments and
        # the stored init_params never share the nested vw_erm_kwargs dict.
        return dict(
            num_actions=K, horizon=T, erm_mode="vw",
            vw_erm_kwargs=dict(passes=1, l2=1e-8, learning_rate=0.5, interactions=("ax",)),
            featurize_ctx=None, delta=0.05
        )

    learner = ADAILTCBPlusBandit(**fresh_settings())
    learner.init_params = fresh_settings()
    learner.name = "ADA-ILTCB+ VW-ERM"
    return learner


def make_master_ctx(K: int, d_ctx: int, T: int, P: int, seed: int) -> MASTERContext:
    """Build a MASTERContext with a Cover base factory and a random policy bank.

    Args:
        K: Number of actions.
        d_ctx: Context dimension; each policy is built from a ``(K, d_ctx)`` matrix.
        T: Horizon.
        P: Number of policies to generate.
        seed: Seed passed to Cover and MASTERContext; the policy generator is
            seeded with ``seed + 123``.

    Returns:
        The algorithm instance, with ``init_params`` holding the constructor
        keyword arguments and ``name`` set to ``"MASTER (Cover base, finite Π)"``.
    """
    def new_cover():
        return Cover(num_actions=K, horizon=T, m=8, psi=1.0, nounif=False, seed=seed)

    # Offset the seed so the policy draws use a different stream than `seed` itself.
    generator = np.random.default_rng(seed + 123)
    policies = [
        LinearArgmaxPolicy(generator.normal(size=(K, d_ctx)))
        for _ in range(P)
    ]

    # One keyword set serves both the constructor call and the stored init_params.
    config = dict(
        num_actions=K,
        horizon=T,
        delta=0.05,
        base_factory=new_cover,
        policies=policies,
        c2=2.0,
        seed=seed
    )
    master = MASTERContext(**config)
    master.init_params = config
    master.name = "MASTER (Cover base, finite Π)"
    return master



def main():
    """Run one contextual-bandit experiment from the command line and save its results.

    Command-line arguments:
        --T:  Horizon length; must be a positive integer.
        --xi: Hazard exponent. The change probability handed to the environment
              is ``T ** (-xi)``, or ``0.0`` when ``xi`` is exactly 0.

    Side effects:
        Prints the total running time, writes ``results/<file_name>.pkl`` with
        the values returned by ``env.run_experiment``, and calls ``plot_regret``
        and ``plot_times`` when the plotting helpers were imported successfully.
    """
    parser = argparse.ArgumentParser(description="Run contextual bandit experiment (shared env per run).")
    parser.add_argument("--T", type=int, required=True, help="Horizon length")
    parser.add_argument("--xi", type=float, required=True, help="Hazard exponent; hazard = T^(-xi). Use 0 for none.")
    args = parser.parse_args()

    T = args.T
    xi = args.xi
    if T <= 0:
        # T ** (-xi) raises ZeroDivisionError for T == 0 and is meaningless for T < 0.
        parser.error("--T must be a positive integer")

    # Fixed experiment configuration.
    K = 100
    d_ctx = 10
    n_mc = 15
    noise_variance = 0.01
    reward_bound = 1.0
    P_bank = 100
    continuous = False
    seed = 205
    np.random.seed(seed)

    change_prob = (T ** (-xi)) if xi != 0.0 else 0.0
    bandit = NSContextBandit(
        num_actions=K,
        noise_variance=noise_variance,
        d=d_ctx,
        d_c=d_ctx,
        reward_bound=reward_bound,
        continuous=continuous,
        seed=seed,
        context_mode='finite'
    )

    algorithms: List[object] = [
        make_dab_square(K, T, noise_variance, seed),
        make_ada_vw(K, T),
        make_master_ctx(K, d_ctx, T, P_bank, seed),
    ]

    env = EnvironmentContext(
        bandit=bandit,
        algorithms=algorithms,
        T=T,
        change_prob=None if continuous else change_prob,
        continuous=continuous,
        reward_bound=reward_bound,
    )

    t0 = time.time()
    avg_regret, std_regret, avg_timings, avg_detections, avg_detection_delays = env.run_experiment(n_mc=n_mc)
    print('Total running time: ', time.time() - t0)

    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs('results', exist_ok=True)

    if not continuous:
        file_name = f'{bandit.__class__.__name__}_T_{T}_d_c_{d_ctx}_A_{K}_N_{n_mc}_xi_{xi}_seed_{seed}'
    else:
        file_name = f'{bandit.__class__.__name__}_T_{T}_d_c_{d_ctx}_A_{K}_N_{n_mc}_cont_seed_{seed}'

    with open(f'results/{file_name}.pkl', 'wb') as f:
        pickle.dump({
            'avg_regret': avg_regret,
            'std_regret': std_regret,
            'avg_timings': avg_timings,
            'avg_detections': avg_detections,
            'avg_detection_delays': avg_detection_delays
        }, f)

    # The plot helpers are imported optionally at module level; calling them
    # when that import failed would raise NameError after the results are saved.
    if HAVE_PLOT:
        plot_regret(avg_regret, std_regret, T, exp_name=file_name)
        plot_times(avg_timings, exp_name=file_name)
    else:
        print('Plotting skipped: the plot module could not be imported.')
    

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
