import torch
import numpy as np
from math import ceil, log2
from .basic_bandit import BasicBandit
from src.bandit.contextual.neural_ucb import Network

# Torch device string: 'cuda' when torch reports CUDA as available, else 'cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ------------------------------------------------------------------
# A helper function that performs OLS (using pseudo-inverse if needed)
# ------------------------------------------------------------------
def fit_ols(X, y, reg=1e-12):
    """
    Least-squares fit of y on X through a ridge-stabilised Gram matrix.

    X:   array of shape (n_samples, h)
    y:   array of shape (n_samples, num_objectives)
    reg: multiple of the identity added to X^T X before pseudo-inversion

    Returns the coefficient matrix pinv(X^T X + reg * I) @ X^T @ y,
    of shape (h, num_objectives).
    """
    n_features = X.shape[1]
    gram = X.T @ X
    ridge = reg * np.eye(n_features)
    # pinv rather than inv so a rank-deficient Gram matrix is still handled.
    gram_pinv = np.linalg.pinv(gram + ridge)
    return gram_pinv @ X.T @ y

def compute_empirical_pareto(means):
    """
    Return the indices of the empirically Pareto-optimal arms.

    means: array of shape (num_arms, num_objectives) of estimated means.

    Arm i is dropped when some other arm j is strictly larger in EVERY
    objective (means[i] < means[j] coordinate-wise); arms that merely tie in
    one coordinate are kept. The result is a sorted 1-D index array.
    """
    num_arms, _ = means.shape  # also enforces a 2-D input

    # strictly_below[i, j] is True when arm i is below arm j in all objectives.
    strictly_below = np.all(means[:, None, :] < means[None, :, :], axis=2)
    # An arm is never compared against itself.
    np.fill_diagonal(strictly_below, False)

    dominated = strictly_below.any(axis=1)
    return np.flatnonzero(~dominated)

def frank_wolfe_gopt(X, max_iter=20, reg=1e-12):
    """
    Approximate a G-optimal design over the rows of X with Frank-Wolfe steps.

    X:        array of shape (n_active, h), one feature row per candidate
    max_iter: number of Frank-Wolfe iterations
    reg:      multiple of the identity added to the information matrix

    Starts from the uniform distribution and, at every step, moves mass
    towards the row with the largest predictive variance x^T M(w)^+ x using
    step size 1 / (t + 2). Returns the weight vector of length n_active.
    """
    num_points, dim = X.shape
    weights = np.ones(num_points) / num_points
    ridge = reg * np.eye(dim)

    for step in range(max_iter):
        info_matrix = X.T @ (weights[:, None] * X) + ridge
        info_pinv = np.linalg.pinv(info_matrix)
        # Row-wise quadratic form x_i^T M^+ x_i.
        leverage = np.sum((X @ info_pinv) * X, axis=1)
        worst = np.argmax(leverage)

        # Convex step towards the vertex e_worst of the simplex.
        alpha = 1.0 / (step + 2.0)
        weights = (1 - alpha) * weights
        weights[worst] += alpha

    return weights

def integer_round_design(w, N):
    """
    Round a weight vector w into integer pull counts for a budget of N.

    The weights are clipped to [0, 1] and renormalised, each entry receives
    floor(N * w_i) pulls, and the pulls still missing from N are handed out
    one each to the entries with the largest fractional remainders.
    Returns an integer array with one count per entry of w.
    """
    num_entries = w.shape[0]

    probs = np.clip(w, 0.0, 1.0)
    # Guard the normalisation against an all-zero vector.
    probs = probs / max(probs.sum(), 1e-15)

    targets = N * probs
    base = np.floor(targets).astype(int)
    remaining = N - np.sum(base)

    # Largest fractional part first.
    by_fraction = np.argsort(-(targets - base))
    top_up = np.zeros(num_entries, dtype=int)
    top_up[by_fraction[:remaining]] = 1
    return base + top_up

class MLP_ConstrainedBandit(BasicBandit):
    """
    Fixed-budget constrained bandit that scores arms with an MLP.

    The budget is split evenly over ``ceil(log2(K))`` rounds. In every round
    the active arms are pulled according to a design, an MLP is refit from
    scratch on the retained (feature, reward) pairs, and the arms with the
    largest estimated gaps are dropped so that ``ceil(K / 2**round)`` remain.
    Reward dimension 0 is treated as the objective and dimension 1 as the
    constrained quantity, which is feasible when ``>= constraints[0]``.
    """

    # Hyper-parameters of the per-round refit. They are class attributes so a
    # subclass or an instance can override them without touching the code.
    TRAIN_EPOCHS = 10000
    TRAIN_LR = 0.01
    # weight_decay = BASE * (n / REF_SAMPLES) ** EXPONENT, n = #training rows.
    WEIGHT_DECAY_BASE = 0.002
    WEIGHT_DECAY_REF_SAMPLES = 150
    WEIGHT_DECAY_EXPONENT = -0.8

    def __init__(
        self,
        num_arms,
        total_budget,
        features,
        num_objectives: int = 2,
        constraints: list = None,
        hidden_size=20, hidden_depth=1,
    ):
        """
        num_arms: int
          Number of arms K
        total_budget: int
          Total number of pulls, stored as ``self.T``
        features: np.ndarray of shape (K, h)
          Feature vectors of each arm, in R^h
        num_objectives: int
          Dimension of each reward vector
        constraints: list, optional
          ``constraints[0]`` is the feasibility threshold on reward
          dimension 1. Defaults to ``[1.0]``.
        hidden_size, hidden_depth:
          Forwarded to ``Network`` as ``hidden_size`` and ``depth``.
        """
        super().__init__(num_arms)
        self.num_arms = num_arms
        self.T = total_budget
        self.X = features         # shape (K, h)
        self.num_objectives = num_objectives  # dimension of each reward
        # NOTE: despite the name, self.h is the number of arms (rows of X),
        # not the feature dimension; it drives the halving schedule.
        self.h = self.X.shape[0]
        self.log_h = ceil(log2(self.h)) if self.h > 1 else 1
        # A fresh list per instance avoids sharing one mutable default.
        self.constraints = [1.0] if constraints is None else constraints
        self.hidden_size = hidden_size
        self.hidden_depth = hidden_depth

        # The network consumes feature vectors, so its input width is the
        # feature dimension (same as the refit done at the end of a round).
        self.mlp = self._new_network()
        self.loss_fn = torch.nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.mlp.parameters(), lr=0.05)

        # Call reset to initialize all other attributes
        self.reset()

    def _new_network(self):
        """Build an untrained Network on `device`, sized for the arm features."""
        return Network(
            input_dim=self.X.shape[1],
            output_dim=self.num_objectives,
            hidden_size=self.hidden_size,
            depth=self.hidden_depth,
        ).to(device)

    def reset(self):
        """
        Reset the bandit to rerun from scratch.

        All arms become active again, the logged pulls/rewards are cleared and
        the budget is split evenly (integer division) over the rounds.
        """
        self.round = 1
        self.max_rounds = self.log_h
        self.budget_per_round = self.T // max(1, self.max_rounds)
        self.active_arms = list(range(self.num_arms))
        self.current_pulls = []
        self.current_rewards = []
        self.current_design_list = []
        self.design_index = 0
        self.current_round_pulls = 0
        self.t = 0
        self.Theta_hat_r = None

    def _is_finished(self):
        """True once all rounds are used or at most one arm is still active."""
        return self.round > self.max_rounds or len(self.active_arms) <= 1

    def choose_action(self):
        """
        Return the next arm to pull according to the design of the current round.

        Returns None when all rounds are exhausted, when at most one arm is
        active on entry, or when the current design list has been used up.
        When an elimination performed inside this call leaves exactly one arm,
        that arm is returned.
        """
        if self._is_finished():
            print('None 1')
            return None

        if len(self.current_design_list) == 0 and self.current_round_pulls == 0:
            print('start build round')
            self._build_round_design()

        if self.current_round_pulls >= self.budget_per_round:
            # The round's budget is spent: eliminate, then plan the next round.
            self._end_round_elimination()
            if self._is_finished():
                if len(self.active_arms) == 1:
                    return self.active_arms[0]
                print('None 2')
                return None
            self._build_round_design()

        if self.design_index < len(self.current_design_list):
            arm = self.current_design_list[self.design_index]
            self.design_index += 1
            return arm

        print('None 3')
        return None

    def update(self, play, reward):
        """
        Log the chosen arm and the observed reward for this step.

        reward may be a dict (its values are used, in iteration order), a
        list/tuple, a numpy array, or a scalar; it is stored as a 1-D float
        array. A ``play`` of None is ignored. When the last round's budget is
        reached, the final elimination is run immediately.
        """
        if play is None:
            return
        self.current_pulls.append(play)
        if isinstance(reward, dict):
            vect = np.array(list(reward.values()), dtype=float)
        elif isinstance(reward, (list, tuple)):
            vect = np.array(reward, dtype=float)
        else:
            # Scalars and numpy arrays alike end up as a flat float vector.
            vect = np.asarray(reward, dtype=float).ravel()
        self.current_rewards.append(vect)

        self.t += 1
        self.current_round_pulls += 1

        if self.round == self.max_rounds and self.current_round_pulls == self.budget_per_round:
            self._end_round_elimination()

    def best_arm(self):
        """Return the first active arm, or None when no arm is active."""
        return self.active_arms[0] if len(self.active_arms) > 0 else None

    def _expand_allocation(self, alloc):
        """Turn per-arm pull counts into a flat list of arms, in active-arm order."""
        design_list = []
        for arm, n_pulls in zip(self.active_arms, alloc):
            # int() so numpy integer counts repeat the list rather than
            # risk being treated as array arithmetic.
            design_list.extend([arm] * int(n_pulls))
        return design_list

    def _build_round_design(self, uniform_design=True):
        """
        Build the pull schedule for the current set of active arms.

        If `uniform_design` is True, the round budget is spread evenly over
        the active arms and the data of still-active arms from earlier rounds
        is kept. Otherwise a G-optimal design is computed with
        `frank_wolfe_gopt` / `integer_round_design` and the logged data is
        cleared.
        """
        self.current_design_list = []
        self.design_index = 0
        self.current_round_pulls = 0

        if len(self.active_arms) == 0:
            return

        if uniform_design:
            # Position of every active arm, for O(1) membership and lookup.
            position = {arm: idx for idx, arm in enumerate(self.active_arms)}

            # Keep only the observations that belong to still-active arms.
            kept = [
                (arm, reward)
                for arm, reward in zip(self.current_pulls, self.current_rewards)
                if arm in position
            ]
            self.current_pulls = [arm for arm, _ in kept]
            self.current_rewards = [reward for _, reward in kept]

            prior_counts = [0] * len(self.active_arms)
            for arm in self.current_pulls:
                prior_counts[position[arm]] += 1

            num_active_arms = len(self.active_arms)
            pulls_per_arm, leftover_pulls = divmod(self.budget_per_round, num_active_arms)

            alloc = [pulls_per_arm] * num_active_arms
            # Leftover pulls go to the arms with the fewest retained samples.
            for _ in range(leftover_pulls):
                index = int(np.argmin(prior_counts))
                alloc[index] += 1
                prior_counts[index] += 1

            print("Uniform allocation:", alloc)

            design_list = self._expand_allocation(alloc)
            self.current_design_list = design_list
            print("Uniform design list:", design_list)

        else:
            # G-optimal design
            X_active = self.X[self.active_arms]  # shape (n_active, h)
            w = frank_wolfe_gopt(X_active, max_iter=100)
            alloc = integer_round_design(w, self.budget_per_round)
            print("G-optimal allocation:", alloc)

            design_list = self._expand_allocation(alloc)
            self.current_design_list = design_list
            print("G-optimal design list:", design_list)

            self.current_pulls = []
            self.current_rewards = []

    def _fit_network(self):
        """Refit a freshly initialised MLP on all currently logged pulls."""
        n = len(self.current_pulls)
        feature_sample = torch.tensor(
            self.X[self.current_pulls], dtype=torch.float32
        ).to(device)
        reward_sample = torch.tensor(
            np.vstack(self.current_rewards), dtype=torch.float32
        ).to(device)

        # Computed once so the logged value is exactly the one the optimizer uses.
        weight_decay = self.WEIGHT_DECAY_BASE * (
            (n / self.WEIGHT_DECAY_REF_SAMPLES) ** self.WEIGHT_DECAY_EXPONENT
        )
        self.mlp = self._new_network()
        self.optimizer = torch.optim.Adam(
            self.mlp.parameters(), lr=self.TRAIN_LR, weight_decay=weight_decay
        )
        print('End of round: ' , self.round, ', Weight decay: ', weight_decay)

        for _ in range(self.TRAIN_EPOCHS):
            output = self.mlp(feature_sample)
            loss = self.loss_fn(output, reward_sample)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def _predict_means(self):
        """Return the MLP predictions for the active arms, shape (n_active, d)."""
        X_active = torch.tensor(self.X[self.active_arms], dtype=torch.float32).to(device)
        with torch.no_grad():  # inference only, no autograd graph needed
            predictions = self.mlp(X_active)
        return predictions.detach().cpu().numpy()

    def _compute_gaps(self, means):
        """
        Pick the empirical best arm J and the elimination gap of every arm.

        Indices refer to positions in ``self.active_arms``. Returns
        ``(J, gaps)`` where ``gaps[J]`` is ``-inf`` so that J is removed last.
        """
        tau = self.constraints[0]
        feasible = means[:, 1] >= tau
        feasible_indices = np.flatnonzero(feasible)

        if len(feasible_indices) > 0:
            # Among the feasible arms, take the highest objective (dimension 0).
            J = int(feasible_indices[np.argmax(means[feasible_indices, 0])])
        else:
            # No arm looks feasible: take the one closest to feasibility,
            # i.e. with the LARGEST value in the constraint dimension.
            J = int(np.argmax(means[:, 1]))

        J_feasible = bool(feasible[J])
        J_obj = means[J, 0]
        J_con = means[J, 1]

        gaps = np.zeros(means.shape[0])
        for i in range(means.shape[0]):
            if i == J:
                gaps[i] = -np.inf
                continue

            i_obj = means[i, 0]
            i_con = means[i, 1]

            if J_feasible:
                if feasible[i]:
                    # Feasible but worse objective.
                    gaps[i] = J_obj - i_obj
                elif i_obj > J_obj:
                    # Deceiver: better objective, but violates the constraint,
                    # so only the violation separates it from J.
                    gaps[i] = tau - i_con
                else:
                    # Infeasible AND not better: either deficit rules it out.
                    # (The max only matters here, where both terms are >= 0.)
                    gaps[i] = max(J_obj - i_obj, tau - i_con)
            elif not feasible[i]:
                gaps[i] = J_con - i_con
            else:
                # Cannot happen: J is infeasible only when no arm is feasible.
                print('error: J is infeasible but i is feasible')

        return J, gaps

    def _end_round_elimination(self):
        """
        Called when the current round's budget is fully used.

        Refits the MLP, estimates the mean reward of every active arm, and
        removes the arms with the largest gaps until
        ``ceil(self.h / 2**round)`` arms remain. Always advances the round
        counter, even when no data has been logged.
        """
        print('elimination')
        print('active: ', self.active_arms)
        if len(self.current_pulls) == 0:
            self.round += 1
            return

        # 1) Fit the MLP on the logged data
        self._fit_network()

        # 2) Compute estimated means & gaps
        means = self._predict_means()
        J, gaps = self._compute_gaps(means)

        # 3) Keep the keep_count arms with the smallest gaps
        keep_count = int(np.ceil(self.h / (2 ** self.round)))
        keep_count = min(keep_count, len(self.active_arms))  # Clamp to avoid exceeding active arms

        print(f"Empirical best arm: {self.active_arms[J]} with reward {means[J]}")
        n_remove = len(self.active_arms) - keep_count
        if n_remove > 0:
            # A stable descending sort reproduces "repeatedly drop the first
            # arg-max": largest gap first, ties broken by lower position.
            order = np.argsort(-gaps, kind='stable')
            removed = set(order[:n_remove].tolist())
            # Slice assignment keeps the same list object and the arm order.
            self.active_arms[:] = [
                arm for idx, arm in enumerate(self.active_arms) if idx not in removed
            ]

        self.round += 1