import time
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from s1_model import MultiTaskCausalModel, train_model

class CausalMultiTaskDataset(Dataset):
    """In-memory map-style dataset of (covariates, treatment, outcomes) rows.

    The three inputs are copied into ``float32`` tensors once, at
    construction time, and exposed as the attributes ``X``, ``A`` and ``Y``.

    Args:
        X: Covariates; anything ``torch.tensor`` accepts.
        A: Treatment values, indexed in lockstep with ``X``.
        Y: Outcome values, indexed in lockstep with ``X``.
    """

    def __init__(self, X, A, Y):
        self.X, self.A, self.Y = (
            torch.tensor(source, dtype=torch.float32) for source in (X, A, Y)
        )

    def __len__(self):
        """Return the number of rows, taken from the covariate tensor."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(X[idx], A[idx], Y[idx])`` as a tuple of tensors."""
        return tuple(column[idx] for column in (self.X, self.A, self.Y))

def simulate_runtime_dgp(n=1000, d=10, K=3, seed=42, quantile=0.4):
    """Simulate a randomized multi-outcome binary dataset for runtime benchmarks.

    Covariates are standard normal and treatment is a fair coin flip. For each
    of the ``K`` outcomes, a logit is built from a linear term, a ``tanh``
    nonlinearity, and a treatment term ``A * (tau_global + tanh(X @ theta_k))``.
    The logit is evaluated for both potential outcomes (all-control and
    all-treated). Each is thresholded at its own ``quantile``-th empirical
    quantile, and the factual outcome is selected by ``A``.

    Args:
        n: Number of samples.
        d: Number of covariates.
        K: Number of binary outcomes.
        seed: Seed for a private legacy ``RandomState``. The draws are
            identical to ``np.random.seed(seed)`` followed by the global
            ``np.random`` functions, but the global NumPy RNG is left untouched.
        quantile: Quantile of each potential-outcome logit used as the
            threshold. The default 0.4 leaves roughly 60% of each arm positive.

    Returns:
        Tuple ``(X, A, Y)``:

        - ``X`` has shape ``(n, d)``.
        - ``A`` has shape ``(n,)`` with values in {0, 1}.
        - ``Y`` has shape ``(n, K)`` and holds the observed 0/1 outcomes as floats.
    """
    # A local generator avoids mutating process-wide RNG state that other
    # code (e.g. a training loop) may rely on. RandomState keeps the exact
    # legacy stream, so generated data is unchanged.
    rng = np.random.RandomState(seed)
    X = rng.randn(n, d)
    A = rng.binomial(1, 0.5, size=n)
    beta = rng.randn(K, d) * 0.5
    gamma = rng.randn(K, d) * 0.5
    theta = rng.randn(K, d) * 0.3
    tau_global = 1.0

    def compute_logit(X, A, k):
        """Return the outcome-``k`` logit for covariates ``X`` under treatment ``A``."""
        linear = X @ beta[k]
        nonlinear = np.tanh(X @ gamma[k])
        treatment_effect = A * (tau_global + np.tanh(X @ theta[k]))
        return linear + nonlinear + treatment_effect

    Y0 = np.zeros((n, K))
    Y1 = np.zeros((n, K))
    for k in range(K):
        logit0 = compute_logit(X, np.zeros(n), k)
        logit1 = compute_logit(X, np.ones(n), k)
        # NOTE(review): each arm is thresholded at its *own* quantile.
        # - Both arms therefore end up with the same positive rate.
        # - The constant tau_global shifts logit1 and thresh1 by the same
        #   amount, so it has no effect on Y1 at all.
        # Only the heterogeneous tanh(X @ theta) term changes which units are
        # positive. Confirm this is intended.
        thresh0 = np.quantile(logit0, quantile)
        thresh1 = np.quantile(logit1, quantile)
        Y0[:, k] = (logit0 > thresh0).astype(float)
        Y1[:, k] = (logit1 > thresh1).astype(float)

    # Observed outcomes: treated rows take Y1, control rows take Y0.
    Y = np.where(A[:, None] == 1, Y1, Y0)
    return X, A, Y

# Evaluate runtime: end-to-end training time as the number of outcomes K grows.
# NOTE(review): train_model lives in s1_model and is not visible here.
# - If it trains on a CUDA device, kernels may still be in flight when it
#   returns. Confirm it synchronizes, or call torch.cuda.synchronize()
#   before reading the end timestamp.
# - The first K may also absorb one-time start-up costs. Consider a
#   discarded warm-up run if that skews the smallest-K timing.
SEED = 42
Ks = [2, 3, 5, 10, 20, 50, 100]
results = []

for K in Ks:
    X, A, Y = simulate_runtime_dgp(K=K, seed=SEED)
    dataset = CausalMultiTaskDataset(X, A, Y)
    # Seed torch so weight init and DataLoader shuffling are reproducible
    # from run to run; previously only the NumPy data was seeded.
    torch.manual_seed(SEED)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    model = MultiTaskCausalModel(input_dim=X.shape[1], hidden_dim=64, num_outcomes=K)

    # perf_counter is monotonic and high-resolution. time.time() is
    # wall-clock and can jump if the system clock is adjusted mid-run.
    start = time.perf_counter()
    train_model(model, dataloader, num_epochs=200, lr=1e-3)
    end = time.perf_counter()

    results.append((K, round(end - start, 2)))

df = pd.DataFrame(results, columns=["Num Components (K)", "Runtime (seconds)"])
print(df)
