import os
import numpy as np
import random
import torch
import pandas as pd
import torch
import torch.nn.functional as F

# Default lower / upper clamp bounds that safe_sqrt applies to its input
# before taking the square root.
SQRT_CONSTL = 1e-10
SQRT_CONSTR = 1e10

def stable_log1pex(x):
    """Return ``log(1 + exp(-x))`` elementwise, without overflow.

    Uses the identity ``log(1 + e^{-x}) = max(-x, 0) + log1p(e^{-|x|})``, so
    the argument passed to ``exp`` is never positive.
    """
    negative_part = -torch.clamp(x, max=0)
    return negative_part + torch.log1p(torch.exp(-torch.abs(x)))

class Config:
    # Experiment settings, stored as plain class attributes (no __init__).
    # NOTE(review): the meanings in the trailing comments are inferred from the
    # attribute names only — confirm against the code that reads them.

    # Model / optimisation
    model = 'CFR-TF'  # model identifier string
    alpha = 0.005  # presumably a loss-term weight — confirm
    beta = 0.01  # presumably a loss-term weight — confirm
    lrate = 0.001  # presumably the learning rate — confirm
    step_put = 100  # presumably the print/evaluation interval — confirm
    iterations = 3000  # presumably the number of training iterations — confirm
    batch_size = 5000  # presumably the mini-batch size — confirm
    batch_flag = True  # presumably toggles mini-batch training — confirm

    # Data sizes / repetitions
    num = 20000  # presumably number of training samples — confirm
    tnum = 3000  # presumably number of test samples — confirm
    exps = 10  # presumably number of repeated experiments — confirm

    dim = 20  # presumably covariate dimension — confirm
    low  = -1  # presumably lower bound of a sampling range — confirm
    high =  1  # presumably upper bound of a sampling range — confirm

    # Presumably offsets / noise level used by the data-generating process — confirm
    w_add  = 0
    y0_add = 1
    y1_add = 2
    d0_add = 0
    d1_add = -0.5
    noise_scale = 0.00

    # Random seeds (seed_base / seed_mul presumably combine into per-run seeds — confirm)
    seed = 888
    seed_base = 666
    seed_mul = 888

def cat(data_list, axis=1):
    """Concatenate a sequence of tensors or NumPy arrays along ``axis``.

    ``torch.cat`` is tried first. If it rejects the inputs because they are not
    all tensors (it raises ``TypeError``, e.g. for NumPy arrays), the sequence
    is concatenated with ``np.concatenate`` instead.

    Args:
        data_list: Sequence of ``torch.Tensor`` objects or NumPy arrays.
        axis: Dimension along which to concatenate (default 1).

    Returns:
        A ``torch.Tensor`` when ``torch.cat`` accepts the inputs, otherwise a
        NumPy array.

    Raises:
        RuntimeError: from ``torch.cat`` for genuine tensor errors (shape,
            dtype or device mismatch). These are no longer masked by a
            fallback to NumPy.
    """
    try:
        return torch.cat(data_list, axis)
    except TypeError:
        # Not every element is a tensor (e.g. NumPy arrays): use NumPy instead.
        return np.concatenate(data_list, axis)

def set_seed(seed=2023):
    """Seed every RNG used here and force deterministic cuDNN kernels.

    Seeds NumPy, Python's ``random`` and PyTorch (CPU and CUDA). It also
    exports ``PYTHONHASHSEED`` and disables cuDNN autotuning.

    Note: ``PYTHONHASHSEED`` is read only at interpreter start-up, so setting
    it here affects child processes, not the current interpreter's ``hash``.
    """
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    torch_seeders = (torch.manual_seed,
                     torch.cuda.manual_seed_all,
                     torch.cuda.manual_seed)
    for seeder in torch_seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

class saveResult(object):
    """Collect evaluation metrics, one row per call, into the DataFrame ``self.full``.

    Inputs to :meth:`one` are torch tensors. The scalar ``loss`` / ``trt_pmse``
    / ``trt_lam`` may also be plain numbers. Tensors are detached and moved to
    the CPU before conversion, so tensors living on a GPU are accepted too.
    """

    def __init__(self) -> None:
        self.full = pd.DataFrame(columns=['epoch', 'loss', 'trt_pmse', 'trt_lam', 'Py0_mse', 'Py1_mse', 'Lam0_mse', 'Lam1_mse',
                                          'LPEHE', 'HPEHE', 'PEHE', 'Y0_acc', 'Y1_acc', 'ATE_bias', 'PPEHE'])

    @staticmethod
    def _to_numpy(value):
        """Convert a tensor (on any device) or a number/array to NumPy data."""
        if torch.is_tensor(value):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    @staticmethod
    def _rmse(truth, estimate):
        """Root of the mean squared difference between two tensors."""
        return torch.sqrt(torch.mean((truth - estimate) ** 2))

    def one(self, epoch, loss, trt_pmse, trt_lam, hat_d0, hat_y0, hat_d1, hat_y1, lam0, y0, p0, lam1, y1, p1):
        """Compute the metrics for one evaluation and append them as a new row.

        Args:
            epoch: Stored unchanged in the ``epoch`` column.
            loss, trt_pmse, trt_lam: Scalar tensors (or numbers) stored unchanged.
            hat_d0, hat_d1: Model outputs. They are inverted (``1 / hat_d``)
                before being compared with ``lam0`` / ``lam1``.
            hat_y0, hat_y1: Model scores. They are compared with ``p0`` /
                ``p1``, and thresholded at 0.5 before comparison with the
                labels ``y0`` / ``y1``.
            lam0, lam1, p0, p1, y0, y1: Ground-truth tensors.
        """
        to_np = self._to_numpy
        rmse = self._rmse

        loss = to_np(loss)
        trt_pmse = to_np(trt_pmse)
        trt_lam = to_np(trt_lam)

        # The hat_d outputs are compared with lam only after inversion.
        hat_d1 = 1 / hat_d1
        hat_d0 = 1 / hat_d0

        Py0_mse = to_np(((hat_y0 - p0) ** 2).mean())
        Py1_mse = to_np(((hat_y1 - p1) ** 2).mean())
        Lam0_mse = to_np(((hat_d0 - lam0) ** 2).mean())
        Lam1_mse = to_np(((hat_d1 - lam1) ** 2).mean())

        # Hard 0/1 predictions. NOTE: despite the "acc" name, Y*_acc is the mean
        # squared difference to the labels, i.e. the error rate for 0/1 labels.
        # The column names are kept for compatibility.
        pred_y0 = (hat_y0 > 0.5).float()
        pred_y1 = (hat_y1 > 0.5).float()
        y0_acc = to_np(((pred_y0 - y0) ** 2).mean())
        y1_acc = to_np(((pred_y1 - y1) ** 2).mean())

        hatP_CATE = hat_y1 - hat_y0
        truP_CATE = p1 - p0
        truY_CATE = y1 - y0

        LPEHE = to_np(rmse(lam1 - lam0, hat_d1 - hat_d0))
        PPEHE = to_np(rmse(truP_CATE, hatP_CATE))
        # HPEHE compares the score difference with the realised label difference.
        HPEHE = to_np(rmse(truY_CATE, hatP_CATE))
        PEHE = to_np(rmse(truY_CATE, pred_y1 - pred_y0))
        ATE_bias = np.abs(to_np(hatP_CATE.mean() - truP_CATE.mean()))

        new_row = {'epoch': epoch, 'loss': loss, 'trt_pmse': trt_pmse, 'trt_lam': trt_lam, 'Py0_mse': Py0_mse, 'Py1_mse': Py1_mse, 'Lam0_mse': Lam0_mse, 'Lam1_mse': Lam1_mse,
                   'LPEHE': LPEHE, 'HPEHE': HPEHE, 'PEHE': PEHE, 'Y0_acc': y0_acc, 'Y1_acc': y1_acc, 'ATE_bias': ATE_bias, 'PPEHE': PPEHE}

        self.full = pd.concat([self.full, pd.DataFrame(new_row, index=[0])], ignore_index=True)

class baseResult(object):
    """Collect evaluation metrics for NumPy inputs, one row per call, in ``self.full``."""

    def __init__(self) -> None:
        self.full = pd.DataFrame(columns=['epoch', 'loss', 'f_error', 'Py0_mse', 'Py1_mse', 'PPEHE', 'HPEHE', 'PEHE', 'Y0_acc', 'Y1_acc', 'ATE_bias'])

    def one(self, epoch, loss, f_error, hat_y0, hat_y1, y0, p0, y1, p1):
        """Compute metrics from NumPy arrays and append them as one row.

        ``epoch``, ``loss`` and ``f_error`` are stored unchanged. ``hat_y0`` /
        ``hat_y1`` are compared with ``p0`` / ``p1`` and, thresholded at 0.5,
        with the labels ``y0`` / ``y1``. The ``*_acc`` columns hold the mean
        squared difference between thresholded predictions and labels.
        """
        def rmse(truth, estimate):
            return np.sqrt(np.mean((truth - estimate) ** 2))

        label0 = (hat_y0 > 0.5).astype(float)
        label1 = (hat_y1 > 0.5).astype(float)

        est_score_effect = hat_y1 - hat_y0
        true_score_effect = p1 - p0
        true_label_effect = y1 - y0

        record = {
            'epoch': epoch,
            'loss': loss,
            'f_error': f_error,
            'Py0_mse': ((hat_y0 - p0) ** 2).mean(),
            'Py1_mse': ((hat_y1 - p1) ** 2).mean(),
            'PPEHE': rmse(true_score_effect, est_score_effect),
            'HPEHE': rmse(true_label_effect, est_score_effect),
            'PEHE': rmse(true_label_effect, label1 - label0),
            'Y0_acc': ((label0 - y0) ** 2).mean(),
            'Y1_acc': ((label1 - y1) ** 2).mean(),
            'ATE_bias': np.abs(est_score_effect.mean() - true_score_effect.mean()),
        }

        appended = pd.DataFrame(record, index=[0])
        self.full = pd.concat([self.full, appended], ignore_index=True)

class batchData(object):
    """Random mini-batch sampler over a list of row-aligned arrays.

    Every element of ``data_list`` must support ``len`` and indexing with a
    list of integer row indices and with a slice (e.g. NumPy arrays or torch
    tensors). The same rows are selected from every element.
    """

    def __init__(self, data_list, batch_size=500):
        self.data_list = data_list
        # Number of rows, taken from the first element.
        self.length = len(self.data_list[0])
        self.batch_size = batch_size

    def get_batch(self, batch_size=None, n=None):
        """Sample ``batch_size`` distinct rows from the first ``n`` rows.

        Defaults are ``self.batch_size`` and ``self.length``. Rows are drawn
        without replacement with :func:`random.sample`, which raises
        ``ValueError`` if ``batch_size > n``.

        Returns:
            A list with the selected rows of each element of ``data_list``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        if n is None:
            n = self.length

        I = random.sample(range(0, n), batch_size)
        return [item[I] for item in self.data_list]

    def get_all(self, n=None):
        """Return the first ``n`` rows (default: all rows) of every element."""
        if n is None:
            n = self.length

        # Return the sliced views, not self.data_list, so that ``n`` is honoured.
        return [item[:n] for item in self.data_list]


class trainData(object):
    """Training arrays ``X, W, T, D, Y, G, P, L`` held as attributes.

    ``data`` must be indexable by those key names (e.g. a dict of NumPy arrays
    or the object returned by ``np.load`` for an ``.npz`` file). When ``data``
    is None, the arrays are loaded from ``path`` with ``np.load``. Each array is
    wrapped with ``torch.from_numpy``, which shares memory with it.
    """

    # Field names, in the order returned by all(), W1() and W0().
    _FIELDS = ('X', 'W', 'T', 'D', 'Y', 'G', 'P', 'L')

    def __init__(self, data=None, path=None):
        if data is None:
            loaded = np.load(path)
            try:
                self._read_fields(loaded)
            finally:
                # np.load on an .npz returns an NpzFile that keeps the file open.
                # The arrays are fully read above, so the handle can be closed.
                close = getattr(loaded, 'close', None)
                if close is not None:
                    close()
        else:
            self._read_fields(data)

    def _read_fields(self, data):
        """Set every field from ``data[name]`` via ``torch.from_numpy``."""
        for name in self._FIELDS:
            setattr(self, name, torch.from_numpy(data[name]))

    def _convert_fields(self, method):
        """Replace every field with the result of calling its ``method()``."""
        for name in self._FIELDS:
            setattr(self, name, getattr(getattr(self, name), method)())

    def _select(self, value):
        """Return all fields restricted to rows where ``W[:, 0] == value``."""
        mask = self.W[:, 0] == value
        return tuple(getattr(self, name)[mask] for name in self._FIELDS)

    def all(self):
        """Return ``(X, W, T, D, Y, G, P, L)``."""
        return tuple(getattr(self, name) for name in self._FIELDS)

    def float(self):
        """Convert every field in place with ``.float()`` (tensors only)."""
        self._convert_fields('float')

    def numpy(self):
        """Convert every field in place with ``.numpy()`` (tensors only)."""
        self._convert_fields('numpy')

    def W1(self):
        """Return all fields for rows whose first ``W`` column equals 1."""
        return self._select(1)

    def W0(self):
        """Return all fields for rows whose first ``W`` column equals 0."""
        return self._select(0)
    
class testData(object):
    """Test arrays ``X, Lam0, Lam1, P0, P1, Y0, Y1`` held as attributes.

    ``data`` must be indexable by those key names (e.g. a dict of NumPy arrays
    or the object returned by ``np.load`` for an ``.npz`` file). When ``data``
    is None, the arrays are loaded from ``path`` with ``np.load``. Each array is
    wrapped with ``torch.from_numpy``, which shares memory with it.
    """

    # Field names, in the order returned by all().
    _FIELDS = ('X', 'Lam0', 'Lam1', 'P0', 'P1', 'Y0', 'Y1')

    def __init__(self, data=None, path=None):
        if data is None:
            loaded = np.load(path)
            try:
                self._read_fields(loaded)
            finally:
                # np.load on an .npz returns an NpzFile that keeps the file open.
                # The arrays are fully read above, so the handle can be closed.
                close = getattr(loaded, 'close', None)
                if close is not None:
                    close()
        else:
            self._read_fields(data)

    def _read_fields(self, data):
        """Set every field from ``data[name]`` via ``torch.from_numpy``."""
        for name in self._FIELDS:
            setattr(self, name, torch.from_numpy(data[name]))

    def _convert_fields(self, method):
        """Replace every field with the result of calling its ``method()``."""
        for name in self._FIELDS:
            setattr(self, name, getattr(getattr(self, name), method)())

    def float(self):
        """Convert every field in place with ``.float()`` (tensors only)."""
        self._convert_fields('float')

    def numpy(self):
        """Convert every field in place with ``.numpy()`` (tensors only)."""
        self._convert_fields('numpy')

    def all(self):
        """Return ``(X, Lam0, Lam1, P0, P1, Y0, Y1)``."""
        return tuple(getattr(self, name) for name in self._FIELDS)
    


def safe_sqrt(x, lbound=SQRT_CONSTL, rbound=SQRT_CONSTR):
    ''' Numerically safe square root for torch tensors.

    ``x`` is clamped to ``[lbound, rbound]`` before ``torch.sqrt``. This keeps
    the result away from the infinite slope of sqrt at 0 and bounds it above.
    '''
    bounded = torch.clamp(x, lbound, rbound)
    return bounded.sqrt()
    
def lindisc(Xc,Xt,p):
    ''' Linear discrepancy between control (Xc) and treated (Xt) samples.

    With ``m_c`` and ``m_t`` the means over dim 0 of ``Xc`` and ``Xt``, this
    returns::

        |p - 0.5| + safe_sqrt(0.25 * (2p - 1)^2 + ||p * m_t - (1 - p) * m_c||^2)

    Below, ``sign(p - 0.5) * (p - 0.5)`` is the ``|p - 0.5|`` term.
    NOTE(review): presumably a port of ``lindisc`` from the TensorFlow CFR
    code — confirm.

    Args:
        Xc: Control rows (tensor), reduced over dim 0.
        Xt: Treated rows (tensor), reduced over dim 0.
        p: Treatment probability, as a tensor or a plain Python number.

    Returns:
        A tensor; 0-dim when ``p`` is a scalar.
    '''
    if not torch.is_tensor(p):
        # torch.square / torch.sign below only accept tensors. Promote a plain
        # number so callers can pass p as they can to mmd2_lin.
        p = torch.as_tensor(p, dtype=Xc.dtype, device=Xc.device)

    mean_control = torch.mean(Xc,dim=0)
    mean_treated = torch.mean(Xt,dim=0)

    c = torch.square(2*p-1)*0.25
    f = torch.sign(p-0.5)

    mmd = torch.sum(torch.square(p*mean_treated - (1-p)*mean_control))
    mmd = f*(p-0.5) + safe_sqrt(c + mmd)

    return mmd

def mmd2_lin(Xc,Xt,p):
    ''' Squared linear MMD between control (Xc) and treated (Xt) samples.

    Returns ``||2 p m_t - 2 (1 - p) m_c||^2``, where ``m_c`` and ``m_t`` are
    the means over dim 0. ``p`` may be a Python number or a tensor.
    '''
    treated_term = 2.0 * p * Xt.mean(dim=0)
    control_term = 2.0 * (1.0 - p) * Xc.mean(dim=0)
    return (treated_term - control_term).square().sum()

def pdist(sample_1, sample_2, norm=2, eps=1e-5):
    """Compute the matrix of pairwise (eps-smoothed) l_p distances.

    Arguments
    ---------
    sample_1 : torch.Tensor
        The first sample, of shape ``(n_1, d)``.
    sample_2 : torch.Tensor
        The second sample, of shape ``(n_2, d)``.
    norm : float
        The l_p norm to use. ``p == 2`` uses the matrix-product expansion
        ``|a|^2 + |b|^2 - 2 a.b``; any other ``p`` builds the full
        ``(n_1, n_2, d)`` difference tensor.
    eps : float
        Added inside the root so that the result (and its gradient) stays
        finite at zero distance.

    Returns
    -------
    torch.Tensor
        Matrix of shape ``(n_1, n_2)``. For ``p == 2`` entry ``[i, j]`` is
        ``sqrt(eps + | ||a_i||^2 + ||b_j||^2 - 2 a_i.b_j |)``. Otherwise it is
        ``(eps + sum_k |a_ik - b_jk|^p) ** (1/p)``.
    """
    rows, cols = sample_1.size(0), sample_2.size(0)
    p = float(norm)

    if p != 2.:
        width = sample_1.size(1)
        left = sample_1.unsqueeze(1).expand(rows, cols, width)
        right = sample_2.unsqueeze(0).expand(rows, cols, width)
        summed = (left - right).abs().pow(p).sum(dim=2)
        return (eps + summed) ** (1. / p)

    sq_1 = sample_1.pow(2).sum(dim=1, keepdim=True)
    sq_2 = sample_2.pow(2).sum(dim=1, keepdim=True)
    sq_sum = sq_1.expand(rows, cols) + sq_2.t().expand(rows, cols)
    # abs() guards against small negative values from floating-point cancellation.
    squared = sq_sum - 2 * sample_1.mm(sample_2.t())
    return torch.sqrt(eps + squared.abs())