import torch
import numpy as np
from math import ceil, log2
import pandas as pd
import matplotlib.pyplot as plt
from .basic_bandit import BasicBandit
from src.bandit.contextual.neural_ucb import Network
from sklearn.preprocessing import StandardScaler

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

### only for debugging ###
# Reference mean scores used only by the commented-out diagnostics in
# MLP_EGEFixedBudget._end_round_elimination.  The CSV is a local experiment
# artefact addressed relative to the working directory, so a missing file must
# not prevent this module from being imported: both names are then None.
try:
    df = pd.read_csv("./results/MO/xsum_WhiteBox_Llama3/prompts_43.csv")
except FileNotFoundError:
    df = None
    true_mean = None
else:
    # Shape (n_rows, 2): column 0 is the rouge score, column 1 the brevity score.
    true_mean = np.asarray([list(df['mean_scores_rouge']), list(df['mean_scores_brevity'])]).T
##########################

# ------------------------------------------------------------------
# A helper function that performs OLS (using pseudo-inverse if needed)
# ------------------------------------------------------------------
def fit_ols(X, y, reg=1e-12):
    """
    Least-squares estimate of Theta with a tiny ridge term.

    The normal equations are solved through the pseudo-inverse of
    (X^T X + reg * I), so a singular Gram matrix is tolerated.

    X shape: (n_samples, h)
    y shape: (n_samples, num_objectives)  # if multi-objective
    Returns: Theta of shape (h, num_objectives)
    """
    n_features = X.shape[1]
    gram = X.T @ X
    gram_pinv = np.linalg.pinv(gram + reg * np.eye(n_features))
    return gram_pinv @ X.T @ y

def compute_empirical_pareto(means):
    """
    Indices of the empirically Pareto-optimal arms.

    'means' is an array of shape (num_arms, num_objectives) of estimated
    means.  Arm i is dropped only if some other arm j beats it strictly in
    every coordinate (means[i, c] < means[j, c] for ALL c); ties in any
    coordinate therefore never count as domination.

    Returns a 1-D integer array of the surviving arm indices, in increasing
    order.
    """
    num_arms, _ = means.shape
    dominated = np.array(
        [
            # Compare arm i against every other arm at once: one row per rival,
            # a rival dominates when its whole row of comparisons is True.
            np.all(means[i] < np.delete(means, i, axis=0), axis=1).any()
            for i in range(num_arms)
        ],
        dtype=bool,
    )
    return np.flatnonzero(~dominated)

def min_gap_estimates(means):
    """
    Empirical sub-optimality gap of every arm (Equation (5) style).

    Parameters
    ----------
    means : np.ndarray of shape (num_arms, num_objectives)
        Empirical means of each arm in the current round.

    Returns
    -------
    gaps : np.ndarray of shape (num_arms,)
        The estimated gap for each arm.

    Quantities
    ----------
    M(i, j) = max_c [ means[i, c] - means[j, c] ]
    m(i, j) = min_c [ means[j, c] - means[i, c] ]
    subopt_gap[i] = max_j m(i, j)   (the maximum includes j == i, where m is 0)

    An arm outside the empirical Pareto set gets gaps[i] = subopt_gap[i].
    An arm inside it gets
        gaps[i] = min_{j != i} min( M(i, j),
                                    max(0, M(j, i)) + max(0, subopt_gap[j]) ),
    or 0 when there is no other arm (or the minimum is not finite).
    """
    num_arms, _ = means.shape

    # Fill the pairwise tables one row at a time: row j of `delta` is
    # means[i] - means[j].
    big_m = np.zeros((num_arms, num_arms))
    small_m = np.zeros((num_arms, num_arms))
    for i in range(num_arms):
        delta = means[i] - means
        big_m[i] = delta.max(axis=1)
        small_m[i] = (-delta).min(axis=1)

    on_front = np.zeros(num_arms, dtype=bool)
    on_front[compute_empirical_pareto(means)] = True

    subopt_gap = np.array([row.max() for row in small_m], dtype=float)

    gaps = np.zeros(num_arms)
    for i in range(num_arms):
        if not on_front[i]:
            # Empirically sub-optimal arm: plain sub-optimality gap.
            gaps[i] = subopt_gap[i]
            continue

        # Empirically Pareto arm: eq. (5) "delta^*_i".
        candidates = [
            min(big_m[i, j], max(0.0, big_m[j, i]) + max(0.0, subopt_gap[j]))
            for j in range(num_arms)
            if j != i
        ]
        # Seeding with +inf reproduces a running "strictly smaller" minimum;
        # with no competitor the result stays infinite and maps to 0 below.
        tightest = min([np.inf, *candidates])
        gaps[i] = tightest if np.isfinite(tightest) else 0.0

    return gaps

def frank_wolfe_gopt(X, max_iter=20, reg=1e-12):
    """
    Frank-Wolfe style approximation of a G-optimal design.

    X has shape (n_active, h), one feature row per arm.  Starting from the
    uniform distribution, each iteration builds the regularised information
    matrix for the current weights, finds the arm whose feature vector has
    the largest x^T M^{-1} x, and moves the weights towards that arm with
    step size 1 / (t + 2).

    Returns the weight vector w of shape (n_active,).
    """
    n_active, h = X.shape
    ridge = reg * np.eye(h)
    w = np.ones(n_active) / n_active

    for step in range(max_iter):
        info = X.T @ (w[:, None] * X) + ridge
        info_pinv = np.linalg.pinv(info)
        leverage = ((X @ info_pinv) * X).sum(axis=1)
        worst = np.argmax(leverage)

        # Convex step towards the vertex e_worst: shrink everything, then
        # add the step mass to the selected coordinate.
        alpha = 1.0 / (step + 2.0)
        w = (1 - alpha) * w
        w[worst] += alpha

    return w

def integer_round_design(w, N):
    """
    Convert a distribution w into an integer allocation of size N.

    Parameters
    ----------
    w : array-like of shape (n_active,)
        Non-negative design weights.  Entries are clipped to [0, 1] and then
        renormalised to sum to 1.  If nothing is left after clipping, the
        uniform distribution is used so that the result still sums to N.
    N : int
        Total number of pulls to distribute.

    Returns
    -------
    np.ndarray of int, shape (n_active,)
        Each entry is floor(N * w_i), plus one extra pull for the arms with
        the largest fractional parts until the total reaches N.
    """
    weights = np.clip(np.asarray(w, dtype=float), 0.0, 1.0)
    n_active = weights.shape[0]
    if n_active == 0:
        return np.zeros(0, dtype=int)

    total = weights.sum()
    if total > 0:
        weights = weights / total
    else:
        # No usable mass: dividing by a tiny floor would leave all weights at
        # zero and the allocation could never reach N.
        weights = np.full(n_active, 1.0 / n_active)

    raw = N * weights
    alloc = np.floor(raw).astype(int)
    leftover = int(N - alloc.sum())

    # Largest-remainder rounding: hand the missing pulls to the arms whose
    # ideal share was truncated the most.
    by_remainder = np.argsort(-(raw - alloc))
    alloc[by_remainder[:leftover]] += 1
    return alloc

class MLP_EGEFixedBudget(BasicBandit):
    """
    Fixed-budget empirical gap elimination with an MLP reward model.

    Follows the round structure of Algorithm 2 from the paper, GEGE
    (G-optimal Empirical Gap Elimination): the budget T is split evenly over
    ceil(log2(h)) rounds.  In each round the active arms are pulled according
    to a design (uniform by default, G-optimal when requested), an MLP is
    refit from scratch on the stored (feature, reward) pairs, and the arms
    with the largest estimated gaps are removed from the active set and filed
    as Pareto-optimal (B) or sub-optimal (D).
    """
    def __init__(self, num_arms, T, features, num_objectives, hidden_size=30, hidden_depth=1):
        """
        num_arms: int
          Number of arms K
        T: int
          Total budget
        features: np.ndarray of shape (K, h)
          Feature vectors of each arm, in R^h
        num_objectives: int
          Number of the rewards
        hidden_size: int
          Passed to Network as hidden_size
        hidden_depth: int
          Passed to Network as depth
        """
        super().__init__(num_arms)
        self.num_arms = num_arms
        self.T = T
        self.X = features         # shape (K, h)
        self.num_objectives = num_objectives  # dimension of each reward
        self.h = self.X.shape[1]  # dimension of features
        print('dimension is ', self.h)
        # Number of elimination rounds; at least one even for h == 1.
        self.log_h = ceil(log2(self.h)) if self.h > 1 else 1
        # self.log_h = ceil(log2(self.num_arms))
        self.hidden_size = hidden_size
        self.hidden_depth = hidden_depth
        self.mlp = self._new_network()
        self.loss_fn = torch.nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.mlp.parameters(), lr=0.05)

        # Call reset to initialize all other attributes
        self.reset()

    def _new_network(self):
        """Build a freshly initialised reward network on the module device."""
        return Network(input_dim=self.h, output_dim=self.num_objectives, hidden_size=self.hidden_size, depth=self.hidden_depth).to(device)

    def reset(self):
        """
        Reset the bandit to rerun from scratch.
        """
        self.round = 1
        self.max_rounds = self.log_h
        self.budget_per_round = self.T // max(1, self.max_rounds)
        self.active_arms = list(range(self.num_arms))  # A_r
        self.B = []  # set of arms classified as "Pareto-optimal" (and removed)
        self.D = []  # set of arms classified as sub-optimal
        self.current_pulls = []
        self.current_rewards = []
        self.current_design_list = []
        self.design_index = 0
        self.current_round_pulls = 0
        self.t = 0
        self.Theta_hat_r = None
        self.arm_rewards = np.zeros([self.num_arms, self.num_objectives])
        # The tiny offset makes this a float array and keeps
        # arm_rewards / arm_pulls finite for arms that were never pulled.
        self.arm_pulls = np.zeros(self.num_arms, dtype=np.int32) + 1e-12

    def choose_action(self):
        """
        Return the next arm to pull according to the design of the current round.

        Returns None once all rounds are exhausted, at most one arm is still
        active, or the current design list has been fully consumed.
        """
        if self.round > self.max_rounds or len(self.active_arms) <= 1:
            print('None 1')
            return None

        # Very first call of a run: nothing has been scheduled yet.
        if len(self.current_design_list) == 0 and self.current_round_pulls == 0:
            print('start build round')
            self._build_round_design()

        # Round budget spent: eliminate, then schedule the next round.
        if self.current_round_pulls >= self.budget_per_round:
            self._end_round_elimination()
            if self.round > self.max_rounds or len(self.active_arms) <= 1:
                if len(self.active_arms) == 1:
                    return self.active_arms[0]
                print('None 2')
                return None
            self._build_round_design()

        if self.design_index < len(self.current_design_list):
            arm = self.current_design_list[self.design_index]
            self.design_index += 1
            return arm

        print('None 3')
        return None

    def update(self, play, reward):
        """
        Log the chosen arm and the observed reward for this step.

        reward may be a dict (its values are used in insertion order), a
        list/tuple/np.ndarray of per-objective values, or a single scalar.
        A call with play=None is ignored.
        """
        if play is None:
            return

        # Convert before touching any state, so a malformed reward cannot
        # leave current_pulls and current_rewards with different lengths.
        if isinstance(reward, dict):
            vect = np.array(list(reward.values()), dtype=float)
        else:
            # atleast_1d turns a scalar into a length-1 vector and keeps
            # sequences and arrays as they are; wrapping an ndarray in a list
            # would add a leading axis that cannot be accumulated below.
            vect = np.atleast_1d(np.array(reward, dtype=float))

        self.current_pulls.append(play)
        self.current_rewards.append(vect)

        # Save rewards for each arm
        self.arm_rewards[play] += vect
        self.arm_pulls[play] += 1

        self.t += 1
        self.current_round_pulls += 1

        # After the last pull of the last round there is no further
        # choose_action call that would trigger the elimination, so do it here.
        if self.round == self.max_rounds and self.current_round_pulls == self.budget_per_round:
            self._end_round_elimination()

    def best_arm(self):
        """
        Returns the final "Pareto set" = B_{final} union A_{final}
        after all rounds end. If the algorithm hasn't ended yet,
        it returns the partial classification.
        """
        return self.B + self.active_arms

    @property
    def pareto_front(self):
        """
        Returns the final Pareto front, which is the union of:
          - Remaining active arms
          - Arms classified as Pareto-optimal (B)
        """
        return set(self.active_arms).union(self.B)

    def _expand_allocation(self, alloc):
        """Turn per-arm pull counts (aligned with active_arms) into a flat pull list."""
        design_list = []
        for arm, n_pulls in zip(self.active_arms, alloc):
            design_list.extend([arm] * n_pulls)
        return design_list

    def _build_round_design(self, uniform_design=True):
        """
        Build the pull schedule for the current set of active arms A_r.

        If `uniform_design` is True, the round budget is spread evenly over
        the active arms and the stored observations of still-active arms are
        kept for the next fit.  Otherwise a G-optimal design is approximated
        and the stored observations are discarded.
        """
        self.current_design_list = []
        self.design_index = 0
        self.current_round_pulls = 0

        if len(self.active_arms) == 0:
            return

        if uniform_design:
            # Keep only the observations that belong to arms still in play.
            active = set(self.active_arms)
            kept = [
                (arm, reward)
                for arm, reward in zip(self.current_pulls, self.current_rewards)
                if arm in active
            ]
            self.current_rewards = [reward for _, reward in kept]
            self.current_pulls = [arm for arm in self.current_pulls if arm in active]

            # How often each active arm has already been observed.
            position = {arm: k for k, arm in enumerate(self.active_arms)}
            prior_pulls = [0] * len(self.active_arms)
            for arm in self.current_pulls:
                prior_pulls[position[arm]] += 1

            num_active_arms = len(self.active_arms)
            pulls_per_arm, leftover_pulls = divmod(self.budget_per_round, num_active_arms)

            # Leftover pulls go to the arms with the fewest observations so far.
            alloc = [pulls_per_arm] * num_active_arms
            for _ in range(leftover_pulls):
                index = int(np.argmin(prior_pulls))
                alloc[index] += 1
                prior_pulls[index] += 1

            print("Uniform allocation:", alloc)

            design_list = self._expand_allocation(alloc)
            self.current_design_list = design_list
            print("Uniform design list:", design_list)

        else:
            # G-optimal design
            X_active = self.X[self.active_arms]  # shape (n_active, h)
            w = frank_wolfe_gopt(X_active, max_iter=100)
            alloc = integer_round_design(w, self.budget_per_round)
            print("G-optimal allocation:", alloc)

            design_list = self._expand_allocation(alloc)
            self.current_design_list = design_list
            print("G-optimal design list:", design_list)

            self.current_pulls = []
            self.current_rewards = []

    def _fit_reward_model(self):
        """Retrain a fresh MLP on every stored (feature, reward) pair."""
        n = len(self.current_pulls)
        pulled_features = self.X[self.current_pulls]
        rewards_matrix = np.vstack(self.current_rewards)
        # Feature standardisation is currently disabled:
        # self.scaler = StandardScaler().fit(self.X[self.current_pulls])
        # X_scaled = self.scaler.transform(pulled_features)
        feature_sample = torch.tensor(pulled_features, dtype=torch.float32).to(device) # X_scaled
        reward_sample = torch.tensor(rewards_matrix, dtype=torch.float32).to(device)

        # The regularisation shrinks as the number of samples grows.  Compute
        # it once so the logged value is the one the optimizer really uses.
        weight_decay = 0.002 * ((n / 150) ** (-0.8))
        self.mlp = self._new_network()
        self.optimizer = torch.optim.Adam(self.mlp.parameters(), lr=0.01, weight_decay=weight_decay)
        print('End of round: ' , self.round, ', Weight decay: ', weight_decay)

        loss_hist = []
        for epoch in range(10000):
            output = self.mlp(feature_sample)
            loss = self.loss_fn(output, reward_sample)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            loss_hist.append(loss.item())
        # plt.figure()
        # plt.plot(loss_hist[100:])
        # plt.savefig(f'./results/MO/xsum_WhiteBox_Llama3/loss_{self.round}.png')
        # plt.close()

    def _end_round_elimination(self):
        """
        Called when the current round's budget is fully used.

        Refits the reward model, estimates the gap of every active arm, keeps
        the ceil(h / 2^(round-1)) arms with the smallest gaps, files the rest
        into B (empirically Pareto-optimal) or D (sub-optimal) and advances
        the round counter.
        """
        print('elimination')
        print('active: ', self.active_arms)
        count = {x: 0 for x in self.active_arms}
        for x in self.current_pulls:
            if x in self.active_arms:
                count[x] += 1
        print('count in current pulls:', count)
        if len(self.current_pulls) == 0:
            # Nothing was observed: skip the fit and just move on.
            self.round += 1
            return

        # 1) Fit
        self._fit_reward_model()

        # 2) Compute estimated means & gaps
        X_active = self.X[self.active_arms]
        means_est = self.mlp(
            torch.tensor(X_active, dtype=torch.float32).to(device)
        ).detach().cpu().numpy()  # shape (n_active, d)

        ### Debugging ###
        # print("estimated means", means_est)
        # empirical = self.arm_rewards[self.active_arms] / self.arm_pulls[self.active_arms][:, None]
        # print("empirical means", empirical)
        # print("true means", true_mean[self.active_arms])
        # bound = np.max(np.abs(true_mean[self.active_arms]))
        # print("average error:", np.mean(np.abs(means_est - true_mean[self.active_arms])) / bound)
        # print("average error (empirical):", np.mean(np.abs(empirical - true_mean[self.active_arms])) / bound)
        # if self.round:
        #     plt.figure()
        #     plt.scatter(true_mean[self.active_arms][:, 0], true_mean[self.active_arms][:, 1], color='blue', alpha=0.3, marker=".", label="True")
        #     plt.scatter(empirical[:, 0], empirical[:, 1], color='red', alpha=0.3, marker=".", label="Empirical")
        #     plt.scatter(means_est[:, 0], means_est[:, 1], color='green', alpha=0.3, marker=".", label="Estimated")
        #     # not_in_active = [x if not (x in self.active) for x in range(self.num_arms)]
        #     plt.scatter(true_mean[self.B][:, 0], true_mean[self.B][:, 1], color='black', alpha=0.1, marker="+", label="Not Active, Pareto")
        #     plt.scatter(true_mean[self.D][:, 0], true_mean[self.D][:, 1], color='black', alpha=0.1, marker=".", label="Not Active, Suboptimal")
        #     plt.xlabel("Objective 1")
        #     plt.ylabel("Objective 2")
        #     plt.title("2D Rewards Visualization")
        #     plt.legend()
        #     plt.savefig(f'./results/MO/xsum_WhiteBox_Llama3/2D_rewards_{self.round}.png')
        #     plt.close()
        ### End of debugging ###

        gaps = min_gap_estimates(means_est)

        # 3) Determine which arms are empirically optimal
        pareto_local_inds = compute_empirical_pareto(means_est)
        pareto_mask = np.zeros(len(self.active_arms), dtype=bool)
        pareto_mask[pareto_local_inds] = True

        # 4) Sort arms by ascending gap, with empirically optimal arms prioritized
        indices = list(range(len(self.active_arms)))

        def sort_key(i):
            return (gaps[i], 0 if pareto_mask[i] else 1)

        indices.sort(key=sort_key)

        # 5) Decide how many to keep
        keep_count = int(np.ceil(self.h / (2 ** (self.round - 1))))
        keep_count = min(keep_count, len(indices))  # Clamp to avoid exceeding active arms
        # keep_count = int(np.ceil(self.num_arms / (2 ** self.round)))
        # keep_count = min(keep_count, len(indices))  # Clamp to avoid exceeding active arms

        keep_local = indices[:keep_count]
        remove_local = indices[keep_count:]

        # 6) Build new active arms, separate removed ones into B vs D
        new_active_arms = [self.active_arms[i] for i in keep_local]
        pareto_set_global = {self.active_arms[idx] for idx in pareto_local_inds}

        for idx in remove_local:
            arm_removed = self.active_arms[idx]
            if arm_removed in pareto_set_global:
                self.B.append(arm_removed)
            else:
                self.D.append(arm_removed)

        self.active_arms = new_active_arms
        self.round += 1
