import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from functools import lru_cache

from typing import List
from helpers import SparseLogisticRegression
from rules import Rule, Operator

class Dataset:
    """Labelled tabular dataset with a fitted classifier and candidate rules.

    Construction one-hot encodes the raw features, generates one candidate
    rule per (feature, threshold) pair, performs a stratified train/test
    split, fits the selected classifier on the training part and finally
    keeps only the candidates whose training support and mean score-quantile
    pass fixed cut-offs.

    Subclasses supply the data by overriding ``get_X_raw`` and ``get_y``.
    The positive-class score is read from column 1 of ``predict_proba``, so
    the target is presumably binary (0/1 or boolean) -- confirm per dataset.

    Attributes set by ``__init__``:
        random_seed: seed used for the split and for the estimator.
        X_train, X_test, y_train, y_test: the stratified split.
        features: column index of the encoded training features.
        model: the fitted scikit-learn classifier.
        rule_candidates_info: DataFrame with columns ``rule``, ``support``
            and ``abbr`` for every generated candidate, sorted by ``abbr``
            in descending order.
        rule_candidates: list of the candidates that passed the filter.
    """

    # Maximum number of thresholded prediction arrays cached per split.
    _MAX_CACHED_PREDS = 5

    def __init__(
        self,
        random_seed: int,
        Q: int,
        test_size: float=0.5,
        model_choice='RF',
        verbose=False,
    ):
        """Load the data, fit the model and build the filtered rule list.

        Args:
            random_seed: seed for ``train_test_split`` and the estimator.
            Q: upper bound on the number of quantile bins used to derive
                thresholds for a feature with more than two distinct values.
            test_size: forwarded to ``train_test_split``.
            model_choice: ``'RF'``, ``'LR'`` or ``'SVC'``.
            verbose: when true, print a one-line summary after loading.

        Raises:
            ValueError: if ``model_choice`` is not one of the known keys.
            NotImplementedError: if the subclass does not provide the data
                hooks ``get_X_raw`` / ``get_y``.
        """
        self.random_seed = random_seed

        # Build the (unfitted) estimator first so that a bad `model_choice`
        # fails before any data is read.
        self.model = self._make_model(model_choice, random_seed)

        # pandas >= 2.0 emits bool dummy columns by default; bool is not a
        # sub-dtype of np.number, so one-hot features would silently produce
        # no rules. Pinning uint8 reproduces the pre-2.0 default.
        X = pd.get_dummies(self.get_X_raw(), prefix_sep='==', dtype=np.uint8)
        y = self.get_y()

        # NOTE: candidates are derived from the full data, before the split.
        self.rule_candidates = self._generate_rule_candidates(X, y, Q)

        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_seed, stratify=y)

        self.features = self.X_train.columns
        self.X_train = self.X_train[self.features]
        self.X_test = self.X_test[self.features]

        self.model.fit(self.X_train, self.y_train)

        # Both inputs are the same for every rule; fetch them once.
        X_train = self.get_X_train()
        train_quantiles = self.get_y_train_quantile()

        rows = []
        for r in self.rule_candidates:
            msk = r.get_mask(X_train)
            rows.append({
                'rule': r,
                # Fraction of training rows selected by the rule.
                'support': np.mean(msk),
                # Mean score-quantile of the selected training rows.
                'abbr': train_quantiles[msk].mean(),
            })

        # Explicit columns keep the sort/filter below valid when no
        # candidate was generated at all.
        info = pd.DataFrame(rows, columns=['rule', 'support', 'abbr'])
        self.rule_candidates_info = info.sort_values('abbr', ascending=False)

        info = self.rule_candidates_info
        keep = (
            (info['support'] >= 0.025) &
            (info['support'] <= 0.95) &
            (info['abbr'] > 0.49)
        )
        self.rule_candidates = list(info[keep]['rule'])

        if verbose:
            print('Loaded Dataset with %d columns => %d rules' %
                  (len(X.columns), len(self.rule_candidates)))

    @staticmethod
    def _make_model(model_choice, random_seed):
        """Return the unfitted estimator for `model_choice`, seeded for reproducibility."""
        if model_choice == 'RF':
            return RandomForestClassifier(
                max_depth=5, n_estimators=300, random_state=random_seed)
        if model_choice == 'LR':
            return LogisticRegression(
                C=0.1, max_iter=1000, random_state=random_seed)
        if model_choice == 'SVC':
            return SVC(
                kernel='linear', probability=True, random_state=random_seed)
        raise ValueError(
            "Unknown model_choice %r; expected 'RF', 'LR' or 'SVC'" % (model_choice,))

    @staticmethod
    def _generate_rule_candidates(X, y, Q):
        """Create one candidate rule per numeric feature threshold of `X`."""
        candidates = []
        for col in X.columns:
            values = X[col]
            if not np.issubdtype(values.dtype, np.number):
                continue
            n_unique = values.nunique()

            if n_unique > 2:
                # One threshold per interior quantile; the operator is the
                # side of the threshold with the higher mean label.
                num_quants = min(Q, n_unique)
                for q in np.arange(1./num_quants, 1, 1./num_quants):
                    val = np.quantile(values, q)
                    if y[values <= val].mean() > y[values >= val].mean():
                        op = Operator.LESS
                    else:
                        op = Operator.GREATER
                    candidates.append(Rule.create_from_feature(col, op, val))
            elif n_unique == 2:
                # Two-valued feature: pick whichever of 1 / 0 has the higher
                # mean label.
                if y[values == 1].mean() > y[values == 0].mean():
                    target = 1
                else:
                    target = 0
                candidates.append(
                    Rule.create_from_feature(col, Operator.EQUAL, target))
        return candidates

    def get_X_raw(self) -> pd.DataFrame:
        """Return the raw feature table; must be overridden by subclasses."""
        raise NotImplementedError

    def get_y(self) -> pd.Series:
        """Return the label series used by ``__init__``; must be overridden."""
        raise NotImplementedError

    def get_y_raw(self) -> pd.Series:
        """Legacy hook, not called by ``__init__`` (which uses ``get_y``)."""
        raise NotImplementedError

    def get_X_train(self) -> pd.DataFrame:
        """Return the training features with a fresh 0..n-1 index."""
        return self.X_train.reset_index(drop=True)

    def get_X_test(self) -> pd.DataFrame:
        """Return the test features with a fresh 0..n-1 index."""
        return self.X_test.reset_index(drop=True)

    def get_y_train(self) -> pd.Series:
        """Return the training labels with a fresh 0..n-1 index."""
        return self.y_train.reset_index(drop=True)

    def get_y_test(self) -> pd.Series:
        """Return the test labels with a fresh 0..n-1 index."""
        return self.y_test.reset_index(drop=True)

    def _cached(self, key, compute):
        """Return the value stored under `key` on this instance, computing it once.

        The cache lives on the instance (not on the class, as an
        ``lru_cache``-decorated method would), so it is released together
        with the dataset and is never evicted by other datasets.
        """
        cache = self.__dict__.setdefault('_cache', {})
        if key not in cache:
            cache[key] = compute()
        return cache[key]

    def _cached_preds(self, split, probs, q):
        """Return `probs` thresholded at its `q`-quantile, cached per (split, q)."""
        cache = self.__dict__.setdefault('_%s_preds_cache' % split, {})
        if q not in cache:
            if len(cache) >= self._MAX_CACHED_PREDS:
                # Drop the oldest entry so a sweep over `q` stays bounded.
                del cache[next(iter(cache))]
            cache[q] = (probs >= np.quantile(probs, q)).astype(int)
        return cache[q]

    def get_y_train_probs(self) -> np.ndarray:
        """Return column 1 of ``model.predict_proba`` on the training features (cached)."""
        return self._cached(
            'train_probs', lambda: self.model.predict_proba(self.X_train)[:, 1])

    def get_y_test_probs(self) -> np.ndarray:
        """Return column 1 of ``model.predict_proba`` on the test features (cached)."""
        return self._cached(
            'test_probs', lambda: self.model.predict_proba(self.X_test)[:, 1])

    def get_y_train_quantile(self) -> np.ndarray:
        """Return the rank-based quantile of every training score (cached)."""
        return self._cached(
            'train_quantile', lambda: self.get_quantile(self.get_y_train_probs()))

    def get_y_test_quantile(self) -> np.ndarray:
        """Return the rank-based quantile of every test score (cached)."""
        return self._cached(
            'test_quantile', lambda: self.get_quantile(self.get_y_test_probs()))

    def get_y_train_preds(self, q = 0.5) -> np.ndarray:
        """Return 0/1 training predictions: 1 where the score is >= its `q`-quantile."""
        return self._cached_preds('train', self.get_y_train_probs(), q)

    def get_y_test_preds(self, q=0.5) -> np.ndarray:
        """Return 0/1 test predictions: 1 where the score is >= its `q`-quantile."""
        return self._cached_preds('test', self.get_y_test_probs(), q)

    def get_quantile(self, predictions: np.ndarray):
        """Map each prediction to ``rank / n``, where rank is its 0-based
        position in ascending sort order.

        The result is aligned with the input and lies in ``[0, 1)``; ties are
        broken by ``np.argsort`` order.
        """
        order = np.argsort(predictions)
        ranks = np.empty(len(order), dtype=float)
        # order[i] is the index of the i-th smallest value, so that element
        # receives rank i.
        ranks[order] = np.arange(len(order))
        return ranks / len(ranks)

    
class Recidivism(Dataset):
    """Dataset read from ``data/X_recid.csv`` and ``data/y_recid.csv``."""

    def get_X_raw(self) -> pd.DataFrame:
        """Read the raw feature table from ``data/X_recid.csv``."""
        features = pd.read_csv('data/X_recid.csv')
        return features

    def get_y(self) -> pd.Series:
        """Read the ``Recidivism_Within_3years`` column of ``data/y_recid.csv``."""
        labels = pd.read_csv('data/y_recid.csv')
        return labels['Recidivism_Within_3years']
    
class Diabetes(Dataset):
    """Dataset read from ``data/X_diabetes.csv`` and ``data/y_diabetes.csv``."""

    def get_X_raw(self) -> pd.DataFrame:
        """Read the raw feature table from ``data/X_diabetes.csv``."""
        features = pd.read_csv('data/X_diabetes.csv')
        return features

    def get_y(self) -> pd.Series:
        """Read the ``Diabetes_binary`` column of ``data/y_diabetes.csv``."""
        labels = pd.read_csv('data/y_diabetes.csv')
        return labels['Diabetes_binary']

class FICO(Dataset):
    """Dataset read from ``data/X_fico.csv`` and ``data/y_fico.csv``."""

    def get_X_raw(self) -> pd.DataFrame:
        """Read the raw feature table from ``data/X_fico.csv``."""
        features = pd.read_csv('data/X_fico.csv')
        return features

    def get_y(self) -> pd.Series:
        """Read the ``RiskPerformance`` column of ``data/y_fico.csv``."""
        labels = pd.read_csv('data/y_fico.csv')
        return labels['RiskPerformance']
    
class Schizo(Dataset):
    """Dataset read from ``data/X_schizo.csv`` and ``data/y_schizo.csv``."""

    def get_X_raw(self) -> pd.DataFrame:
        """Read the raw feature table from ``data/X_schizo.csv``."""
        features = pd.read_csv('data/X_schizo.csv')
        return features

    def get_y(self) -> pd.Series:
        """Read the ``label`` column of ``data/y_schizo.csv``."""
        labels = pd.read_csv('data/y_schizo.csv')
        return labels['label']
    
class Adults(Dataset):
    """Dataset read from ``data/X_adult.csv`` and ``data/y_adult.csv``."""

    def get_X_raw(self) -> pd.DataFrame:
        """Read the raw feature table from ``data/X_adult.csv``."""
        features = pd.read_csv('data/X_adult.csv')
        return features

    def get_y(self) -> pd.Series:
        """Read the ``income`` column of ``data/y_adult.csv``."""
        labels = pd.read_csv('data/y_adult.csv')
        return labels['income']

class Readmission(Dataset):
    """Dataset read from the first 20000 rows of the readmission CSV files."""

    def get_X_raw(self) -> pd.DataFrame:
        """Read ``data/X_readmission.csv`` and keep its first 20000 rows."""
        features = pd.read_csv('data/X_readmission.csv', low_memory=False)
        return features.iloc[:20000]

    def get_y(self) -> pd.Series:
        """Return a boolean series: True where ``readmitted`` equals ``'>30'``.

        Only the first 20000 rows of ``data/y_readmission.csv`` are used, in
        line with the row limit applied to the features.
        """
        labels = pd.read_csv('data/y_readmission.csv')
        first_rows = labels.iloc[:20000]
        return first_rows['readmitted'] == '>30'