from weaver.verification import VerificationMethod, FilteringStrategyType

import numpy as np
from tqdm import tqdm
from collections import defaultdict
from loguru import logger
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from copy import deepcopy
from scipy.special import logsumexp
# Import-time logging setup: drop the handlers loguru currently has registered and
# install a single sink that echoes each formatted record through print() at INFO and above.
logger.remove()
logger.add(lambda msg: print(msg, end=""), format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {message}", level="INFO")
import warnings

# NOTE: this filter is installed as a side effect of importing this module and is
# process-wide: any RuntimeWarning emitted afterwards (e.g. numpy's divide-by-zero or
# invalid-value warnings) is raised as an exception instead of being printed.
warnings.simplefilter("error", RuntimeWarning)  # Convert warnings into errors


class TopRMs(VerificationMethod):
    def filter(self, dataset, cumulative_mask, metric, m, **kwargs):
        """
            Pick the m best reward models (RMs) according to `metric` and filter generations
            by the unweighted mean of those RMs' scores.

            Args:
                dataset: unused by this method; accepted for interface compatibility.
                cumulative_mask: optional boolean (problems x samples) mask. Entries that are
                    False have every verifier score replaced by -10e8 before averaging, so they
                    rank last for TOP_K and fail any realistic THRESHOLD.
                metric: key looked up in self.verifier_metrics[rm] to rank the RMs
                    (higher is treated as better).
                m: number of RMs to keep.

            Returns:
                None if there are no reward models. Otherwise a tuple (filtered_idxs, weights):
                  - TOP_K: filtered_idxs is a (problems x k) array of sample indices, best
                    first, with k = self.filter_strategy.params.
                  - THRESHOLD: filtered_idxs is a list with one index array per problem,
                    holding the samples whose mean score is >= self.filter_strategy.params.
                  - weights is a 0/1 vector over all verifier columns of self.scores marking
                    the selected RMs.

            Raises:
                ValueError: if self.filter_strategy is neither TOP_K nor THRESHOLD.
        """
        if len(self.reward_models) == 0:
            return None

        # NOTE: RM performance depends on the threshold used to compute the metric;
        # there is currently no way to pass that threshold in here.
        rm_metrics = [(rm, self.verifier_metrics[rm][metric]) for rm in self.reward_models]
        top_rms = sorted(rm_metrics, key=lambda pair: pair[1], reverse=True)[:m]
        logger.info(f"Selected top {m} RMs based on {metric}:")
        for rm, score in top_rms:
            logger.info(f"  {rm}: {score:.3f}")

        # Translate RM names into their column positions inside self.scores.
        top_rm_idxs = np.array(
            [self.reward_model_idxs[self.reward_models.index(rm)] for rm, _ in top_rms]
        )
        weights = np.zeros(self.scores.shape[-1])
        weights[top_rm_idxs] = 1

        strategy_name = self.filter_strategy.name
        if strategy_name == FilteringStrategyType.TOP_K.value:
            logger.warning("top-k for a judge has arbitrary tie-breaking policy.")

        if cumulative_mask is not None:
            # Broadcast the (problems x samples) mask over the verifier axis.
            mask = np.array(cumulative_mask)[..., np.newaxis]
            masked_scores = np.where(mask, self.scores, -10e8)
        else:
            masked_scores = self.scores

        ensembled_scores = masked_scores[:, :, top_rm_idxs].mean(axis=-1)  # average across the selected RMs
        if strategy_name == FilteringStrategyType.TOP_K.value:
            filtered_idxs = np.argsort(-ensembled_scores, axis=1, kind='stable')[:, :self.filter_strategy.params]
        elif strategy_name == FilteringStrategyType.THRESHOLD.value:
            filtered_idxs = [np.where(row >= self.filter_strategy.params)[0] for row in ensembled_scores]
        else:
            # The original built this exception without raising it, which surfaced later
            # as an UnboundLocalError on `filtered_idxs`.
            raise ValueError(f"{self.filter_strategy.name} is not a valid filtering strategy for {self.__class__.__name__}.")

        return filtered_idxs, weights
        

class TopJudges(VerificationMethod):
    def filter(self, dataset, cumulative_mask, metric, m, **kwargs):
        """
        Pick the m best LM judges according to `metric` and filter generations by the
        unweighted mean of those judges' scores (the average judge pass rate).

        Args:
            dataset: unused by this method; accepted for interface compatibility.
            cumulative_mask: optional boolean (problems x samples) mask. Entries that are
                False have every verifier score replaced by -10e8 before averaging.
            metric: key looked up in self.verifier_metrics[judge] to rank the judges
                (higher is treated as better).
            m: number of judges to keep.

        Returns:
            None if there are no LM judges. Otherwise a tuple (filtered_idxs, weights):
              - TOP_K: (problems x k) array of sample indices, best first.
              - BEST_TIED: list of index arrays, one per problem, holding every sample
                tied for that problem's maximum mean score.
              - THRESHOLD: list of index arrays, one per problem, holding the samples
                whose mean score is >= self.filter_strategy.params.
              - weights is a 0/1 vector over all verifier columns of self.scores marking
                the selected judges.

        Raises:
            ValueError: if self.filter_strategy is not TOP_K, BEST_TIED or THRESHOLD.
        """
        if len(self.lm_judges) == 0:
            return None

        # Rank judges by the requested metric and keep the m best.
        judge_metrics = [(lm, self.verifier_metrics[lm][metric]) for lm in self.lm_judges]
        top_judges = sorted(judge_metrics, key=lambda pair: pair[1], reverse=True)[:m]
        logger.info(f"Selected top {m} judges based on {metric}:")
        for judge, score in top_judges:
            logger.info(f"  {judge}: {score:.3f}")

        # Translate judge names into their column positions inside self.scores.
        top_judge_idxs = np.array(
            [self.lm_judges_idxs[self.lm_judges.index(judge)] for judge, _ in top_judges]
        )
        weights = np.zeros(self.scores.shape[-1])
        weights[top_judge_idxs] = 1

        strategy_name = self.filter_strategy.name
        if strategy_name == FilteringStrategyType.TOP_K.value:
            logger.warning("top-k for a judge has arbitrary tie-breaking policy.")

        if cumulative_mask is not None:
            # Broadcast the (problems x samples) mask over the verifier axis.
            mask = np.array(cumulative_mask)[..., np.newaxis]
            masked_scores = np.where(mask, self.scores, -10e8)
        else:
            masked_scores = self.scores

        ensembled_scores = masked_scores[:, :, top_judge_idxs].mean(axis=-1)  # average across the selected judges
        if strategy_name == FilteringStrategyType.TOP_K.value:
            filtered_idxs = np.argsort(-ensembled_scores, axis=1, kind='stable')[:, :self.filter_strategy.params]
        elif strategy_name == FilteringStrategyType.BEST_TIED.value:
            max_pass_rate = ensembled_scores.max(axis=1)  # per problem
            filtered_idxs = [np.where(row == max_pass_rate[i])[0] for i, row in enumerate(ensembled_scores)]
        elif strategy_name == FilteringStrategyType.THRESHOLD.value:
            filtered_idxs = [np.where(row >= self.filter_strategy.params)[0] for row in ensembled_scores]
        else:
            # The original built this exception without raising it, which surfaced later
            # as an UnboundLocalError on `filtered_idxs`.
            raise ValueError(f"{self.filter_strategy.name} is not a valid filtering strategy for {self.__class__.__name__}.")

        return filtered_idxs, weights


class Majority_X_then_Top_Y_Judges(VerificationMethod):
    def filter(self, dataset, cumulative_mask, majority_k=20, top_k=3, metric="accuracy", **kwargs):
        """
        Placeholder for a two-stage filter; not implemented yet.

        Intended design:
        1. Majority@K: find the K most common answers.
        2. Score each of those answers with the top judges and keep the answer with the
           highest approval rate, then return every sample that gave that answer.

        Raises:
            NotImplementedError: always, because cumulative_mask handling is still missing.
        """
        reason = "Need to implement cumulative_mask handling"
        raise NotImplementedError(reason)


class Majority_X_then_Top_Y_RMs(VerificationMethod):
    def filter(self, dataset, cumulative_mask, majority_k=20, top_k=3, metric="accuracy", threshold=0.5, **kwargs):
        """
        Placeholder for a two-stage filter; not implemented yet.

        Intended design:
        1. Majority@K: find the K most common answers.
        2. Score each of those answers with the top reward models and keep the ones that
           pass the threshold, then return every sample that gave a selected answer.

        Raises:
            NotImplementedError: always, because cumulative_mask handling is still missing.
        """
        reason = "Need to implement cumulative_mask handling"
        raise NotImplementedError(reason)
    

class WeightedEnsemble(VerificationMethod):
    def filter(self, dataset, cumulative_mask, metric, **kwargs):
        """
            Weight every verifier by `metric` and filter generations by the weighted sum
            of their scores.

            Args:
                dataset: unused by this method; accepted for interface compatibility.
                cumulative_mask: optional boolean (problems x samples) mask. Entries that are
                    False have every verifier score replaced by -10e8 before ensembling.
                metric: key looked up in self.verifier_metrics[verifier]; the value is the
                    verifier's weight.
                **kwargs: `shift` (default True) — unless the metric name contains
                    "selection_accuracy", weights become max(0, weight - 0.5), so a verifier
                    scoring 0.5 or less on the metric gets weight 0.

            Returns:
                None if every weight is 0. Otherwise a tuple (filtered_idxs, weights):
                  - TOP_K: (problems x k) array of sample indices, best first.
                  - THRESHOLD: list of index arrays, one per problem, holding the samples
                    whose weighted sum is >= self.filter_strategy.params.
                  - weights holds one (unnormalised) weight per verifier, reward models
                    first and LM judges after.

            Raises:
                ValueError: if self.filter_strategy is neither TOP_K nor THRESHOLD.
        """
        shift = kwargs.pop('shift', True)
        # NOTE(review): assumes the columns of self.scores are ordered as
        # reward_models followed by lm_judges — confirm against where scores are built.
        verifiers = self.reward_models + self.lm_judges
        weights = np.array([self.verifier_metrics[v][metric] for v in verifiers])

        if shift and "selection_accuracy" not in metric:
            weights = np.maximum(0, weights - 0.5)

        if np.all(weights == 0):
            logger.warning("All weights are 0, quitting!")
            return None

        if cumulative_mask is not None:
            # Broadcast the (problems x samples) mask over the verifier axis.
            mask = np.array(cumulative_mask)[..., np.newaxis]
            masked_scores = np.where(mask, self.scores, -10e8)
        else:
            masked_scores = self.scores

        ensembled_scores = masked_scores @ weights  # (problems x samples) weighted sums
        strategy_name = self.filter_strategy.name
        if strategy_name == FilteringStrategyType.TOP_K.value:
            filtered_idxs = np.argsort(-ensembled_scores, axis=1, kind='stable')[:, :self.filter_strategy.params]
        elif strategy_name == FilteringStrategyType.THRESHOLD.value:
            filtered_idxs = [np.where(row >= self.filter_strategy.params)[0] for row in ensembled_scores]
        else:
            # The original built this exception without raising it, which surfaced later
            # as an UnboundLocalError on `filtered_idxs`.
            raise ValueError(f"{self.filter_strategy.name} is not a valid filtering strategy for {self.__class__.__name__}.")

        return filtered_idxs, weights


class WeightedRMEnsemble(VerificationMethod):
    def filter(self, dataset, cumulative_mask, metric, **kwargs):
        """
            Weight every reward model by `metric` and filter generations by the weighted
            sum of the reward-model scores (LM judges are ignored).

            Args:
                dataset: unused by this method; accepted for interface compatibility.
                cumulative_mask: optional boolean (problems x samples) mask. Entries that are
                    False have every reward-model score replaced by -10e8 before ensembling.
                metric: key looked up in self.verifier_metrics[rm]; the value is the RM's weight.
                **kwargs: `shift` (default True) — unless the metric name contains
                    "selection_accuracy", weights become max(0, weight - 0.5).

            Returns:
                None if there are no reward models or every weight is 0. Otherwise a tuple
                (filtered_idxs, weights):
                  - TOP_K: (problems x k) array of sample indices, best first.
                  - THRESHOLD: list of index arrays, one per problem, holding the samples
                    whose weighted sum is >= self.filter_strategy.params.
                  - weights holds one (unnormalised) weight per reward model, in
                    self.reward_models order.

            Raises:
                ValueError: if self.filter_strategy is neither TOP_K nor THRESHOLD.
        """
        if len(self.reward_models) == 0:
            return None

        shift = kwargs.pop('shift', True)
        weights = np.array([self.verifier_metrics[v][metric] for v in self.reward_models])

        if shift and "selection_accuracy" not in metric:
            weights = np.maximum(0, weights - 0.5)

        if np.all(weights == 0):
            logger.warning("All weights are 0, quitting!")
            return None

        # Restrict to the reward-model columns so they line up with `weights`.
        rm_scores = self.scores[:, :, self.reward_model_idxs]

        if cumulative_mask is not None:
            # Broadcast the (problems x samples) mask over the verifier axis.
            mask = np.array(cumulative_mask)[..., np.newaxis]
            masked_scores = np.where(mask, rm_scores, -10e8)
        else:
            masked_scores = rm_scores

        ensembled_scores = masked_scores @ weights  # (problems x samples) weighted sums
        strategy_name = self.filter_strategy.name
        if strategy_name == FilteringStrategyType.TOP_K.value:
            filtered_idxs = np.argsort(-ensembled_scores, axis=1, kind='stable')[:, :self.filter_strategy.params]
        elif strategy_name == FilteringStrategyType.THRESHOLD.value:
            filtered_idxs = [np.where(row >= self.filter_strategy.params)[0] for row in ensembled_scores]
        else:
            # The original built this exception without raising it, which surfaced later
            # as an UnboundLocalError on `filtered_idxs`.
            raise ValueError(f"{self.filter_strategy.name} is not a valid filtering strategy for {self.__class__.__name__}.")

        return filtered_idxs, weights
    

class WeightedJudgeEnsemble(VerificationMethod):
    def filter(self, dataset, cumulative_mask, metric, **kwargs):
        """
            Weight every LM judge by `metric` and filter generations by the weighted sum
            of the judge scores (reward models are ignored).

            Args:
                dataset: unused by this method; accepted for interface compatibility.
                cumulative_mask: optional boolean (problems x samples) mask. Entries that are
                    False have every judge score replaced by -10e8 before ensembling.
                metric: key looked up in self.verifier_metrics[judge]; the value is the
                    judge's weight.
                **kwargs: `shift` (default True) — unless the metric name contains
                    "selection_accuracy", weights become max(0, weight - 0.5).

            Returns:
                None if there are no LM judges or every weight is 0. Otherwise a tuple
                (filtered_idxs, weights):
                  - TOP_K: (problems x k) array of sample indices, best first.
                  - THRESHOLD: list of index arrays, one per problem, holding the samples
                    whose weighted sum is >= self.filter_strategy.params.
                  - weights holds one (unnormalised) weight per judge, in
                    self.lm_judges order.

            Raises:
                ValueError: if self.filter_strategy is neither TOP_K nor THRESHOLD.
        """
        if len(self.lm_judges) == 0:
            return None

        shift = kwargs.pop('shift', True)
        weights = np.array([self.verifier_metrics[v][metric] for v in self.lm_judges])

        if shift and "selection_accuracy" not in metric:
            weights = np.maximum(0, weights - 0.5)

        if np.all(weights == 0):
            logger.warning("All weights are 0, quitting!")
            return None

        # Restrict to the judge columns so they line up with `weights`.
        judge_scores = self.scores[:, :, self.lm_judges_idxs]

        if cumulative_mask is not None:
            # Broadcast the (problems x samples) mask over the verifier axis.
            mask = np.array(cumulative_mask)[..., np.newaxis]
            masked_scores = np.where(mask, judge_scores, -10e8)
        else:
            masked_scores = judge_scores

        ensembled_scores = masked_scores @ weights  # (problems x samples) weighted sums
        strategy_name = self.filter_strategy.name
        if strategy_name == FilteringStrategyType.TOP_K.value:
            filtered_idxs = np.argsort(-ensembled_scores, axis=1, kind='stable')[:, :self.filter_strategy.params]
        elif strategy_name == FilteringStrategyType.THRESHOLD.value:
            filtered_idxs = [np.where(row >= self.filter_strategy.params)[0] for row in ensembled_scores]
        else:
            # The original built this exception without raising it, which surfaced later
            # as an UnboundLocalError on `filtered_idxs`.
            raise ValueError(f"{self.filter_strategy.name} is not a valid filtering strategy for {self.__class__.__name__}.")

        return filtered_idxs, weights


class LogisticRegressionEnsemble(VerificationMethod):
    @staticmethod
    def _fit_classifier(features, labels, use_l1, seed):
        """Fit a logistic regression on (features, labels); L1/liblinear when use_l1, sklearn defaults otherwise."""
        if use_l1:
            clf = LogisticRegression(penalty='l1', solver='liblinear', random_state=seed, max_iter=1000)
        else:
            clf = LogisticRegression(random_state=seed, max_iter=1000)
        return clf.fit(features, labels)

    def filter(self, dataset, cumulative_mask, train_split=None, use_l1=False, per_problem=False, **kwargs):
        """
            Fit a logistic regression from verifier scores to the ground-truth correctness
            labels and filter generations by the predicted probability of being correct.

            The classifier is trained on dataset['answer_correct'], i.e. this method uses
            the true labels of the very samples it then scores.

            Args:
                dataset: mapping whose 'answer_correct' entry holds (problems x samples) labels.
                cumulative_mask: optional boolean (problems x samples) mask; None means keep
                    everything. Masked-out samples are excluded from fitting and keep a
                    probability of -inf.
                train_split: only used when per_problem is True. Fraction of each class
                    (positives and negatives separately) randomly sampled, via the global
                    numpy RNG, to fit that problem's classifier. None fits on all kept samples.
                use_l1: use an L1 penalty with the liblinear solver instead of sklearn's default.
                per_problem: fit one classifier per problem instead of a single global one.
                    A problem whose training labels do not contain both classes is skipped
                    and its row stays at -inf, so TOP_K falls back to the first samples.

            Returns:
                filtered_idxs only (no weights, unlike the other ensembles in this file):
                  - TOP_K: (problems x k) array of sample indices, best first.
                  - THRESHOLD: list of index arrays, one per problem, holding the samples
                    whose probability is >= self.filter_strategy.params.

            Raises:
                ValueError: if self.filter_strategy is neither TOP_K nor THRESHOLD.
        """
        n_problems, n_samples = self.scores.shape[0], self.scores.shape[1]
        true_labels = np.array(dataset['answer_correct']).astype(int)

        # Create mask if none provided (all True)
        if cumulative_mask is None:
            cumulative_mask = np.ones((n_problems, n_samples), dtype=bool)
        else:
            cumulative_mask = np.array(cumulative_mask).astype(bool)

        # Output starts at -inf so anything never scored sorts last / fails every threshold.
        full_probs = np.full((n_problems, n_samples), float('-inf'))

        if per_problem:
            for i in range(n_problems):
                keep = cumulative_mask[i]
                flat_features = self.scores[i, :, :][keep]
                flat_labels = true_labels[i][keep]

                if train_split is not None:
                    # Stratified subsample: draw the same fraction from each class.
                    pos_idxs = np.where(flat_labels == 1)[0]
                    neg_idxs = np.where(flat_labels == 0)[0]
                    pos_train_idxs = np.random.choice(pos_idxs, int(train_split*len(pos_idxs)), replace=False)
                    neg_train_idxs = np.random.choice(neg_idxs, int(train_split*len(neg_idxs)), replace=False)
                    train_idxs = np.concatenate([pos_train_idxs, neg_train_idxs])
                    train_features = flat_features[train_idxs]
                    train_labels = flat_labels[train_idxs]
                else:
                    train_features = flat_features
                    train_labels = flat_labels

                if len(np.unique(train_labels)) < 2:
                    # Zero or one class available: nothing to fit, so the row stays at -inf
                    # (TOP_K then defaults to the first samples).
                    continue

                clf = self._fit_classifier(train_features, train_labels, use_l1, seed=42)
                # Score every kept sample (not just the training subset) and write the
                # probabilities back to their original positions.
                full_probs[i, keep] = clf.predict_proba(flat_features)[:, 1]
        else:
            # Flatten everything and apply mask
            flat_mask = cumulative_mask.flatten()
            flat_features = self.scores.reshape(-1, self.scores.shape[-1])[flat_mask]
            flat_labels = true_labels.flatten()[flat_mask]

            clf = self._fit_classifier(flat_features, flat_labels, use_l1, seed=0)

            # Diagnostics: learned coefficients, intercept and training accuracy.
            print(clf.coef_[0])
            print(clf.intercept_)
            print(clf.score(flat_features, flat_labels))

            # Place predictions back in original positions
            full_probs.flat[flat_mask] = clf.predict_proba(flat_features)[:, 1]

        strategy_name = self.filter_strategy.name
        if strategy_name == FilteringStrategyType.TOP_K.value:
            filtered_idxs = np.argsort(-full_probs, axis=1, kind='stable')[:, :self.filter_strategy.params]
        elif strategy_name == FilteringStrategyType.THRESHOLD.value:
            filtered_idxs = [np.where(row >= self.filter_strategy.params)[0] for row in full_probs]
        else:
            # The original built this exception without raising it, which surfaced later
            # as an UnboundLocalError on `filtered_idxs`.
            raise ValueError(f"{self.filter_strategy.name} is not a valid filtering strategy for {self.__class__.__name__}.")

        return filtered_idxs


class NaiveBayesEnsemble(VerificationMethod):

    def get_nb_accs(self, binary_scores: np.ndarray, masked_flattened_true_labels: np.ndarray, cumulative_mask: np.ndarray, uniform):
        """
        Estimate each verifier's true-positive and true-negative rate.

        Args:
            binary_scores: binary verifier votes with verifiers on the last axis.
            masked_flattened_true_labels: 0/1 labels of the samples kept by the mask,
                flattened, in the same order as the kept rows of binary_scores.
            cumulative_mask: boolean mask over the leading axes of binary_scores.
            uniform: when truthy, ignore the estimates and use 0.7 for every rate.

        Returns:
            (tpr, tnr): two arrays with one entry per verifier, clipped to [0.01, 0.99]
            so the log-likelihoods computed from them stay finite.
        """
        flattened_binary_scores = binary_scores.reshape(-1, binary_scores.shape[-1])
        flattened_binary_scores = flattened_binary_scores[cumulative_mask.flatten()]

        indices_0 = np.where(masked_flattened_true_labels == 0)[0]
        indices_1 = np.where(masked_flattened_true_labels == 1)[0]
        labels_0 = masked_flattened_true_labels[indices_0]
        labels_1 = masked_flattened_true_labels[indices_1]

        tnr = []
        tpr = []
        for i in range(flattened_binary_scores.shape[1]):
            verifier_scores = flattened_binary_scores[:, i]
            # Accuracy restricted to one true class is that class's recall.
            tnr.append(accuracy_score(labels_0, verifier_scores[indices_0]))
            tpr.append(accuracy_score(labels_1, verifier_scores[indices_1]))

        tpr = np.clip(np.array(tpr), 0.01, 0.99)
        tnr = np.clip(np.array(tnr), 0.01, 0.99)

        if uniform:
            tpr = np.ones(len(tpr)) * 0.7
            tnr = np.ones(len(tnr)) * 0.7

        print(f"TPR: {tpr}, TNR: {tnr}")
        return tpr, tnr

    def get_nb_probs(self, binary_scores, masked_flattened_true_labels, cumulative_mask, train_split=None, uniform=False):
        """
        Compute the naive-Bayes posterior P(correct | verifier votes) for every sample.

        Args:
            binary_scores: binary verifier votes with verifiers on the last axis.
            masked_flattened_true_labels: 0/1 labels of the samples kept by the mask, flattened.
            cumulative_mask: boolean mask over the leading axes of binary_scores.
            train_split: when truthy, fraction of each class (positives and negatives
                separately) randomly sampled, via the global numpy RNG, to estimate the
                verifier rates and the class prior. Falsy uses every kept sample.
            uniform: forwarded to get_nb_accs.

        Returns:
            (probs, tpr, tnr). probs has the shape of cumulative_mask, with masked-out
            entries set to -inf. If the training subset does not contain both classes,
            probs is all zeros (mask not applied) and tpr/tnr are arrays of None.
        """
        if train_split:
            # train_idxs below index the *masked* label vector, so the votes must be
            # masked the same way before they are subsampled; indexing the unmasked
            # array (as the original did) misaligns rows and labels whenever the mask
            # drops a sample.
            masked_scores = binary_scores.reshape(-1, binary_scores.shape[-1])[cumulative_mask.flatten()]

            pos_idxs = np.where(masked_flattened_true_labels == 1)[0]
            neg_idxs = np.where(masked_flattened_true_labels == 0)[0]
            pos_train_idxs = np.random.choice(pos_idxs, int(train_split*len(pos_idxs)), replace=False)
            neg_train_idxs = np.random.choice(neg_idxs, int(train_split*len(neg_idxs)), replace=False)
            train_idxs = np.concatenate([pos_train_idxs, neg_train_idxs])
            train_scores = masked_scores[train_idxs]
            train_labels = masked_flattened_true_labels[train_idxs]
            # The training rows are already mask-filtered, so keep all of them.
            train_cumulative_mask = np.ones(len(train_idxs), dtype=bool)

            if len(np.unique(train_labels)) < 2:
                # Rates cannot be estimated without both classes (or with no samples).
                return np.zeros(len(binary_scores)), np.full(self.scores.shape[-1], None), np.full(self.scores.shape[-1], None)
        else:
            train_scores = binary_scores
            train_labels = masked_flattened_true_labels
            train_cumulative_mask = cumulative_mask

        tpr, tnr = self.get_nb_accs(train_scores, train_labels, train_cumulative_mask, uniform)
        fpr = 1 - tnr
        fnr = 1 - tpr

        # Per-sample log-likelihood of the observed votes under each class, assuming
        # verifiers are conditionally independent given the label.
        log_likelihood_y1 = np.sum(
            np.log(binary_scores * tpr + (1 - binary_scores) * fnr), axis=-1
        )
        log_likelihood_y0 = np.sum(
            np.log(binary_scores * fpr + (1 - binary_scores) * tnr), axis=-1
        )

        # Class prior = fraction of positives in the training labels.
        cb = train_labels.mean()

        # Compute log posteriors (log-space multiplication turns into addition)
        log_posterior_y1 = log_likelihood_y1 + np.log(cb)
        log_posterior_y0 = log_likelihood_y0 + np.log(1 - cb)

        # Use logsumexp for stability: log(exp(a) + exp(b)) = logsumexp([a, b])
        log_prob_y1_given_features = log_posterior_y1 - logsumexp([log_posterior_y1, log_posterior_y0], axis=0)

        # Convert back to probability space
        prob_y1_given_features = np.exp(log_prob_y1_given_features)

        # Masked-out samples must never be selected.
        prob_y1_given_features[~cumulative_mask] = float('-inf')

        return prob_y1_given_features, tpr, tnr

    def filter(self, dataset, cumulative_mask, train_split=None, per_problem=False, uniform=False, **kwargs):
        """
            Filter generations by the naive-Bayes posterior probability of being correct,
            computed from self.binary_scores with per-verifier TPR/TNR estimated against
            the ground-truth labels in dataset['answer_correct'].

            Args:
                dataset: mapping whose 'answer_correct' entry holds (problems x samples) labels.
                cumulative_mask: optional boolean (problems x samples) mask; None keeps everything.
                train_split: only used when per_problem is True; see get_nb_probs.
                per_problem: estimate rates separately for each problem. Problems whose kept
                    labels contain a single class are skipped and keep probability 0 for
                    every sample.
                uniform: use a fixed 0.7 TPR/TNR for every verifier instead of estimates.

            Returns:
                (filtered_idxs, weights):
                  - TOP_K: (problems x k) array of sample indices, best first.
                  - BEST_TIED: list of index arrays, one per problem, holding every sample
                    tied for that problem's maximum probability.
                  - THRESHOLD: list of index arrays, one per problem, holding the samples
                    whose probability is >= self.filter_strategy.params.
                  - weights is {"tpr": ..., "tnr": ...}; one row per problem when
                    per_problem is True (None entries for skipped problems), otherwise a
                    single vector each.

            Raises:
                ValueError: if self.filter_strategy is not TOP_K, BEST_TIED or THRESHOLD.
        """
        true_labels = np.array(dataset['answer_correct']).astype(int)
        if cumulative_mask is None:
            cumulative_mask = np.ones((self.scores.shape[0], self.scores.shape[1]), dtype=bool)
        else:
            cumulative_mask = np.array(cumulative_mask, dtype=bool)

        if per_problem:
            bad_count = 0  # problems for which no rates could be estimated
            prob_y1_given_features = np.zeros((self.scores.shape[0], self.scores.shape[1]))
            all_tpr = []
            all_tnr = []
            for i in range(self.scores.shape[0]):
                problem_scores = self.binary_scores[i, :, :]  # n_samples x n_verifiers
                masked_flattened_problem_labels = true_labels[i][cumulative_mask[i]]

                if len(np.unique(masked_flattened_problem_labels)) == 1:
                    # Single class: leave the row at 0 so the selection is effectively arbitrary.
                    all_tpr.append(np.full(self.scores.shape[-1], None))
                    all_tnr.append(np.full(self.scores.shape[-1], None))
                    bad_count += 1
                    continue
                prob_y1_given_features[i], tpr, tnr = self.get_nb_probs(problem_scores, masked_flattened_problem_labels, cumulative_mask[i], train_split, uniform)
                if tpr[0] is None:
                    bad_count += 1
                all_tpr.append(tpr)
                all_tnr.append(tnr)
            print(f"Bad count: {bad_count}")
        else:
            masked_flattened_true_labels = true_labels.flatten()[cumulative_mask.flatten()]
            prob_y1_given_features, all_tpr, all_tnr = self.get_nb_probs(self.binary_scores, masked_flattened_true_labels, cumulative_mask, None, uniform)

        strategy_name = self.filter_strategy.name
        if strategy_name == FilteringStrategyType.TOP_K.value:
            filtered_idxs = np.argsort(-prob_y1_given_features, axis=1, kind='stable')[:, :self.filter_strategy.params]
        elif strategy_name == FilteringStrategyType.BEST_TIED.value:
            best_scores = prob_y1_given_features.max(axis=1)
            # np.where with a single argument returns a tuple; take [0] to get the index
            # array itself, matching the other strategies.
            filtered_idxs = [np.where(row == best_scores[i])[0] for i, row in enumerate(prob_y1_given_features)]
        elif strategy_name == FilteringStrategyType.THRESHOLD.value:
            filtered_idxs = [np.where(row >= self.filter_strategy.params)[0] for row in prob_y1_given_features]
        else:
            # The original built this exception without raising it, which surfaced later
            # as an UnboundLocalError on `filtered_idxs`.
            raise ValueError(f"{self.filter_strategy.name} is not a valid filtering strategy for {self.__class__.__name__}.")

        weights = {
            "tpr": np.array(all_tpr),
            "tnr": np.array(all_tnr)
        }

        return filtered_idxs, weights
