import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tqdm
import argparse
import time
import os
import copy

from datasets import Recidivism, FICO, Dataset, Schizo, Adults, Diabetes, Readmission
from rules import ORRule, Rule, Condition, Operator
from complement import Complement, ComplementParams
from consistency import Consistency, ConsistencySoft, CoverageConsistencyParams
from dataclasses import dataclass, replace
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor

from scipy.stats import norm

# NOTE(review): import-time side effect -- importing this module switches the
# process working directory to a hard-coded absolute path. Presumably this is so
# that relative data paths used by the project resolve; confirm before running
# on another machine.
os.chdir('/home/evanyao/paper')

def parse_tree_to_rules(dtr: DecisionTreeRegressor, dataset: Dataset) -> list[Rule]:
    """Convert each leaf of a fitted decision tree into a conjunctive ``Rule``.

    The tree is walked from the root; for every leaf the sequence of splits on
    the path to it is turned into one ``Rule`` made of ``Condition`` objects.
    A tree consisting of a single root leaf (no splits) yields no rules.

    NOTE(review): the numeric split threshold is not used. A ``<=`` branch is
    encoded as ``feature == 0`` and a ``>`` branch as ``feature == 1``, which
    presumably relies on all features being binary 0/1 -- confirm against the
    datasets in use.

    Args:
        dtr: Fitted tree exposing ``tree_`` and ``feature_names_in_``.
        dataset: Source of the training features and training targets used to
            order the resulting rules.

    Returns:
        The rules sorted by descending mean of
        ``dataset.get_y_train_quantile()`` over the training rows each rule
        selects.
    """
    tree_ = dtr.tree_
    feature_names = dtr.feature_names_in_

    # One entry per leaf: the tuple of (feature name, branch operator) pairs
    # taken from the root down to that leaf.
    leaf_paths = []

    def collect_leaf_paths(node_id=0, path=()):
        left_child = tree_.children_left[node_id]
        right_child = tree_.children_right[node_id]

        # scikit-learn marks a leaf by giving it identical (sentinel) left and
        # right children, so this test needs no private sklearn constant.
        if left_child == right_child:
            leaf_paths.append(path)
            return

        feature = feature_names[tree_.feature[node_id]]
        collect_leaf_paths(left_child, path + ((feature, '<='),))
        collect_leaf_paths(right_child, path + ((feature, '>'),))

    collect_leaf_paths()

    # Build one rule per leaf; an empty path (root-only tree) carries no
    # conditions and is skipped.
    leaf_node_rules = [
        Rule([
            Condition(feature, Operator.EQUAL, 0 if branch == '<=' else 1)
            for feature, branch in path
        ])
        for path in leaf_paths
        if path
    ]

    # Fetch the training data once instead of once per rule inside the key.
    y_train = dataset.get_y_train_quantile()
    X_train = dataset.get_X_train()
    leaf_node_rules.sort(key=lambda r: -y_train[r.get_mask(X_train)].mean())
    return leaf_node_rules
 
class GreedyRuleList:
    """Greedy builder of an OR-of-rules list from random-forest leaf rules.

    On construction a ``RandomForestRegressor`` is fitted on the training data,
    every leaf of every tree is converted to a rule via
    ``parse_tree_to_rules``, and the rules are reduced to a pool
    (``pareto_rules``) that trades off support against the mean target value
    over the rows they select. ``generate_rule`` then greedily ORs rules from a
    random half of that pool until a training-coverage target is exceeded.

    Attributes set by ``__init__``:
        y_train_quantile: Training target used for scoring rules (quantiles or
            raw predictions depending on ``use_quantile``).
        all_rules: Every rule extracted from the forest.
        rules_table: DataFrame of all rules with masks, support and ``abbr``
            (mean target over selected training rows), sorted by support.
        pareto_rules: Filtered pool of rules used by ``generate_rule``.
        dataset: The dataset passed in.
        support_lb: Minimum support / marginal coverage threshold.
    """

    def __init__(
        self,
        dataset: Dataset,
        max_depth: int,
        n_estimators: int = 20,
        support_lb: float = 0.025,
        tolerance: float = 0.9,
        use_quantile=True,
    ):
        """Fit the forest and build the candidate rule pool.

        Args:
            dataset: Provides train/test features and targets.
            max_depth: Maximum depth of each tree in the forest.
            n_estimators: Number of trees in the forest.
            support_lb: Rules must select strictly more than this fraction of
                training rows to enter the pool; also used later as the minimum
                marginal coverage of a greedy step.
            tolerance: A rule is kept when its ``abbr`` is at least
                ``tolerance`` times the best ``abbr`` among rules with equal or
                larger position in the support ordering.
            use_quantile: Score with ``get_y_train_quantile()`` when true,
                otherwise with ``get_y_train_preds()``.
        """
        rfc = RandomForestRegressor(max_depth=max_depth, n_estimators=n_estimators, max_features=0.5)

        self.y_train_quantile = dataset.get_y_train_quantile() if use_quantile else dataset.get_y_train_preds()
        X_train = dataset.get_X_train()
        # Targets below 0.5 are flattened to 0 before fitting the forest.
        rfc.fit(X_train, [0 if y < 0.5 else y for y in self.y_train_quantile])

        self.all_rules = []

        for est in tqdm.tqdm(rfc.estimators_):
            # parse_tree_to_rules reads feature names from the estimator itself.
            est.feature_names_in_ = rfc.feature_names_in_
            self.all_rules += parse_tree_to_rules(est, dataset)

        X_test = dataset.get_X_test()
        rules_table = []

        for r in tqdm.tqdm(self.all_rules):
            msk = r.get_mask(X_train)
            msk_test = r.get_mask(X_test)

            rules_table.append({
                'rule': r,
                'mask': msk,
                'mask_test': msk_test,
                'support': np.mean(msk),
                'abbr': self.y_train_quantile[msk].mean(),
            })

        rules_table = pd.DataFrame(rules_table).sort_values('support')
        self.rules_table = rules_table

        # Explicit copy so the column assignment below targets an independent
        # frame rather than a slice of ``self.rules_table``.
        rules_table = rules_table[rules_table.support > support_lb].copy()
        # With rows in ascending support order, the reversed running maximum
        # gives, for each rule, the best ``abbr`` among itself and all rules
        # that come after it in that order.
        rules_table['best_abbr'] = rules_table['abbr'][::-1].cummax()[::-1]

        self.pareto_rules = rules_table[rules_table['abbr'] >= tolerance * rules_table['best_abbr']].sort_values(['abbr', 'support'], ascending=False).reset_index(drop=True)
        self.dataset = dataset
        self.support_lb = support_lb

    def get_next_index(self, indices, rules_pool):
        """Evaluate the current selection and pick the next rule to add.

        Args:
            indices: Index labels (into ``rules_pool``) of the rules selected
                so far; must be non-empty.
            rules_pool: DataFrame with ``rule``, ``mask`` and ``mask_test``
                columns.

        Returns:
            A dict with the OR of the selected rules (``rule_list``), its
            train/test coverage and mean target (``*_cover``, ``*_abbr``), and
            ``next_index``: the label of the candidate whose union with the
            current selection has the highest mean target (ties broken by
            coverage) among candidates adding at least ``self.support_lb``
            marginal coverage, or ``None`` when no candidate qualifies.
        """
        # Local imports: ``reduce`` is not imported at module level.
        from functools import reduce
        from operator import or_

        selected = rules_pool.loc[indices]
        rule_list = ORRule(selected['rule'])

        covered = reduce(or_, selected['mask'])
        covered_test = reduce(or_, selected['mask_test'])

        iteration_info = pd.DataFrame()
        iteration_info['marginal_additional_coverage'] = rules_pool['mask'].apply(lambda msk: np.mean(msk & ~covered))
        iteration_info['new_coverage'] = rules_pool['mask'].apply(lambda msk: np.mean(msk | covered))
        iteration_info['new_abbr'] = rules_pool['mask'].apply(lambda msk: self.y_train_quantile[msk | covered].mean())

        candidates = iteration_info[iteration_info.marginal_additional_coverage >= self.support_lb].sort_values(['new_abbr', 'new_coverage'], ascending=False)

        # No candidate may add enough coverage; report that instead of raising
        # IndexError so the caller can stop cleanly.
        next_index = None if candidates.empty else candidates.index[0]

        return {
            'next_index': next_index,
            'rule_list': rule_list,
            'train_cover': np.mean(covered),
            'test_cover': np.mean(covered_test),
            'train_abbr': self.dataset.get_y_train_quantile()[covered].mean(),
            'test_abbr': self.dataset.get_y_test_quantile()[covered_test].mean()
        }

    def generate_rule(self, c: float = None, debug=False):
        """Greedily grow an OR rule until training coverage exceeds ``c``.

        A random half of ``pareto_rules`` is sampled as the pool, the first
        rule of that pool seeds the selection, and rules are added one at a
        time via ``get_next_index``. Growth stops once the training coverage
        of the selection exceeds ``c``, or when no remaining rule adds at least
        ``support_lb`` marginal coverage (in which case the result may cover
        less than ``c``).

        Args:
            c: Target training coverage. The stopping test compares coverage
                against this value, so callers are expected to pass a number.
            debug: When true, extra evaluation columns are computed per step
                and the whole per-step log is returned.

        Returns:
            The final ``ORRule`` when ``debug`` is false; otherwise a DataFrame
            with one row per greedy step.
        """
        rules_pool = self.pareto_rules.sample(frac=0.5).sort_index()

        indices = [rules_pool.index[0]]

        param_dummy = CoverageConsistencyParams(
            num_iter=0,
            N=5,
            c=c,
            p=c,
            allow_high_low_switch=False,
            should_validate=True,
        )

        # Only used for the debug evaluation columns below.
        alg_dummy = ConsistencySoft(self.dataset, param_dummy, False, skeleton=True)

        log = []

        while True:
            info = self.get_next_index(indices, rules_pool)

            if debug:
                info['eval_in'] = alg_dummy.evaluate_rule_train(info['rule_list'])
                info['eval_out'] = alg_dummy.score_rule(info['rule_list'])
                info['cons_in'] = self.dataset.get_y_train_preds()[info['rule_list'].get_mask(self.dataset.get_X_train()) > 0].mean()
                info['cons_out'] = self.dataset.get_y_test_preds()[info['rule_list'].get_mask(self.dataset.get_X_test()) > 0].mean()

            log.append(info)

            # Stop when the target is met, or when the pool cannot extend the
            # selection any further.
            if info['train_cover'] > c or info['next_index'] is None:
                break

            indices.append(info['next_index'])

        log = pd.DataFrame(log)

        return log if debug else log.rule_list.iloc[-1]
    
