import torch
import torch.nn as nn

import torch.nn.functional as F
import numpy as np
from sklearn.model_selection import train_test_split


def binary_classification_loss(concat_true, concat_pred):
    """Return the summed binary cross-entropy of the treatment prediction.

    Args:
        concat_true: 2-D tensor; column 1 is read as the observed treatment
            indicator (must be a float tensor with values in [0, 1]).
        concat_pred: 2-D tensor; column 2 is read as the predicted treatment
            probability.

    Returns:
        Scalar tensor: the cross-entropy summed over all rows.
    """
    t_true = concat_true[:, 1]
    # Rescale so a prediction in [0, 1] lands strictly inside (0, 1), keeping
    # the logarithms in the cross-entropy finite.
    t_pred = (concat_pred[:, 2] + 0.001) / 1.002

    # F.binary_cross_entropy averages by default; request the sum explicitly so
    # this term is on the same (summed) scale as regression_loss.
    return F.binary_cross_entropy(t_pred, t_true, reduction='sum')


def regression_loss(concat_true, concat_pred):
    """Return the summed squared error of the factual outcome prediction.

    Column 0 of ``concat_true`` is read as the outcome and column 1 as the
    treatment indicator. Column 0 of ``concat_pred`` is scored on rows weighted
    by ``1 - t`` and column 1 on rows weighted by ``t``.

    Returns:
        Scalar tensor: control-arm error plus treated-arm error.
    """
    outcome, treated = concat_true[:, 0], concat_true[:, 1]

    control_sq_err = torch.square(outcome - concat_pred[:, 0])
    treated_sq_err = torch.square(outcome - concat_pred[:, 1])

    control_term = torch.sum((1. - treated) * control_sq_err)
    treated_term = torch.sum(treated * treated_sq_err)
    return control_term + treated_term

def ned_loss(concat_true, concat_pred):
    """Return the summed binary cross-entropy for a two-column prediction.

    Args:
        concat_true: 2-D tensor; column 1 is read as the observed treatment
            indicator (must be a float tensor with values in [0, 1]).
        concat_pred: 2-D tensor; column 1 is read as the predicted treatment
            probability and must lie in [0, 1].

    Returns:
        Scalar tensor: the cross-entropy summed over all rows.
    """
    t_true = concat_true[:, 1]
    t_pred = concat_pred[:, 1]
    # The default reduction is 'mean'; ask for the sum explicitly so the result
    # really is a sum over rows, consistent with regression_loss.
    return F.binary_cross_entropy(t_pred, t_true, reduction='sum')


def dead_loss(concat_true, concat_pred):
    """Return the outcome-only loss; a thin alias for ``regression_loss``."""
    outcome_loss = regression_loss(concat_true, concat_pred)
    return outcome_loss

def dragonnet_loss_binarycross(concat_true, concat_pred):
    """Return the outcome regression loss plus the treatment cross-entropy."""
    outcome_term = regression_loss(concat_true, concat_pred)
    propensity_term = binary_classification_loss(concat_true, concat_pred)
    return outcome_term + propensity_term

def treatment_accuracy(concat_true, concat_pred):
    """Return the fraction of rows whose thresholded treatment prediction is correct.

    Args:
        concat_true: 2-D tensor; column 1 is read as the observed treatment
            indicator.
        concat_pred: 2-D tensor; column 2 is read as the predicted treatment
            probability.

    Returns:
        Scalar float tensor in [0, 1] (NaN when there are no rows).
    """
    t_true = concat_true[:, 1]
    t_pred = concat_pred[:, 2]
    # No `binary_accuracy` helper exists in this module, so compute it here:
    # a probability above 0.5 counts as a predicted treatment.
    predicted_label = (t_pred > 0.5).to(t_true.dtype)
    return torch.mean((predicted_label == t_true).float())

def track_epsilon(concat_true, concat_pred):
    """Return the absolute value of the mean of column 3 of ``concat_pred``.

    ``concat_true`` is accepted for signature parity with the loss functions
    and is not read.
    """
    return concat_pred[:, 3].mean().abs()

class EpsilonLayer(nn.Module):
    """Holds one trainable scalar and broadcasts it to one value per input row."""

    def __init__(self):
        super().__init__()
        # nn.Module has no deferred "build" step, so the parameter must be
        # registered here; otherwise forward() fails and optimizers never see it.
        self.epsilon = nn.Parameter(torch.randn(1, 1))

    def build(self, input_shape):
        """Re-draw epsilon from a standard normal, in place.

        Kept so that callers which still invoke this hook explicitly keep
        working. ``input_shape`` is accepted but not used. The update is done
        in place so the parameter object (and any optimizer reference to it)
        stays the same.
        """
        with torch.no_grad():
            self.epsilon.normal_()

    def forward(self, inputs):
        """Return a ``(rows, 1)`` tensor filled with epsilon.

        ``inputs`` must be at least 2-D; only its first dimension, dtype and
        device are used.
        """
        return self.epsilon * torch.ones_like(inputs[:, 0:1])



def make_tarreg_loss(ratio=1., dragonnet_loss=dragonnet_loss_binarycross):
    """Build a loss that adds a targeted-regularization term to ``dragonnet_loss``.

    Args:
        ratio: Weight applied to the targeted-regularization term.
        dragonnet_loss: Base loss called as ``dragonnet_loss(concat_true, concat_pred)``.

    Returns:
        A function ``(concat_true, concat_pred) -> scalar tensor``. It reads
        columns 0/1 of ``concat_true`` as outcome/treatment and columns 0-3 of
        ``concat_pred`` as y0, y1, treatment probability and epsilon.
    """
    def tarreg_ATE_unbounded_domain_loss(concat_true, concat_pred):
        base_loss = dragonnet_loss(concat_true, concat_pred)

        outcome, treated = concat_true[:, 0], concat_true[:, 1]
        mu0, mu1, eps = concat_pred[:, 0], concat_pred[:, 1], concat_pred[:, 3]

        # Rescale the probability away from 0 and 1 so the divisions below
        # stay finite for predictions in [0, 1].
        propensity = (concat_pred[:, 2] + 0.01) / 1.02

        # Prediction for the arm each row actually received.
        factual_pred = treated * mu1 + (1 - treated) * mu0

        # Inverse-propensity weight: 1/p for treated rows, -1/(1-p) for controls.
        ip_weight = treated / propensity - (1 - treated) / (1 - propensity)

        perturbed_pred = factual_pred + eps * ip_weight
        tarreg_term = torch.sum(torch.square(outcome - perturbed_pred))

        return base_loss + ratio * tarreg_term

    return tarreg_ATE_unbounded_domain_loss


# import torch
# import torch.nn as nn

# class EpsilonLayer(nn.Module):
#     def __init__(self):
#         super(EpsilonLayer, self).__init__()

#     def forward(self, t_predictions):
#         return t_predictions

class DragonNet(nn.Module):
    """Three-headed network: shared trunk, treatment head, two outcome heads.

    The treatment head reads the shared representation. ``forward`` returns a
    ``(rows, 4)`` tensor with columns ``[y0, y1, treatment probability, epsilon]``.

    Args:
        input_dim: Number of input features.
        reg_l2: Accepted for interface compatibility; this class does not use it.
    """

    def __init__(self, input_dim, reg_l2):
        super().__init__()

        # Shared representation trunk (ELU is applied in forward()).
        self.inputs = nn.Linear(input_dim, 200)
        self.hidden1 = nn.Linear(200, 200)
        self.hidden2 = nn.Linear(200, 200)

        # Treatment head on top of the shared representation.
        self.t_predictions = nn.Linear(200, 1)

        # One hidden tower and one linear output per potential outcome.
        self.y0_hidden = self._outcome_tower()
        self.y1_hidden = self._outcome_tower()
        self.y0_predictions = nn.Linear(100, 1)
        self.y1_predictions = nn.Linear(100, 1)

        self.epsilon_layer = EpsilonLayer()

    @staticmethod
    def _outcome_tower():
        """Return a fresh 200 -> 100 -> 100 ELU tower for one outcome head."""
        return nn.Sequential(
            nn.Linear(200, 100),
            nn.ELU(),
            nn.Linear(100, 100),
            nn.ELU(),
        )

    def forward(self, inputs):
        """Compute ``[y0, y1, treatment probability, epsilon]`` for each row."""
        representation = inputs
        for trunk_layer in (self.inputs, self.hidden1, self.hidden2):
            representation = F.elu(trunk_layer(representation))

        propensity = torch.sigmoid(self.t_predictions(representation))

        y0 = self.y0_predictions(self.y0_hidden(representation))
        y1 = self.y1_predictions(self.y1_hidden(representation))

        eps = self.epsilon_layer(propensity)

        return torch.cat([y0, y1, propensity, eps], dim=1)

# import torch
# import torch.nn as nn

# class EpsilonLayer(nn.Module):
#     def __init__(self):
#         super(EpsilonLayer, self).__init__()

#     def forward(self, t_predictions):
#         return t_predictions

class TarNet(nn.Module):
    """Two-headed outcome network with a treatment head on the raw inputs.

    Unlike ``DragonNet``, the treatment head here is a single linear layer fed
    by the raw input features rather than the shared representation.
    ``forward`` returns a ``(rows, 4)`` tensor with columns
    ``[y0, y1, treatment probability, epsilon]``.

    Args:
        input_dim: Number of input features.
        reg_l2: Accepted for interface compatibility; this class does not use it.
    """

    def __init__(self, input_dim, reg_l2):
        super().__init__()

        # Shared representation trunk (ELU is applied in forward()).
        self.inputs = nn.Linear(input_dim, 200)
        self.hidden1 = nn.Linear(200, 200)
        self.hidden2 = nn.Linear(200, 200)

        # Treatment head: sized for the raw features, not the representation.
        self.t_predictions = nn.Linear(input_dim, 1)

        # One hidden tower and one linear output per potential outcome.
        self.y0_hidden = self._outcome_tower()
        self.y1_hidden = self._outcome_tower()
        self.y0_predictions = nn.Linear(100, 1)
        self.y1_predictions = nn.Linear(100, 1)

        self.epsilon_layer = EpsilonLayer()

    @staticmethod
    def _outcome_tower():
        """Return a fresh 200 -> 100 -> 100 ELU tower for one outcome head."""
        return nn.Sequential(
            nn.Linear(200, 100),
            nn.ELU(),
            nn.Linear(100, 100),
            nn.ELU(),
        )

    def forward(self, inputs):
        """Compute ``[y0, y1, treatment probability, epsilon]`` for each row."""
        representation = inputs
        for trunk_layer in (self.inputs, self.hidden1, self.hidden2):
            representation = F.elu(trunk_layer(representation))

        # The treatment probability is computed from the raw inputs.
        propensity = torch.sigmoid(self.t_predictions(inputs))

        y0 = self.y0_predictions(self.y0_hidden(representation))
        y1 = self.y1_predictions(self.y1_hidden(representation))

        eps = self.epsilon_layer(propensity)

        return torch.cat([y0, y1, propensity, eps], dim=1)

class NEDNet(nn.Module):
    """Shared trunk with one outcome head and one treatment head.

    ``forward`` returns a ``(rows, 2)`` tensor with columns
    ``[y prediction, treatment probability]``.

    Args:
        input_dim: Number of input features.
        reg_l2: Accepted for interface compatibility; this class does not use it.
    """

    def __init__(self, input_dim, reg_l2=0.01):
        super().__init__()

        # Shared representation trunk (ELU is applied in forward()).
        self.inputs = nn.Linear(input_dim, 200)
        self.hidden1 = nn.Linear(200, 200)
        self.hidden2 = nn.Linear(200, 200)

        # Both heads read the shared representation.
        self.t_predictions = nn.Linear(200, 1)
        self.y_predictions = nn.Linear(200, 1)

    def forward(self, inputs):
        """Compute ``[y prediction, treatment probability]`` for each row."""
        representation = inputs
        for trunk_layer in (self.inputs, self.hidden1, self.hidden2):
            representation = F.elu(trunk_layer(representation))

        propensity = torch.sigmoid(self.t_predictions(representation))
        outcome = self.y_predictions(representation)

        return torch.cat([outcome, propensity], dim=1)

class PostCutNet(nn.Module):
    """Two outcome heads on top of the frozen representation of a trained NED net.

    ``forward`` returns a ``(rows, 2)`` tensor with columns ``[y0, y1]``.

    Args:
        nednet: Module exposing ``inputs``, ``hidden1`` and ``hidden2`` linear
            layers (as ``NEDNet`` does). These layers are reused, not copied,
            and their parameters are frozen in place.
        input_dim: Accepted for interface compatibility; not used here.
        reg_l2: Accepted for interface compatibility; not used here.
    """

    def __init__(self, nednet, input_dim, reg_l2=0.01):
        super().__init__()

        trunk_layers = (nednet.inputs, nednet.hidden1, nednet.hidden2)
        # Freeze the reused representation so only the new heads are trained.
        for trunk_layer in trunk_layers:
            trunk_layer.requires_grad_(False)

        # Rebuild the full three-layer trunk. The ELU activations are added
        # explicitly because the source network applies them in its forward()
        # rather than holding them as child modules.
        self.frozen = nn.Sequential(
            trunk_layers[0],
            nn.ELU(),
            trunk_layers[1],
            nn.ELU(),
            trunk_layers[2],
            nn.ELU(),
        )

        self.y0_hidden = nn.Sequential(
            nn.Linear(200, 100),
            nn.ELU(),
            nn.Linear(100, 100),
            nn.ELU()
        )
        self.y1_hidden = nn.Sequential(
            nn.Linear(200, 100),
            nn.ELU(),
            nn.Linear(100, 100),
            nn.ELU()
        )
        self.y0_predictions = nn.Linear(100, 1)
        self.y1_predictions = nn.Linear(100, 1)

    def forward(self, inputs):
        """Compute ``[y0, y1]`` for each row from the frozen representation."""
        x = self.frozen(inputs)

        y0_predictions = self.y0_predictions(self.y0_hidden(x))
        y1_predictions = self.y1_predictions(self.y1_hidden(x))

        return torch.cat([y0_predictions, y1_predictions], dim=1)



def _fit_dragon_phase(model, loss_fn, optimizer, x_fit, yt_fit, x_val, yt_val,
                      epochs, batch_size, stop_patience):
    """Run one optimizer phase of mini-batch training on ``model``.

    After each epoch the loss on the validation rows (if any) drives a
    halve-on-plateau learning-rate schedule and early stopping after
    ``stop_patience`` epochs without improvement. The phase ends immediately
    if a batch loss becomes NaN or infinite.
    """
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5)
    n_fit = x_fit.shape[0]
    has_val = x_val.shape[0] > 0
    best_val = float('inf')
    stale_epochs = 0

    for _ in range(epochs):
        model.train()
        order = torch.randperm(n_fit)
        for start in range(0, n_fit, batch_size):
            rows = order[start:start + batch_size]
            optimizer.zero_grad()
            batch_loss = loss_fn(yt_fit[rows], model(x_fit[rows]))
            if not torch.isfinite(batch_loss):
                # Further updates would only spread NaN/inf through the weights.
                return
            batch_loss.backward()
            optimizer.step()

        if not has_val:
            continue

        model.eval()
        with torch.no_grad():
            val_loss = loss_fn(yt_val, model(x_val)).item()
        scheduler.step(val_loss)

        if val_loss < best_val:
            best_val, stale_epochs = val_loss, 0
        else:
            stale_epochs += 1
            if stale_epochs >= stop_patience:
                break


def train_and_predict_dragons(t, y_unscaled, x, targeted_regularization=True, output_dir='',
                              knob_loss=dragonnet_loss_binarycross, ratio=1., dragon='', val_split=0.2, batch_size=512):
    """Train a TarNet/DragonNet 25 times and collect in-sample predictions.

    Each run trains a fresh model with Adam (up to 100 epochs) followed by
    Nesterov SGD (up to 300 epochs), then predicts on all rows.

    Args:
        t: 2-D array of treatment indicators, one row per sample.
        y_unscaled: 2-D array of outcomes; standardized before training.
        x: 2-D array of covariates.
        targeted_regularization: If true, wrap ``knob_loss`` with ``make_tarreg_loss``.
        output_dir: Accepted for interface compatibility; not used here.
        knob_loss: Base loss called as ``knob_loss(concat_true, concat_pred)``.
        ratio: Weight of the targeted-regularization term.
        dragon: ``'tarnet'`` or ``'dragonnet'``.
        val_split: Fraction of rows (taken from the end of the shuffled data)
            held out to monitor the loss during training.
        batch_size: Mini-batch size.

    Returns:
        ``(test_outputs, train_outputs)``: two lists with one ``_split_output``
        result per run. Train and test cover the same rows.

    Raises:
        ValueError: If ``dragon`` is not a supported model name.
    """
    # Imported locally: the module header only imports train_test_split from scikit-learn.
    from sklearn.preprocessing import StandardScaler

    model_classes = {'tarnet': TarNet, 'dragonnet': DragonNet}
    if dragon not in model_classes:
        raise ValueError(
            "dragon must be one of {}, got {!r}".format(sorted(model_classes), dragon))

    # Standardize the target variable y.
    y_scaler = StandardScaler().fit(y_unscaled)
    y = y_scaler.transform(y_unscaled)

    if targeted_regularization:
        loss = make_tarreg_loss(ratio=ratio, dragonnet_loss=knob_loss)
    else:
        loss = knob_loss

    # Every row is used for both training and evaluation. A fixed-seed shuffle
    # replaces train_test_split(test_size=0), which scikit-learn rejects.
    n_rows = x.shape[0]
    train_index = np.random.RandomState(1).permutation(n_rows)
    test_index = train_index
    x_train, x_test = x[train_index], x[test_index]
    y_train, y_test = y[train_index], y[test_index]
    t_train, t_test = t[train_index], t[test_index]
    yt_train = np.concatenate([y_train, t_train], 1)

    x_train_t = torch.as_tensor(x_train, dtype=torch.float32)
    x_test_t = torch.as_tensor(x_test, dtype=torch.float32)
    yt_train_t = torch.as_tensor(yt_train, dtype=torch.float32)

    # Hold out the trailing val_split fraction of rows for monitoring.
    split_at = int(n_rows * (1. - val_split))
    x_fit, x_val = x_train_t[:split_at], x_train_t[split_at:]
    yt_fit, yt_val = yt_train_t[:split_at], yt_train_t[split_at:]

    # NOTE(review): the SGD learning rate, momentum and the early-stopping
    # patience values below follow the reference Dragonnet training schedule;
    # confirm they match what this project intends.
    sgd_lr = 1e-5
    momentum = 0.9

    train_outputs = []
    test_outputs = []
    runs = 25

    for _ in range(runs):
        dragonnet = model_classes[dragon](x.shape[1], 0.01)

        # Phase 1: Adam.
        adam = torch.optim.Adam(dragonnet.parameters(), lr=1e-3)
        _fit_dragon_phase(dragonnet, loss, adam, x_fit, yt_fit, x_val, yt_val,
                          epochs=100, batch_size=batch_size, stop_patience=2)

        # Phase 2: switch to Nesterov SGD and continue training.
        sgd = torch.optim.SGD(dragonnet.parameters(), lr=sgd_lr, momentum=momentum, nesterov=True)
        _fit_dragon_phase(dragonnet, loss, sgd, x_fit, yt_fit, x_val, yt_val,
                          epochs=300, batch_size=batch_size, stop_patience=40)

        # Predict on test and train sets using the trained model.
        dragonnet.eval()
        with torch.no_grad():
            yt_hat_test = dragonnet(x_test_t).numpy()
            yt_hat_train = dragonnet(x_train_t).numpy()

        # _split_output is expected to be defined elsewhere in this module.
        test_outputs.append(_split_output(yt_hat_test, t_test, y_test, y_scaler, x_test, test_index))
        train_outputs.append(_split_output(yt_hat_train, t_train, y_train, y_scaler, x_train, train_index))

    return test_outputs, train_outputs
