import scipy
import numpy as np
from scipy.linalg import lstsq
from scipy.linalg import norm 
import pandas as pd
from sklearn import linear_model
import sklearn.metrics as metrics
from sklearn.base import clone
import matplotlib.pyplot as plt
import seaborn as sns
import os
from sklearn.preprocessing import PolynomialFeatures
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer



# Game helper functions -----

def isPowerOfTwo(n):
    """Return whether the integer ``n`` is a power of two (1, 2, 4, 8, ...).

    A power of two has exactly one set bit, so ``n & (n - 1)`` (which clears
    the lowest set bit) is zero for it. Zero is rejected explicitly because
    ``0 & -1`` is also zero.
    """
    if n == 0:
        return False
    lowest_bit_cleared = n & (n - 1)
    return lowest_bit_cleared == 0

def edge(a, b):
    """Return whether vertices ``a`` and ``b`` of the hypercube are adjacent.

    Two vertex indices are joined by an edge of the oriented incidence matrix
    exactly when they differ in a single bit, i.e. when ``a ^ b`` is a power
    of two.
    """
    differing_bits = a ^ b
    return isPowerOfTwo(differing_bits)

def get_d(F, v = -1):
    """Build the oriented edge-incidence matrix of the F-dimensional hypercube.

    Vertices are the integers ``0 .. 2**F - 1``. An edge joins ``i < j`` when
    they differ in exactly one bit ``k`` (so ``j == i + 2**k``). Each edge is
    one row holding ``-1`` in column ``i`` (source) and ``+1`` in column ``j``
    (destination). Rows are ordered by ``i`` and then by ``j``.

    Parameters: F = hypercube dimension (number of players);
                v = bit index of the only edge direction to keep, or -1 to
                    keep every direction. Rows of other directions are still
                    emitted but left all-zero, so the row count never depends
                    on ``v``.
    Returns: integer numpy array of shape (number of edges, 2**F).
    """
    dim = 2**F
    # Enumerate edges directly instead of testing all dim**2 vertex pairs:
    # the neighbours of i above it are i + 2**k for every bit k clear in i,
    # and increasing k gives increasing j, which reproduces the (i, j) order
    # of a double loop over i < j.
    edges = [(i, i | (1 << k), k)
             for i in range(dim)
             for k in range(F)
             if not (i >> k) & 1]
    D = np.zeros((len(edges), dim), dtype=int)
    for row, (i, j, k) in enumerate(edges):
        if v == -1 or k == v:
            D[row, i] = -1
            D[row, j] = 1
    return D

def getShapleyProjection(v, F=None):
    """Least-squares projection of a game onto one component game per player.

    ``v`` is indexed by coalition: bit ``k`` of the index says whether player
    ``k`` belongs to the coalition. For every player ``k``, the partial
    gradient of ``v`` in direction ``k`` (``v[j] - v[i]`` on edges that flip
    bit ``k``, zero elsewhere) is fitted in the least-squares sense by the
    gradient ``D @ g_k`` of some game ``g_k``.

    Parameters: v = 1-D array-like game of length dim (a power of two is
                    assumed by the default F);
                F = number of players, defaults to int(log2(len(v))).
    Returns: results      = (dim, F) array; column k is the component game of
                            player k, shifted so its value at index 0 (the
                            empty coalition) is 0;
             residualGame = (n_edges, F) array; D @ g_k minus the partial
                            gradient, i.e. the part the fit cannot explain;
             origGame     = (n_edges,) array; the full gradient D @ v.
    """
    v = np.asarray(v)
    if F is None:
        F = int(np.log2(len(v)))
    dim = len(v)
    # One row per hypercube edge (i, j) with i < j and i ^ j == 2**k, in the
    # same order as a double loop over i < j (for fixed i, j = i + 2**k grows
    # with k). The bit index k is exact, unlike a float log2 of j - i.
    edges = [(i, i | (1 << k), k)
             for i in range(dim)
             for k in range(dim.bit_length())
             if not (i >> k) & 1 and (i | (1 << k)) < dim]
    D = np.zeros((len(edges), dim), dtype=int)
    # dels[r, k] is the partial derivative of v along edge r in direction k.
    # Filling it directly avoids materialising F separate (n_edges x dim)
    # partial-derivative operators only to multiply them by v.
    dels = np.zeros((len(edges), F))
    for r, (i, j, k) in enumerate(edges):
        D[r, i] = -1
        D[r, j] = 1
        if k < F:
            dels[r, k] = v[j] - v[i]

    vals, _residues, _rank, _singular = lstsq(D, dels)
    # vals = component games (one column per player)
    # D.dot(vals) = "gradient" of the component games
    # dels = "partial gradients" of the original game
    results = vals - vals[0]
    residualGame = D.dot(vals) - dels
    origGame = D.dot(v)
    return results, residualGame, origGame

def getShapleyResiduals(v, F=None):
    """Return per-player values at the grand coalition and the relative residual.

    Parameters: v = game vector indexed by coalition (see getShapleyProjection);
                F = number of players, defaults to int(log2(len(v))).
    Returns: (values, ratio) where ``values`` is the last row of the projection
             result (the grand coalition) in reversed player order, and
             ``ratio`` is ||residualGame|| / ||origGame||.
    """
    n_players = F
    if n_players is None:
        n_players = int(np.log2(len(v)))
    projected, residual_game, original_gradient = getShapleyProjection(v, n_players)
    # Reversed so that bit order (bit k = player k) lines up with an MSB-first
    # binary key ordering of players.
    grand_coalition_values = projected[-1][::-1]
    relative_residual = norm(residual_game) / norm(original_gradient)
    return grand_coalition_values, relative_residual

def getShapleyPartialResiduals(v, F=None):
    """Return per-player values and a per-player relative residual.

    Parameters: v = game vector indexed by coalition (see getShapleyProjection);
                F = number of players, defaults to int(log2(len(v))).
    Returns: (values, ratios) where ``values`` is the grand-coalition row of the
             projection in reversed player order, and ``ratios[p]`` is the norm
             of player p's residual column divided by ||origGame||, also in
             reversed order.
    """
    n_players = F
    if n_players is None:
        n_players = int(np.log2(len(v)))
    projected, residual_game, original_gradient = getShapleyProjection(v, n_players)
    # NOTE(review): every column is divided by the norm of the whole gradient
    # origGame, not by the norm of that player's own partial gradient —
    # confirm this normalization is intended.
    per_player_ratio = norm(residual_game, axis = 0) / norm(original_gradient, axis = 0)
    return projected[-1][::-1], per_player_ratio[::-1]

# Shapley explanations for models ------

class RegressionGame():
    """Cooperative games induced by a model or function over feature subsets.

    A coalition is an integer ``i`` in ``range(2**F)``. Its zero-padded,
    most-significant-bit-first binary representation (``makeKey``) selects
    columns of ``X``: character ``p`` of the key is 1 when column ``p`` is in
    the coalition. ``i == 0`` is the empty coalition and ``2**F - 1`` the full
    one.
    """

    def __init__(self, X, y = None, mdl = None, function = None, transform = lambda x: x):
        self.X = X
        self.F = X.shape[1]
        self.dim = 2**self.F
        self.y = y
        self.mdl = mdl
        self.function = function
        self.transform = transform
        self.trainedModels = []
        self.MSEs = []
        # Always fitted here; only getConditionalSHAPPrediction uses it.
        self.imputer = self.fitImputer(X)
        if (mdl is not None):
            # Retrains one model per coalition (2**F fits), then scores them.
            self.makeModelList()
            self.makeMSEList()


    # helper functions -----

    # Use index to generate binary representation of a set
    # Parameters: i = integer
    # Returns: i in binary, zero-padded to width F, most significant bit first
    def makeKey(self, i):
        return np.binary_repr(i, self.F)

    # Boolean column mask of the features in the indexed set
    # Parameters: i = integer
    # Returns: bool array, True where the key of i has a '1'
    def _subsetMask(self, i):
        return np.array(list(self.makeKey(i)), dtype = int).astype(bool)

    # Return dataset with only columns in the index's set
    # Parameters: i = integer
    # Returns: training data with columns subsetted using binary representation
    def getTrainingData(self, i):
        return self.X[:, self._subsetMask(i)]

    # Train model corresponding with index
    # Parameters: i = integer
    # Returns: trained model, or None for the empty coalition (i == 0)
    def makeModelEntry(self, i):
        if (i == 0):
            return None
        mdl_S = clone(self.mdl, safe = False)
        X_tr = self.getTrainingData(i)
        mdl_S.fit(self.transform(X_tr), self.y)
        return mdl_S

    # Train all models (index i of the list is the model for coalition i)
    def makeModelList(self):
        self.trainedModels = [self.makeModelEntry(i) for i in range(self.dim)]

    # Fit imputer for conditional SHAP
    # Parameters: X = n x F numpy array
    # Returns: IterativeImputer fitted on self.transform(X), i.e. it imputes
    #          in transformed space
    def fitImputer(self, X):
        imp = IterativeImputer(random_state = 0, max_iter = 30, sample_posterior = True)
        imp.fit(self.transform(X))
        return imp

    # Predict from data that is already in transformed space
    # Parameters: Z = output of self.transform applied to an n x F array
    # Returns: vector of predictions from self.function, or from the model
    #          trained on all features when no function was given
    def _predictTransformed(self, Z):
        if (self.function is not None):
            return self.function(Z)
        return self.trainedModels[self.dim-1].predict(Z)

    # Predict using "full" trained model or function
    # Parameters: X = n x F numpy array (raw, untransformed features)
    # Returns: vector of predictions using model trained on all features
    def getWholePrediction(self, X):
        return self._predictTransformed(self.transform(X))

    # Predict using model trained on features in indexed subset
    # Parameters: i = integer, X = n x F numpy array
    # Returns: Predictions on X using model trained on features indicated by
    #          binary rep of i; for i == 0 the mean of y, repeated once per row
    def getRetrainedPrediction(self, i, X):
        if (i == 0):
            # One prediction per row, like every other coalition (previously a
            # 0-d scalar regardless of how many rows X had).
            return np.full(np.shape(X)[0], np.mean(self.y))
        X_modified = X[:, self._subsetMask(i)]
        return self.trainedModels[i].predict(self.transform(X_modified))

    # Predict by masking features not in indexed subset with that from their marginal distributions
    # Parameters: i = integer, x = 1 x F numpy array, X_bg = n x F numpy array
    # Returns: length-1 array, the mean prediction over the rows of X_bg after
    #          the features in the indexed subset are replaced by the values in x
    def getKernelSHAPPrediction(self, i, x, X_bg):
        x = np.asarray(x)
        X_bg = np.asarray(X_bg)
        replace = ~self._subsetMask(i)
        # Tile in a dtype that can hold both x and the background values:
        # tiling an integer x and assigning float background values into it
        # would silently truncate them.
        common_dtype = np.result_type(x.dtype, X_bg.dtype)
        X_modified = np.tile(x.astype(common_dtype), (X_bg.shape[0], 1))
        X_modified[:, replace] = X_bg[:, replace]
        return np.mean(self.getWholePrediction(X_modified)).reshape((1))

    # Predict by masking features not in indexed subset with that from their conditional distributions
    # Parameters: i = integer, x = 1 x F numpy array, nsamples = integer
    # Returns: length-1 array, the mean prediction over nsamples copies of x
    #          whose masked features are filled in by self.imputer
    def getConditionalSHAPPrediction(self, i, x, nsamples = 100):
        x = np.asarray(x)
        replace = ~self._subsetMask(i)
        # Masked entries must hold NaN, which integer arrays cannot store.
        float_dtype = x.dtype if np.issubdtype(x.dtype, np.floating) else np.float64
        X_modified = np.tile(x.astype(float_dtype), (nsamples, 1))
        X_modified[:, replace] = np.nan
        # The imputer was fitted on transformed data (fitImputer), so its output
        # is already transformed: predict on it directly instead of through
        # getWholePrediction, which would apply self.transform a second time.
        X_imputed = self.imputer.transform(self.transform(X_modified))
        return np.mean(self._predictTransformed(X_imputed)).reshape((1))

    # Evaluate error of the model trained on indexed subset
    # Parameters: i = integer
    # Returns: in-sample MSE on (self.X, self.y) of model trained on features indicated by i
    def getMSE(self, i):
        if (i == 0):
            # predict mean for all values
            return metrics.mean_squared_error(self.y, np.full(self.y.shape, np.mean(self.y)))
        return metrics.mean_squared_error(self.y, self.getRetrainedPrediction(i, self.X))

    # Evaluate error on all models
    def makeMSEList(self):
        self.MSEs = [self.getMSE(i) for i in range(self.dim)]


    # game methods ------
    # Each game is a length-2**F array whose entry i is the value of coalition i.

    def getAccuracyGame(self):
        return np.array(self.MSEs)

    # x is expected to be a single row (1 x F), one prediction per coalition.
    def getRetrainedGame(self, x):
        return np.concatenate([self.getRetrainedPrediction(z, x).reshape((1))
                               for z in range(self.dim)])

    def getKernelSHAPGame(self, x, X_bg = None):
        if (X_bg is None):
            X_bg = self.X
        return np.concatenate([self.getKernelSHAPPrediction(z, x, X_bg).reshape((1))
                               for z in range(self.dim)])

    def getConditionalSHAPGame(self, x):
        return np.concatenate([self.getConditionalSHAPPrediction(z, x).reshape((1))
                               for z in range(self.dim)])



# Kernel learning class ------

class LinearRegressionWithFeatureEngineering():
    """Linear regression fitted on a user-supplied feature transform of X.

    ``transform`` is applied to the inputs both in ``fit`` and in ``predict``,
    so the wrapped ``LinearRegression`` only ever sees engineered features.
    """

    def __init__(self, transform):
        self.model = linear_model.LinearRegression()
        # Number of raw input columns seen by the most recent fit (None before fitting).
        self.dim = None
        self.transform = transform

    def fit(self, X, y):
        """Fit the linear model on ``transform(X)``.

        Returns self, following the scikit-learn estimator convention so that
        ``est.fit(X, y).predict(X_new)`` works.
        """
        self.dim = X.shape[1]
        X_inter = self.transform(X)
        self.model.fit(X_inter, y)
        return self

    def predict(self, X):
        """Predict with the fitted linear model on ``transform(X)``."""
        return self.model.predict(self.transform(X))

