from util import *
from time import time
from tqdm import tqdm
import numpy as np

def evaluate_one(Alg, params, env, horizon):
    """Run one episode of ``Alg`` on ``env`` for ``horizon`` rounds.

    Every round the environment is re-drawn via ``env.randomize()``, the
    algorithm is shown the current arm features ``env.X``, picks an arm, and
    is updated with that arm's reward from ``env.reward``.

    Returns:
        (regret, alg): a length-``horizon`` float array where entry ``t`` is
        ``env.regret(arm)`` for the arm chosen at round ``t``, and the
        algorithm instance as it stands after the episode.
    """
    learner = Alg(env, horizon, params)
    per_step = np.zeros(horizon)

    rounds = tqdm(range(horizon), ncols=100, leave=False)
    for step in rounds:
        env.randomize()
        learner.get_arms(env.X)
        chosen = learner.get_arm(step)
        payoff = env.reward(chosen)
        learner.update(step, chosen, payoff)
        per_step[step] = env.regret(chosen)

    return per_step, learner

def evaluate_compare(Alg, params, envs, horizon=1000, printout=True):
    """Evaluate ``Alg`` on several environments and summarize cumulative regret.

    One episode of ``horizon`` rounds is run per environment through
    ``evaluate_one``. Before each episode the environment's RNG is reseeded
    with ``(env.seed or 0) + i``, so episode ``i`` is reproducible.

    Args:
        Alg: algorithm class, instantiated as ``Alg(env, horizon, params)``.
        params: forwarded unchanged to ``Alg``.
        envs: non-empty sized sequence of environments.
        horizon: number of rounds per episode.
        printout: if true, print a one-line summary and the elapsed time.

    Returns:
        (mean_reg, std_err, ep_cum): mean cumulative regret over episodes,
        its standard error (NaN when there is only one episode), and the
        per-episode cumulative regrets as a length-``len(envs)`` array.

    Raises:
        ValueError: if ``envs`` is empty.
    """
    num_eps = len(envs)
    if num_eps == 0:
        # Without episodes every statistic is undefined; fail clearly instead
        # of returning NaNs or crashing later inside ``max()``.
        raise ValueError("evaluate_compare requires at least one environment")

    name = Alg.__name__
    if printout:
        print(f"Evaluating {name}", end=" ")

    t0 = time()
    regrets_mat = np.zeros((horizon, num_eps))

    for i, env in enumerate(tqdm(envs, desc=name, ncols=100, unit="episode")):
        env.reset_random(seed=(env.seed or 0) + i)
        step_regret, _ = evaluate_one(Alg, params, env, horizon)
        regrets_mat[:, i] = step_regret

    ep_cum = regrets_mat.sum(axis=0)
    mean_reg = ep_cum.mean()
    # The sample std (ddof=1) is undefined for a single episode; report NaN
    # explicitly rather than letting NumPy emit a degrees-of-freedom warning.
    if num_eps > 1:
        std_err = ep_cum.std(ddof=1) / np.sqrt(num_eps)
    else:
        std_err = np.nan

    if printout:
        print(
            f"Regret: {mean_reg:.2f} ± {std_err:.2f}  "
            f"(median {np.median(ep_cum):.2f}, max {ep_cum.max():.2f})"
        )
        print(f"Elapsed: {time() - t0:.1f} s")
        print("-" * 80)

    return mean_reg, std_err, ep_cum

def evaluate(Alg, params, envs, horizon=1000, printout=True):
    """Run ``Alg`` once on each environment and return the raw per-round regrets.

    Before episode ``i`` the environment's RNG is reseeded with
    ``(env.seed or 0) + i``; the episode itself is run by ``evaluate_one``.

    Args:
        Alg: algorithm class, instantiated as ``Alg(env, horizon, params)``.
        params: forwarded unchanged to ``Alg``.
        envs: sized sequence of environments.
        horizon: number of rounds per episode.
        printout: if true, print a summary of cumulative regret and timing.

    Returns:
        (regret, algs): ``regret`` has shape ``(horizon, len(envs))``, with
        column ``i`` holding episode ``i``'s per-round regret; ``algs`` holds
        the algorithm instance produced by each episode.
    """
    if printout:
        print(f"Evaluating {Alg.__name__}", end=" ")

    start = time()
    num_episode = len(envs)
    regret = np.zeros((horizon, num_episode))
    algs = [None] * num_episode

    for i, env in enumerate(tqdm(envs, desc=Alg.__name__, ncols=100, unit="iter")):
        env.reset_random(seed=(env.seed or 0) + i)
        step_regret, alg_inst = evaluate_one(Alg, params, env, horizon)
        regret[:, i] = step_regret
        algs[i] = alg_inst

    if printout:
        if num_episode == 0:
            # max()/median of an empty array are undefined; previously this crashed.
            print("Regret: n/a (no episodes)")
        else:
            total = regret.sum(axis=0)
            mean_reg = total.mean()
            # Sample std (ddof=1) needs >= 2 episodes; avoid NumPy's warning.
            if num_episode > 1:
                std_err = total.std(ddof=1) / np.sqrt(num_episode)
            else:
                std_err = np.nan
            print(f"Regret: {mean_reg:.2f} ± {std_err:.2f}  "
                  f"(median {np.median(total):.2f}, max {total.max():.2f})")
        print(f"{time() - start:.1f} seconds")
        print("+" * 80)

    return regret, algs


class ArmSet:
    """Random generator for ``K`` arm feature vectors in ``R^d``.

    Every generated arm vector has Euclidean norm ``norm_X``.
    """

    def __init__(self, d, K, norm_X=1.0):
        self.d, self.K = d, K
        self.norm_X = norm_X  # target L2 norm of every arm vector

    def generate(self, rng):
        """Draw a ``(K, d)`` array of arms using the NumPy generator ``rng``.

        Each row is an isotropic Gaussian sample rescaled to length ``norm_X``.
        """
        samples = rng.normal(0, 1, (self.K, self.d))
        row_norms = np.linalg.norm(samples, axis=1, keepdims=True)
        unit_rows = samples / row_norms
        return unit_rows * self.norm_X

class Bandit:
    """Stochastic linear or logistic bandit over one or more fixed arm sets.

    At construction an RNG seeded with ``seed`` draws a hidden parameter
    ``theta`` (rescaled to norm ``norm_theta``) and then ``C`` arm sets of
    ``K`` arms each (via ``ArmSet``). Each call to ``randomize`` optionally
    switches context and samples one reward per arm into ``self.rt``.
    """

    def __init__(self, d, K, C=1, arm_set_type='fixed', model="linear", norm_theta=1.0, norm_X=1.0, noise_var=1.0, seed=None):
        """
        Args:
            d: feature dimension.
            K: number of arms per context.
            C: number of contexts (arm sets). Contexts other than 0 are only
                visited when ``arm_set_type == 'contextual'`` and ``C > 1``.
            arm_set_type: 'fixed' (always context 0) or 'contextual'
                (a context is drawn uniformly each round).
            model: 'linear' (mean x.theta plus Gaussian noise) or 'logistic'
                (Bernoulli reward with mean sigmoid(x.theta)). Any value other
                than "linear" is treated as logistic.
            norm_theta: desired ||theta||_2.
            norm_X: desired ||x||_2 for each arm.
            noise_var: variance of the Gaussian reward noise (linear model only).
            seed: RNG seed; None (or any falsy value) is treated as 0.
        """
        self.d = d
        self.K = K
        self.C = C
        self.arm_set_type = arm_set_type
        self.model = model
        self.noise_var = noise_var
        self.seed = seed or 0
        self.norm_theta = norm_theta
        self.norm_X = norm_X
        self.rng = np.random.default_rng(self.seed)
        # True parameter: random direction, rescaled to norm_theta.
        raw_theta = self.rng.standard_normal(d)
        self.theta = raw_theta / np.linalg.norm(raw_theta) * norm_theta
        # One (K, d) arm matrix per context -> array of shape (C, K, d).
        # Drawn after theta from the same RNG, so the order matters for reproducibility.
        self.arm_sets = np.stack(
            [ArmSet(d, K, norm_X=norm_X).generate(self.rng) for _ in range(C)]
        )
        # Expected reward of every arm in every context, shape (C, K).
        logits = self.arm_sets.dot(self.theta)
        if model == "linear":
            self.mu_sets = logits
        else:
            self.mu_sets = 1 / (1 + np.exp(-logits))
        self.best_arms = np.argmax(self.mu_sets, axis=1)
        self.count = 0
        self._set_context(0)

    def _set_context(self, idx):
        """Make context ``idx`` current: its arms, mean rewards and best arm."""
        self.context_idx = idx
        self.X = self.arm_sets[idx]
        self.mu = self.mu_sets[idx]
        self.best_arm = int(self.best_arms[idx])

    def reset_random(self, seed=None):
        """Reseed the per-round RNG and return to context 0.

        ``theta`` and the arm sets are left unchanged; ``seed=None`` reuses
        ``self.seed``.
        """
        self.rng = np.random.default_rng(seed if seed is not None else self.seed)
        self.count = 0
        self._set_context(0)

    def randomize(self):
        """Advance one round: maybe switch context, then sample ``self.rt``."""
        if self.arm_set_type == 'contextual' and self.C > 1:
            idx = self.rng.integers(0, self.C)
            self._set_context(idx)
        if self.model == "linear":
            noise = self.rng.normal(0, np.sqrt(self.noise_var), size=self.K)
            self.rt = self.mu + noise
        else:
            self.rt = (self.rng.random(self.K) < self.mu).astype(float)
        self.count += 1

    def reward(self, arm):
        """Realized reward of ``arm`` in the current round."""
        return float(self.rt[arm])

    def regret(self, arm):
        """Realized regret of ``arm`` in the current round.

        Compares realized rewards against the arm with the highest *mean*,
        so the value can be negative in a given round.
        """
        return float(self.rt[self.best_arm] - self.rt[arm])

    def pregret(self, arm):
        """Pseudo-regret of ``arm``: gap in expected reward to the best arm."""
        return float(self.mu[self.best_arm] - self.mu[arm])

    def print(self):
        """Return a one-line description of this bandit (does not print)."""
        return f"{self.model.capitalize()} bandit: d={self.d}, K={self.K}, C={self.C}, norm_X={self.norm_X}"