import numpy as np
import torch

from sklearn.metrics import pairwise_distances
from sklearn.metrics import pairwise_kernels
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import AdaBoostClassifier
from sklearn.neighbors import NearestNeighbors

from drf import drf


class Debiaser:
    """Build cross-fitted sample weights from two nuisance estimates.

    The sample is split into a first fold ``[:n // 2]`` and a second fold
    ``[n // 2:]``.  Each nuisance model is trained on one fold and evaluated
    on the other.  Every ``find_weights_*`` method reduces those estimates to
    one weight per sample and returns the flattened outer product of that
    weight vector with itself (a 1-D tensor of length ``n * n``).

    Parameters
    ----------
    pi_hat_model : str
        Propensity model: "logistic", "RF", "boosting" or "experiment"
        (constant 0.5, nothing is fitted).
    beta_hat_model : str
        Conditional-weight model: "CME", "DRF" or "1NN".

    Notes
    -----
    ``A`` is assumed to be a 1-D 0/1 indicator aligned with the rows of ``X``
    and ``Y`` -- TODO confirm against callers.
    """

    def __init__(self, pi_hat_model = "logistic", beta_hat_model = "CME"):
        self.pi_hat_model = pi_hat_model
        self.beta_hat_model = beta_hat_model

    # ------------------------------------------------------------------
    # Single-fold fitting
    # ------------------------------------------------------------------
    def train_eval_pi_hat(self, X_train, A_train, X_eval):
        """Fit the propensity model on the training fold and score ``X_eval``.

        Returns
        -------
        numpy.ndarray
            Column 1 of ``predict_proba`` for each row of ``X_eval`` (or a
            constant 0.5 per row for the "experiment" model).

        Raises
        ------
        ValueError
            If ``self.pi_hat_model`` is not one of the supported names.
        """
        if self.pi_hat_model == "experiment":
            # Constant propensity; no classifier is trained.
            return np.zeros(len(X_eval)) + 0.5

        if self.pi_hat_model == "logistic":
            classifier = LogisticRegression(C=1e5, max_iter=1000)
        elif self.pi_hat_model == "RF":
            classifier = RandomForestClassifier()
        elif self.pi_hat_model == "boosting":
            classifier = AdaBoostClassifier()
        else:
            # Fail loudly: returning None here only produced an obscure
            # TypeError later on.
            raise ValueError(
                f"Pi hat model not found! Got {self.pi_hat_model!r}; expected "
                "'logistic', 'RF', 'boosting' or 'experiment'."
            )

        return classifier.fit(X_train, A_train).predict_proba(X_eval)[:, 1]

    def train_eval_beta_hat(self, X_train, A_train, Y_train, X_eval, **kwargs):
        """Fit the conditional-weight model and evaluate it at ``X_eval``.

        Dispatches to ``train_eval_CME``, ``train_eval_DRF`` or
        ``train_eval_1NN``.  Only the DRF estimator receives ``Y_train``;
        the 1NN estimator does not receive ``kwargs``.

        Returns
        -------
        numpy.ndarray
            Whatever the selected estimator returns; the callers in this
            class store it as a ``(len(X_train), len(X_eval))`` block.

        Raises
        ------
        ValueError
            If ``self.beta_hat_model`` is not one of the supported names.
        """
        if self.beta_hat_model == "CME":
            return train_eval_CME(X_train, A_train, X_eval, **kwargs)
        if self.beta_hat_model == "DRF":
            return train_eval_DRF(X_train, A_train, Y_train, X_eval, **kwargs)
        if self.beta_hat_model == "1NN":
            return train_eval_1NN(X_train, A_train, X_eval)

        # Returning None here used to be written into a float array as NaN,
        # silently poisoning every downstream weight.
        raise ValueError(
            f"Beta hat model not found! Got {self.beta_hat_model!r}; expected "
            "'CME', 'DRF' or '1NN'."
        )

    # ------------------------------------------------------------------
    # Cross-fitting helpers (private)
    # ------------------------------------------------------------------
    def _cross_fit_pi_hat(self, X, A):
        """Return (pi_hat on fold 1 trained on fold 2, pi_hat on fold 2 trained on fold 1)."""
        n2 = len(X) // 2
        X1, X2 = X[:n2], X[n2:]
        A1, A2 = A[:n2], A[n2:]
        pi_hat_1 = self.train_eval_pi_hat(X_train = X2, A_train = A2, X_eval = X1)
        pi_hat_2 = self.train_eval_pi_hat(X_train = X1, A_train = A1, X_eval = X2)
        return pi_hat_1, pi_hat_2

    def _cross_fit_beta_hat(self, X, A, Y, **kwargs):
        """Return the two zero-padded ``(n, fold_size)`` beta-hat matrices."""
        n = len(X)
        n2 = n // 2
        X1, X2 = X[:n2], X[n2:]
        A1, A2 = A[:n2], A[n2:]
        Y1, Y2 = Y[:n2], Y[n2:]

        # Rows index all n samples, columns index the evaluation fold.  Only
        # the rows of the *training* fold are filled; the rest stay zero.
        beta_hat_1 = np.zeros((n, n2))
        beta_hat_2 = np.zeros((n, n - n2))
        beta_hat_1[n2:, :] = self.train_eval_beta_hat(
            X_train = X2, A_train = A2, Y_train = Y2, X_eval = X1, **kwargs)
        beta_hat_2[:n2, :] = self.train_eval_beta_hat(
            X_train = X1, A_train = A1, Y_train = Y1, X_eval = X2, **kwargs)
        return beta_hat_1, beta_hat_2

    @staticmethod
    def _fold_selectors(n):
        """Return the ``(n, n2)`` and ``(n, n - n2)`` 0/1 matrices that pick each fold's own rows."""
        n2 = n // 2
        select_1 = np.vstack((np.eye(n2), np.zeros((n - n2, n2))))
        select_2 = np.vstack((np.zeros((n2, n - n2)), np.eye(n - n2)))
        return select_1, select_2

    @staticmethod
    def _inverse_propensity(A, n2, pi_hat_1, pi_hat_2):
        """Return ``A / pi_hat`` per fold as numpy arrays (accepts arrays or CPU tensors)."""
        ratio_1 = np.asarray(A[:n2]) / np.asarray(pi_hat_1)
        ratio_2 = np.asarray(A[n2:]) / np.asarray(pi_hat_2)
        return ratio_1, ratio_2

    @staticmethod
    def _pairwise_weights(matrix_1, matrix_2):
        """Average the row means of both fold matrices and return the flattened outer product."""
        per_sample_1 = torch.as_tensor(matrix_1).mean(1)
        per_sample_2 = torch.as_tensor(matrix_2).mean(1)
        weights = (per_sample_1 + per_sample_2) / 2
        return torch.outer(weights, weights).flatten()

    # ------------------------------------------------------------------
    # Stateful training
    # ------------------------------------------------------------------
    def train_pi_hat(self, X, A):
        """Cross-fit the propensity model and store ``self.pi_hat_1`` / ``self.pi_hat_2``."""
        self.pi_hat_1, self.pi_hat_2 = self._cross_fit_pi_hat(X, A)

    def train_beta_hat(self, X, A, Y, **kwargs):
        """Cross-fit the beta-hat model and store ``self.beta_hat_1`` / ``self.beta_hat_2``."""
        self.beta_hat_1, self.beta_hat_2 = self._cross_fit_beta_hat(X, A, Y, **kwargs)

    # ------------------------------------------------------------------
    # Weight construction
    # ------------------------------------------------------------------
    def find_weights_AIPW(self, X, A, Y, trained = False, **kwargs):
        """Weights from ``beta_hat + (A / pi_hat) * (fold_selector - beta_hat)``.

        Parameters
        ----------
        trained : bool
            If True, reuse the estimates stored by ``train_pi_hat`` and
            ``train_beta_hat``; otherwise cross-fit both here.
        **kwargs
            Forwarded to the beta-hat estimator when fitting.

        Returns
        -------
        torch.Tensor
            1-D tensor of length ``n * n``.
        """
        n = len(X)
        n2 = n // 2

        if trained:
            pi_hat_1, pi_hat_2 = self.pi_hat_1, self.pi_hat_2
            beta_hat_1, beta_hat_2 = self.beta_hat_1, self.beta_hat_2
        else:
            pi_hat_1, pi_hat_2 = self._cross_fit_pi_hat(X, A)
            beta_hat_1, beta_hat_2 = self._cross_fit_beta_hat(X, A, Y, **kwargs)

        ratio_1, ratio_2 = self._inverse_propensity(A, n2, pi_hat_1, pi_hat_2)
        select_1, select_2 = self._fold_selectors(n)

        # The 1-D ratios broadcast across columns, i.e. one ratio per
        # evaluation sample of the fold.
        corrected_1 = beta_hat_1 + ratio_1 * (select_1 - beta_hat_1)
        corrected_2 = beta_hat_2 + ratio_2 * (select_2 - beta_hat_2)
        return self._pairwise_weights(corrected_1, corrected_2)

    def find_weights_PI(self, X, A, Y, trained = False, **kwargs):
        """Weights from the beta-hat matrices alone (no propensity term).

        Parameters
        ----------
        trained : bool
            If True, reuse the estimates stored by ``train_beta_hat``;
            otherwise cross-fit here.
        **kwargs
            Forwarded to the beta-hat estimator when fitting.

        Returns
        -------
        torch.Tensor
            1-D tensor of length ``n * n``.
        """
        if trained:
            beta_hat_1, beta_hat_2 = self.beta_hat_1, self.beta_hat_2
        else:
            beta_hat_1, beta_hat_2 = self._cross_fit_beta_hat(X, A, Y, **kwargs)

        return self._pairwise_weights(beta_hat_1, beta_hat_2)

    def find_weights_IPW(self, X, A, Y, trained = False):
        """Weights from ``(A / pi_hat) * fold_selector`` alone (no beta-hat term).

        Both folds contribute: the row means of the two fold matrices are
        averaged, exactly as in ``find_weights_AIPW`` and ``find_weights_PI``.
        ``Y`` is accepted for signature parity but is not used.

        Parameters
        ----------
        trained : bool
            If True, reuse the estimates stored by ``train_pi_hat``;
            otherwise cross-fit here.

        Returns
        -------
        torch.Tensor
            1-D tensor of length ``n * n``.
        """
        n = len(X)
        n2 = n // 2

        if trained:
            pi_hat_1, pi_hat_2 = self.pi_hat_1, self.pi_hat_2
        else:
            pi_hat_1, pi_hat_2 = self._cross_fit_pi_hat(X, A)

        ratio_1, ratio_2 = self._inverse_propensity(A, n2, pi_hat_1, pi_hat_2)
        select_1, select_2 = self._fold_selectors(n)

        # Average each fold over its own columns *before* combining: the two
        # matrices have different widths when n is odd.
        return self._pairwise_weights(ratio_1 * select_1, ratio_2 * select_2)

    def find_weights_IPW_PI_AIPW(self, X, A, Y, **kwargs):
        """Cross-fit both nuisance models once and return all three weight vectors.

        Returns
        -------
        tuple of torch.Tensor
            ``(weights_IPW, weights_PI, weights_AIPW)``.
        """
        self.train_pi_hat(X, A)
        self.train_beta_hat(X, A, Y, **kwargs)

        weights_IPW = self.find_weights_IPW(X, A, Y, trained = True)
        weights_PI = self.find_weights_PI(X, A, Y, trained = True)
        weights_AIPW = self.find_weights_AIPW(X, A, Y, trained = True)
        return weights_IPW, weights_PI, weights_AIPW


def train_eval_CME(X_train, A_train, X_eval, **kwargs):
    """Kernel ridge weights of the treated training samples at each evaluation point.

    Uses an RBF kernel whose ``gamma`` is the inverse of the squared median
    pairwise Euclidean distance among the rows with ``A_train == 1``, and
    solves ``(K + reg * I) W = K_eval`` for ``W``.

    Parameters
    ----------
    X_train, A_train : array-like
        Training covariates and indicator; only rows with ``A_train == 1``
        are used to build the kernel.
    X_eval : array-like
        Points at which the weights are evaluated.
    **kwargs
        ``gamma``: ridge regulariser ``reg``.  Defaults to the squared median
        distance (the same quantity that sets the kernel bandwidth).

    Returns
    -------
    numpy.ndarray
        ``(len(X_train), len(X_eval))`` matrix; rows with ``A_train != 1``
        are zero.
    """
    treated_mask = A_train == 1
    treated_X = X_train[treated_mask]

    # Median heuristic for the kernel bandwidth.
    median_sq = np.median(pairwise_distances(treated_X, metric='euclidean'))**2
    rbf_gamma = 1.0/median_sq
    gram = pairwise_kernels(treated_X, metric='rbf', gamma=rbf_gamma)
    cross_gram = pairwise_kernels(treated_X, X_eval, metric='rbf', gamma=rbf_gamma)

    ridge = kwargs.get('gamma', median_sq)
    regularised_gram = gram + ridge * np.eye(len(gram))

    beta_hat = np.zeros((len(X_train), len(X_eval)))
    beta_hat[treated_mask] = np.linalg.solve(regularised_gram, cross_gram)
    return beta_hat


def train_eval_DRF(X_train, A_train, Y_train, X_eval, **kwargs):
    """Forest weights of the treated training samples at each evaluation point.

    Fits a ``drf`` model (splitting rule "FourierMMD") on the rows with
    ``A_train == 1`` and reads the ``weights`` attribute of its prediction
    at ``X_eval``.

    Parameters
    ----------
    X_train, A_train, Y_train : array-like
        Training covariates, indicator and outcomes; only rows with
        ``A_train == 1`` are used for fitting.
    X_eval : array-like
        Points at which the weights are evaluated.
    **kwargs
        Forest options ``min_node_size`` (default 15) and ``num_trees``
        (default 200).  They may be supplied either inside a dict passed as
        the keyword ``kwargs`` (the historical calling convention) or
        directly as keywords.  The caller's dict is not modified.

    Returns
    -------
    numpy.ndarray
        ``(len(X_train), len(X_eval))`` matrix; rows with ``A_train != 1``
        are zero.
    """
    # Previously kwargs['kwargs'] was indexed unconditionally, so a call
    # without that keyword raised KeyError before the defaults could apply.
    options = kwargs.get('kwargs')
    if options is None:
        options = kwargs
    min_node_size = options.get('min_node_size', 15)
    num_trees = options.get('num_trees', 200)

    treated_mask = A_train == 1
    X = X_train[treated_mask]
    Y = Y_train[treated_mask]

    # fit model
    DRF = drf(min_node_size = min_node_size, num_trees = num_trees, splitting_rule = "FourierMMD")
    DRF.fit(X, Y)

    beta_hat = np.zeros((len(X_train), len(X_eval)))
    # Transposed so that rows line up with the treated training samples
    # (presumably ``weights`` is (n_eval, n_treated) -- TODO confirm).
    beta_hat[treated_mask] = DRF.predict(newdata = X_eval).weights.T

    return beta_hat

def train_eval_1NN(X_train, A_train, X_eval):
    """One-hot nearest-neighbour weights of the treated training samples.

    For every row of ``X_eval`` the single nearest neighbour among the rows
    with ``A_train == 1`` receives weight 1; all other entries are 0.

    Parameters
    ----------
    X_train, A_train : array-like
        Training covariates and indicator; neighbours are searched only
        among rows with ``A_train == 1``.
    X_eval : array-like
        Query points.

    Returns
    -------
    numpy.ndarray
        ``(len(X_train), len(X_eval))`` matrix with exactly one 1 per column,
        located in a row where ``A_train == 1``.
    """
    treated_mask = A_train == 1
    X = X_train[treated_mask]
    beta_hat = np.zeros((len(X_train), len(X_eval)))
    beta_hat_X = np.zeros((len(X), len(X_eval)))

    neigh = NearestNeighbors(n_neighbors=1)
    neigh.fit(X)
    nn = np.asarray(neigh.kneighbors(X_eval, return_distance=False)).ravel()

    # ``nn`` indexes rows of the treated subset ``X``, so the one-hot entries
    # belong in ``beta_hat_X``.  Writing them into ``beta_hat`` directly put
    # them on the wrong rows and the assignment below then erased the
    # treated rows with zeros.
    beta_hat_X[nn, np.arange(len(X_eval))] = 1
    beta_hat[treated_mask] = beta_hat_X

    return beta_hat

