from typing import List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tqdm
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

from rules import TwoWeightKnapsackRule, IntegerKnapsackRule, Rule, Operator, Condition

def get_rule_candidates(X: pd.DataFrame, y: pd.Series, num_features: int) -> List[Rule]: 
    """
    Build one candidate rule per feature kept by a sparse logistic regression.

    A ``SparseLogisticRegression`` targeting ``num_features`` non-zero
    coefficients is fit on ``(X, y)``. Every surviving feature becomes an
    equality rule: ``feature == 1`` when its coefficient is positive, and
    ``feature == 0`` otherwise.
    """
    surviving = SparseLogisticRegression(X, y, num_features).get_coefs()
    return [
        Rule.create_from_feature(name, Operator.EQUAL, int(coef > 0))
        for name, coef in surviving.items()
    ]

def get_nonzero_coefs(lr: "LogisticRegression", feature_names: List[str], *, tol: float = 0.00001):
    """
    Return the coefficients of a fitted linear model whose magnitude exceeds ``tol``.

    Only ``lr.coef_[0]`` (the first coefficient row) is used.

    Args:
        lr: fitted model exposing a 2-D ``coef_`` array.
        feature_names: labels for the columns of ``coef_``, in order.
        tol: coefficients with ``|w| <= tol`` are treated as zero
            (default ``1e-5``).

    Returns:
        pandas Series of the kept coefficients indexed by feature name,
        sorted by value in descending order.
    """
    coefs = pd.Series(lr.coef_[0], index=feature_names)
    return coefs[coefs.abs() > tol].sort_values(ascending=False)

def get_interpolated_auc(
    rule: "TwoWeightKnapsackRule",
    X: pd.DataFrame,
    y: pd.Series,
    bb_scores: list[float],
    desired_coverage: float,
    to_plot=False,
):
    """
    Score a rule layered on top of black-box scores at a fixed coverage.

    The rule's per-row values ``rule.get_mask(X)`` are swept from largest to
    smallest distinct value ``v``. At each step the rows with value ``>= v``
    are forced to the top of the black-box ranking. The mean of ``y`` over the
    top ``int(desired_coverage * n)`` rows is then recorded ("confidence").
    It is paired with the share of that budget the rule occupies
    ("transparency"). The curve starts at the black-box-only confidence
    (transparency 0), is linearly interpolated on a 0.05 grid over [0, 1],
    and averaged.

    Args:
        rule: object whose ``get_mask(X)`` returns one number per row of ``X``.
        X: rows passed to ``rule.get_mask``.
        y: outcome per row. Only its values are used, positionally; the
            pandas index is ignored.
        bb_scores: black-box score per row (list or array), aligned
            positionally with ``y``.
        desired_coverage: fraction of rows selected at each step.
        to_plot: if True, return ``(x_vals, y_vals)`` (the grid and the
            interpolated curve) instead of its mean. Nothing is drawn.

    Returns:
        The mean of the interpolated curve, or ``(x_vals, y_vals)`` when
        ``to_plot`` is True.
    """
    assert(len(y) == len(bb_scores))
    assert(len(y) == len(X))

    # Convert to plain arrays so lists are accepted and all indexing below is
    # positional (a pandas Series with a non-default index would otherwise be
    # indexed by label).
    y_arr = np.asarray(y, dtype=float)
    scores = np.asarray(bb_scores, dtype=float)
    num_conditions_apply = np.asarray(rule.get_mask(X))
    n_selected = int(desired_coverage * len(scores))

    # Baseline: mean outcome among the rows the black box alone ranks in its
    # top `desired_coverage` quantile.
    bb_conf = y_arr[scores >= np.quantile(scores, 1 - desired_coverage)].mean()

    confs = [bb_conf]
    trans = [0]

    for v in sorted(set(num_conditions_apply), reverse=True):
        covered = num_conditions_apply >= v
        temp = scores.copy()
        # +inf ranks rule-covered rows above every black-box score, whatever
        # the score range; a stable sort makes tie-breaking deterministic.
        temp[covered] = np.inf
        top = np.argsort(-temp, kind="stable")[:n_selected]
        conf = y_arr[top].mean()
        transp = covered.sum() / len(num_conditions_apply) / desired_coverage

        if transp > 1.0:
            # NOTE(review): the rule now covers more rows than the coverage
            # budget; the point is clamped to transparency 1.0 and the sweep
            # stops. The original author flagged this handling as unresolved.
            confs.append(conf)
            trans.append(1.0)
            break

        confs.append(conf)
        trans.append(transp)

    x_vals = list(np.arange(0, 1.05, 0.05))
    y_vals = list(np.interp(x_vals, trans, confs))

    if to_plot:
        return x_vals, y_vals
    return np.mean(y_vals)


class SparseLogisticRegression:
    """
    L1-regularised logistic regression tuned towards ``num_coefs`` non-zero
    coefficients.

    The inverse regularisation strength ``C`` is found by bisection in
    ``search_for_C``, which runs during construction. Creating an instance
    therefore fits several models.
    """

    def __init__(self, X_train: pd.DataFrame, y_train: pd.Series, num_coefs: int, sample_weight=None):
        self.X_train = X_train
        self.y_train = y_train
        # Target count of non-zero coefficients; None accepts the first fit.
        self.num_coefs = num_coefs
        # Forwarded unchanged to LogisticRegression.fit.
        self.sample_weight = sample_weight

        self.search_for_C()

    def fit_model(self, C: float):
        """Fit an L1 (liblinear) logistic regression with inverse strength ``C``.

        Returns:
            ``(n, model)``, where ``n`` is the number of coefficients kept by
            ``get_nonzero_coefs``.
        """
        lr = LogisticRegression(penalty='l1', solver='liblinear', C=C, max_iter=1000)
        lr.fit(self.X_train, self.y_train, sample_weight=self.sample_weight)
        n = len(get_nonzero_coefs(lr, self.X_train.columns))
        return n, lr

    def search_for_C(self, lower: float = 0.001, upper: float = 1.0, max_iter: int = 20):
        """Bisect ``C`` in ``[lower, upper]`` and store the chosen model in ``self.model``.

        Too few non-zero coefficients raises the lower bound; too many lowers
        the upper bound. On an exact hit the upper bound is lowered as well, so
        the search continues toward smaller ``C`` while remembering the most
        recent exact hit. That hit is kept as ``self.model``. If no fit ever
        matches ``num_coefs``, or ``num_coefs`` is None (stop after the first
        fit), the last fitted model is used instead.

        Args:
            lower: initial lower bound for ``C``.
            upper: initial upper bound for ``C``.
            max_iter: maximum number of bisection steps (one fit each).
        """
        matched = None
        lr = None
        for _ in range(max_iter):
            c = (lower + upper) / 2
            n, lr = self.fit_model(c)
            if self.num_coefs is None:
                break
            if n < self.num_coefs:
                lower = c
            elif n > self.num_coefs:
                upper = c
            else:
                matched = lr
                upper = c
        # Prefer an exact match over the final fit, which may have drifted off
        # the target coefficient count in later bisection steps.
        self.model = matched if matched is not None else lr

    def predict_proba(self, X_test: pd.DataFrame):
        """Return the chosen model's probability of class 1 for each row."""
        return self.model.predict_proba(X_test)[:, 1]

    def get_coefs(self):
        """Non-zero coefficients of the chosen model, sorted descending by value."""
        return get_nonzero_coefs(self.model, self.X_train.columns)

    def to_integer_knapsack(self, scale=1, only_one=False) -> "IntegerKnapsackRule":
        """Turn each non-zero coefficient into a ``feature == 1`` rule.

        Args:
            scale: multiplier applied to each coefficient to form the rule weight.
            only_one: if True, ignore ``scale`` and use only the coefficient's
                sign (+1 / -1) as the weight.
        """
        rules = []
        for name, weight in self.get_coefs().items():
            rule_weight = np.sign(weight) if only_one else weight * scale
            rules.append(Rule.create_from_feature(name, Operator.EQUAL, 1, weight=rule_weight))
        return IntegerKnapsackRule(rules, name='Rounding from Dense Model')

class ForwardFeatureSelection:
    """
    Greedy forward selection of ``num_features`` columns for a logistic regression.

    At each step every not-yet-selected column is tried. A model is fit on the
    selected columns plus that candidate, and the candidate with the lowest
    training log-loss is kept. All fitting happens in the constructor.
    """

    def __init__(self, X_train: pd.DataFrame, y_train: pd.Series, num_features: int, unitary=False, **lr_kwargs):
        """
        Args:
            X_train: training features; columns are the selection candidates.
            y_train: training labels.
            num_features: number of columns to select (must be >= 1).
            unitary: accepted for backward compatibility; currently unused.
            **lr_kwargs: extra keyword arguments for ``LogisticRegression``.

        Raises:
            ValueError: if ``num_features`` < 1 or exceeds the number of columns.
        """
        if num_features < 1:
            raise ValueError(f"num_features must be >= 1, got {num_features}")

        self.X_train = X_train
        self.y_train = y_train
        self.num_features = num_features
        self.lr_kwargs = lr_kwargs

        self.features = []
        best = None

        for _ in range(self.num_features):
            best = self.forward(self.features)
            self.features = self.features + [best['new_feature']]

        self.model: LogisticRegression = best['model']

    def process_subset(self, feature_set: List[str]):
        """Fit on ``feature_set`` and return ``{'model', 'log_loss'}`` on the training data.

        If ``class_weight`` was given in ``lr_kwargs``, the log-loss is weighted
        per sample in the same way the model weights classes during fitting.
        """
        from sklearn.utils.class_weight import compute_sample_weight

        # Defaults that caller-supplied lr_kwargs may override.
        params = {'solver': 'lbfgs', 'max_iter': 10000, **self.lr_kwargs}
        model = LogisticRegression(**params)
        X_subset = self.X_train[feature_set]
        model.fit(X_subset, self.y_train)
        y_pred = model.predict_proba(X_subset)[:, 1]

        sample_weights = None
        if 'class_weight' in self.lr_kwargs:
            # Handles dict, 'balanced' and None the same way LogisticRegression does.
            sample_weights = compute_sample_weight(self.lr_kwargs['class_weight'], self.y_train)

        return {
            'model': model,
            'log_loss': log_loss(self.y_train, y_pred, sample_weight=sample_weights),
        }

    def forward(self, feature_set: List[str]):
        """Try adding each unused column and return the best result.

        Returns:
            pandas Series with entries ``'model'``, ``'log_loss'`` and
            ``'new_feature'`` for the candidate with the lowest log-loss
            (the first one on ties).

        Raises:
            ValueError: if every column is already in ``feature_set``.
        """
        chosen = set(feature_set)
        remaining_features = [p for p in self.X_train.columns if p not in chosen]
        if not remaining_features:
            raise ValueError("No remaining features to add.")

        results = []
        for p in remaining_features:
            r = self.process_subset(feature_set + [p])
            r['new_feature'] = p
            results.append(r)

        models = pd.DataFrame(results)
        return models.loc[models['log_loss'].idxmin()]

    def predict_probability(self, X_test):
        """Return the probability of class 1 for each row of ``X_test``.

        If ``X_test`` has named columns that include all selected features,
        those columns are selected by name, in training order. Otherwise, if
        it has exactly ``num_features`` columns (e.g. a NumPy array), it is
        assumed to already be in ``self.features`` order.
        """
        columns = getattr(X_test, 'columns', None)
        if columns is not None and set(self.features).issubset(columns):
            return self.model.predict_proba(X_test[self.features])[:, 1]
        if X_test.shape[1] == self.num_features:
            return self.model.predict_proba(X_test)[:, 1]
        return self.model.predict_proba(X_test[self.features])[:, 1]

    def get_coefs(self):
        """Non-zero coefficients scaled so the smallest magnitude is 1, largest magnitude first."""
        coefs = get_nonzero_coefs(self.model, self.features)
        if coefs.empty:
            return coefs
        coefs = coefs / coefs.abs().min()
        return coefs.sort_values(key=lambda s: -s.abs())

    def to_integer_knapsack(self, scale=1, only_one=False) -> "IntegerKnapsackRule":
        """Turn each scaled coefficient into a ``feature == 1`` rule.

        Args:
            scale: multiplier applied to each coefficient to form the rule weight.
            only_one: if True, ignore ``scale`` and use only the coefficient's
                sign (+1 / -1) as the weight.
        """
        rules = []
        for name, weight in self.get_coefs().items():
            rule_weight = np.sign(weight) if only_one else weight * scale
            rules.append(Rule.create_from_feature(name, Operator.EQUAL, 1, weight=rule_weight))
        return IntegerKnapsackRule(rules, name='Rounding from Subset Selection')

    def to_two_weight_knapsack(self, num_high):
        """Split features into high-weight (first ``num_high`` by magnitude) and low-weight rules.

        Each feature becomes ``feature == 1`` if its coefficient is positive,
        else ``feature == 0``.
        """
        coefs = self.get_coefs()

        def as_rules(weights):
            return [
                Rule.create_from_feature(name, Operator.EQUAL, 1 if weight > 0 else 0)
                for name, weight in weights.items()
            ]

        low_weight = as_rules(coefs.iloc[num_high:])
        high_weight = as_rules(coefs.iloc[:num_high])
        return TwoWeightKnapsackRule(high_weight_rules=high_weight, low_weight_rules=low_weight, name='Rounding from Subset Selection')


def get_coverage_confidence(
    y_probs: pd.Series,
    y_true: pd.Series,
):
    """
    Sweep score thresholds and report parallel coverage / confidence lists.

    For each quantile level 0.02, 0.04, ..., 0.98, the rows whose ``y_probs``
    is at or above that quantile of ``y_probs`` are selected. For each
    selection:

    * coverage = sum of the selected ``y_true`` values / total number of rows;
    * confidence = mean of the selected ``y_true`` values.

    NOTE(review): for 0/1 labels the coverage above is the share of all rows
    that are both selected and positive, not the share of rows selected.
    Confirm this is the intended definition.
    """
    n_rows = len(y_true)
    thresholds = [np.quantile(y_probs, level) for level in np.arange(0.02, 1.00, 0.02)]
    selections = [y_true[y_probs >= cutoff] for cutoff in thresholds]
    coverage = [picked.sum() / n_rows for picked in selections]
    confidence = [picked.mean() for picked in selections]
    return coverage, confidence


# class ComplementAUCGreedyStart(ComplementAUC):   
#     def get_start(self) -> TwoWeightKnapsackRule: 
#         """
#         Get the starting knapsack by using a greedy search algorithm.
#         """
        
#         universe_of_conditions = [
#             Rule.create_from_feature(f, Operator.EQUAL, 1)
#             for f in self.dataset.get_X_train().columns 
#         ] + [
#             Rule.create_from_feature(f, Operator.EQUAL, 0)
#             for f in self.dataset.get_X_train().columns 
#         ]
        
#         features_taken = []
        
#         rule_set = []
#         current_score = 0

#         for i in range(self.params.N):
#             best_new_set = None
#             for u in universe_of_conditions:
#                 current_best_copy = rule_set.copy()
#                 current_best_copy += [u]
                
#                 evaluation = self.evaluate_rule_train(
#                     TwoWeightKnapsackRule(current_best_copy, []),
#                 )

#                 if evaluation >= current_score: 
#                     best_new_set = current_best_copy.copy()
#                     current_score = evaluation 
            
#             if best_new_set is not None:
#                 rule_set = best_new_set.copy()
#                 print(current_score)
#             else: 
#                 rule_set = current_best_copy.copy()
            
#         return TwoWeightKnapsackRule(rule_set, [], name='Start')
