import copy
import itertools
from abc import ABC
from typing import Any, List
from functools import cached_property, lru_cache

import numpy as np
import pandas as pd
import tqdm
import random

from rules import Operator, Rule, Condition
from datasets import Dataset
from dataclasses import dataclass, fields
from sa import BaseAlgorithmParams, BaseAlgorithm
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn import tree

class ORRule:
    """
    Class representing the ordered OR of many Rules.

    Rules earlier in ``rule_list`` take precedence: ``get_mask`` scores every
    row by the first rule (lowest list position) that matches it.
    """

    def __init__(self, rule_list: List["Rule"], name=''):
        # ``rule_list`` is stored as given (not copied); ``name`` is a free-form label.
        self.rule_list = rule_list
        self.name = name

    def __str__(self):
        lines = ["An OR Rule of the %d Rules" % len(self.rule_list)]
        lines.extend(str(r) for r in self.rule_list)
        return '\n'.join(lines)

    def __repr__(self):
        return self.__str__()

    def __getitem__(self, item):
        return self.rule_list[item]

    def get_mask(self, X: pd.DataFrame) -> pd.Series:
        """
        Score each row of ``X`` by the highest-priority rule that matches it.

        With ``S = len(self.rule_list)``, a row whose first matching rule is
        ``self.rule_list[j]`` gets the value ``S - j``. Rows matched by no rule
        get 0. Larger values therefore mean "matched by an earlier rule".

        Returns:
            An int64 Series aligned with ``X.index``.
        """
        mask = pd.Series(0, index=X.index, dtype=np.int64)
        n_rules = len(self.rule_list)
        # Walk from the last rule to the first so earlier rules overwrite later ones.
        for j in range(n_rules - 1, -1, -1):
            positions = np.where(self.rule_list[j].get_mask(X))[0]
            # np.where yields *positions*, so assign positionally: plain
            # ``mask[positions]`` would be label-based on an integer index
            # that is not 0..n-1 (e.g. a filtered or shuffled DataFrame).
            mask.iloc[positions] = n_rules - j
        return mask

    def apply(self, X: pd.DataFrame, y: pd.Series) -> pd.DataFrame:
        """
        Tabulate coverage and mean target for each score cutoff of ``get_mask``.

        For every distinct score ``c`` (ascending), the row reports:
        ``support``, the fraction of rows with score >= c; and ``confidence``,
        the mean of ``y`` over those rows. ``y`` is indexed with a boolean
        Series aligned on ``X.index``.

        Returns:
            A DataFrame with columns ``cutoff``, ``support`` and ``confidence``.
        """
        mask = self.get_mask(X)
        rows = []
        for cutoff in sorted(set(mask)):
            selected = mask >= cutoff
            rows.append({
                'cutoff': cutoff,
                'support': np.mean(selected),
                'confidence': y[selected].mean(),
            })
        return pd.DataFrame(rows)

@dataclass
class RuleListMinerParams(BaseAlgorithmParams): 
    """
    Hyper-parameters for RuleListMiner, added on top of BaseAlgorithmParams.
    """
    # max_depth of the random-forest trees RuleListMiner.get_pool mines rules from
    # (that method uses 300 trees when this is 2, otherwise 50).
    max_depth: int = 2
    # Rules must cover strictly more than this fraction of training rows to be kept.
    min_support: float = 0.025
    # Relative slack of the Pareto filter in RuleListMiner.get_pool: a rule is kept
    # when (1 + tolerance) * abbr >= the best abbr among rules with at least as much support.
    tolerance: float = 0.1
    # If True the forest is fit on the training quantile target; if False, on that
    # target binarized as quantile > 0.5.
    use_quantile: bool = True
    
    
class RuleListMiner(BaseAlgorithm):
    """
    Mine rules from a random forest and search over ORRules built from them.

    ``get_pool`` builds the candidate pool of single rules, ``get_start`` builds
    the initial ORRule from the first ``params.N`` pool rules, and
    ``get_neighbor`` proposes a new ORRule by swapping one rule for a random
    pool rule. The search loop lives in ``BaseAlgorithm`` (module ``sa``),
    which presumably stores ``get_pool()``'s result as ``self.pool`` — confirm there.
    """

    # Columns of ``self.rules_table`` before the derived ``best_abbr`` column.
    _RULES_TABLE_COLUMNS = ['rule', 'mask', 'mask_test', 'support', 'abbr', 'consistency']
    # Rules covering this fraction of training rows or more are discarded.
    _MAX_SUPPORT = 0.25
    # Rules must have a mean training quantile ("abbr") strictly above this.
    _MIN_ABBR = 0.50

    def get_pool(self) -> List[Rule]:
        """
        Mine candidate rules and return the Pareto-filtered ones.

        1. Fit a random forest and convert every tree into rules.
        2. Tabulate each rule's train/test masks, train support, mean training
           quantile (``abbr``) and mean training prediction (``consistency``).
        3. Keep rules with ``min_support < support < _MAX_SUPPORT`` and
           ``abbr > _MIN_ABBR``. Then keep those with
           ``(1 + tolerance) * abbr >= best_abbr``, where ``best_abbr`` is the
           best abbr among rules with at least as much support.

        Side effects:
            Sets ``self.all_rules``, ``self.rules_table`` and ``self.pareto_rules``.

        Returns:
            The selected rules ordered by ``abbr`` then ``support``, both descending.
        """
        self.all_rules = self._mine_forest_rules()
        rules_table = self._build_rules_table(self.all_rules)

        keep = (
            (rules_table['support'] > self.params.min_support)
            & (rules_table['support'] < self._MAX_SUPPORT)
            & (rules_table['abbr'] > self._MIN_ABBR)
        )
        rules_table = rules_table[keep]
        # The table is sorted by ascending support, so a reversed cumulative max
        # gives, for each row, the best abbr among it and every later
        # (higher-support) row. ``assign`` returns a new frame, which avoids
        # chained assignment on the filtered slice.
        rules_table = rules_table.assign(
            best_abbr=rules_table['abbr'][::-1].cummax()[::-1]
        )
        self.rules_table = rules_table

        self.pareto_rules = rules_table[
            (1 + self.params.tolerance) * rules_table['abbr'] >= rules_table['best_abbr']
        ].sort_values(['abbr', 'support'], ascending=False).reset_index(drop=True)

        return list(self.pareto_rules['rule'])

    def _mine_forest_rules(self) -> List[Rule]:
        """Fit a RandomForestRegressor on the training target and collect every tree's rules."""
        y_quantile = self.dataset.get_y_train_quantile()
        # With use_quantile=False, the target is binarized as quantile > 0.5.
        y_train_target = y_quantile if self.params.use_quantile else y_quantile > 0.5

        forest = RandomForestRegressor(
            max_depth=self.params.max_depth,
            n_estimators=300 if self.params.max_depth == 2 else 50,
            max_features=0.5,
        )
        # Targets below 0.5 are zeroed; the rest are passed through unchanged.
        forest.fit(self.dataset.get_X_train(), [0 if y < 0.5 else y for y in y_train_target])

        rules = []
        for est in forest.estimators_:
            # parse_tree_to_rules reads ``feature_names_in_`` from each tree,
            # so copy the forest's feature names onto it.
            est.feature_names_in_ = forest.feature_names_in_
            rules += parse_tree_to_rules(est, self.dataset)
        return rules

    def _build_rules_table(self, rules: List[Rule]) -> pd.DataFrame:
        """Tabulate each rule's masks and statistics, sorted by ascending train support."""
        # Fetch the data once rather than once per rule.
        X_train = self.dataset.get_X_train()
        X_test = self.dataset.get_X_test()
        y_quantile = self.dataset.get_y_train_quantile()
        y_preds = self.dataset.get_y_train_preds()

        rows = []
        for rule in rules:
            msk = rule.get_mask(X_train)
            msk_test = rule.get_mask(X_test)
            rows.append({
                'rule': rule,
                'mask': msk,
                'mask_test': msk_test,
                'support': np.mean(msk),
                'abbr': y_quantile[msk].mean(),
                'consistency': y_preds[msk].mean(),
            })

        # Explicit columns/dtypes keep the frame usable even when ``rules`` is
        # empty (otherwise ``sort_values('support')`` raises KeyError).
        table = pd.DataFrame(rows, columns=self._RULES_TABLE_COLUMNS).astype(
            {'support': float, 'abbr': float, 'consistency': float}
        )
        # A stable sort makes the order of equal-support rules deterministic.
        return table.sort_values('support', kind='mergesort')

    def get_start(self) -> ORRule:
        """Return the initial candidate: an ORRule of the first ``params.N`` pool rules."""
        return ORRule(rule_list=self.pool[:self.params.N], name='Start')

    def get_neighbor(self, rule: ORRule):
        """
        Return a copy of ``rule`` with one random slot replaced by a random pool rule.

        The input ORRule is not modified. The swapped-in rule may already occur
        elsewhere in the list, or may be the very rule it replaces. The name
        records the first-condition field of the old and new rules.
        Raises ValueError (from ``random.randint``) if ``rule`` or the pool is empty.
        """
        rule_list_copy = rule.rule_list.copy()

        n1 = random.randint(0, len(rule_list_copy) - 1)
        n2 = random.randint(0, len(self.pool) - 1)

        name1 = rule_list_copy[n1].conditions_list[0].field_name
        name2 = self.pool[n2].conditions_list[0].field_name

        rule_list_copy[n1] = self.pool[n2]

        return ORRule(rule_list=rule_list_copy, name='Swap %s for %s' % (name1, name2))


    
def parse_tree_to_rules(dtr: DecisionTreeRegressor, dataset: Dataset) -> list[Rule]:
    """
    Convert a fitted decision tree into one Rule per non-root node.

    Each Rule is the conjunction of the split conditions on the path from the
    root to that node. Internal nodes are included as well as leaves, so a
    depth-2 tree yields both 1-condition and 2-condition rules. The root
    (empty path) produces no rule.

    Args:
        dtr: A fitted tree whose ``feature_names_in_`` attribute is set.
        dataset: Used only to rank the rules.

    Returns:
        Rules sorted by descending mean of ``dataset.get_y_train_quantile()``
        over the training rows each rule covers. Rules whose mean is NaN
        (e.g. covering no rows) are placed last.
    """
    tree_ = dtr.tree_
    feature_names = dtr.feature_names_in_
    children_left = tree_.children_left
    children_right = tree_.children_right

    # (node_id, path) for every node in pre-order, left subtree first.
    # Each path is a tuple of (feature_name, Operator, threshold).
    node_paths = []

    def visit(node_id=0, path=()):
        node_paths.append((node_id, path))
        # sklearn leaves have identical (sentinel) left and right children.
        if children_left[node_id] == children_right[node_id]:
            return
        feature = feature_names[tree_.feature[node_id]]
        threshold = tree_.threshold[node_id]
        # sklearn sends ``x <= threshold`` left and ``x > threshold`` right.
        # NOTE(review): assumes Operator.LESS is evaluated as <= — confirm in rules.Operator.
        visit(children_left[node_id], path + ((feature, Operator.LESS, threshold),))
        visit(children_right[node_id], path + ((feature, Operator.GREATER, threshold),))

    visit()

    # Build fresh Condition objects per rule so no two rules share one.
    rules = [
        Rule([Condition(feature, op, threshold) for feature, op, threshold in path])
        for _, path in node_paths
        if path
    ]

    # Fetch the ranking data once instead of once per rule.
    X_train = dataset.get_X_train()
    y_quantile = dataset.get_y_train_quantile()

    def sort_key(rule):
        score = y_quantile[rule.get_mask(X_train)].mean()
        # NaN compares False with everything, which would make the order
        # input-dependent; push such rules to the end instead.
        return np.inf if pd.isna(score) else -score

    rules.sort(key=sort_key)
    return rules