import os
import numpy as np
import pandas as pd
import random
import torch
import torch.utils.data as Data
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from utils import Config, set_seed, trainData, testData, batchData, saveResult, mmd2_lin, safe_sqrt

class MLP(nn.Module):
    """Stack of ``nn.Linear`` layers, each followed by a per-layer activation.

    Args:
        dims: layer widths; layer ``i`` maps ``dims[i] -> dims[i+1]``.
        activate: one activation name per layer.
            ``"relu"`` / ``"elu"``: activation plus a residual connection
            (so ``dims[i]`` must equal ``dims[i+1]``), then ``BatchNorm1d``.
            ``"sigmoid"``, ``"exp"``, ``"nexp"`` (``exp(-x)``) and
            ``"softplus"`` (``log(1 + exp(x))``) are applied with no residual
            and no batch norm. Any other name leaves the layer linear.

    The module is cast to float64 on construction.
    """

    def __init__(self, dims, activate):
        super(MLP, self).__init__()
        self.num_layers = len(dims) - 1
        self.layers = nn.ModuleList()
        self.batchnorms = nn.ModuleList()
        self.activate = activate

        for i in range(self.num_layers):
            self.layers.append(nn.Linear(dims[i], dims[i+1]))
            # One BatchNorm per layer is created, but forward() only applies
            # it after the residual "relu"/"elu" activations.
            self.batchnorms.append(nn.BatchNorm1d(dims[i+1]))

        self.convert_to_double()

    def convert_to_double(self):
        """Cast all parameters and buffers to float64.

        ``Module.double()`` already recurses into every submodule, so the
        linear layers need no separate cast.
        """
        self.double()

    def convert_to_float(self):
        """Cast all parameters and buffers to float32 (recursive, like ``double``)."""
        self.float()

    def forward(self, x):
        """Apply each linear layer followed by its configured activation."""
        _pre = x
        _next = x
        for i in range(self.num_layers):
            _next = self.layers[i](_pre)
            act = self.activate[i]
            if act == "relu":
                _next = nn.functional.relu(_next) + _pre
                _next = self.batchnorms[i](_next)
            elif act == "elu":
                _next = nn.functional.elu(_next) + _pre
                _next = self.batchnorms[i](_next)
            elif act == "sigmoid":
                _next = torch.sigmoid(_next)
            elif act == "exp":
                _next = torch.exp(_next)
            elif act == "nexp":
                # exp(-x) == 1/exp(x), but stays differentiable when exp(x)
                # overflows (1/exp(x) back-propagates 0 * inf = NaN there).
                _next = torch.exp(-_next)
            elif act == "softplus":
                # Stable log(1 + exp(x)): the naive form overflows to inf for
                # large x and rounds to exactly 0 for very negative x.
                _next = nn.functional.softplus(_next)
            _pre = _next
        return _next

class Nets(nn.Module):
    """Two-arm model: per-arm representation MLPs feeding per-arm output heads.

    For each arm ``k`` in {0, 1}:
      * ``ry{k}`` -> ``y{k}``: representation feeding a sigmoid head (``hat_y{k}``);
      * ``rd{k}`` -> ``d{k}``: representation feeding a softplus head (``hat_d{k}``).

    Args:
        dim: input feature dimension. All hidden widths equal ``dim`` because
            the residual ``"elu"`` layers need matching input/output widths.
        p_treated: treated fraction, forwarded to ``mmd2_lin``.
        mode: stored as ``self.mode``; not read inside this class.
        alpha: weight of the imbalance penalty on the ``ry*`` representations
            (added only when ``alpha > 0``).
        beta: weight of the imbalance penalty on the ``rd*`` representations
            (added only when ``beta > 0``).
    """

    def __init__(self, dim, p_treated, mode='CFR_DF', alpha=0.001, beta=0.001):
        super(Nets, self).__init__()
        self.ry0 = MLP([dim, dim, dim], ['elu','elu'])
        self.ry1 = MLP([dim, dim, dim], ['elu','elu'])
        self.rd0 = MLP([dim, dim], ['elu'])
        self.rd1 = MLP([dim, dim], ['elu'])

        self.y1 = MLP([dim, 1], ['sigmoid'])
        self.d1 = MLP([dim, 1], ["softplus"])

        self.y0 = MLP([dim, 1], ['sigmoid'])
        self.d0 = MLP([dim, 1], ["softplus"])

        self.p_treated = p_treated
        self.mode  = mode
        self.alpha = alpha
        self.beta  = beta

    def predict(self, x):
        """Return ``(hat_d0, hat_y0, hat_d1, hat_y1)`` for covariates ``x``."""
        *_, hat_d0, hat_y0, hat_d1, hat_y1 = self.output(x)
        return hat_d0, hat_y0, hat_d1, hat_y1

    def output(self, x):
        """Return all representations and head outputs for covariates ``x``.

        Returns:
            ``(x_repy1, x_repd1, x_repy0, x_repd0, hat_d0, hat_y0, hat_d1, hat_y1)``.
        """
        x_repy1 = self.ry1(x)
        x_repy0 = self.ry0(x)

        x_repd1 = self.rd1(x)
        x_repd0 = self.rd0(x)

        hat_y1 = self.y1(x_repy1)
        hat_d1 = self.d1(x_repd1)

        hat_y0 = self.y0(x_repy0)
        hat_d0 = self.d0(x_repd0)

        return x_repy1, x_repd1, x_repy0, x_repd0, hat_d0, hat_y0, hat_d1, hat_y1

    def ipmloss(self, rep_batch, w_batch):
        """Imbalance between treated (``w > 0``) and control (``w < 1``) rows.

        Returns ``safe_sqrt(mmd2_lin(rep_t0, rep_t1, p_treated))``, or a zero
        scalar when the batch contains no rows of one of the two groups.
        """
        rep_t1 = rep_batch[(w_batch > 0).nonzero()[:, 0]]
        rep_t0 = rep_batch[(w_batch < 1).nonzero()[:, 0]]

        # A sampled mini-batch can miss one arm entirely. The imbalance is
        # undefined then, and (presumably, since mmd2_lin averages each group
        # -- confirm) would come back NaN and poison the whole loss.
        if rep_t1.shape[0] == 0 or rep_t0.shape[0] == 0:
            return rep_batch.new_zeros(())

        # Alternative IPM (disabled): wasserstein(rep_t0.float(), rep_t1.float(), self.p_treated)
        imb_dist = mmd2_lin(rep_t0, rep_t1, self.p_treated)
        return safe_sqrt(imb_dist)

    @staticmethod
    def _arm_nll(p, lamb, d, t, y):
        """Per-row negative log-likelihood for one arm.

        Rows with ``y == 1`` contribute ``-log(p * lamb * exp(-lamb * d))``.
        Rows with ``y == 0`` contribute ``-log(1 - p + p * exp(-lamb * t))``.
        This looks like a mixture ("cure") exponential model -- confirm.
        """
        event = -(torch.log(p) + torch.log(lamb) - lamb * d)
        censored = -(torch.log(1-p+p*torch.exp(-lamb*t)))
        return y * event + (1-y) * censored

    def forward(self, x, w, d, t, y):
        """Compute factual predictions and the training loss.

        Args:
            x: covariates.
            w: arm indicator. It selects the treated/control rows in
                ``ipmloss`` and mixes the two arms' outputs below.
            d: multiplied by ``lamb`` in the ``y == 1`` likelihood term.
            t: multiplied by ``lamb`` in the ``y == 0`` likelihood term.
            y: mixes the two likelihood terms (presumably a 0/1 indicator).

        Returns:
            ``(p, lamb, loss)``: the ``w``-mixture of the sigmoid heads, the
            ``w``-mixture of the softplus heads, and the scalar mean NLL plus
            the weighted imbalance penalties.
        """
        x_repy1, x_repd1, x_repy0, x_repd0, hat_d0, hat_y0, hat_d1, hat_y1 = self.output(x)

        lossw1 = self._arm_nll(hat_y1, hat_d1, d, t, y)
        lossw0 = self._arm_nll(hat_y0, hat_d0, d, t, y)

        p = hat_y1 * w + hat_y0 * (1-w)
        lamb = hat_d1 * w + hat_d0 * (1-w)
        loss = (lossw1 * w + lossw0 * (1-w)).mean()

        # Each penalty is gated on its own weight. Previously the beta term
        # was nested under `alpha > 0`, so alpha=0/beta>0 silently dropped it.
        if self.alpha > 0:
            imb_error1 = self.ipmloss(x_repy1, w) + self.ipmloss(x_repy0, w)
            loss = loss + self.alpha * imb_error1
        if self.beta > 0:
            imb_error2 = self.ipmloss(x_repd1, w) + self.ipmloss(x_repd0, w)
            loss = loss + self.beta * imb_error2

        return p, lamb, loss


cfg = Config()

batch_size = cfg.batch_size
epochs = cfg.iterations
lrate = cfg.lrate
w_decay = 0.0001
step_put = cfg.step_put
seed = 2023

# Epochs at which test-set predictions are dumped for figure 2.
epochs_draw = [600, 800, 1000, 1200]
fig2_flag = True

set_seed(seed)


def _evaluate(net, trains, teval, ttest, result_trt, result_tst, epoch):
    """Record full-train errors plus validation/test predictions for ``epoch``.

    Intended to run with ``net`` in eval mode under ``torch.no_grad()``.
    ``result_trt`` receives the validation-set predictions and ``result_tst``
    the test-set predictions; both get the same full-train loss and errors.
    """
    X_b, w_b, t_b, d_b, y_b, g_b, pY_b, lam_b = trains.get_all()
    p, lamb, loss = net(X_b, w_b, d_b, t_b, y_b)
    # Computed once and reused (previously re-evaluated three times).
    outcome_err = ((p-pY_b)**2).mean()
    time_err = ((lamb*lam_b - 1)**2).mean()

    for result, split in ((result_trt, teval), (result_tst, ttest)):
        hat_d0, hat_y0, hat_d1, hat_y1 = net.predict(split.X)
        result.one(epoch, loss, outcome_err, time_err, hat_d0, hat_y0,
                   hat_d1, hat_y1, split.Lam0, split.Y0, split.P0, split.Lam1, split.Y1, split.P1)

    print("Epoch-{}: Outcome-loss - {:.2f}, ResponseTime-loss - {:.2f}.".format(epoch, outcome_err, time_err))


def _save_fig2(net, ttest, fig2_path):
    """Save arm-0 predictions next to ground-truth columns of ``ttest`` as .npy."""
    hat_d0, hat_y0, _, _ = net.predict(ttest.X)
    # NOTE(review): column 1 is hat_y0 but sits beside Lam0/Lam1. If the plot
    # expects predicted (lam0, lam1), this should be hat_d1 -- confirm against
    # the plotting code before changing.
    fig2_lams = torch.cat([hat_d0, hat_y0, ttest.Lam0, ttest.Lam1, ttest.Y0, ttest.Y1], axis=1).detach().numpy()
    np.save(fig2_path, fig2_lams)


if cfg.batch_flag:
    method_name = f'{cfg.model}-{cfg.alpha}-{cfg.beta}-{cfg.batch_size}'
else:
    method_name = f'{cfg.model}-{cfg.alpha}-{cfg.beta}'
# The data setting does not depend on the experiment index; build it once.
data_setting = f"{cfg.num}_{cfg.dim}_{cfg.y0_add}_{cfg.y1_add}_{cfg.noise_scale}"
result_dir = f'./results/{data_setting}/{method_name}'

for exp in range(cfg.exps):
    print(f"This is the {exp}-th experiments ---- {method_name}. ")
    os.makedirs(f'{result_dir}/figure', exist_ok=True)
    data_dir = f'./data/{data_setting}/{exp}'

    ttrain = trainData(np.load(f'{data_dir}/train.npz'))
    X, w, t, d, y, g, pY, lam = ttrain.all()
    trains = batchData([X, w, t, d, y, g, pY, lam], batch_size)

    teval = testData(np.load(f'{data_dir}/valid.npz'))
    ttest = testData(np.load(f'{data_dir}/test.npz'))

    p_treated = torch.mean(w.float())

    dim = X.shape[1]
    net = Nets(dim, p_treated, cfg.model, cfg.alpha, cfg.beta)
    optimizer = optim.Adam(net.parameters(), lr=lrate, weight_decay=w_decay)

    result_trt = saveResult()
    result_tst = saveResult()
    for epoch in range(epochs):
        if cfg.batch_flag:
            X_b, w_b, t_b, d_b, y_b, g_b, pY_b, lam_b = trains.get_batch()
        else:
            X_b, w_b, t_b, d_b, y_b, g_b, pY_b, lam_b = trains.get_all()

        p, lamb, loss = net(X_b, w_b, d_b, t_b, y_b)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        do_eval = epoch % step_put == 0
        do_draw = epoch in epochs_draw and fig2_flag
        if do_eval or do_draw:
            # Score in eval mode: in train mode the BatchNorm layers normalize
            # validation/test predictions with that batch's own statistics and
            # fold eval data into their running stats. no_grad skips building
            # autograd graphs for these read-only passes.
            net.eval()
            with torch.no_grad():
                if do_eval:
                    _evaluate(net, trains, teval, ttest, result_trt, result_tst, epoch)
                if do_draw:
                    _save_fig2(net, ttest, f'{result_dir}/figure/lam{exp}_{epoch}.npy')
            net.train()

    result_trt.full = result_trt.full.round(4)
    result_trt.full.to_csv(f'{result_dir}/re{exp}_trt.csv', index=False)

    result_tst.full = result_tst.full.round(4)
    result_tst.full.to_csv(f'{result_dir}/re{exp}_tst.csv', index=False)
    print(f"Result save to: /results/{data_setting}/{method_name}/re{exp}_*.csv.")
