import numpy as np
import torch
import torch.nn as nn
import copy as copy
import torchvision.transforms as transforms
import torch.nn.functional as F


class RMSELoss(nn.Module):
    """Element-wise root of the squared error, clamped to ``[0, upper]``.

    No reduction is applied: the result has the same shape as the inputs.

    Args:
        upper: Upper bound applied to every element of the returned loss.
    """

    def __init__(self, upper):
        super().__init__()
        # ``reduction='none'`` is the supported spelling of the deprecated
        # ``reduce=False`` and keeps the per-element squared errors.
        self.mse = nn.MSELoss(reduction='none')
        self.upper = upper

    def forward(self, yhat, y):
        """Return ``sqrt((yhat - y) ** 2)`` per element, clamped to ``[0, upper]``."""
        return torch.sqrt(self.mse(yhat, y)).clamp(0, self.upper)

class L2D_loss(nn.Module):
    """Surrogate losses and evaluation helpers for deferring among several agents.

    ``agents`` tensors handed to the methods are split column-wise: the first
    ``num_classes`` columns are costed as the model's entries and the
    remaining columns as the experts' entries (scaled by ``alpha`` and shifted
    by the per-expert ``beta``).

    Args:
        args: Configuration object. The class reads ``n_agents``, ``beta``,
            ``alpha``, ``Lambda``, ``upper``, ``num_classes``, ``task``
            (``'classification'`` or ``'regression'``), ``decision``
            (``'indicator'``, ``'majority'`` or ``'weighted'``) and
            ``beta_model`` from it.
    """

    def __init__(self, args):
        super().__init__()
        self.n_agents = args.n_agents
        self.beta = args.beta
        self.args = args
        self.alpha = args.alpha
        self.surrogate_loss = nn.LogSoftmax(dim=1)
        # List of indices 0..n_agents-1; not referenced elsewhere in this class.
        self.K = np.arange(0, self.n_agents).tolist()
        self.Lambda = args.Lambda
        self.upper = args.upper
        # Constant used by cardinality_aware() to rescale its combined cost.
        self.normalizer_cardinality = self.upper + self.Lambda * np.sum(self.beta)
        self.num_classes = args.num_classes
        self.rmse = RMSELoss(args.upper)

    def surrogate_01(self, outputs):
        """Return the negative log-softmax of ``outputs`` along dim 1."""
        return - self.surrogate_loss(outputs)

    def loss_terms(self, target, output):
        """Return the per-sample, per-column loss of ``output`` against ``target``.

        For classification this is the 0/1 mismatch (integer tensor); for
        regression it is ``self.rmse`` applied element-wise after repeating
        ``target`` across the columns of ``output``.

        Raises:
            ValueError: If ``args.task`` is neither ``'classification'`` nor
                ``'regression'`` (previously ``None`` was returned silently and
                the failure surfaced later).
        """
        task = self.args.task
        if task == 'classification':
            target_cls = target.long()[:, None]
            return 1 - (output == target_cls).long()
        if task == 'regression':
            # assumes target is 1-D with one entry per row of output -- TODO confirm
            target = target.repeat(output.size()[1], 1).T
            return self.rmse(output, target)
        raise ValueError(f"Unsupported task: {task!r}")

    def get_cost(self, classification, target, expert=False, i=None):
        """Return the cost of each prediction column in ``classification``.

        Args:
            classification: Predictions, one column per agent entry.
            target: Ground-truth values.
            expert: When true the loss is scaled by ``alpha`` and the
                per-expert ``beta`` is added; otherwise the raw loss is used.
            i: Unused.

        Returns:
            The cost tensor; detached for the classification task.
        """
        cls_loss = self.loss_terms(target, classification)
        if expert:
            beta = torch.tensor(self.beta, device=cls_loss.device)
            cost = self.alpha * cls_loss + beta[:, None].repeat(1, len(target)).T
        else:
            cost = cls_loss

        if self.args.task == 'classification':
            return cost.detach()
        return cost

    def cost_tau(self, cost_agents, labels):
        """Return, for every agent, the summed cost of all *other* agents.

        ``tau[i, j] = sum_q cost_agents[i, q] - cost_agents[i, j]``.
        ``labels`` is unused.
        """
        total_per_sample = cost_agents.sum(dim=1, keepdim=True)
        return total_per_sample - cost_agents

    def forward(self, outputs, labels, agents, images=None, model=None, val=None, attacked=None):
        """Return the mean deferral surrogate loss, or ``None`` when ``images`` is ``None``.

        ``images`` only acts as a switch; ``model``, ``val`` and ``attacked``
        are unused. The loss weights the negative log-softmax of ``outputs``
        by ``cost_tau`` of the agent costs, sums over dim 1 and averages.
        """
        if images is None:
            return None
        cost_agents = self.get_agent_cost(outputs, labels, agents)
        tau_j = self.cost_tau(cost_agents, labels)
        loss = self.surrogate_01(outputs)
        return (tau_j * loss).sum(dim=1).mean()

    def get_agent_cost(self, outputs, labels, agents, selector=None):
        """Return model and expert costs concatenated along dim 1.

        ``outputs`` and ``selector`` are unused.
        """
        cost_model = self.get_cost(agents[:, :self.num_classes], labels, expert=False)
        cost_experts = self.get_cost(agents[:, self.num_classes:], labels, expert=True)
        return torch.cat((cost_model, cost_experts), dim=1)

    def true_deferral_top_k(self, outputs, labels, agents, selector=None):
        """Return cumulative agent costs in ranked order, plus the ranking.

        Agents are ranked by descending ``outputs``; column ``k`` of the
        returned loss is the summed cost of the first ``k + 1`` ranked agents.
        When ``selector`` is given, only the column ``argmax(selector, dim=1)``
        is kept for each sample.

        Returns:
            Tuple ``(loss, ranking)`` of NumPy arrays.
        """
        cost_agents = self.get_agent_cost(outputs, labels, agents, selector=selector)

        id_deferral = torch.argsort(outputs, dim=1, descending=True)
        costs_top_k = torch.gather(cost_agents, 1, id_deferral)
        true_deferral_loss = costs_top_k.cumsum(dim=1)

        if selector is not None:
            selector_id = torch.argmax(selector, dim=1)
            rows = torch.arange(len(labels), device=true_deferral_loss.device)
            true_deferral_loss = true_deferral_loss[rows, selector_id]

        return true_deferral_loss.cpu().numpy(), id_deferral.cpu().numpy()

    def cost_cardinality(self, agents, labels, id_deferral):
        """Return the best-of-top-k loss for every prefix of the ranking.

        Column ``k`` considers the first ``k + 1`` agents of ``id_deferral``:
        for non-regression tasks it is ``0`` if any of them equals the label
        and ``1`` otherwise; for regression it is the smallest ``self.rmse``
        error among them.
        """
        is_regression = self.args.task == 'regression'
        per_k = []
        for k in range(agents.shape[1]):
            # Predictions of the top (k + 1) ranked agents for each sample.
            ids_top_k = id_deferral[:, :k + 1]
            prediction_top_k = torch.gather(agents, 1, ids_top_k)
            if is_regression:
                errors = self.rmse(
                    prediction_top_k,
                    labels.unsqueeze(1).repeat(1, prediction_top_k.size(1)),
                )
                # The reduction over dim 1 already yields shape (batch,); an
                # extra squeeze would drop the batch axis when batch == 1 and
                # break the stack below.
                per_k.append(errors.min(dim=1).values.float())
            else:
                hit = (prediction_top_k == labels.unsqueeze(1)).any(dim=1)
                per_k.append(1 - hit.float())
        return torch.stack(per_k, dim=1)

    def costs_for_cardinality_metric(self, outputs_rejector, inputs, outputs_selector, labels, agents):
        """Return ``(loss_metric, cum_sum_costs)`` for every top-k prefix.

        The ranking is the descending argsort of ``outputs_rejector``. The
        metric depends on ``args.decision``: ``'indicator'`` uses
        ``cost_cardinality``, ``'majority'`` and ``'weighted'`` use the
        module-level voting losses. ``inputs`` and ``outputs_selector`` are
        unused. Runs under ``torch.no_grad()``.

        Raises:
            ValueError: If ``args.decision`` is not one of the three values.
        """
        with torch.no_grad():
            id_deferral = torch.argsort(outputs_rejector, dim=1, descending=True)
            decision = self.args.decision
            if decision == 'indicator':
                loss_metric = self.cost_cardinality(agents, labels, id_deferral)
            elif decision == 'majority':
                loss_metric = vote_majority_loss(id_deferral, labels, agents, selector=None, args=self.args, rmse=self.rmse)
            elif decision == 'weighted':
                loss_metric = vote_weighted_loss(outputs_rejector, labels, agents, selector=None, args=self.args, rmse=self.rmse)
            else:
                raise ValueError(f"Unsupported decision rule: {decision!r}")
            cum_sum_costs = self.querying_costs(id_deferral, labels)
            return loss_metric, cum_sum_costs

    def cardinality_aware(self, rejector, inputs, outputs_selector, labels, agents):
        """Return the mean cardinality-aware surrogate loss for the selector.

        The rejector is put in eval mode (and left there) and its detached
        outputs define the ranking. Each selector column is weighted by
        ``1 - (metric + Lambda * cumulative_query_cost) / normalizer``.
        """
        rejector = rejector.eval()
        outputs_rejector = rejector(inputs).detach()
        loss_metric, cum_sum_costs = self.costs_for_cardinality_metric(outputs_rejector, inputs, None, labels, agents)
        loss = self.surrogate_01(outputs_selector)
        cost = 1 - (loss_metric + self.Lambda * cum_sum_costs) / self.normalizer_cardinality
        cardinality_loss = (cost * loss).sum(dim=1)
        return cardinality_loss.mean()

    def cardinality_aware_true(self, rejector, inputs, outputs_selector, labels, agents):
        """Evaluate the selector's chosen cardinality for each sample.

        The chosen index ``s`` is the *last* maximal entry of
        ``outputs_selector`` along dim 1.

        Returns:
            Tuple of NumPy arrays ``(cardinality_loss, s, accuracy,
            sum_consultation_cost, deferral_ratio)``. ``accuracy`` is
            ``1 - metric`` for non-regression tasks and the metric itself
            (an error) for regression. ``deferral_ratio`` counts how often
            each agent index appears among the first ``s + 1`` ranked agents.
        """
        rejector.eval()
        outputs_rejector = rejector(inputs).detach()
        loss_metric, cum_sum_costs = self.costs_for_cardinality_metric(outputs_rejector, inputs, None,
                                                                       labels, agents)

        s = argmax_last(outputs_selector, dim=1)
        cost = loss_metric + self.Lambda * cum_sum_costs
        rows = torch.arange(len(labels), device=cost.device)
        cardinality_loss = cost[rows, s]

        if self.args.task != 'regression':
            accuracy = 1 - loss_metric[rows, s]
        else:
            accuracy = loss_metric[rows, s]
        sum_consultation_cost = cum_sum_costs[rows, s]

        id_deferral = torch.argsort(outputs_rejector, dim=1, descending=True)
        deferral_ratio = occurence(id_deferral, s, n_agents=self.n_agents + self.num_classes)

        return (cardinality_loss.cpu().numpy(), s.cpu().numpy(), accuracy.cpu().numpy(),
                sum_consultation_cost.cpu().numpy(), deferral_ratio)

    def querying_costs(self, id_deferral, labels):
        """Return cumulative querying costs following the ranking ``id_deferral``.

        The per-agent cost is ``args.beta_model`` for each of the first
        ``num_classes`` entries followed by ``beta`` for the experts. The
        result lives on ``id_deferral``'s device (the original hard-coded
        ``'cuda'``, which failed on CPU-only runs).
        """
        per_agent_cost = np.concatenate(
            [np.ones(self.num_classes) * self.args.beta_model, np.array(self.beta)]
        )
        cost_table = torch.tensor(per_agent_cost, device=id_deferral.device)
        cost_table = cost_table.unsqueeze(0).repeat(len(labels), 1)
        extracted_costs = torch.gather(cost_table, 1, id_deferral)
        return extracted_costs.cumsum(dim=1)

    def MRR(self, outputs_rejector, labels, agents):
        """Return, per sample, the 0-based rank position of the lowest-cost agent.

        NOTE(review): despite the name this is a rank position, not a
        reciprocal rank.
        """
        cost_agents = self.get_agent_cost(outputs_rejector, labels, agents)
        id_deferral = torch.argsort(outputs_rejector, dim=1, descending=True)
        costs_top_k = torch.gather(cost_agents, 1, id_deferral)
        MRR = torch.argmin(costs_top_k, dim=1).float()
        return MRR.cpu().numpy()

    def proba_top_k(self, outputs_rejector):
        """Return softmax probabilities of the rejector, sorted by descending score."""
        id_deferral = torch.argsort(outputs_rejector, dim=1, descending=True)
        proba = torch.softmax(outputs_rejector, dim=1)
        gather = torch.gather(proba, 1, id_deferral)
        return gather.cpu().numpy()


def majority_vote_first_numpy(row_np):
    """Return the most frequent value of a 1-D NumPy array.

    Ties are resolved in favour of the value that appears earliest in
    ``row_np``.
    """
    values, first_seen, freq = np.unique(row_np, return_index=True, return_counts=True)

    # Positions (within ``values``) of every value reaching the top frequency.
    leaders = np.flatnonzero(freq == np.max(freq))

    # Among the leaders, keep the one whose first appearance comes first.
    winner = leaders[np.argmin(first_seen[leaders])]
    return values[winner]

def vote_majority_loss(id_deferral, labels, agents, selector=None, args=None, rmse=None):
    """Loss of the unweighted ensemble built from each top-k prefix of a ranking.

    ``agents`` is reordered by ``id_deferral``; for every prefix length the
    ensemble prediction is the majority vote (classification, ties broken by
    first occurrence) or the mean (any other task). The loss is the 0/1
    mismatch for classification and ``rmse`` for regression.

    When ``selector`` is given, only the prefix ``argmax(selector, dim=1)`` is
    kept per sample, so the result has a single column.
    """

    def _combine(top_predictions):
        # Ensemble prediction for one prefix of the ranking.
        if args.task == 'classification':
            voted = np.apply_along_axis(majority_vote_first_numpy, axis=1, arr=top_predictions.cpu().numpy())
            return torch.tensor(voted, device=labels.device)
        return top_predictions.mean(dim=1)

    with torch.no_grad():
        ranked_predictions = torch.gather(agents, 1, id_deferral)
        prefix_sizes = range(1, id_deferral.size(1) + 1)
        pred_system = torch.stack(
            [_combine(ranked_predictions[:, :size]) for size in prefix_sizes], dim=1
        )

        if selector is not None:
            chosen = torch.argmax(selector, dim=1)
            pred_system = pred_system[torch.arange(len(labels)), chosen][:, None]

        if args.task == 'classification':
            vote_loss = 1 - (pred_system == labels.unsqueeze(1)).float()
        elif args.task == 'regression':
            targets = labels.unsqueeze(1).repeat(1, pred_system.size(1))
            vote_loss = rmse(pred_system, targets)
        return vote_loss


def vote_weighted_loss(outputs_rejector, labels, agents, selector=None, args=None, rmse=None):
    """Loss of the score-weighted ensemble built from each top-k prefix.

    Agents are ranked by descending ``outputs_rejector``. For every prefix the
    weights are the softmax of the prefix's scores; the ensemble prediction is
    the weighted vote (classification) or the weighted average (any other
    task). The loss is the 0/1 mismatch for classification and ``rmse`` for
    regression.

    When ``selector`` is given, only the prefix ``argmax(selector, dim=1)`` is
    kept per sample, so the result has a single column.
    """
    with torch.no_grad():
        ranking = torch.argsort(outputs_rejector, dim=1, descending=True)
        ranked_predictions = torch.gather(agents, 1, ranking)
        ranked_scores = torch.gather(outputs_rejector, 1, ranking)

        per_prefix = []
        for size in range(1, ranking.size(1) + 1):
            top_predictions = ranked_predictions[:, :size]
            # Weights are renormalised within each prefix.
            top_weights = torch.softmax(ranked_scores[:, :size], dim=1)
            if args.task == 'classification':
                combined = compute_weighted_vote_vectorized(top_predictions, top_weights, n_class=args.num_classes)
            else:
                combined = (top_predictions * top_weights).sum(dim=1) / top_weights.sum(dim=1)
            per_prefix.append(combined)
        pred_system = torch.stack(per_prefix, dim=1)

        if selector is not None:
            chosen = torch.argmax(selector, dim=1)
            pred_system = pred_system[torch.arange(len(labels)), chosen][:, None]

        if args.task == 'classification':
            vote_loss = 1 - (pred_system == labels.unsqueeze(1)).float()
        elif args.task == 'regression':
            targets = labels.unsqueeze(1).repeat(1, pred_system.size(1))
            vote_loss = rmse(pred_system, targets)

        return vote_loss


def compute_weighted_vote_vectorized(prediction_k, weight, n_class=10):
    """Pick, per row, the class with the largest total weight.

    Classes tied on total weight are separated by choosing the one that
    appears first in the row.

    Parameters:
        prediction_k (torch.Tensor): Shape (B, k), integer class labels.
        weight (torch.Tensor): Shape (B, k), weight of each prediction.
        n_class (int): Number of classes used for the one-hot encoding.

    Returns:
        torch.Tensor: Shape (B,), the selected class index for each row.
    """
    batch, width = prediction_k.shape

    # (B, k, n_class) indicator of which class each prediction voted for.
    votes = F.one_hot(prediction_k.long(), num_classes=n_class).to(weight.dtype)

    # (B, n_class) total weight received by each class.
    class_weight = (votes * weight.unsqueeze(-1)).sum(dim=1)

    # (B, n_class, k) position index of every slot, replicated per class.
    slot_index = torch.arange(width, device=prediction_k.device).unsqueeze(0).expand(batch, width).float()
    slot_index = slot_index.unsqueeze(1).expand(batch, n_class, width)

    # First position at which each class occurs; ``width`` marks "never".
    class_absent = votes.transpose(1, 2) == 0
    first_seen = slot_index.masked_fill(class_absent, width).min(dim=2).values

    # Disqualify every class that does not reach the row's top weight, then
    # take the earliest-appearing survivor.
    top_weight = class_weight.max(dim=1, keepdim=True).values
    leads = class_weight == top_weight
    return first_seen.masked_fill(~leads, width).argmin(dim=1)



def vectorized_select(id_deferral, selector_id):
    """Return, for each row, its leading ``selector_id[i] + 1`` entries.

    Args:
        id_deferral (torch.Tensor): Tensor of shape (N, C).
        selector_id (torch.Tensor): Tensor of shape (N,); entry ``i`` is the
            last column index (inclusive) to keep from row ``i``.

    Returns:
        A tuple of N tensors; the i-th holds the first ``selector_id[i] + 1``
        elements of ``id_deferral[i]``.
    """
    keep_counts = selector_id + 1

    # Row i keeps every column whose index is below keep_counts[i].
    column_ids = torch.arange(id_deferral.size(1), device=id_deferral.device)
    keep_mask = column_ids.unsqueeze(0) < keep_counts.unsqueeze(1)

    # Boolean indexing flattens the kept entries row by row; splitting by the
    # per-row counts restores one tensor per row.
    flat_kept = id_deferral[keep_mask]
    return torch.split(flat_kept, keep_counts.tolist())

import torch
import torch.nn.functional as F
import torch
import torch.nn.functional as F


import torch
import torch.nn.functional as F

def weighted_ensemble_vectorized_regression(outputs, labels, agents, selector=None, rmse=None):
    """Error of the softmax-weighted regression ensemble for every top-k prefix.

    Agents are ranked by descending ``outputs``; weights are the softmax of
    ``outputs`` over all agents. The top-k ensemble prediction is the weighted
    average of the first k ranked predictions, and its error is ``rmse`` of
    that prediction against the label.

    Args:
      outputs:  Tensor of shape (batch_size, num_agents), one score per agent.
      labels:   Tensor of shape (batch_size,), regression targets.
      agents:   Tensor of shape (batch_size, num_agents), agent predictions.
      selector: Optional tensor; when given, ``argmax(selector, dim=1)``
                selects the single prefix kept for each sample.
      rmse:     Callable ``rmse(prediction, target)`` returning per-element errors.

    Returns:
      NumPy array of shape (batch_size, num_agents) when ``selector`` is
      ``None``, otherwise of shape (batch_size, 1).
    """
    ranking = torch.argsort(outputs, dim=1, descending=True)
    ranked_predictions = torch.gather(agents, 1, ranking)
    ranked_weights = torch.gather(F.softmax(outputs, dim=1), 1, ranking)

    # Running numerator and denominator of the weighted average.
    running_weighted_sum = torch.cumsum(ranked_predictions * ranked_weights, dim=1)
    running_weight = torch.cumsum(ranked_weights, dim=1)

    # Small constant guards the division.
    ensemble_preds = running_weighted_sum / (running_weight + 1e-12)

    targets = labels.unsqueeze(1).repeat(1, ensemble_preds.size(1))
    errors = rmse(ensemble_preds, targets)

    if selector is None:
        return errors.cpu().numpy()

    chosen = torch.argmax(selector, dim=1)
    chosen_error = errors[torch.arange(len(labels)), chosen].unsqueeze(1)
    return chosen_error.cpu().numpy()



def weighted_majority_vectorized(outputs, labels, agents, num_classes, selector=None):
    """Correctness of the softmax-weighted vote for every top-k prefix.

    Agents are ranked by descending ``outputs``; weights are the softmax of
    ``outputs`` over all agents. For each prefix the predicted class is the
    one with the largest accumulated weight.

    Returns:
        NumPy float array holding ``1.0`` where the ensemble prediction equals
        the label and ``0.0`` otherwise (i.e. a correctness indicator, not an
        error). One column per prefix, or a single column when ``selector``
        is given and ``argmax(selector, dim=1)`` picks the prefix.
    """
    ranking = torch.argsort(outputs, dim=1, descending=True)

    # one_hot requires integer class indices.
    ranked_classes = torch.gather(agents, 1, ranking).long()
    ranked_weights = torch.gather(F.softmax(outputs, dim=1), 1, ranking)

    # Kept from the original: unpacking requires a 2-D prediction tensor.
    batch_size, num_agents = ranked_classes.size()

    # (batch, agents, classes): weight each agent contributes to its class,
    # accumulated along the ranking so slice k covers the first k + 1 agents.
    class_votes = F.one_hot(ranked_classes, num_classes=num_classes).float()
    accumulated = torch.cumsum(ranked_weights.unsqueeze(-1) * class_votes, dim=1)

    pred_system = torch.argmax(accumulated, dim=2)

    if selector is not None:
        chosen = torch.argmax(selector, dim=1)
        pred_system = pred_system[torch.arange(len(labels)), chosen][:, None]

    is_correct = (pred_system == labels.unsqueeze(1)).float()
    return is_correct.cpu().numpy()


def occurence(id_deferral, k_prediction, n_agents=1):
    """Count how often each agent index is among the selected top entries.

    For row ``i`` the first ``k_prediction[i] + 1`` entries of
    ``id_deferral[i]`` are selected; the selected indices of all rows are then
    tallied.

    Returns:
        NumPy array of counts with at least ``n_agents`` entries.
    """
    ranking = id_deferral.cpu().numpy()
    cutoff = k_prediction.cpu().numpy()

    _, width = ranking.shape
    # Row-wise prefix mask built by broadcasting (n, 1) against (width,).
    selected = np.arange(width) < (cutoff[:, None] + 1)

    return np.bincount(ranking[selected], minlength=n_agents)


def argmax_last(tensor, dim=0):
    """Return the index of the maximum along ``dim``, using the last one on ties."""
    # Searching the reversed tensor finds what is the last maximum in the
    # original ordering; the result is then mapped back to original indices.
    index_in_reversed = torch.argmax(torch.flip(tensor, dims=[dim]), dim=dim)
    last_position = tensor.size(dim) - 1
    return last_position - index_in_reversed


def mode_first_occurrence_v2(tensor, dim=1):
    """Row-wise mode of a 2-D tensor, with ties going to the earliest element.

    Parameters:
        tensor (torch.Tensor): Input of shape [B, N].
        dim (int): Accepted for interface compatibility; the computation
            always runs along dimension 1.

    Returns:
        mode_vals (torch.Tensor): Shape [B], the mode of each row.
        first_occurrence_indices (torch.Tensor): Shape [B], the column of the
            first element in each row that attains the top frequency.
    """
    B, N = tensor.shape

    # pairwise[b, j, k] is True when row b has equal values at columns j and k,
    # so summing over k gives the frequency of the value at column j.
    pairwise = tensor[:, :, None] == tensor[:, None, :]
    frequency = pairwise.sum(dim=2)

    top_frequency = frequency.max(dim=1, keepdim=True).values
    attains_top = (frequency == top_frequency).float()

    # argmax over a 0/1 tensor lands on the first 1 in each row.
    first_occurrence_indices = attains_top.argmax(dim=1)

    mode_vals = tensor.gather(1, first_occurrence_indices.unsqueeze(1)).squeeze(1)
    return mode_vals, first_occurrence_indices








