import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tqdm
import itertools
from sa import BaseAlgorithm, AlgorithmParams
from rules import TwoWeightKnapsackRule, IntegerKnapsackRule, Rule, Condition, ORRule
from dataclasses import dataclass, fields, replace

@dataclass
class CoverageConsistencyParams(AlgorithmParams):
    """Parameter bundle exposing the ``c`` and ``p`` settings.

    The field names match the ``self.params.c`` / ``self.params.p`` attributes
    read by the algorithm classes in this module (presumably this is their
    params type -- confirm against ``BaseAlgorithm``).
    """

    # Coverage level: passed as ``cov_value`` to ``Coverage.confidence_at_coverage``
    # and used by ``ConsistencySoft.get_preds`` as the quantile cut ``1 - c``.
    c: float
    # ``Consistency`` requests dataset predictions via ``get_y_*_preds(1 - p)``.
    p: float

class Coverage(BaseAlgorithm):
    """Scores a knapsack rule by its confidence at a fixed coverage level.

    Every entry point delegates to :meth:`confidence_at_coverage`; they differ
    only in which dataset split supplies ``X`` and ``y``. The coverage level
    is read from ``self.params.c``.
    """

    def evaluate_rule(
        self,
        knapsack_rule: TwoWeightKnapsackRule,
        X: pd.DataFrame,
        y: pd.Series,
    ) -> float:
        """Return the rule's confidence on ``(X, y)`` at coverage ``self.params.c``."""
        return self.confidence_at_coverage(knapsack_rule, X, y, self.params.c)

    def evaluate_rule_train(self, knapsack_rule: TwoWeightKnapsackRule) -> float:
        """Return the rule's confidence at ``self.params.c`` on the train split."""
        return self.confidence_at_coverage(
            knapsack_rule,
            self.dataset.get_X_train(),
            self.dataset.get_y_train(),
            self.params.c,
        )

    def evaluate_rule_valid(self, knapsack_rule: TwoWeightKnapsackRule) -> float:
        """Return the rule's confidence at ``self.params.c`` on the validation split."""
        return self.confidence_at_coverage(
            knapsack_rule,
            self.dataset.get_X_valid(),
            self.dataset.get_y_valid(),
            self.params.c,
        )

    def score_rule(self, knapsack_rule: TwoWeightKnapsackRule) -> float:
        """Return the rule's confidence at ``self.params.c`` on the test split."""
        return self.confidence_at_coverage(
            knapsack_rule,
            self.dataset.get_X_test(),
            self.dataset.get_y_test(),
            self.params.c,
        )

    def analyze_rule_metrics(self, info: dict, rule: IntegerKnapsackRule) -> dict:
        """Record train/test confidence for coverage levels 0.025, 0.05, ... into ``info``.

        Keys are ``('train_confidence', '%.2f' % c)`` and
        ``('test_confidence', '%.2f' % c)``. ``info`` is mutated in place and
        also returned.
        """
        for c in np.arange(0.025, 1.025, 0.025):
            label = '%.2f' % c
            info[('train_confidence', label)] = self.confidence_at_coverage(
                rule,
                self.dataset.get_X_train(),
                self.dataset.get_y_train(),
                c,
            )
            info[('test_confidence', label)] = self.confidence_at_coverage(
                rule,
                self.dataset.get_X_test(),
                self.dataset.get_y_test(),
                c,
            )

        return info

    @classmethod
    def confidence_at_coverage(
        cls,
        knapsack_rule: IntegerKnapsackRule,
        X: pd.DataFrame,
        y: pd.Series,
        cov_value: float
    ) -> float:
        """Return the rule's confidence on ``(X, y)`` at coverage ``cov_value``.

        ``knapsack_rule.apply(X, y)`` must return a DataFrame with ``cutoff``,
        ``confidence`` and ``coverage`` columns. Two anchor rows are added so
        the curve spans the full coverage range: cutoff ``0`` (coverage 1.0,
        confidence ``mean(y)``) and cutoff ``max + 1`` (coverage 0.0,
        confidence 1.0).

        The result interpolates between the two rows that bracket
        ``cov_value``: the mix of the two operating points whose combined
        coverage equals ``cov_value``, with each point's confidence weighted
        by its share of the covered samples.

        Returns:
            The interpolated confidence, or ``0.0`` when ``cov_value`` lies
            outside the coverage range of the curve or no bracketing rows
            exist (e.g. ``cov_value`` is NaN).
        """
        thresholds = knapsack_rule.apply(X, y)
        row_first = pd.DataFrame([{'cutoff': 0, 'confidence': np.mean(y), 'coverage': 1.0}])
        row_last = pd.DataFrame(
            [{'cutoff': max(thresholds.cutoff) + 1, 'confidence': 1.0, 'coverage': 0.0}]
        )

        thresholds = pd.concat([row_first, thresholds, row_last])

        if cov_value > thresholds.coverage.max() or cov_value < thresholds.coverage.min():
            return 0.0

        try:
            # Nearest rows on either side of the requested coverage.
            left = thresholds[thresholds.coverage <= cov_value].sort_values('coverage').iloc[-1]
            right = thresholds[thresholds.coverage >= cov_value].sort_values('coverage').iloc[0]
        except IndexError:
            # An empty selection: nothing brackets cov_value (NaN compares
            # false against everything and slips past the range check above).
            return 0.0

        # Exact mixing weight. When both sides are the same coverage the
        # weight is irrelevant, so pin it to 0 instead of padding the
        # denominator, which would pull x toward 0 for every other input.
        span = right.coverage - left.coverage
        x = (cov_value - left.coverage) / span if span > 0 else 0.0

        right_weight = x * right.coverage
        left_weight = (1 - x) * left.coverage
        return (
            right_weight * right.confidence + left_weight * left.confidence
        ) / (right_weight + left_weight)
    

class Consistency(Coverage):
    """Coverage-style scoring against dataset predictions instead of labels.

    The split-based entry points use the same computation as ``Coverage`` but
    take their target from ``self.dataset.get_y_*_preds(1 - self.params.p)``
    (presumably model predictions cut at that level -- confirm against the
    dataset class). The coverage level is still ``self.params.c``.
    """

    def evaluate_rule(
        self,
        knapsack_rule: TwoWeightKnapsackRule,
        X: pd.DataFrame,
        y: pd.Series,
    ) -> float:
        """Return the rule's confidence on the caller-supplied ``(X, y)`` at ``self.params.c``."""
        return self.confidence_at_coverage(knapsack_rule, X, y, self.params.c)

    def evaluate_rule_train(self, knapsack_rule: TwoWeightKnapsackRule) -> float:
        """Return confidence at ``self.params.c`` against train-split predictions."""
        targets = self.dataset.get_y_train_preds(1 - self.params.p)
        return self.confidence_at_coverage(
            knapsack_rule,
            self.dataset.get_X_train(),
            targets,
            self.params.c,
        )

    def evaluate_rule_valid(self, knapsack_rule: TwoWeightKnapsackRule) -> float:
        """Return confidence at ``self.params.c`` against validation-split predictions."""
        targets = self.dataset.get_y_valid_preds(1 - self.params.p)
        return self.confidence_at_coverage(
            knapsack_rule,
            self.dataset.get_X_valid(),
            targets,
            self.params.c,
        )

    def score_rule(self, knapsack_rule: TwoWeightKnapsackRule) -> float:
        """Return confidence at ``self.params.c`` against test-split predictions."""
        targets = self.dataset.get_y_test_preds(1 - self.params.p)
        return self.confidence_at_coverage(
            knapsack_rule,
            self.dataset.get_X_test(),
            targets,
            self.params.c,
        )

    def analyze_rule_metrics(self, info: dict, rule: IntegerKnapsackRule) -> dict:
        """Return ``info`` unchanged; no extra consistency metrics are recorded.

        Returning the dict (rather than ``None``) matches
        ``Coverage.analyze_rule_metrics``, so callers that rebind the result
        keep working.
        """
        # NOTE: a sweep over p and c used to sit here behind an early return,
        # so it never ran. Its keys omitted p, meaning every p pass would have
        # overwritten the previous one; include p in the key if it is revived.
        return info


class ConsistencySoft(Coverage):
    """Scores a rule by the mean target value among its top-quantile samples.

    The rule's mask is mapped to quantiles with ``self.convert_to_quantile``
    (not defined in this file; presumably inherited -- confirm), and the
    samples whose quantile is at least ``1 - c`` are kept.
    """

    def get_preds(
        self,
        knapsack_rule: TwoWeightKnapsackRule,
        X: pd.DataFrame,
        y: pd.Series,
    ):
        """Return the entries of ``y`` whose rule quantile is at least ``1 - self.params.c``."""
        y_preds = self.convert_to_quantile(knapsack_rule.get_mask(X))
        return y[y_preds >= 1 - self.params.c]

    def evaluate_rule(
        self,
        knapsack_rule: TwoWeightKnapsackRule,
        X: pd.DataFrame,
        y: pd.Series,
    ) -> float:
        """Return the mean of the ``y`` entries selected by :meth:`get_preds`."""
        return np.mean(self.get_preds(knapsack_rule, X, y))

    def evaluate_rule_train(self, knapsack_rule: TwoWeightKnapsackRule) -> float:
        """Evaluate the rule on the train split against the train quantile targets."""
        return self.evaluate_rule(
            knapsack_rule,
            self.dataset.get_X_train(),
            self.dataset.get_y_train_quantile(),
        )

    def evaluate_rule_valid(self, knapsack_rule: TwoWeightKnapsackRule) -> float:
        """Evaluate the rule for validation; currently uses the TRAIN split.

        NOTE(review): this reads the same data as ``evaluate_rule_train``.
        No validation quantile getter is visible in this file, so the
        behavior is kept as-is -- confirm whether that is intended.
        """
        return self.evaluate_rule(
            knapsack_rule,
            self.dataset.get_X_train(),
            self.dataset.get_y_train_quantile(),
        )

    def score_rule(self, knapsack_rule: TwoWeightKnapsackRule) -> float:
        """Evaluate the rule on the test split against the test quantile targets."""
        return self.evaluate_rule(
            knapsack_rule,
            self.dataset.get_X_test(),
            self.dataset.get_y_test_quantile(),
        )

    def analyze_rule_metrics(self, info: dict, rule: IntegerKnapsackRule) -> dict:
        """Record soft consistency on train/test for c = 0.025, 0.05, ... into ``info``.

        Keys are ``('train_consistency_soft', '%.2f' % c)`` and
        ``('test_consistency_soft', '%.2f' % c)``. ``info`` is mutated in
        place and also returned, as ``Coverage.analyze_rule_metrics`` does.
        """
        # Targets and rule quantiles do not depend on c, so compute them once
        # per split instead of once per coverage level.
        train_targets = self.dataset.get_y_train_quantile()
        train_quantiles = self.convert_to_quantile(
            rule.get_mask(self.dataset.get_X_train())
        )
        test_targets = self.dataset.get_y_test_quantile()
        test_quantiles = self.convert_to_quantile(
            rule.get_mask(self.dataset.get_X_test())
        )

        for c in np.arange(0.025, 1.025, 0.025):
            label = '%.2f' % c
            info[('train_consistency_soft', label)] = np.mean(
                train_targets[train_quantiles >= 1 - c]
            )
            info[('test_consistency_soft', label)] = np.mean(
                test_targets[test_quantiles >= 1 - c]
            )

        return info