import numpy as np
import pandas as pd

from enum import Enum
from functools import lru_cache
from typing import Any, List


class Operator(Enum):
    """
    Comparison operators a condition can apply; each value is the symbol
    used when the condition is rendered as text.
    """
    # Exact equality.
    EQUAL = "="
    # Inclusive lower bound: the symbol is ">=", not a strict ">".
    GREATER = ">="
    # Inclusive upper bound: the symbol is "<=", not a strict "<".
    LESS = "<="

class Condition:
    """
    A single <=, = or >= comparison of one DataFrame column against a value.

    Attributes:
        field_name: name of the column the comparison reads.
        operator: the Operator member selecting the comparison.
        value: right-hand side of the comparison.
    """
    def __init__(self, field_name: str, operator: Operator, value: Any):
        self.field_name = field_name
        self.operator = operator
        self.value = value

    def __str__(self):
        # Rendered as "Rule: <field> <symbol> <value>" using the operator's symbol.
        return "Rule: %s %s %s" % (self.field_name, self.operator.value, self.value)

    def __repr__(self):
        return self.__str__()

    def get_mask(self, df: pd.DataFrame) -> pd.Series:
        """
        Evaluate the comparison on ``df[self.field_name]``.

        Returns the result of comparing that column with ``self.value``.
        Any operator other than GREATER or LESS is treated as equality.
        """
        column = df[self.field_name]
        if self.operator == Operator.GREATER:
            return column >= self.value
        if self.operator == Operator.LESS:
            return column <= self.value
        return column == self.value

    @classmethod
    def parse_condition(cls, field_name):
        """
        Build an equality condition from a field name.

        A name ending in ``_rev`` yields ``<name without the suffix> = 0``;
        any other name yields ``<name> = 1``.
        """
        suffix = '_rev'
        # Only a trailing suffix is stripped: a name that merely contains
        # "_rev" somewhere else must not lose its last four characters.
        if field_name.endswith(suffix):
            return cls(field_name=field_name[:-len(suffix)], operator=Operator.EQUAL, value=0)
        return cls(field_name=field_name, operator=Operator.EQUAL, value=1)

class Rule:
    """
    Class representing the AND of many conditions, together with a weight.

    Attributes:
        conditions_list: the conditions that must all hold.
        weight: numeric weight stored with the rule (defaults to 1).
    """
    def __init__(self, conditions_list: List[Condition], weight: float = 1):
        self.conditions_list: List[Condition] = conditions_list
        self.weight = weight

    def __str__(self):
        # A single-condition rule is printed on one line; otherwise a header
        # line is followed by one line per condition.
        if len(self.conditions_list) == 1:
            return str(self.conditions_list[0]) + " (weight=%.2f)" % self.weight

        header = "An AND Rule of %d conditions (weight=%.2f)" % (len(self.conditions_list), self.weight)
        return '\n'.join([header] + [str(c) for c in self.conditions_list])

    def __repr__(self):
        return self.__str__()

    def __getitem__(self, item):
        """Index (or slice) directly into the list of conditions."""
        return self.conditions_list[item]

    def get_mask(self, X: pd.DataFrame) -> pd.Series:
        """
        Return the element-wise AND of every condition's mask on ``X``.

        With exactly one condition its mask is returned as-is; with no
        conditions the result is all True over ``X.index``.
        """
        if len(self.conditions_list) == 1:
            return self.conditions_list[0].get_mask(X)

        # Scalar broadcast with an explicit dtype: avoids building a Python
        # list and keeps the mask boolean even when X has no rows.
        mask = pd.Series(True, index=X.index, dtype=bool)
        for c in self.conditions_list:
            mask = mask & c.get_mask(X)
        return mask

    def apply(self, X: pd.DataFrame, y: pd.Series):
        """
        Score the rule against labels ``y``.

        Returns a one-row DataFrame with:
            confidence: mean of ``y`` over the rows selected by the mask.
            coverage: mean of the mask, i.e. the fraction of rows selected.
            cutoff: always 1.
        """
        mask = self.get_mask(X)
        return pd.DataFrame([{
            'confidence': y[mask].mean(),
            'coverage': mask.mean(),
            'cutoff': 1,
        }])

    def features(self) -> List[str]:
        """Return the field name of every condition, in order (duplicates kept)."""
        return [c.field_name for c in self.conditions_list]

    @classmethod
    def create_from_feature(cls, field_name: str, operator: Operator, value: float, weight: float=1.0):
        """Build a rule holding a single condition ``field_name <operator> value``."""
        # Construct through cls so subclasses get an instance of their own type.
        return cls([Condition(field_name, operator, value)], weight=weight)

    @classmethod
    def parse_rule(cls, fields: List[str], weight: float=1.0):
        """Build a rule with one condition per entry in ``fields``, each produced by ``Condition.parse_condition``."""
        return cls([Condition.parse_condition(f) for f in fields], weight=weight)


class IntegerKnapsackRule:
    """
    A named collection of weighted rules scored additively: each row's score
    is the sum of the weights of the rules whose mask is true for that row.

    Attributes:
        rules: the rules contributing to the score.
        name: optional label, used only in the textual representation.
    """
    def __init__(self, rules: List[Rule], name=''):
        self.rules = rules
        self.name = name

    def __repr__(self):
        header = "An integer knapsack (name=%s) with %d rules\n" % (self.name, len(self.rules))
        return header + ''.join("Weight: %.1f - %s\n" % (r.weight, r) for r in self.rules)

    def get_mask(self, X: pd.DataFrame) -> np.ndarray:
        """
        Return the per-row score of ``X``.

        Starts from a float array of zeros with one entry per row and adds
        ``rule.get_mask(X) * rule.weight`` for every rule.
        """
        mask = np.zeros(len(X))
        for r in self.rules:
            mask += r.get_mask(X) * r.weight

        return mask

    def apply(self, X: pd.DataFrame, y: pd.Series = None) -> pd.DataFrame:
        """
        Evaluate every distinct score as a threshold.

        For each distinct score value (ascending) the rows with
        ``score >= cutoff`` are selected from ``y`` and one output row is
        produced with:
            cutoff: the threshold.
            confidence: mean of the selected ``y`` values.
            coverage: number of selected rows divided by ``len(y)``.

        NOTE(review): ``y`` defaults to None but is indexed for every cutoff,
        so it must be supplied whenever ``X`` has rows.
        """
        mask = self.get_mask(X)

        data = []
        # np.unique yields the distinct scores already sorted ascending.
        for cutoff in np.unique(mask):
            ys = y[mask >= cutoff]
            data.append({
                'cutoff': cutoff,
                'confidence': np.mean(ys),
                # Coverage is the share of rows selected; a sum-based value
                # that was computed and immediately overwritten is dropped.
                'coverage': len(ys) / len(y),
            })

        return pd.DataFrame(data)

    def __len__(self) -> int:
        return len(self.rules)

    def get_all_features(self) -> List[str]:
        """Return the feature names of all rules, concatenated in rule order (duplicates kept)."""
        return [name for r in self.rules for name in r.features()]

