
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from functools import lru_cache

from typing import List
from rules import Rule, Operator

class Dataset:
    """Base class that loads a tabular dataset, splits it and fits a classifier.

    Subclasses supply the data by overriding :meth:`get_X_raw` and
    :meth:`get_y` (or, equivalently, :meth:`get_y_raw`).  On construction the
    raw features are one-hot encoded with ``pd.get_dummies``, split into
    stratified train/test partitions, and the selected scikit-learn model is
    fitted on the training partition.
    """

    def __init__(
        self,
        random_seed: int = 0,
        test_size: float = 0.7,
        model_choice: str = 'RF',
        verbose: bool = False,
    ):
        """Load the data, split it and fit the chosen model.

        Args:
            random_seed: Seed passed to ``train_test_split`` and to the
                estimators that accept a ``random_state``.
            test_size: Forwarded to ``train_test_split`` as ``test_size``.
            model_choice: ``'RF'`` (random forest), ``'LR'`` (logistic
                regression) or ``'SVC'`` (linear SVC with probabilities).
            verbose: When true, print a one-line summary after fitting.

        Raises:
            ValueError: If ``model_choice`` is not a supported value.
            NotImplementedError: If the subclass does not provide the data
                loading hooks.
        """
        self.random_seed = random_seed

        # Build the (unfitted) model first so that an unsupported choice is
        # reported before any data is read.
        self.model = self._build_model(model_choice)

        X = pd.get_dummies(self.get_X_raw(), prefix_sep='==')
        y = self.get_y()

        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_seed, stratify=y
        )

        # Record the encoded column set and select both splits by it.
        self.features = self.X_train.columns
        self.X_train = self.X_train[self.features]
        self.X_test = self.X_test[self.features]

        self.model.fit(self.X_train, self.y_train)

        if verbose:
            # ``rule_candidates`` is not set by this class; only report it
            # when a subclass (or caller) has provided it.
            rule_candidates = getattr(self, 'rule_candidates', None)
            if rule_candidates is None:
                print('Loaded Dataset with %d columns' % len(X.columns))
            else:
                print('Loaded Dataset with %d columns => %d rules' %
                      (len(X.columns), len(rule_candidates)))

    def _build_model(self, model_choice: str):
        """Return an unfitted estimator for ``model_choice`` or raise ValueError."""
        if model_choice == 'RF':
            return RandomForestClassifier(
                max_depth=5, n_estimators=300, random_state=self.random_seed
            )
        if model_choice == 'LR':
            return LogisticRegression(C=0.1, max_iter=1000)
        if model_choice == 'SVC':
            return SVC(
                kernel='linear', probability=True, random_state=self.random_seed
            )
        raise ValueError(
            "Unknown model_choice %r; expected 'RF', 'LR' or 'SVC'" % (model_choice,)
        )

    def _memoized(self, key: str, compute):
        """Return the value stored under ``key``, computing it on first use.

        The store lives on the instance, so cached arrays are released together
        with the instance and are never shared between instances.
        """
        cache = self.__dict__.setdefault('_prediction_cache', {})
        if key not in cache:
            cache[key] = compute()
        return cache[key]

    def get_X_raw(self) -> pd.DataFrame:
        """Return the raw (un-encoded) feature table.  Must be overridden."""
        raise NotImplementedError(
            '%s must implement get_X_raw()' % type(self).__name__
        )

    def get_y_raw(self) -> pd.Series:
        """Return the raw label column.  Override this or :meth:`get_y`."""
        raise NotImplementedError(
            '%s must implement get_y() or get_y_raw()' % type(self).__name__
        )

    def get_y(self) -> pd.Series:
        """Return the labels used for splitting and fitting.

        Defaults to :meth:`get_y_raw` so that subclasses may override either
        hook.
        """
        return self.get_y_raw()

    def get_X_train(self) -> pd.DataFrame:
        """Return the training features with a fresh 0..n-1 index."""
        return self.X_train.reset_index(drop=True)

    def get_X_test(self) -> pd.DataFrame:
        """Return the test features with a fresh 0..n-1 index."""
        return self.X_test.reset_index(drop=True)

    def get_y_train(self) -> pd.Series:
        """Return the training labels with a fresh 0..n-1 index."""
        return self.y_train.reset_index(drop=True)

    def get_y_test(self) -> pd.Series:
        """Return the test labels with a fresh 0..n-1 index."""
        return self.y_test.reset_index(drop=True)

    def get_y_train_probs(self) -> np.ndarray:
        """Return column 1 of ``model.predict_proba`` on the training split.

        The result is computed once per instance and then reused.
        """
        return self._memoized(
            'train_probs', lambda: self.model.predict_proba(self.X_train)[:, 1]
        )

    def get_y_test_probs(self) -> np.ndarray:
        """Return column 1 of ``model.predict_proba`` on the test split.

        The result is computed once per instance and then reused.
        """
        return self._memoized(
            'test_probs', lambda: self.model.predict_proba(self.X_test)[:, 1]
        )

    def get_y_train_quantile(self) -> np.ndarray:
        """Return the rank-based quantile of each training probability."""
        return self._memoized(
            'train_quantile', lambda: self.get_quantile(self.get_y_train_probs())
        )

    def get_y_test_quantile(self) -> np.ndarray:
        """Return the rank-based quantile of each test probability."""
        return self._memoized(
            'test_quantile', lambda: self.get_quantile(self.get_y_test_probs())
        )

    def get_y_train_preds(self, q=0.5) -> np.ndarray:
        """Return 0/1 training predictions thresholded at the ``q``-quantile.

        A row is labelled 1 when its probability is at least the ``q``-quantile
        of all training probabilities.  A new array is returned on each call.
        """
        y_pred = self.get_y_train_probs()
        return (y_pred >= np.quantile(y_pred, q)).astype(int)

    def get_y_test_preds(self, q=0.5) -> np.ndarray:
        """Return 0/1 test predictions thresholded at the ``q``-quantile.

        A row is labelled 1 when its probability is at least the ``q``-quantile
        of all test probabilities.  A new array is returned on each call.
        """
        y_pred = self.get_y_test_probs()
        return (y_pred >= np.quantile(y_pred, q)).astype(int)

    def get_quantile(self, predictions: np.ndarray) -> np.ndarray:
        """Map each value of a 1-D array to ``rank / n``.

        ``rank`` is the position of the value in ascending order (0-based), so
        the smallest value maps to ``0.0`` and the largest to ``(n - 1) / n``.

        Returns:
            A float array with the same length as ``predictions``.
        """
        order = np.argsort(predictions)
        count = len(order)
        ranks = np.empty(count, dtype=float)
        if count == 0:
            return ranks
        # Inverse permutation: the element at sorted position i gets rank i.
        ranks[order] = np.arange(count)
        return ranks / count

    
class Recidivism(Dataset):
    """Dataset backed by ``data/X_recid.csv`` and ``data/y_recid.csv``."""

    def get_X_raw(self) -> pd.DataFrame:
        """Read the raw feature table from ``data/X_recid.csv``."""
        features = pd.read_csv('data/X_recid.csv')
        return features

    def get_y(self) -> pd.Series:
        """Read the label file and return its ``Recidivism_Within_3years`` column."""
        labels = pd.read_csv('data/y_recid.csv')
        return labels['Recidivism_Within_3years']
    
class Diabetes(Dataset):
    """Dataset backed by ``data/X_diabetes.csv`` and ``data/y_diabetes.csv``."""

    def get_X_raw(self) -> pd.DataFrame:
        """Read the raw feature table from ``data/X_diabetes.csv``."""
        features = pd.read_csv('data/X_diabetes.csv')
        return features

    def get_y(self) -> pd.Series:
        """Read the label file and return its ``Diabetes_binary`` column."""
        labels = pd.read_csv('data/y_diabetes.csv')
        return labels['Diabetes_binary']

class FICO(Dataset):
    """Dataset backed by ``data/X_fico.csv`` and ``data/y_fico.csv``."""

    def get_X_raw(self) -> pd.DataFrame:
        """Read the raw feature table from ``data/X_fico.csv``."""
        features = pd.read_csv('data/X_fico.csv')
        return features

    def get_y(self) -> pd.Series:
        """Read the label file and return its ``RiskPerformance`` column."""
        labels = pd.read_csv('data/y_fico.csv')
        return labels['RiskPerformance']
    
class Schizo(Dataset):
    """Dataset backed by ``data/X_schizo.csv`` and ``data/y_schizo.csv``."""

    def get_X_raw(self) -> pd.DataFrame:
        """Read the raw feature table from ``data/X_schizo.csv``."""
        features = pd.read_csv('data/X_schizo.csv')
        return features

    def get_y(self) -> pd.Series:
        """Read the label file and return its ``label`` column."""
        labels = pd.read_csv('data/y_schizo.csv')
        return labels['label']
    
class Adults(Dataset):
    """Dataset backed by ``data/X_adult.csv`` and ``data/y_adult.csv``."""

    def get_X_raw(self) -> pd.DataFrame:
        """Read the raw feature table from ``data/X_adult.csv``."""
        features = pd.read_csv('data/X_adult.csv')
        return features

    def get_y(self) -> pd.Series:
        """Read the label file and return its ``income`` column."""
        labels = pd.read_csv('data/y_adult.csv')
        return labels['income']

class Readmission(Dataset):
    """Dataset backed by the first 20000 rows of the readmission CSV files."""

    def get_X_raw(self) -> pd.DataFrame:
        """Read ``data/X_readmission.csv`` and keep its first 20000 rows."""
        full_table = pd.read_csv('data/X_readmission.csv', low_memory=False)
        return full_table.iloc[:20000]

    def get_y(self) -> pd.Series:
        """Return a boolean series: ``readmitted == '>30'`` for the first 20000 rows."""
        label_rows = pd.read_csv('data/y_readmission.csv').iloc[:20000]
        readmitted = label_rows['readmitted']
        return readmitted == '>30'
