
from weaver.verification import VerificationMethod, FilteringStrategyType
import numpy as np
import itertools

from collections import defaultdict 
from loguru import logger
try:
    from metal.label_model import LabelModel
except:
    pass

from sklearn.metrics import precision_score, recall_score, accuracy_score

# At import time, drop every existing loguru handler and install a single INFO-level
# sink that prints each already-formatted record to stdout (no extra newline).
logger.remove()
logger.add(
    lambda record_text: print(record_text, end=""),
    format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {message}",
    level="INFO",
)

class WeakSupervision(VerificationMethod):
    """
        Verification method that combines reward-model and LM-judge votes with a
        weak-supervision label model (metal's ``LabelModel``).

        The instance is expected to provide ``reward_models``, ``lm_judges``,
        ``reward_model_idxs``, ``lm_judges_idxs``, ``scores`` and ``filter_strategy``
        (presumably set up by ``VerificationMethod`` -- not visible here).
        ``scores`` is indexed below as (problem, generation, verifier).
    """

    # ------------------------------------------------------------------ #
    # Vote construction
    # ------------------------------------------------------------------ #

    @staticmethod
    def _apply_mask(values, mask_idxs):
        """Convert `values` to an array, keeping only the positions in `mask_idxs` (all positions when None)."""
        arr = np.array(values)
        return arr if mask_idxs is None else arr[mask_idxs]

    def compute_vote_matrices(self, dataset, cumulative_mask, judges_only, rms_only):
        """
            Binarize every verifier's outputs into one vote matrix per problem.

            A reward-model score votes 1 when it is strictly above the sample's
            ``f"{rm}_threshold"`` value and 0 otherwise. Missing (None) scores vote 0.
            LM-judge outputs are cast to int, with missing values mapped to 0.
            A verifier absent from a sample is skipped for that sample.

            Args:
            - dataset: iterable of per-problem dicts keyed by verifier name, plus 'answer_correct'.
            - cumulative_mask: optional per-problem 0/1 mask over generations; only generations marked 1 are kept.
            - judges_only: if True, reward models are skipped.
            - rms_only: if True, LM judges are skipped.

            Returns:
            - vote_matrices: list of (n_kept_generations, n_verifiers) arrays, one per problem.
            - true_labels: list of int arrays holding 'answer_correct' for each kept generation.
        """
        # Convert the mask once instead of once per problem.
        mask_array = None if cumulative_mask is None else np.array(cumulative_mask)
        vote_matrices = []
        true_labels = []
        for i, sample in enumerate(dataset):
            mask_idxs = None if mask_array is None else np.where(mask_array[i] == 1)[0]
            votes = []

            if not judges_only:
                for rm in self.reward_models:
                    if rm not in sample:
                        continue
                    raw_scores = self._apply_mask(sample[rm], mask_idxs)
                    thresh = sample[f"{rm}_threshold"]
                    # Build as int directly so None entries never leave an object-dtype column behind.
                    votes.append(np.array(
                        [1 if (x is not None and x > thresh) else 0 for x in raw_scores],
                        dtype=int,
                    ))

            if not rms_only:
                for judge in self.lm_judges:
                    if judge not in sample:
                        continue
                    raw_votes = self._apply_mask(sample[judge], mask_idxs)
                    # `== None` (not `is None`) so the comparison is elementwise on object arrays.
                    votes.append(np.where(raw_votes == None, 0, raw_votes).astype(int))  # noqa: E711

            vote_matrices.append(np.array(votes).T)
            true_labels.append(self._apply_mask(np.array(sample['answer_correct']).astype(int), mask_idxs))

        return vote_matrices, true_labels

    def _verifier_names(self, judges_only, rms_only):
        """
            Names of the verifiers compute_vote_matrices can emit, in column order
            (assuming every verifier is present in each sample). Used for logging.
        """
        if judges_only:
            return list(self.lm_judges)
        if rms_only:
            return list(self.reward_models)
        return list(self.reward_models) + list(self.lm_judges)

    # ------------------------------------------------------------------ #
    # Label-model introspection
    # ------------------------------------------------------------------ #

    def _get_ws_estimated_verifier_class_accuracies(self, label_model, n_verifiers):
        """
            After computing the label model, get each verifier's class-conditional rates, estimated using WS.

            This is useful in comparing WS's estimated TPR/FPR against the true TPR/FPR.

            Args:
            - label_model: trained WS label model.
            - n_verifiers: number of verifiers the model was trained on.

            Returns:
            - (TPR, TNR, FPR, FNR), each an (n_verifiers,) array:
              TPR = Pr(lf_i = 1 | y = 1), TNR = Pr(lf_i = 0 | y = 0),
              FPR = Pr(lf_i = 1 | y = 0), FNR = Pr(lf_i = 0 | y = 1).
        """
        weights = np.asarray(label_model.get_conditional_probs()).reshape((n_verifiers, -1, label_model.k))
        # Keep only the last two rows per verifier (row 0 is presumably the abstain row), leaving a
        # 2x2 matrix indexed as [vote value, true class].
        weights = weights[:, 1:, :]

        TNR = weights[:, 0, 0]
        TPR = weights[:, 1, 1]
        FPR = weights[:, 1, 0]
        FNR = weights[:, 0, 1]
        return TPR, TNR, FPR, FNR

    def _get_ws_estimated_verifier_accuracies(self, label_model, n_verifiers):
        """
            After computing the label model, get each verifier's accuracy, estimated using WS:
            Pr(lf_i = y)

            This can be used in weighting each verifier's score.

            Args:
            - label_model: trained WS label model.
            - n_verifiers: number of verifiers.

            Returns:
            - (n_verifiers,) accuracy vector A where A[i] = Pr(lf_i = y).
        """
        TPR, TNR, _, _ = self._get_ws_estimated_verifier_class_accuracies(label_model, n_verifiers)
        # Law of total probability: Pr(lf_i = y) = Pr(lf_i = y | y = 1)Pr(y = 1) + Pr(lf_i = y | y = 0)Pr(y = 0)
        return TNR * label_model.p[0] + TPR * label_model.p[1]

    # ------------------------------------------------------------------ #
    # Dependency handling
    # ------------------------------------------------------------------ #

    def _get_deps(self, votes, truth):
        """
            Estimate dependencies between verifiers from the inverse covariance of [votes, truth].

            The true labels are stacked with the votes, so this assumes access to the true covariance.
            Returns the off-diagonal pairs (i, j) with the largest |inverse covariance| entries. It keeps
            10% of the possible verifier pairs, at least 1.
        """
        all_scores = np.hstack([votes, truth[:, np.newaxis]])
        cov = np.cov(all_scores.T)
        cov = cov + 1e-6 * np.eye(cov.shape[0])  # regularize the diagonal so the matrix is invertible
        # Drop the row/column belonging to the true label.
        inv_cov = np.linalg.inv(cov)[:-1, :-1]

        m = inv_cov.shape[0]
        k = max(1, int(0.1 * (m * (m - 1) // 2)))  # desired density
        deps = []
        seen = set()
        for flat_idx in np.argsort(-np.abs(inv_cov), axis=None):
            i, j = divmod(int(flat_idx), m)
            # Skip the diagonal and the mirror image of a pair that is already selected.
            if i == j or (j, i) in seen:
                continue
            deps.append((i, j))
            seen.add((i, j))
            if len(deps) == k:
                break
        return deps

    def _drop_deps(self, votes, truth, current_verifiers, k=3):
        """
            Select the k verifiers that look most conditionally independent given the label.

            For every size-k subset, invert the covariance of [votes[:, subset], truth]. The truth
            row/column and the diagonal are discarded, and the subset is scored by the largest
            remaining |entry|. The subset with the smallest score wins (first one on ties).

            Args:
            - votes: (n_samples, n_verifiers) vote matrix.
            - truth: (n_samples,) labels.
            - current_verifiers: names aligned with the columns of `votes` (used for logging only).
            - k: subset size.

            Returns:
            - (votes restricted to the chosen columns, array of the chosen column indices).
              When fewer than k verifiers are available, all of them are kept.
        """
        logger.info(f"Finding maximally independent verifier set of size {k}")
        n_verifiers = votes.shape[-1]
        if n_verifiers < k:
            # No size-k subset exists; keep everything rather than failing on an empty search.
            logger.warning(f"Only {n_verifiers} verifiers available (fewer than {k}); keeping all of them")
            return votes, np.arange(n_verifiers)

        triple_to_marginal = {}
        triple_to_sparsity = {}
        for triple in itertools.combinations(range(n_verifiers), k):
            cols = list(triple)

            # Inverse covariance of the selected verifiers + truth.
            selected_scores = np.hstack([votes[:, cols], truth[:, np.newaxis]])
            selected_cov = np.cov(selected_scores.T)
            try:
                selected_inv_cov = np.linalg.inv(selected_cov)
            except np.linalg.LinAlgError:
                # Add a small value to the diagonal to make a singular covariance invertible.
                selected_cov = selected_cov + 1e-6 * np.eye(selected_cov.shape[0])
                selected_inv_cov = np.linalg.inv(selected_cov)

            # Discard the entries involving the true answer, and ignore self-dependencies.
            selected_inv_cov = selected_inv_cov[:-1, :-1]
            np.fill_diagonal(selected_inv_cov, 0)

            triple_to_sparsity[triple] = np.abs(selected_inv_cov).max()
            triple_to_marginal[triple] = votes[:, cols].mean(axis=0)

        top_triple = min(triple_to_sparsity, key=triple_to_sparsity.get)
        triple_names = [v for i, v in enumerate(current_verifiers) if i in top_triple]
        logger.info(f"Top triple: {triple_names}, sparsity: {triple_to_sparsity[top_triple]}")
        top_marginals = triple_to_marginal[top_triple]
        if any(top_marginals > 0.9) or any(top_marginals < 0.1):
            logger.warning(f"Some of the verifiers in the top triple have marginal probabilities that are too extreme: {triple_names}, {top_marginals}")

        top_triple = np.array(top_triple)
        return votes[:, top_triple], top_triple

    # ------------------------------------------------------------------ #
    # Shared fitting helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _select_balanced(votes, verifier_names):
        """Return the column indices (and names) of verifiers whose vote rate lies strictly in (0.1, 0.9)."""
        marginals = votes.mean(axis=0)
        balanced_idxs = np.where((marginals > 0.1) & (marginals < 0.9))[0]
        kept = set(balanced_idxs.tolist())
        balanced_models = [name for idx, name in enumerate(verifier_names) if idx in kept]
        logger.info(f"Only using balanced verifiers: {balanced_models}")
        return balanced_idxs, balanced_models

    @staticmethod
    def _class_balance_args(class_balance, truth):
        """LabelModel kwargs: use the (shifted) true labels as Y_dev when no class balance is given."""
        if class_balance is None:
            return {'Y_dev': truth + 1, 'class_balance': None}
        return {'Y_dev': None, 'class_balance': np.array([1 - class_balance, class_balance])}

    @staticmethod
    def _estimated_precision(tpr, tnr, class_balance):
        """Precision from TPR/TNR via Bayes' rule: TPR*p / (TPR*p + (1 - TNR)*(1 - p))."""
        return (tpr * class_balance) / (tpr * class_balance + (1 - tnr) * (1 - class_balance))

    @staticmethod
    def _expand_masked_scores(scores, mask_row):
        """Scatter `scores` back to the positions where `mask_row == 1`; every other slot gets -10e8."""
        mask_row = np.asarray(mask_row)
        full_scores = np.full(len(mask_row), -10e8)
        full_scores[np.where(mask_row == 1)[0]] = scores
        return full_scores

    @staticmethod
    def _resolve_kept_idxs(balanced_idxs, remaining_idxs, drop_imbalanced_verifiers, use_deps):
        """Map the surviving vote columns back to full-verifier indices, or None when nothing was dropped."""
        if drop_imbalanced_verifiers and use_deps == 'drop':
            return balanced_idxs[remaining_idxs]
        if drop_imbalanced_verifiers:
            return balanced_idxs
        if use_deps == 'drop':
            return remaining_idxs
        return None

    def _scatter_to_full(self, values, kept_idxs):
        """Place per-kept-verifier `values` into a NEW zero vector spanning every column of self.scores."""
        full = np.zeros(self.scores.shape[-1])
        full[kept_idxs] = values
        return full

    # ------------------------------------------------------------------ #
    # Fitting
    # ------------------------------------------------------------------ #

    def fit(self, dataset, cumulative_mask, class_balance, level, n_epochs, mu_epochs, lr, use_deps, judges_only, rms_only, drop_imbalanced_verifiers):
        """
            Fit the weak-supervision label model and estimate per-verifier quality.

            Args:
            - dataset, cumulative_mask, judges_only, rms_only: forwarded to compute_vote_matrices.
            - class_balance: Pr(y = 1). When None, the true labels are passed to the label model as
              Y_dev, and their mean is used for the precision estimate.
            - level: "per_problem" (one label model per problem) or "all_data" (one model over all generations).
            - n_epochs, mu_epochs, lr: LabelModel training overrides; None uses the level's default.
            - use_deps: 'model' passes estimated dependencies to the label model; 'drop' keeps only the
              3 most independent verifiers; any other value ignores dependencies.
            - drop_imbalanced_verifiers: drop verifiers whose vote rate is not within (0.1, 0.9).

            Returns:
            - results: dict of lists.
              - Per-verifier vectors, one entry per problem: 'acc', 'recall', 'precision', and for
                all_data also 'selection_accuracy', 'normalized_recall', 'normalized_precision'.
                Dropped verifiers are filled with 0.
              - 'scores': positive-class probabilities from the label model. These are per-problem arrays
                for per_problem and a flat per-generation sequence for all_data. Masked-out generations get -10e8.
            - weights: dict with 'tpr', 'tnr', 'fpr', 'fnr' for the verifiers the model was trained on.

            Raises:
            - ValueError: if `level` is not recognised.
        """
        vote_matrices, true_labels = self.compute_vote_matrices(dataset, cumulative_mask, judges_only, rms_only)
        verifier_names = self._verifier_names(judges_only, rms_only)
        fit_args = (vote_matrices, true_labels, cumulative_mask, class_balance, n_epochs, mu_epochs, lr,
                    use_deps, drop_imbalanced_verifiers, verifier_names)
        if level == "per_problem":
            return self._fit_per_problem(*fit_args)
        if level == "all_data":
            return self._fit_all_data(*fit_args)
        raise ValueError(f"Unknown level '{level}'; expected 'per_problem' or 'all_data'")

    def _fit_per_problem(self, vote_matrices, true_labels, cumulative_mask, class_balance, n_epochs, mu_epochs, lr, use_deps, drop_imbalanced_verifiers, verifier_names):
        """Train one label model per problem; returns (results, weights) as described in `fit`."""
        results = defaultdict(list)
        all_tpr, all_tnr, all_fpr, all_fnr = [], [], [], []
        mask_array = None if cumulative_mask is None else np.array(cumulative_mask)
        n_problems = len(vote_matrices)

        for i, (votes, truth) in enumerate(zip(vote_matrices, true_labels)):
            logger.info(f"Processing problem {i+1}/{n_problems}")
            votes = np.array(votes)
            truth = np.array(truth)

            balanced_idxs = None
            candidate_names = verifier_names
            if drop_imbalanced_verifiers:
                balanced_idxs, candidate_names = self._select_balanced(votes, verifier_names)
                votes = votes[:, balanced_idxs]

            label_model = LabelModel(k=2, seed=123)
            cb_args = self._class_balance_args(class_balance, truth)

            deps = []
            remaining_idxs = None
            if use_deps == 'model':
                # For now, assume we have access to the true covariance matrix --- stack both votes and labels
                deps = self._get_deps(votes, truth)
            elif use_deps == 'drop':
                votes, remaining_idxs = self._drop_deps(votes, truth, candidate_names)
            n_verifiers = votes.shape[1]

            # Shift {0, 1} votes to {1, 2}, matching the +1 shift applied to Y_dev.
            votes_scaled = votes + 1
            label_model.train_model(
                votes_scaled,
                deps=deps,
                L_train_continuous=None,
                abstains=False,
                symmetric=False,
                n_epochs=1000 if n_epochs is None else n_epochs,
                mu_epochs=10000 if mu_epochs is None else mu_epochs,
                log_train_every=100,
                lr=0.001 if lr is None else lr,
                **cb_args,
            )
            scores = label_model.predict_proba(votes_scaled)[:, 1]

            estimated_accuracies = self._get_ws_estimated_verifier_accuracies(label_model, n_verifiers)
            TPR, TNR, FPR, FNR = self._get_ws_estimated_verifier_class_accuracies(label_model, n_verifiers)
            all_tpr.append(TPR)
            all_tnr.append(TNR)
            all_fpr.append(FPR)
            all_fnr.append(FNR)

            cb = class_balance if class_balance is not None else truth.mean()
            estimated_precision = self._estimated_precision(TPR, TNR, cb)

            if mask_array is not None:
                scores = self._expand_masked_scores(scores, mask_array[i])

            kept_idxs = self._resolve_kept_idxs(balanced_idxs, remaining_idxs, drop_imbalanced_verifiers, use_deps)
            per_verifier = {'acc': estimated_accuracies, 'recall': TPR, 'precision': estimated_precision}
            for key, values in per_verifier.items():
                # Each metric gets its own full-width vector (no aliasing between metrics).
                results[key].append(values if kept_idxs is None else self._scatter_to_full(values, kept_idxs))
            results['scores'].append(scores)

        weights = {"tpr": all_tpr, "tnr": all_tnr, "fpr": all_fpr, "fnr": all_fnr}
        return results, weights

    def _fit_all_data(self, vote_matrices, true_labels, cumulative_mask, class_balance, n_epochs, mu_epochs, lr, use_deps, drop_imbalanced_verifiers, verifier_names):
        """Train a single label model over all generations; returns (results, weights) as described in `fit`."""
        results = defaultdict(list)
        n_problems = len(vote_matrices)
        n_generations, n_verifiers = vote_matrices[0].shape
        # NOTE(review): the reshapes below assume every problem keeps the same number of generations.
        votes = np.array([row for problem_votes in vote_matrices for row in problem_votes]).reshape((-1, n_verifiers))
        truth = np.array([label for problem_labels in true_labels for label in problem_labels]).flatten()

        balanced_idxs = None
        candidate_names = verifier_names
        if drop_imbalanced_verifiers:
            balanced_idxs, candidate_names = self._select_balanced(votes, verifier_names)
            votes = votes[:, balanced_idxs]
            n_verifiers = len(balanced_idxs)

        label_model = LabelModel(k=2, seed=0)
        cb_args = self._class_balance_args(class_balance, truth)

        deps = []
        remaining_idxs = None
        if use_deps == 'model':
            # For now, assume we have access to the true covariance matrix --- stack both votes and labels
            deps = self._get_deps(votes, truth)
            print(f"Deps are: {deps}")
        elif use_deps == 'drop':
            votes, remaining_idxs = self._drop_deps(votes, truth, candidate_names)
            n_verifiers = len(remaining_idxs)

        votes_scaled = votes + 1
        label_model.train_model(
            votes_scaled,
            deps=deps,
            L_train_continuous=None,
            abstains=False,
            symmetric=False,
            n_epochs=5000 if n_epochs is None else n_epochs,
            mu_epochs=10000 if mu_epochs is None else mu_epochs,
            log_train_every=1000,
            lr=0.0001 if lr is None else lr,
            **cb_args,
        )
        scores = label_model.predict_proba(votes_scaled)[:, 1]
        pseudolabels = np.round(scores).reshape(n_problems, -1)  # 0 or 1

        estimated_accuracies = self._get_ws_estimated_verifier_accuracies(label_model, n_verifiers)
        TPR, TNR, FPR, FNR = self._get_ws_estimated_verifier_class_accuracies(label_model, n_verifiers)
        logger.info(f"TPR: {TPR}, TNR: {TNR}")
        logger.info(f"Mu: {label_model.mu}")
        if class_balance is None:
            class_balance = truth.mean()
        estimated_precision = self._estimated_precision(TPR, TNR, class_balance)

        kept_idxs = self._resolve_kept_idxs(balanced_idxs, remaining_idxs, drop_imbalanced_verifiers, use_deps)
        # Column of self.scores holding each remaining verifier (columns may have been dropped above).
        score_columns = kept_idxs if kept_idxs is not None else np.arange(n_verifiers)

        votes = votes.reshape(n_problems, n_generations, n_verifiers)
        # Problems with no positive pseudolabel are skipped for precision/recall.
        positive_problems = [i for i in range(n_problems) if pseudolabels[i].sum() != 0]
        estimated_normalized_precision = np.zeros(n_verifiers)
        estimated_normalized_recall = np.zeros(n_verifiers)
        estimated_selection_accuracy = np.zeros(n_verifiers)
        for j in range(n_verifiers):
            estimated_normalized_precision[j] = np.array([
                precision_score(pseudolabels[i], votes[i, :, j], zero_division=0) for i in positive_problems
            ]).mean()
            estimated_normalized_recall[j] = np.array([
                recall_score(pseudolabels[i], votes[i, :, j], zero_division=0) for i in positive_problems
            ]).mean()
            best_idxs = self.scores[:, :, score_columns[j]].argmax(axis=1)
            estimated_selection_accuracy[j] = pseudolabels[np.arange(n_problems), best_idxs].mean()

        if cumulative_mask is not None:
            scores = self._expand_masked_scores(scores, np.array(cumulative_mask).flatten())

        per_verifier = {
            'acc': estimated_accuracies,
            'recall': TPR,
            'precision': estimated_precision,
            'selection_accuracy': estimated_selection_accuracy,
            'normalized_recall': estimated_normalized_recall,
            'normalized_precision': estimated_normalized_precision,
        }
        for key, values in per_verifier.items():
            # Each metric gets its own full-width vector (no aliasing between metrics).
            full = values if kept_idxs is None else self._scatter_to_full(values, kept_idxs)
            results[key].extend([full] * n_problems)
        results['scores'].extend(scores)

        weights = {"tpr": TPR, "tnr": TNR, "fpr": FPR, "fnr": FNR}
        return results, weights

    # ------------------------------------------------------------------ #
    # Public entry points
    # ------------------------------------------------------------------ #

    def validate_params(self, class_balance, level, metric, judges_only, rms_only):
        """Sanity-check filter() arguments. The checks are assert-based, so they vanish under `python -O`."""
        assert class_balance is None or isinstance(class_balance, float)
        assert level in ["per_problem", "all_data"]
        assert metric in ["top3", "scores", "acc", "recall", "precision", "selection_accuracy", "normalized_recall", "normalized_precision"]
        assert not (judges_only and rms_only)

    def filter(self, dataset, cumulative_mask, class_balance, level, metric, use_deps=False, top_verifiers=None, n_epochs=None, mu_epochs=None, lr=None, judges_only=False, rms_only=False, drop_imbalanced_verifiers=False, **kwargs):
        """
            Fit the weak-supervision model and select generations for each problem.

            With metric == 'scores', generations are ranked by the label model's positive-class
            probability. Otherwise the chosen per-verifier metric is used as ensemble weights over self.scores.

            Args:
            - metric: which entry of the fit() results drives the selection.
            - top_verifiers: if set, zero the weights of all but this many highest-weighted verifiers per problem.
            - kwargs['shift'] (default True): subtract 0.5 from non-selection-accuracy weights and clip at 0.
            - Remaining arguments are forwarded to fit().

            Returns:
            - None if judges_only/rms_only is requested but no such verifiers exist.
            - Otherwise (filtered_idxs, weights). filtered_idxs holds the selected generation indices per
              problem. weights is the TPR/TNR/FPR/FNR dict for metric == 'scores', else the first problem's weight vector.

            Raises:
            - ValueError: if the configured filter strategy is not supported for the chosen metric.
            - NotImplementedError: if metric == 'scores' is combined with a cumulative_mask.
        """
        self.validate_params(class_balance, level, metric, judges_only, rms_only)
        shift = kwargs.pop('shift', True)
        # NOTE: these restrictions overwrite self.scores, so they persist on the instance after this call.
        if judges_only:
            if len(self.lm_judges) == 0:
                return None
            self.scores = self.scores[:, :, self.lm_judges_idxs]
        if rms_only:
            if len(self.reward_models) == 0:
                return None
            self.scores = self.scores[:, :, self.reward_model_idxs]

        results, weights = self.fit(dataset, cumulative_mask, class_balance, level, n_epochs, mu_epochs, lr, use_deps, judges_only, rms_only, drop_imbalanced_verifiers)
        strategy = self.filter_strategy.name

        if metric == 'scores':
            if cumulative_mask is not None:
                raise NotImplementedError("We have not implemented WS for second stage")
            scores = np.array(results['scores']).reshape(len(dataset), -1)
            filtered_idxs = []
            for problem_scores in scores:
                if strategy == FilteringStrategyType.TOP_K.value:
                    best_idxs = np.argsort(-problem_scores, kind='stable')[:self.filter_strategy.params]
                elif strategy == FilteringStrategyType.BEST_TIED.value:
                    best_idxs = np.where(problem_scores == problem_scores.max())[0]
                elif strategy == FilteringStrategyType.THRESHOLD.value:
                    best_idxs = np.where(problem_scores >= self.filter_strategy.params)[0]
                else:
                    raise ValueError(f"{self.filter_strategy.name} is not a valid filtering strategy for {self.__class__.__name__}.")
                filtered_idxs.append(best_idxs.tolist())
            return filtered_idxs, weights

        weights = np.array(results[metric])
        if "selection_accuracy" not in metric and shift:
            # Only verifiers estimated to be better than chance keep a positive weight.
            weights = np.maximum(0, weights - 0.5)

        if top_verifiers is not None:
            # Clamp at 0 so asking for more verifiers than exist keeps all of them.
            n_to_zero = max(0, weights.shape[1] - top_verifiers)
            bottom_k_indices = np.argsort(weights, axis=1)[:, :n_to_zero]
            keep = np.ones_like(weights, dtype=bool)
            keep[np.arange(weights.shape[0])[:, None], bottom_k_indices] = False
            weights = weights * keep

        if cumulative_mask is not None:
            mask = np.array(cumulative_mask)[..., np.newaxis]
            masked_scores = np.where(mask, self.scores, -10e8)
        else:
            masked_scores = self.scores

        ensembled_scores = (masked_scores * weights[:, np.newaxis, :]).sum(axis=-1)  # (n_problems, n_generations)
        if strategy == FilteringStrategyType.TOP_K.value:
            filtered_idxs = np.argsort(-ensembled_scores, axis=1, kind='stable')[:, :self.filter_strategy.params]
        elif strategy == FilteringStrategyType.THRESHOLD.value:
            filtered_idxs = [np.where(row >= self.filter_strategy.params)[0] for row in ensembled_scores]
        else:
            raise ValueError(f"{self.filter_strategy.name} is not a valid filtering strategy for {self.__class__.__name__}.")

        return filtered_idxs, weights[0]
