from s1_model import CausalMultiTaskDataset
from s1_model import SharedRepresentation
from s1_model import MultiTaskCausalModel
from s1_model import train_model
from s1_model import SingleTaskCompositeModel
from s1_model import train_composite_model
from s1_model import CompositeOutcomeDataset
from s1_model import IndependentOutcomeModel
from s1_model import train_independent_model
from s1_model import compute_utility_tensor

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import torch


# ----- Synthetic Data -----
def simulate_simplified_dgp(n=1000, d=10, K=3, utility='or', weights=None, seed=42, data='synthetic',
                            data_path=None):
    """Simulate K binary potential outcomes per unit and a composite utility.

    Covariates are drawn from a standard normal (``data='synthetic'``) or read from a
    CSV file (``'IHDP'`` / ``'EHR'``). Treatment is randomised with probability 0.5.
    For each outcome k a logit is built from a linear term, a tanh nonlinearity and a
    heterogeneous treatment term. The binary potential outcome is 1 when the logit
    exceeds that arm's 40th percentile.

    Parameters
    ----------
    n, d : int
        Sample size and covariate dimension for ``'synthetic'``. For file-based data
        both are taken from the loaded covariate matrix instead.
    K : int
        Number of component outcomes.
    utility : {'or', 'weighted_sum', 'tanh'}
        Composite utility applied row-wise to the K binary outcomes.
    weights : sequence of float, optional
        Per-outcome weights for 'weighted_sum' / 'tanh' (default: all ones).
    seed : int
        Seed passed to ``np.random.seed``. This reseeds NumPy's *global* RNG.
    data : {'synthetic', 'IHDP', 'EHR'}
        Covariate source.
    data_path : str, optional
        CSV path for 'IHDP' / 'EHR'. Defaults to the original hard-coded locations.

    Returns
    -------
    dict
        Keys 'X', 'A', 'Y', 'Y0', 'Y1' (outcomes, shape (n, K)), and 'U', 'U0', 'U1',
        'tau' (utilities and individual effect U1 - U0, shape (n,)).

    Raises
    ------
    ValueError
        If ``data`` or ``utility`` is not supported.
    """
    np.random.seed(seed)

    default_paths = {
        'IHDP': "C:/Users/Kan Chen/Dropbox/causal_representation_learning/simulation/IHDP.csv",
        'EHR': "C:/Users/Kan Chen/Dropbox/causal_representation_learning/simulation/EHR.csv",
    }
    csv_path = data_path if data_path is not None else default_paths.get(data)

    # Covariates, plus per-source scales for the random coefficients and the
    # shared treatment shift: (beta sd, gamma sd, theta sd, tau_global).
    if data == 'synthetic':
        X = np.random.randn(n, d)
        beta_sd, gamma_sd, theta_sd, tau_global = 0.5, 0.5, 0.3, 1.0
    elif data == 'IHDP':
        X = pd.read_csv(csv_path, header=None).iloc[:, 5:30].to_numpy(dtype=np.float64)
        beta_sd, gamma_sd, theta_sd, tau_global = 0.5, 0.5, 0.3, 1.0
    elif data == 'EHR':
        columns_to_keep = [
            "X..BAND.NEUTROPHILS..count", "X..BAND.NEUTROPHILS..last",
            "X..LYMPHOCYTES.MANUAL..last", "X..LYMPHOCYTES..first", "X..MONOCYTES..last",
            "X..NEUTROPHILS.MANUAL..count", "X..LYMPHOCYTES.MANUAL..count",
            "X..LYMPHOCYTES..count", "X..LYMPHOCYTES..first.1", "X..LYMPHOCYTES..last",
            "ALKALINE.PHOSPHATASE..first", "ALKALINE.PHOSPHATASE..last",
            "ALPHA.1.FETOPROTEINS..first", "ALPHA.1.FETOPROTEINS..last",
            "ALT..count", "ALT..first", "ALT..last", "ANION.GAP..last",
            "AST..first", "AST..last", "BILIRUBIN.DIRECT..count", "BILIRUBIN.INDIRECT..last"]
        X = pd.read_csv(csv_path).head(1000)[columns_to_keep].to_numpy(dtype=np.float64)
        beta_sd, gamma_sd, theta_sd, tau_global = 1.5, 1.5, 1.3, 3.0
    else:
        raise ValueError(f"Unsupported data source: {data!r} (expected 'synthetic', 'IHDP' or 'EHR')")

    # For file-based sources n and d come from the loaded matrix.
    n, d = X.shape

    # RNG draw order (X, A, beta, gamma, theta) matches the original per-branch code,
    # so a given seed yields the same assignment and coefficients.
    A = np.random.binomial(1, 0.5, size=n)
    beta = np.random.randn(K, d) * beta_sd      # linear main effects
    gamma = np.random.randn(K, d) * gamma_sd    # nonlinear effects
    theta = np.random.randn(K, d) * theta_sd    # heterogeneous treatment modifier

    def compute_logit(X, A, k):
        """Logit of outcome k for treatment vector A (all 0s or all 1s here)."""
        linear = X @ beta[k]
        # The projection must be tanh's *input*. Passing it as a second positional
        # argument makes it the ufunc's ``out`` buffer and yields tanh(0) = 0.
        nonlinear = np.tanh(X @ gamma[k])
        treatment_effect = A * (tau_global + np.tanh(X @ theta[k]))  # heterogeneous, centered around tau_global
        return linear + nonlinear + treatment_effect

    # Potential outcomes.
    # NOTE(review): each arm is thresholded at its own 40th percentile. As a result,
    # about 60% of units are positive in both arms, and the constant shift
    # tau_global cancels out of Y1; only the heterogeneous term re-ranks units.
    # Confirm this is intended.
    Y0 = np.zeros((n, K))
    Y1 = np.zeros((n, K))
    for k in range(K):
        logit0 = compute_logit(X, np.zeros(n), k)
        logit1 = compute_logit(X, np.ones(n), k)
        thresh0 = np.quantile(logit0, 0.4)
        thresh1 = np.quantile(logit1, 0.4)
        Y0[:, k] = (logit0 > thresh0).astype(float)
        Y1[:, k] = (logit1 > thresh1).astype(float)

    # Observed outcomes follow the assigned arm.
    Y = np.where(A[:, None] == 1, Y1, Y0)

    def compute_utility(Y, utility, weights):
        """Row-wise composite utility of an (n, K) binary outcome matrix."""
        if utility == 'or':
            return 1 - np.prod(1 - Y, axis=1)
        elif utility == 'weighted_sum':
            w = np.array(weights if weights is not None else [1.0] * K)
            return Y @ w
        elif utility == 'tanh':
            w = np.array(weights if weights is not None else [1.0] * K)
            return np.tanh(Y @ w)
        else:
            raise ValueError(f"Unsupported utility: {utility}")

    U0 = compute_utility(Y0, utility, weights)
    U1 = compute_utility(Y1, utility, weights)
    U = np.where(A == 1, U1, U0)
    tau = U1 - U0

    return {
        'X': X,
        'A': A,
        'Y': Y,
        'Y0': Y0,
        'Y1': Y1,
        'U': U,
        'U0': U0,
        'U1': U1,
        'tau': tau
    }


# Simulate the synthetic DGP (K = 4 component outcomes, default 'or' utility)
# and split it into train/test sets.
data = simulate_simplified_dgp(n=1000, d=10, K=4, data='synthetic')

X, A, Y, U, tau_true = (data[key] for key in ('X', 'A', 'Y', 'U', 'tau'))

p = X.shape[1]   # covariate dimension
K = Y.shape[1]   # number of component outcomes

# --- Train/Test Split ---
# With random_state=None, scikit-learn draws from NumPy's global RNG. That RNG was
# seeded inside simulate_simplified_dgp, so the split depends on that seed.
idx_train, idx_test = train_test_split(np.arange(X.shape[0]), test_size=0.2, random_state=None)
X_train, A_train, Y_train, U_train, tau_true_train = (arr[idx_train] for arr in (X, A, Y, U, tau_true))
X_test, A_test, Y_test, U_test, tau_true_test = (arr[idx_test] for arr in (X, A, Y, U, tau_true))



n_epoches = 100
# === Multi-Task Shared Rep ===
# Fit the shared-representation multi-task model on the training split.
dataset = CausalMultiTaskDataset(X_train, A_train, Y_train)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
model = MultiTaskCausalModel(input_dim=p, hidden_dim=64, num_outcomes=K)
train_model(model, dataloader, num_epochs=n_epoches, lr=1e-3)
model.eval()

# In-sample counterfactual predictions for every training row.
x_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y0_pred, y1_pred, tau_train = model.predict_counterfactuals_individual(x_train_tensor)

# --- First-order decomposition of the composite effect ---
# U(y1) - U(y0) is approximated by sum_k dU/dy_k(y_bar) * (y1_k - y0_k), with the
# gradient evaluated at the midpoint y_bar.
# Assumes y0_pred / y1_pred are (n, K) tensors; TODO confirm against the model API.

# Component-level effects
delta = (y1_pred - y0_pred).detach().numpy()  # shape (n, K)

# Midpoint \bar{y} = average of the two predicted potential outcomes
y_bar = ((y0_pred + y1_pred) / 2).detach().numpy()  # shape (n, K)

# Utility gradient at y_bar
if model.utility == 'or':
    # For U = 1 - prod_j (1 - y_j), dU/dy_k = prod_{j != k} (1 - y_j).
    # Compute the leave-one-out product directly. prod(1 - y) / (1 - y_k + eps)
    # collapses to 0 as y_k -> 1 even when the true gradient is non-zero.
    one_minus_y = 1.0 - y_bar
    grad_u = np.column_stack([
        np.prod(np.delete(one_minus_y, k, axis=1), axis=1)
        for k in range(one_minus_y.shape[1])
    ])
elif model.utility == 'weighted_sum':
    grad_u = model.utility_weights.numpy()[None, :]  # constant gradient
elif model.utility == 'tanh_reward':
    linear_sum = (y_bar * model.utility_weights.numpy()).sum(axis=1, keepdims=True)
    grad_u = (1 - np.tanh(linear_sum) ** 2) * model.utility_weights.numpy()[None, :]  # d/dy_k tanh(w^T y)
else:
    raise NotImplementedError("Utility not supported for decomposition")

# Component contributions (elementwise product)
component_contribs = grad_u * delta  # shape (n, K)

# Keep rows with at least one non-negligible contribution (1e-6 absorbs float noise).
nonzero_mask = np.any(np.abs(component_contribs) > 1e-6, axis=1)
component_contribs_nonzero = component_contribs[nonzero_mask]

# Number of non-negligible components per kept row
nonzero_counts = np.count_nonzero(np.abs(component_contribs_nonzero) > 1e-6, axis=1)

# Select the top_k rows with the most non-negligible components. A stable sort
# breaks ties by row order, so the selection is reproducible.
top_k = 8
top_indices = np.argsort(-nonzero_counts, kind='stable')[:top_k]

component_contribs_top = component_contribs_nonzero[top_indices]


import seaborn as sns
import matplotlib.pyplot as plt

# Heatmap of per-component contributions for the selected patients.
# Rows of component_contribs_top are training rows filtered by nonzero_mask and then
# reordered by top_indices. Label them with each patient's index in the full simulated
# dataset (via idx_train), not with their position 0..top_k-1.
heatmap_patient_ids = idx_train[np.flatnonzero(nonzero_mask)[top_indices]].tolist()

plt.figure(figsize=(10, 6))
sns.heatmap(
    component_contribs_top,
    annot=True,
    fmt=".2f",
    cmap='coolwarm',
    center=0,
    xticklabels=[f"Y{k+1}" for k in range(delta.shape[1])],
    yticklabels=heatmap_patient_ids,
)
plt.title("Heatmap of Component Contributions to Composite Effect")
plt.xlabel("Component Outcomes")
plt.ylabel("Patient Index")
plt.tight_layout()
plt.show()



# Stacked bar plot of each selected patient's component shares.
# component_contribs_top is (top_k, K). Each share is |contribution_k| divided by
# sum_j |contribution_j|, so it is a proportion of *absolute* contributions:
# components with opposite signs both add positive height. Every selected row has
# at least one |entry| > 1e-6 (rows were filtered by nonzero_mask), so no row sum
# is zero.
contribs_norm = np.abs(component_contribs_top)
contribs_norm = contribs_norm / contribs_norm.sum(axis=1, keepdims=True)

# Bar colours: yellow, orange, red, pink. They are cycled when there are more than
# four components, so K > 4 does not raise IndexError.
colors = ['#FFC107', '#FF5722', '#F44336', '#E91E63']

# Plot
fig, ax = plt.subplots(figsize=(12, 6))
indices = np.arange(contribs_norm.shape[0])
bottom = np.zeros(contribs_norm.shape[0])

for i in range(contribs_norm.shape[1]):
    ax.bar(indices, contribs_norm[:, i], bottom=bottom, color=colors[i % len(colors)], label=f'Y{i+1}')
    bottom += contribs_norm[:, i]

# Tick each bar with the patient's index in the full simulated dataset.
ax.set_xticks(indices)
ax.set_xticklabels([str(i) for i in idx_train[np.flatnonzero(nonzero_mask)[top_indices]]])

# Labels and legend
ax.set_title("Stacked Bar Plot of Component Contributions to Composite Treatment Effect", fontsize=14)
ax.set_xlabel("Individuals")
ax.set_ylabel("Proportion of Composite Treatment Effect")
ax.legend(title="Component Outcomes")
plt.tight_layout()
plt.show()









