
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tqdm
import argparse
import time
from sklearn.metrics import auc
from multiprocessing import Pool, cpu_count

from dataclasses import dataclass, fields, replace
from datasets import *
from sa import BaseAlgorithm, AlgorithmParams
from rules import TwoWeightKnapsackRule, IntegerKnapsackRule
from neighbors import swap_high_rule, swap_low_rule, move_low_to_high, move_high_to_low
from helpers import ForwardFeatureSelection, get_interpolated_auc
from rules import Rule, Operator

@dataclass
class ComplementParams(AlgorithmParams):
    """Hyper-parameters for the ``Complement`` rule search.

    Extends ``AlgorithmParams`` with the two proportions used when a rule is
    scored in ``Complement.evaluate_rule``, which requires ``p >= c``.
    """

    p: float  # Desired overall proportion of rows to select (top-p quantile of the combined ranking)
    c: float  # Proportion of the data that the interpretable rule selects (predicts on) itself
    
class Complement(BaseAlgorithm):
    """Rule search where an interpretable rule is complemented by a black-box model.

    A candidate ``TwoWeightKnapsackRule`` is scored by merging its ranking with
    the black-box probabilities exposed by ``self.dataset`` (see
    ``evaluate_rule``). ``get_acceptance_probability`` supplies the accept/reject
    rule consumed by ``BaseAlgorithm`` (presumably a simulated-annealing loop,
    given the ``sa`` module it is imported from -- TODO confirm).
    """

    # Proposals scoring at most this far below the current score may still be
    # accepted (with a probability that decays over the run); worse ones never are.
    ACCEPT_TOLERANCE = 0.002
    # How sharply acceptance of worse proposals decays as t approaches T.
    ACCEPT_SHARPNESS = 30

    def evaluate_rule(
        self,
        knapsack_rule: TwoWeightKnapsackRule,
        X: pd.DataFrame,
        y: pd.Series,
        mode='train',
    ):
        """Return the mean of ``y`` over rows selected by the rule plus the black box.

        Rows are ranked by the black-box probabilities of the split named by
        ``mode``. Rows whose rule-mask quantile rank exceeds ``1 - c`` (the rule's
        own top-``c`` rows) are forced to the top of that ranking. The mean of ``y``
        over the top-``p`` quantile of the combined ranking is returned.

        Args:
            knapsack_rule: candidate rule; ``knapsack_rule.get_mask(X)`` is ranked
                to decide which rows the rule claims.
            X: features handed to ``knapsack_rule.get_mask``.
            y: targets; assumed to be positionally aligned with ``X`` and with the
                black-box probabilities for ``mode`` -- TODO confirm.
            mode: ``'train'``, ``'valid'`` or ``'test'``; selects which split's
                black-box probabilities are used.

        Raises:
            ValueError: if ``params.p < params.c`` (or either is NaN), or if
                ``mode`` is not one of the three recognised values.
        """
        params = self.params
        # Explicit check rather than ``assert`` so it survives ``python -O``;
        # ``not (p >= c)`` also rejects NaN, matching the previous assert.
        if not (params.p >= params.c):
            raise ValueError(
                f"params.p must be >= params.c, got p={params.p!r}, c={params.c!r}"
            )

        if mode == 'train':
            probs = self.dataset.get_y_train_probs()
        elif mode == 'valid':
            probs = self.dataset.get_y_valid_probs()
        elif mode == 'test':
            probs = self.dataset.get_y_test_probs()
        else:
            # Without this branch ``probs`` would be unbound below.
            raise ValueError(
                f"mode must be 'train', 'valid' or 'test', got {mode!r}"
            )

        bb_scores = self.convert_to_quantile(probs)
        interpretable_scores = self.convert_to_quantile(knapsack_rule.get_mask(X))

        # Quantile ranks lie in [0, 1), so 1.1 lifts the rule's rows above every
        # black-box-ranked row before the top-p cut is taken.
        # NOTE(review): ranks are i / n, so the strict ``>`` keeps one row fewer
        # than c * n when c * n is an integer -- confirm this is intended.
        bb_scores[interpretable_scores > 1 - params.c] = 1.1
        selected = bb_scores >= np.quantile(bb_scores, 1 - params.p)
        return y[selected].mean()

    def evaluate_rule_train(self, knapsack_rule: TwoWeightKnapsackRule):
        """Score ``knapsack_rule`` on the training split."""
        return self.evaluate_rule(knapsack_rule, self.dataset.get_X_train(), self.dataset.get_y_train(), 'train')

    def evaluate_rule_valid(self, knapsack_rule: TwoWeightKnapsackRule):
        """Score ``knapsack_rule`` on the validation split."""
        return self.evaluate_rule(knapsack_rule, self.dataset.get_X_valid(), self.dataset.get_y_valid(), 'valid')

    def score_rule(self, knapsack_rule: TwoWeightKnapsackRule):
        """Score ``knapsack_rule`` on the test split."""
        return self.evaluate_rule(knapsack_rule, self.dataset.get_X_test(), self.dataset.get_y_test(), 'test')

    def get_acceptance_probability(
        self,
        current_score: float,
        proposed_score: float,
        t: int,
        T: int
    ) -> float:
        """
        Return the acceptance probability given the current score, proposed score, current iteration (t)
        and total number of iterations (T).

        - Strict improvements are always accepted (returns 1).
        - Proposals less than ``ACCEPT_TOLERANCE`` below the current score
          (ties included) are accepted with probability
          ``exp(-ACCEPT_SHARPNESS * d * t / T)``. Here ``d`` is the shortfall as a
          fraction of the tolerance, in [0, 1), so the probability shrinks as
          ``t`` grows towards ``T``.
        - Anything worse is rejected (returns 0).
        """
        if proposed_score > current_score:
            return 1
        tolerance = self.ACCEPT_TOLERANCE
        if proposed_score > current_score - tolerance:
            shortfall = (current_score - proposed_score) / tolerance
            return np.exp(-self.ACCEPT_SHARPNESS * shortfall * t / T)
        return 0

    @staticmethod
    def convert_to_quantile(predictions: np.ndarray) -> np.ndarray:
        """Map each value to its rank divided by the number of values.

        The smallest value gets 0.0 and the largest gets (n - 1) / n. Ties are
        broken by ``np.argsort``'s default (non-stable) ordering. Returns a
        float64 array; an empty input yields an empty array.
        """
        order = np.argsort(np.asarray(predictions))
        n = len(order)
        # Invert the sorting permutation: position order[k] has rank k.
        ranks = np.empty(n, dtype=float)
        ranks[order] = np.arange(n)
        return ranks / n

