# tarnet_torch.py
import numpy as np
from copy import deepcopy
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error

def set_seed(seed: int = 42):
    """Seed the global NumPy and PyTorch random number generators.

    Args:
        seed: Value passed to ``np.random.seed`` and ``torch.manual_seed``;
            also passed to ``torch.cuda.manual_seed_all`` when CUDA is
            available.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    # Only touch the CUDA generators when a CUDA runtime is actually present.
    torch.cuda.manual_seed_all(seed)

class MLP(nn.Module):
    """Feed-forward stack of ``n_layers`` Linear -> activation -> Dropout stages.

    Every stage maps to ``hidden_dim`` units; the first one reads ``in_dim``
    features. With ``n_layers == 0`` the network is an empty ``nn.Sequential``
    and therefore returns its input unchanged.

    Args:
        in_dim: Number of input features.
        hidden_dim: Width of every hidden layer.
        n_layers: Number of Linear/activation/Dropout stages.
        activation: ``"elu"`` (case-insensitive) selects ``nn.ELU``; any other
            value selects ``nn.ReLU``.
        dropout: Dropout probability applied after each activation.
    """

    def __init__(self, in_dim, hidden_dim, n_layers, activation="elu", dropout=0.0):
        super().__init__()
        activation_cls = nn.ELU if activation.lower() == "elu" else nn.ReLU

        # Consecutive (fan_in, fan_out) pairs: in_dim -> hidden -> hidden -> ...
        widths = [in_dim] + [hidden_dim] * n_layers
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(fan_in, fan_out))
            stages.append(activation_cls())
            stages.append(nn.Dropout(dropout))
        self.net = nn.Sequential(*stages)

        # Re-initialise every linear layer: Kaiming-uniform weights, zero biases.
        for layer in self.net:
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Apply the stacked layers to ``x`` and return the result."""
        return self.net(x)

@dataclass
class TARNetConfig:
    """Hyper-parameters and runtime options consumed by ``TARNetTorch``."""

    # --- architecture -------------------------------------------------------
    n_layers_rep: int = 3      # layers in the shared representation network
    n_units_rep: int = 200     # units per representation layer
    n_layers_head: int = 2     # layers in each per-treatment head
    n_units_head: int = 100    # units per head layer
    activation: str = "elu"    # "elu" selects ELU; any other value selects ReLU
    dropout: float = 0.0       # dropout probability used in every MLP

    # --- optimisation -------------------------------------------------------
    lr: float = 1e-3           # Adam learning rate
    weight_decay: float = 1e-4 # Adam weight decay
    batch_size: int = 64       # mini-batch size of the training DataLoader
    n_epochs: int = 200        # maximum number of training epochs
    patience: int = 20         # epochs without improvement before stopping
    early_stopping: bool = True  # whether `patience` is enforced

    # --- data handling / runtime -------------------------------------------
    binary_y: bool = False     # True: BCE-with-logits loss, sigmoid at predict time
    standardize: bool = True   # True: fit a StandardScaler on the covariates
    random_state: int = 42     # seed for RNGs and the train/validation split
    # Evaluated once, when this module is imported.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    verbose: bool = False      # print progress every 10 epochs

class TARNetTorch(nn.Module):
    """Two-headed outcome network for treatment-effect estimation.

    A shared representation MLP feeds two separate head MLPs, one per
    treatment arm, each followed by a single-output linear layer. Training
    minimises the loss on the *factual* outcome only (the head matching the
    observed treatment). The estimated effect is the difference between the
    two heads' predictions.

    Args:
        n_features: Number of covariates per sample.
        config: Hyper-parameters; a private deep copy is stored as
            ``self.cfg``. ``None`` (the default) uses ``TARNetConfig()``.
    """

    def __init__(self, n_features: int, config: "TARNetConfig | None" = None):
        super().__init__()
        # A None sentinel avoids instantiating the default config at
        # class-definition time; the config is copied so later mutation by
        # the caller cannot affect this model.
        self.cfg = deepcopy(config) if config is not None else TARNetConfig()
        set_seed(self.cfg.random_state)

        self.scaler_X = StandardScaler() if self.cfg.standardize else None

        # With zero layers an MLP is the identity, so the width flowing into
        # the next stage is that of the previous stage, not the nominal
        # hidden size.
        rep_out_dim = self.cfg.n_units_rep if self.cfg.n_layers_rep > 0 else n_features
        head_out_dim = self.cfg.n_units_head if self.cfg.n_layers_head > 0 else rep_out_dim

        # Construction order (rep, head0, head1, out0, out1) determines how
        # the seeded RNG is consumed; keep it stable.
        self.rep = MLP(
            in_dim=n_features,
            hidden_dim=self.cfg.n_units_rep,
            n_layers=self.cfg.n_layers_rep,
            activation=self.cfg.activation,
            dropout=self.cfg.dropout
        )
        self.head0 = MLP(
            in_dim=rep_out_dim,
            hidden_dim=self.cfg.n_units_head,
            n_layers=self.cfg.n_layers_head,
            activation=self.cfg.activation,
            dropout=self.cfg.dropout
        )
        self.head1 = MLP(
            in_dim=rep_out_dim,
            hidden_dim=self.cfg.n_units_head,
            n_layers=self.cfg.n_layers_head,
            activation=self.cfg.activation,
            dropout=self.cfg.dropout
        )
        self.out0 = nn.Linear(head_out_dim, 1)
        self.out1 = nn.Linear(head_out_dim, 1)
        nn.init.zeros_(self.out0.bias)
        nn.init.zeros_(self.out1.bias)

        self.to(self.cfg.device)

        self._best_state = None
        self._fitted = False

    # ------------------------------------------------------------------ #
    # Metric helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _true_effect(mu0, mu1):
        """Return the flattened ground-truth effect ``mu1 - mu0``."""
        if mu0 is None or mu1 is None:
            raise ValueError("Both mu0 and mu1 are required to compute the true effect.")
        return np.asarray(mu1).reshape(-1) - np.asarray(mu0).reshape(-1)

    @staticmethod
    def _pehe(tau_hat, mu0=None, mu1=None):
        """Root-mean-squared error between ``tau_hat`` and ``mu1 - mu0``.

        All inputs are flattened first so that column-shaped ``(n, 1)``
        arrays do not broadcast against 1-D predictions into an ``(n, n)``
        matrix.
        """
        tau_true = TARNetTorch._true_effect(mu0, mu1)
        tau_hat = np.asarray(tau_hat).reshape(-1)
        return float(np.sqrt(np.mean((tau_hat - tau_true) ** 2)))

    @staticmethod
    def _abs_ate_error(tau_hat, mu0=None, mu1=None):
        """Absolute difference between the mean predicted and mean true effect."""
        tau_true = TARNetTorch._true_effect(mu0, mu1)
        tau_hat = np.asarray(tau_hat).reshape(-1)
        return float(abs(np.mean(tau_hat) - np.mean(tau_true)))

    @staticmethod
    def _rel_ate_error(tau_hat, mu0=None, mu1=None):
        """Absolute ATE error divided by ``|mean true effect| + 1e-8``."""
        tau_true = TARNetTorch._true_effect(mu0, mu1)
        tau_hat = np.asarray(tau_hat).reshape(-1)
        denom = abs(np.mean(tau_true)) + 1e-8  # epsilon guards a zero true ATE
        return float(abs(np.mean(tau_hat) - np.mean(tau_true)) / denom)

    @staticmethod
    def _policy_risk_from_tau(tau_hat, T, Y, lam=0.0):
        """Policy risk ``1 - value`` of the rule "treat iff ``tau_hat > lam``".

        ``value = E[Y | treated, policy treats] * P(policy treats)
                + E[Y | control, policy withholds] * P(policy withholds)``.

        A conditional mean over an empty group is taken as 0 instead of NaN,
        so a policy that treats nobody (or everybody) still yields a finite
        risk. Returns NaN when there are no samples at all.
        """
        tau_hat = np.asarray(tau_hat).reshape(-1)
        T = np.asarray(T).reshape(-1)
        Y = np.asarray(Y).reshape(-1)
        if tau_hat.size == 0:
            return float("nan")

        treat_policy = tau_hat > lam
        p_treat = float(np.mean(treat_policy))

        # Units whose observed treatment agrees with the policy's choice.
        agree_treated = treat_policy & (T == 1)
        agree_control = ~treat_policy & (T == 0)
        value_treat = float(np.mean(Y[agree_treated])) if agree_treated.any() else 0.0
        value_control = float(np.mean(Y[agree_control])) if agree_control.any() else 0.0

        value = value_treat * p_treat + value_control * (1.0 - p_treat)
        return 1.0 - value

    def _factual_rmse(self, X, T, Y):
        """RMSE between ``Y`` and the prediction of the head matching ``T``."""
        mu0, mu1 = self.predict_mu(X)
        mt = np.where(np.asarray(T).reshape(-1) == 1, mu1, mu0)
        return float(np.sqrt(mean_squared_error(np.asarray(Y).reshape(-1), mt)))

    def att_abs_error_rct(self, X, T, Y, e):
        """Absolute error of the estimated ATT on the samples with ``e == 1``.

        The reference ATT is the difference in mean observed outcome between
        treated and control units among rows where ``e == 1``; the estimate
        is the mean predicted effect over the treated ones.

        Returns:
            The absolute error as a float, or NaN if either group is empty.
        """
        X = np.asarray(X)
        T = np.asarray(T).reshape(-1)
        Y = np.asarray(Y).reshape(-1)
        e = np.asarray(e).reshape(-1)

        in_rct = e == 1
        treated_rct = np.flatnonzero(in_rct & (T == 1))
        control_rct = np.flatnonzero(in_rct & (T == 0))
        if treated_rct.size == 0 or control_rct.size == 0:
            return float("nan")

        att_true = float(np.mean(Y[treated_rct]) - np.mean(Y[control_rct]))
        att_hat = float(np.mean(self.predict_tau(X[treated_rct])))
        return abs(att_hat - att_true)

    def policy_risk_rct(self, X, T, Y, e, lam=0.0):
        """Policy risk of "treat iff predicted effect > ``lam``" on ``e == 1`` rows.

        ``T``, ``Y`` and ``e`` may be lists or arrays of any shape with one
        entry per row of ``X``; they are flattened before use.

        Returns:
            ``1 - policy value`` as a float, or NaN if no row has ``e == 1``.
        """
        X = np.asarray(X)
        T = np.asarray(T).reshape(-1)
        Y = np.asarray(Y).reshape(-1)
        e = np.asarray(e).reshape(-1)

        idx = np.flatnonzero(e == 1)
        if idx.size == 0:
            return float("nan")
        tau_hat = self.predict_tau(X[idx])
        return self._policy_risk_from_tau(tau_hat, T[idx], Y[idx], lam)

    # ------------------------------------------------------------------ #
    # Model
    # ------------------------------------------------------------------ #
    def forward(self, x):
        """Return ``(y0, y1)``: the two heads' outputs for a batch ``x``.

        The outputs are raw linear activations (logits when
        ``cfg.binary_y`` is set), with the trailing unit dimension squeezed.
        """
        h = self.rep(x)
        h0 = self.head0(h) if self.cfg.n_layers_head > 0 else h
        h1 = self.head1(h) if self.cfg.n_layers_head > 0 else h
        y0 = self.out0(h0).squeeze(-1)  # logits if binary_y
        y1 = self.out1(h1).squeeze(-1)
        return y0, y1

    def fit(self, X_train, T_train, Y_train, val_split: float = 0.2):
        """Train on factual outcomes with Adam and optional early stopping.

        The covariate scaler (if enabled) is fitted on all rows, after which
        a seeded shuffle holds out ``int(n * val_split)`` rows for
        validation. The weights with the lowest monitored loss are restored
        at the end. The monitored loss is the validation loss, or the
        epoch's mean training loss when the validation split is empty.

        Args:
            X_train: Covariates, one row per sample.
            T_train: Treatment indicators; values ``> 0.5`` select head 1.
            Y_train: Observed outcomes.
            val_split: Fraction of rows held out, in ``[0, 1)``.

        Returns:
            ``self``.

        Raises:
            ValueError: If ``val_split`` is outside ``[0, 1)``, the inputs
                are empty, or their lengths disagree.
        """
        if not 0.0 <= val_split < 1.0:
            raise ValueError(f"val_split must be in [0, 1), got {val_split!r}")

        X = np.asarray(X_train, dtype=np.float32)
        T = np.asarray(T_train, dtype=np.float32).reshape(-1)
        Y = np.asarray(Y_train, dtype=np.float32).reshape(-1)

        n = X.shape[0]
        if n == 0:
            raise ValueError("Cannot fit on an empty training set.")
        if T.shape[0] != n or Y.shape[0] != n:
            raise ValueError(
                f"X, T and Y must have the same number of rows; "
                f"got {n}, {T.shape[0]} and {Y.shape[0]}."
            )

        if self.scaler_X is not None:
            X = self.scaler_X.fit_transform(X).astype(np.float32)

        idx = np.arange(n)
        rng = np.random.default_rng(self.cfg.random_state)
        rng.shuffle(idx)
        n_val = int(n * val_split)
        val_idx = idx[:n_val]
        tr_idx = idx[n_val:]
        has_val = n_val > 0

        X_tr = torch.from_numpy(X[tr_idx]).to(self.cfg.device)
        T_tr = torch.from_numpy(T[tr_idx]).to(self.cfg.device)
        Y_tr = torch.from_numpy(Y[tr_idx]).to(self.cfg.device)
        X_va = torch.from_numpy(X[val_idx]).to(self.cfg.device)
        T_va = torch.from_numpy(T[val_idx]).to(self.cfg.device)
        Y_va = torch.from_numpy(Y[val_idx]).to(self.cfg.device)

        train_loader = DataLoader(
            TensorDataset(X_tr, T_tr, Y_tr),
            batch_size=self.cfg.batch_size, shuffle=True, drop_last=False
        )

        criterion = nn.BCEWithLogitsLoss() if self.cfg.binary_y else nn.MSELoss()
        optim = torch.optim.Adam(self.parameters(), lr=self.cfg.lr, weight_decay=self.cfg.weight_decay)
        best_val = float("inf")
        no_improve = 0
        # Drop any checkpoint left over from a previous fit() so it can
        # never be restored in place of this run's weights.
        self._best_state = None

        self.train()
        for ep in range(self.cfg.n_epochs):
            ep_loss = 0.0
            for xb, tb, yb in train_loader:
                optim.zero_grad()
                y0_logit, y1_logit = self.forward(xb)
                # Factual prediction: pick the head of the observed arm.
                yf = torch.where(tb > 0.5, y1_logit, y0_logit)
                loss = criterion(yf, yb)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=5.0)
                optim.step()
                ep_loss += loss.item() * xb.size(0)
            ep_loss /= len(train_loader.dataset)

            if has_val:
                self.eval()
                with torch.no_grad():
                    y0v, y1v = self.forward(X_va)
                    yfv = torch.where(T_va > 0.5, y1v, y0v)
                    val_loss = criterion(yfv, Y_va).item()
                self.train()
            else:
                # A loss over an empty batch is NaN and would never count as
                # an improvement.
                val_loss = float("nan")

            if self.cfg.verbose and ((ep + 1) % 10 == 0):
                print(f"Epoch {ep+1:03d} | train={ep_loss:.4f} | val={val_loss:.4f}")

            monitored = val_loss if has_val else ep_loss
            improved = monitored < best_val - 1e-6
            if improved:
                best_val = monitored
                self._best_state = deepcopy(self.state_dict())
                no_improve = 0
            else:
                no_improve += 1
                if self.cfg.early_stopping and no_improve >= self.cfg.patience:
                    if self.cfg.verbose:
                        print(f"Early stopping at epoch {ep+1}, best val={best_val:.4f}")
                    break

        if self._best_state is not None:
            self.load_state_dict(self._best_state)
        self._fitted = True
        return self

    def _transform_X(self, X):
        """Convert ``X`` to a float32 tensor on ``cfg.device``, scaled if enabled."""
        X = np.asarray(X, dtype=np.float32)
        if self.scaler_X is not None:
            X = self.scaler_X.transform(X).astype(np.float32)
        return torch.from_numpy(X).to(self.cfg.device)

    def predict_mu(self, X):
        """Return ``(mu0, mu1)`` as 1-D NumPy arrays of per-arm predictions.

        Switches the module to eval mode. When ``cfg.binary_y`` is set the
        logits are passed through a sigmoid first.
        """
        self.eval()
        with torch.no_grad():
            Xt = self._transform_X(X)
            y0, y1 = self.forward(Xt)
            if self.cfg.binary_y:
                y0 = torch.sigmoid(y0)
                y1 = torch.sigmoid(y1)
            return y0.detach().cpu().numpy().reshape(-1), y1.detach().cpu().numpy().reshape(-1)

    def predict_tau(self, X):
        """Return the predicted effect ``mu1 - mu0`` per sample as a 1-D array."""
        mu0, mu1 = self.predict_mu(X)
        return (mu1 - mu0).reshape(-1)

    def evaluate(self, X, T=None, Y=None, e=None, mu0=None, mu1=None):
        """Compute whichever metrics the supplied ground truth allows.

        Args:
            X: Covariates to evaluate on.
            T, Y, e: Treatment, outcome and the indicator whose value 1
                selects the rows used by the ``*_rct`` metrics; all three
                are needed for ``ATT_abs_error_rct`` and ``policy_risk_rct``.
            mu0, mu1: True potential-outcome means; both are needed for
                ``PEHE``, ``ATE_abs_error`` and ``rel_ATE_error``.

        Returns:
            Dict mapping metric name to float; empty if nothing could be
            computed.
        """
        metrics = {}
        X = np.asarray(X)
        tau_hat = self.predict_tau(X)
        if (mu0 is not None) and (mu1 is not None):
            metrics["PEHE"] = self._pehe(tau_hat, mu0, mu1)
            metrics["ATE_abs_error"] = self._abs_ate_error(tau_hat, mu0, mu1)
            metrics["rel_ATE_error"] = self._rel_ate_error(tau_hat, mu0, mu1)
        if (e is not None) and (T is not None) and (Y is not None):
            metrics["ATT_abs_error_rct"] = self.att_abs_error_rct(X, T, Y, e)
            metrics["policy_risk_rct"] = self.policy_risk_rct(X, T, Y, e)

        return metrics
