from rules import IntegerKnapsackRule, Rule, Condition 
from sa import BaseAlgorithm, BaseAlgorithmParams
from dataclasses import dataclass

import random
import pandas as pd
import numpy as np
from typing import List
from rules import Operator, Condition


class ChecklistRule(IntegerKnapsackRule):
    """A rule holding a list of conditions that prints as a checklist.

    The conditions and the name are stored on the instance and are also
    forwarded to the ``IntegerKnapsackRule`` initializer.
    """

    def __init__(self, conditions: List[Rule], name=''):
        # Attributes are set before delegating to the parent initializer,
        # so they are already available while it runs.
        self.conditions, self.name = conditions, name

        super().__init__(conditions, name)

    def __str__(self):
        """Return a header line, one line per condition, and a footer line."""
        body = "\n".join(map(str, self.conditions))
        return "\n".join(["---A Checklist of the Conditions---", body, "---"])

    def __repr__(self):
        """Use the same text as ``__str__``."""
        return str(self)

    def __len__(self):
        """Return the number of conditions in the checklist."""
        return len(self.conditions)

    
@dataclass
class ChecklistMinerParams(BaseAlgorithmParams): 
    """Parameters read by ``ChecklistMiner`` in addition to the base ones."""

    # Upper bound on the number of quantile bins used per numeric column when
    # ChecklistMiner.get_pool generates threshold candidates (the actual count
    # is min(Q, number of distinct values in the column)).
    Q: int = 5
    # A candidate condition is kept by ChecklistMiner.get_pool only if the
    # fraction of training rows it matches lies in
    # [min_support, 1 - min_support].
    min_support: float = 0.025
    

class ChecklistMiner(BaseAlgorithm):
    """Builds checklists out of single-feature conditions.

    ``get_pool`` generates candidate conditions from the training data,
    ``get_start`` builds the initial checklist from the head of the pool and
    ``get_neighbor`` produces a checklist with one condition replaced by a
    random pool member.
    """

    # Columns of the per-candidate statistics table built by ``get_pool``.
    _INFO_COLUMNS = ['rule', 'support', 'abbr']
    # A candidate is kept only if the mean of ``y`` over the rows it matches
    # is strictly above this value.
    _MIN_ABBR = 0.49

    def get_pool(self) -> List[Rule]:
        """Generate, score and filter candidate conditions.

        Side effects: sets ``self.conditions_candidates`` (all generated
        rules) and ``self.conditions_candidates_info`` (a DataFrame with one
        row per candidate and the columns ``rule``, ``support`` and ``abbr``,
        sorted by ``abbr`` in descending order).

        Returns:
            The candidate rules, in descending ``abbr`` order, whose support
            lies in ``[min_support, 1 - min_support]`` and whose ``abbr`` is
            above 0.49. The list is empty when no candidate qualifies or no
            candidate could be generated at all.
        """
        X = self.dataset.get_X_train()
        y = self.dataset.get_y_train_quantile()

        self.conditions_candidates = []
        for col in X.columns:
            self.conditions_candidates.extend(
                self._column_candidates(X[col], y, col)
            )

        records = []
        for rule in self.conditions_candidates:
            mask = rule.get_mask(X)
            records.append({
                'rule': rule,
                'support': np.mean(mask),  # fraction of rows matched
                'abbr': y[mask].mean(),    # mean of y over matched rows
            })

        # Explicit columns keep the table well-formed when there are no
        # candidates; otherwise sorting by 'abbr' would raise KeyError.
        info = pd.DataFrame(records, columns=self._INFO_COLUMNS)
        info = info.sort_values('abbr', ascending=False)
        self.conditions_candidates_info = info

        support = info['support']
        keep = (
            (support >= self.params.min_support)
            & (support <= 1 - self.params.min_support)
            & (info['abbr'] > self._MIN_ABBR)
        )
        return list(info.loc[keep, 'rule'])

    def _column_candidates(self, values, y, col) -> List[Rule]:
        """Return the candidate rules derived from one column of X.

        Non-numeric columns and columns with fewer than two distinct values
        produce no candidates.
        """
        if not np.issubdtype(values.dtype, np.number):
            return []

        n_unique = values.nunique()

        if n_unique > 2:
            num_quants = min(self.params.Q, n_unique)
            if num_quants < 2:
                # Fewer than two bins means there is no interior quantile.
                return []
            rules = []
            # Interior quantile levels 1/n, ..., (n-1)/n. linspace is used
            # instead of a float-step arange, whose rounding can add an
            # extra level at (or just above) 1.0.
            for q in np.linspace(0, 1, num_quants + 1)[1:-1]:
                val = np.quantile(values, q)
                # Pick the side of the threshold with the higher mean y.
                if y[values <= val].mean() > y[values >= val].mean():
                    op = Operator.LESS
                else:
                    op = Operator.GREATER
                rules.append(Rule.create_from_feature(col, op, val))
            return rules

        if n_unique == 2:
            # NOTE(review): the two distinct values are assumed to be 0 and 1
            # -- confirm against the dataset encoding.
            if y[values == 1].mean() > y[values == 0].mean():
                target = 1
            else:
                target = 0
            return [Rule.create_from_feature(col, Operator.EQUAL, target)]

        return []

    def get_start(self) -> ChecklistRule:
        """Return the starting checklist: the first ``params.N`` pool rules."""
        return ChecklistRule(conditions=self.pool[:self.params.N], name='Start')

    def get_neighbor(self, rule: ChecklistRule):
        """Return a copy of ``rule`` with one condition replaced.

        A uniformly random position of the checklist is overwritten with a
        uniformly random member of ``self.pool``; ``rule`` itself is left
        untouched. The name of the result records the field names of the
        removed and the inserted condition.

        Raises:
            ValueError: if ``rule`` has no conditions or the pool is empty.
        """
        conditions_copy = rule.conditions.copy()

        if not conditions_copy:
            raise ValueError('cannot build a neighbor of an empty checklist')
        if not self.pool:
            raise ValueError('cannot build a neighbor from an empty pool')

        out_idx = random.randrange(len(conditions_copy))
        in_idx = random.randrange(len(self.pool))

        replacement = self.pool[in_idx]
        removed_name = conditions_copy[out_idx].conditions_list[0].field_name
        added_name = replacement.conditions_list[0].field_name

        conditions_copy[out_idx] = replacement

        return ChecklistRule(
            conditions=conditions_copy,
            name=f'Swap {removed_name!s} for {added_name!s}',
        )