import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import sklearn.metrics as metrics
# NOTE(review): presumably a workaround for the Intel OpenMP "OMP: Error #15"
# abort that occurs when more than one copy of the OpenMP runtime is loaded
# (e.g. by torch and numpy/MKL); the flag tells the runtime to continue rather
# than abort. It mutates the process-wide environment at import time, and it
# runs only after torch/numpy are imported above — confirm it still takes
# effect at this point.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"



# CROME model
class CausalMultiTaskDataset(Dataset):
    """Map-style dataset of ``(covariates, treatment, outcome labels)`` triples.

    ``X``, ``A`` and ``Y`` are each copied into a float32 tensor once, at
    construction time. They are indexed in parallel, and the dataset length
    is taken from ``X``.

    Args:
        X: covariates; anything ``torch.tensor`` accepts (lists, numpy arrays).
        A: treatment indicators, indexed in parallel with ``X``.
        Y: outcome labels, indexed in parallel with ``X``.
    """

    def __init__(self, X, A, Y):
        # Converted in the order X, A, Y.
        self.X, self.A, self.Y = (
            torch.tensor(source, dtype=torch.float32) for source in (X, A, Y)
        )

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # One sample: the idx-th entry of each stored tensor.
        return tuple(column[idx] for column in (self.X, self.A, self.Y))

# ----- Model Definition -----
class SharedRepresentation(nn.Module):
    """Two-layer ReLU MLP encoder shared by all outcome heads.

    Maps an ``input_dim`` feature vector to a ``hidden_dim`` representation
    through Linear -> ReLU -> Linear -> ReLU.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # A single ``encoder`` Sequential keeps state_dict keys at
        # ``encoder.0.*`` / ``encoder.2.*``.
        stack = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*stack)

    def forward(self, x):
        """Encode ``x``; the final ReLU makes every output non-negative."""
        encoded = self.encoder(x)
        return encoded

class MultiTaskCausalModel(nn.Module):
    """Multi-outcome, two-arm model built on one shared representation.

    A ``SharedRepresentation`` encoder feeds one sigmoid head per outcome and
    per treatment arm: ``heads_t1`` for rows with ``a == 1`` and ``heads_t0``
    for all other rows. Per-outcome predictions are combined into a composite
    utility by :meth:`compute_utility`.

    Args:
        input_dim: size of the covariate vector.
        hidden_dim: width of the shared representation.
        num_outcomes: number of outcome heads per arm.
        utility: how outcomes are combined. One of:
            ``'or'`` (noisy-OR, ``1 - prod_k(1 - p_k)``);
            ``'weighted_sum'``;
            ``'tanh_reward'`` (tanh of the weighted sum). ``'tanh'`` is
            accepted as an alias, matching ``compute_utility_tensor``.
        utility_weights: per-outcome weights for the weighted utilities.
            Defaults to all ones and is ignored by ``'or'``.
    """

    # Utility names that share the tanh-of-weighted-sum formula.
    _TANH_UTILITIES = ('tanh_reward', 'tanh')

    def __init__(self, input_dim, hidden_dim, num_outcomes, utility='or', utility_weights=None):
        super().__init__()
        self.representation = SharedRepresentation(input_dim, hidden_dim)
        self.heads_t0 = nn.ModuleList(
            [nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid()) for _ in range(num_outcomes)]
        )
        self.heads_t1 = nn.ModuleList(
            [nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid()) for _ in range(num_outcomes)]
        )
        self.utility = utility
        weights = [1.0] * num_outcomes if utility_weights is None else utility_weights
        # Kept as a plain attribute (not a registered buffer), so existing
        # state_dicts still load; compute_utility moves it to the predictions'
        # device on every call. as_tensor + detach/clone makes an independent
        # copy without the warning torch.tensor() emits for tensor inputs.
        self.utility_weights = torch.as_tensor(weights, dtype=torch.float32).detach().clone()

    def _head_outputs(self, h):
        """Apply every control and treated head to the representation ``h``.

        Returns ``(y0, y1)``, each with one column per outcome in head order.
        """
        y0 = torch.cat([head(h) for head in self.heads_t0], dim=1)
        y1 = torch.cat([head(h) for head in self.heads_t1], dim=1)
        return y0, y1

    def forward(self, x, a):
        """Predict outcomes under each row's observed arm, plus their composite utility.

        Args:
            x: covariate batch.
            a: treatment indicator per row. Rows equal to 1 use the treated
               heads; all other rows use the control heads.

        Returns:
            ``(preds, composite)``: per-outcome predictions (one column per
            outcome) and the composite utility (one column).
        """
        y0, y1 = self._head_outputs(self.representation(x))
        # The (batch, 1) mask broadcasts across all outcome columns at once.
        preds = torch.where(a.view(-1, 1) == 1, y1, y0)
        composite = self.compute_utility(preds)
        return preds, composite

    def compute_utility(self, preds):
        """Combine per-outcome predictions (one column per outcome) into a single column.

        Raises:
            NotImplementedError: if ``self.utility`` is not a known utility name.
        """
        if self.utility == 'or':
            return 1 - torch.prod(1 - preds, dim=1, keepdim=True)
        if self.utility == 'weighted_sum' or self.utility in self._TANH_UTILITIES:
            weighted = torch.sum(preds * self.utility_weights.to(preds.device), dim=1, keepdim=True)
            return weighted if self.utility == 'weighted_sum' else torch.tanh(weighted)
        raise NotImplementedError(f"Unknown utility: {self.utility}")

    def predict_counterfactuals(self, x):
        """Return ``(u0, u1, tau)``: composite utility under control and under treatment, and ``u1 - u0``."""
        y_hat_0, y_hat_1 = self._head_outputs(self.representation(x))
        u0 = self.compute_utility(y_hat_0)
        u1 = self.compute_utility(y_hat_1)
        return u0, u1, u1 - u0

    def predict_counterfactuals_individual(self, x):
        """Return per-outcome ``(y_hat_0, y_hat_1, tau)``, where ``tau = y_hat_1 - y_hat_0``."""
        y_hat_0, y_hat_1 = self._head_outputs(self.representation(x))
        return y_hat_0, y_hat_1, y_hat_1 - y_hat_0
    


# Single-task composite model: predicts the composite utility directly (one sigmoid head per treatment arm).

class SingleTaskCompositeModel(nn.Module):
    """Two-arm model that predicts the composite utility directly.

    A two-layer ReLU trunk (``shared``) feeds one sigmoid head per treatment
    arm: ``head0`` for control and ``head1`` for treated.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        trunk = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.shared = nn.Sequential(*trunk)
        self.head0 = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())
        self.head1 = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())

    def _arm_outputs(self, x):
        """Return ``(control, treated)`` head outputs for the batch ``x``."""
        features = self.shared(x)
        return self.head0(features), self.head1(features)

    def forward(self, x, a):
        """Return each row's utility under its observed arm (``a == 1`` selects treated)."""
        control, treated = self._arm_outputs(x)
        is_treated = a.view(-1, 1) == 1
        return torch.where(is_treated, treated, control)

    def predict_counterfactuals(self, x):
        """Return ``(u0, u1, tau)``, where ``tau = u1 - u0``."""
        control, treated = self._arm_outputs(x)
        return control, treated, treated - control


class CompositeOutcomeDataset(Dataset):
    """Map-style dataset of ``(covariates, treatment, composite utility)`` triples.

    Inputs are copied into float32 tensors. ``U`` is reshaped into a single
    column, so each item's utility tensor has shape ``(1,)``.
    """

    def __init__(self, X, A, U):
        self.X, self.A = (torch.tensor(source, dtype=torch.float32) for source in (X, A))
        utility = torch.tensor(U, dtype=torch.float32)
        self.U = utility.view(-1, 1)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        covariates = self.X[idx]
        treatment = self.A[idx]
        target = self.U[idx]
        return covariates, treatment, target


# Independent Outcome Model (no shared representation)

class IndependentOutcomeModel(nn.Module):
    """Per-outcome baseline in which each outcome has its own encoder and arm heads.

    Outcome ``k`` is predicted from ``representations[k]`` alone, so no
    parameters are shared across outcomes.
    """

    def __init__(self, input_dim, hidden_dim, num_outcomes):
        super().__init__()
        self.num_outcomes = num_outcomes

        def make_encoder():
            return nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
            )

        def make_head():
            return nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())

        # Build order (all encoders, then control heads, then treated heads)
        # matches parameter registration and RNG consumption order.
        self.representations = nn.ModuleList([make_encoder() for _ in range(num_outcomes)])
        self.heads_t0 = nn.ModuleList([make_head() for _ in range(num_outcomes)])
        self.heads_t1 = nn.ModuleList([make_head() for _ in range(num_outcomes)])

    def _arm_outputs(self, x):
        """Return ``(y0_all, y1_all)``, each with one column per outcome."""
        control_cols, treated_cols = [], []
        for k in range(self.num_outcomes):
            z = self.representations[k](x)
            control_cols.append(self.heads_t0[k](z))
            treated_cols.append(self.heads_t1[k](z))
        return torch.cat(control_cols, dim=1), torch.cat(treated_cols, dim=1)

    def forward(self, x, a):
        """Return per-outcome predictions under each row's observed arm (``a == 1`` selects treated)."""
        control, treated = self._arm_outputs(x)
        return torch.where(a.view(-1, 1) == 1, treated, control)

    def predict_counterfactuals(self, x):
        """Return ``(y0_all, y1_all)``; unlike the shared models, no ``tau`` is returned."""
        y0_all, y1_all = self._arm_outputs(x)
        return y0_all, y1_all



# ----- Training Loop -----
def train_model(model, dataloader, num_epochs=10, lr=1e-3):
    """Fit a multi-outcome model with Adam and a summed per-outcome BCE loss.

    ``model(x, a)`` must return ``(per_outcome_preds, composite)``. Only the
    per-outcome predictions are trained: column ``k`` is scored with BCE
    against column ``k`` of the label batch, so predictions must be
    probabilities. The composite output is ignored.

    Args:
        model: module mapping ``(x, a)`` to ``(preds, composite)``.
        dataloader: iterable of ``(x, a, y)`` batches.
        num_epochs: number of passes over ``dataloader``.
        lr: Adam learning rate.

    Returns:
        list[float]: the training loss of each epoch, summed over batches.
        This is the same value that is printed.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCELoss()
    history = []

    for epoch in range(num_epochs):
        total_loss = 0.0
        model.train()
        for x_batch, a_batch, y_batch in dataloader:
            optimizer.zero_grad()
            pred_components, _ = model(x_batch, a_batch)

            # Unweighted sum of one BCE term per outcome column.
            loss = sum(
                loss_fn(pred_components[:, k], y_batch[:, k])
                for k in range(pred_components.shape[1])
            )

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss:.4f}")
        history.append(total_loss)

    return history


    
    
def train_composite_model(model, dataloader, num_epochs=20, lr=1e-3):
    """Fit a single-output composite model with Adam and BCE loss.

    ``model(x, a)`` must return probabilities shaped like the utility batch
    ``u`` (required by ``nn.BCELoss``).

    Args:
        model: module mapping ``(x, a)`` to predicted utility.
        dataloader: iterable of ``(x, a, u)`` batches.
        num_epochs: number of passes over ``dataloader``.
        lr: Adam learning rate.

    Returns:
        list[float]: the training loss of each epoch, summed over batches.
        This is the same value that is printed.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCELoss()
    history = []
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for x_batch, a_batch, u_batch in dataloader:
            optimizer.zero_grad()
            pred_u = model(x_batch, a_batch)
            loss = loss_fn(pred_u, u_batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"[Composite Model] Epoch {epoch+1}, Loss: {total_loss:.4f}")
        history.append(total_loss)
    return history
    
    
    
    
def train_independent_model(model, dataloader, num_epochs=20, lr=1e-3):
    """Fit a multi-outcome model whose forward returns predictions only.

    Same objective as ``train_model``: an unweighted sum of per-column BCE
    losses. The difference is that ``model(x, a)`` returns the per-outcome
    probability tensor directly, not a ``(preds, composite)`` tuple.

    Args:
        model: module mapping ``(x, a)`` to per-outcome probabilities.
        dataloader: iterable of ``(x, a, y)`` batches.
        num_epochs: number of passes over ``dataloader``.
        lr: Adam learning rate.

    Returns:
        list[float]: the training loss of each epoch, summed over batches.
        This is the same value that is printed.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCELoss()
    history = []

    for epoch in range(num_epochs):
        total_loss = 0.0
        model.train()
        for x_batch, a_batch, y_batch in dataloader:
            optimizer.zero_grad()
            preds = model(x_batch, a_batch)
            loss = sum(loss_fn(preds[:, k], y_batch[:, k]) for k in range(preds.shape[1]))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"[Independent Model] Epoch {epoch+1}, Loss: {total_loss:.4f}")
        history.append(total_loss)

    return history
    
    
    
def compute_utility_tensor(Y, utility='or', weights=None):
    """Combine per-outcome values ``Y`` (one column per outcome) into one utility column.

    Uses the same formulas as ``MultiTaskCausalModel.compute_utility``, so it
    can be applied to observed label tensors.

    Args:
        Y: 2-D tensor; the columns of each row are combined.
        utility: one of
            ``'or'`` (noisy-OR, ``1 - prod_k(1 - y_k)``);
            ``'weighted_sum'``;
            ``'tanh'`` (tanh of the weighted sum). ``'tanh_reward'`` is
            accepted as an alias, matching the model's name for this utility.
        weights: per-column weights for the weighted utilities. Defaults to
            all ones and is ignored by ``'or'``.

    Returns:
        Tensor with one row per row of ``Y`` and a single column.

    Raises:
        ValueError: if ``utility`` is not supported.
    """
    if utility == 'or':
        return 1 - torch.prod(1 - Y, dim=1, keepdim=True)
    if utility not in ('weighted_sum', 'tanh', 'tanh_reward'):
        raise ValueError(f"Unsupported utility: {utility}")
    # Built once for both weighted utilities. as_tensor also avoids the copy
    # warning torch.tensor() emits when weights is already a tensor.
    w = torch.as_tensor(
        weights if weights is not None else [1.0] * Y.shape[1],
        dtype=torch.float32,
        device=Y.device,
    )
    weighted = torch.sum(Y * w, dim=1, keepdim=True)
    return weighted if utility == 'weighted_sum' else torch.tanh(weighted)
    
    
    
    
    
    
    
