import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import tqdm
import torch.optim as optim
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn


class BoundedContinuousFunctionModel(nn.Module):
    """Fully connected network whose output is passed through a sigmoid.

    Each hidden layer is a ``nn.Linear`` followed by ``nn.ReLU``; the final
    ``nn.Linear`` is followed by ``nn.Sigmoid``, which bounds every output
    value to the [0, 1] range.
    """

    def __init__(self, input_dim, hidden_dims=(8, 8), output_dim=1):
        """
        Initializes the BoundedContinuousFunctionModel.

        Parameters:
        - input_dim (int): Dimensionality of the input features.
        - hidden_dims (sequence of int): Dimensions of the hidden layers, in
          order. An empty sequence yields a single linear layer + sigmoid.
        - output_dim (int): Dimensionality of the output.
        """
        super().__init__()
        layers = []
        last_dim = input_dim
        # Chain the hidden layers, feeding each layer's width into the next.
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(last_dim, hidden_dim))
            layers.append(nn.ReLU())
            last_dim = hidden_dim
        # Output projection, squashed by the sigmoid to keep it bounded.
        layers.append(nn.Linear(last_dim, output_dim))
        layers.append(nn.Sigmoid())
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the network and return the sigmoid-bounded output."""
        return self.network(x)


class Generator(nn.Module):
    """Small feed-forward network: two linear layers, each followed by ELU."""

    def __init__(self, input_dim, output_dim):
        """
        Builds the generator network.

        Parameters:
        - input_dim (int): Dimensionality of the input features.
        - output_dim (int): Dimensionality of the output.
        """
        super().__init__()
        width = 16  # size of the single hidden layer
        stack = [
            nn.Linear(input_dim, width),
            nn.ELU(),
            nn.Linear(width, output_dim),
            nn.ELU(),  # ELU is also applied to the final output
        ]
        self.network = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the network to ``x`` and return the result."""
        out = self.network(x)
        return out
    

# A simple s-network model: an ELU MLP with a single output unit.
class snet(nn.Module):
    """MLP with two ELU hidden layers of equal width and one linear output unit."""

    def __init__(self, input_dim, hidden_dim=32):
        """
        Builds the network.

        Parameters:
        - input_dim (int): Dimensionality of the input features.
        - hidden_dim (int): Width of both hidden layers.
        """
        super().__init__()
        widths = [input_dim, hidden_dim, hidden_dim]
        modules = []
        # Hidden part: Linear + ELU for each consecutive pair of widths.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(nn.ELU())
        # Output layer has a single unit and no activation.
        modules.append(nn.Linear(hidden_dim, 1))
        self.network = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the network to ``x`` and return the result."""
        result = self.network(x)
        return result


class CFR(nn.Module):
    """Shared encoder followed by two separate outcome heads (y0 and y1)."""

    def __init__(self, input_dim, output_dim, rep_dim=32, hyp_dim=16):
        """
        Builds the encoder and both outcome heads.

        Parameters:
        - input_dim (int): Dimensionality of the input features.
        - output_dim (int): Dimensionality of each outcome head's output.
        - rep_dim (int): Width of the representation (encoder) layers.
        - hyp_dim (int): Width of the hidden layers in each outcome head.
        """
        super().__init__()

        # Representation layer: three Linear + ELU stages.
        encoder_layers = []
        in_width = input_dim
        for _ in range(3):
            encoder_layers.append(nn.Linear(in_width, rep_dim))
            encoder_layers.append(nn.ELU())
            in_width = rep_dim
        self.encoder = nn.Sequential(*encoder_layers)

        # Potential outcome for control (y0)
        self.func0 = self._outcome_head(rep_dim, hyp_dim, output_dim)

        # Potential outcome for treated (y1)
        self.func1 = self._outcome_head(rep_dim, hyp_dim, output_dim)

    @staticmethod
    def _outcome_head(rep_dim, hyp_dim, output_dim):
        """Return one outcome head: two Linear + ELU stages and a linear output."""
        return nn.Sequential(
            nn.Linear(rep_dim, hyp_dim),
            nn.ELU(),
            nn.Linear(hyp_dim, hyp_dim),
            nn.ELU(),
            nn.Linear(hyp_dim, output_dim),
        )

    def forward(self, X):
        """Return ``(Phi, Y0, Y1)``: the encoding of ``X`` and both head outputs."""
        Phi = self.encoder(X)
        return Phi, self.func0(Phi), self.func1(Phi)

    

class TLearner(nn.Module):
    """Two independent ReLU MLPs applied to the same input, one per outcome."""

    def __init__(self, input_dim, output_dim, hyp_dim=100):
        """
        Builds the two outcome networks.

        Parameters:
        - input_dim (int): Dimensionality of the input features.
        - output_dim (int): Dimensionality of each network's output.
        - hyp_dim (int): Width of the hidden layers in each network.
        """
        super().__init__()

        def build_head():
            # Two Linear + ReLU stages followed by a linear output layer.
            return nn.Sequential(
                nn.Linear(input_dim, hyp_dim),
                nn.ReLU(),
                nn.Linear(hyp_dim, hyp_dim),
                nn.ReLU(),
                nn.Linear(hyp_dim, output_dim),
            )

        # Potential outcome y0
        self.func0 = build_head()

        # Potential outcome y1
        self.func1 = build_head()

    def forward(self, X):
        """Return ``(Y0, Y1)``, the outputs of both networks evaluated on ``X``."""
        return self.func0(X), self.func1(X)





