
import warnings
warnings.filterwarnings("ignore")

import os
os.makedirs('output', exist_ok=True)

import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer

import models
from tools.labeltransform import LabTransDiscreteTime
from tools.evaluate import calculate_ici_survival, calculate_d_calibration, calculate_ece_survival
from sksurv.util import Surv
from sksurv.metrics import concordance_index_ipcw, integrated_brier_score
from mixup.survmixup import SurvMixup

from data.framingham import Framingham
from data.metabric import Metabric
from data.support import Support
from data.flchain import FLChain

# Prefer CUDA when it is available; everything below is placed on `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------------------------------------------------------------------
# Command-line interface: each entry is (flag, keyword arguments for argparse).
# ---------------------------------------------------------------------------
_cli_options = [
    ('--dataset', dict(type=str, default='framingham', choices=['metabric', 'support', 'framingham', 'flchain'])),
    ('--framingham_event', dict(type=str, default='CVD', help='Framingham target (e.g., CVD, STROKE, DEATH, ANYCHD)')),
    ('--model', dict(type=str, default='DeepHit', choices=['DeepMTLR', 'DeepHit', 'DeepAFT', 'DeepCox', 'DeepIBS'])),
    ('--mixup_strategy', dict(type=str, default='none', help='hmix | chmix | smix | omix | none')),
    ('--mixup_alpha', dict(type=float, default=0.1)),
    ('--seed', dict(type=int, default=0)),
    ('--test_size', dict(type=float, default=0.3)),
    ('--valid_size', dict(type=float, default=0.1)),
    ('--n_disc', dict(type=int, default=20)),
    ('--epochs', dict(type=int, default=200)),
    ('--batch_size', dict(type=int, default=64)),
    ('--lr', dict(type=float, default=5e-4)),
    ('--weight_decay', dict(type=float, default=1e-6)),
    ('--hidden', dict(type=int, nargs='+', default=[64, 32])),
]
parser = argparse.ArgumentParser(description='Tabular Survival Analysis (features-first dataset)')
for _flag, _spec in _cli_options:
    parser.add_argument(_flag, **_spec)
args = parser.parse_args()

# Seed both torch and NumPy from the same CLI value.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# ---------------------------------------------------------------------------
# Dataset selection. Framingham is built from a target event and exposes its
# frame through preprocess() with fixed 'event'/'time' columns; the other
# loaders expose `.data`, `.event` and `.time` directly.
# ---------------------------------------------------------------------------
_plain_loaders = {'metabric': Metabric, 'support': Support, 'flchain': FLChain}

if args.dataset == 'framingham':
    ds = Framingham(target_event=args.framingham_event)
    df = ds.preprocess()
    event_col, time_col = 'event', 'time'
elif args.dataset in _plain_loaders:
    ds = _plain_loaders[args.dataset]()
    df = ds.data.copy()
    event_col, time_col = ds.event, ds.time
else:
    raise ValueError("Unknown dataset")

# Keep only the declared predictors that actually exist in the frame.
cont_cols, cat_cols = ds.continuous_predictors, ds.categorical_predictors
cont_cols = [name for name in cont_cols if name in df.columns]
cat_cols = [name for name in cat_cols if name in df.columns]

times_all = df[time_col].to_numpy(dtype='float32')
events_all = df[event_col].to_numpy().astype(dtype='float32')

# ---------------------------------------------------------------------------
# Train / validation / test split on row indices. Splits are stratified on the
# event indicator whenever at least one event is present.
# ---------------------------------------------------------------------------
idx_all = np.arange(len(df))
_strat_all = events_all if events_all.sum() > 0 else None
idx_tr, idx_te = train_test_split(
    idx_all,
    test_size=args.test_size,
    random_state=args.seed,
    stratify=_strat_all,
)

# valid_size is a fraction of the full data, so rescale it to the remainder.
_valid_fraction = args.valid_size / (1.0 - args.test_size)
_strat_tr = events_all[idx_tr] if events_all[idx_tr].sum() > 0 else None
idx_tr, idx_va = train_test_split(
    idx_tr,
    test_size=_valid_fraction,
    random_state=args.seed,
    stratify=_strat_tr,
)

df_tr, df_va, df_te = (
    df.iloc[_rows].reset_index(drop=True) for _rows in (idx_tr, idx_va, idx_te)
)


def _labels(frame):
    """Return the (time, event) columns of `frame` as float32 arrays."""
    durations = frame[time_col].to_numpy(dtype='float32')
    indicators = frame[event_col].to_numpy(dtype='float32')
    return durations, indicators


t_tr, e_tr = _labels(df_tr)
t_va, e_va = _labels(df_va)
t_te, e_te = _labels(df_te)

# ---------------------------------------------------------------------------
# Feature preprocessing, fitted on the training split only.
# ---------------------------------------------------------------------------
cont_pipe = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='mean')),
    ('scaler', MinMaxScaler()),
])


def _build_cat_pipe(**encoder_kwargs):
    """Build the one-hot pipeline with the given dense-output keyword."""
    encoder = OneHotEncoder(handle_unknown='ignore', **encoder_kwargs)
    return Pipeline([('onehot', encoder)])


# The dense-output keyword is `sparse_output` in some scikit-learn versions
# and `sparse` in others; fall back when the first spelling is rejected.
try:
    cat_pipe = _build_cat_pipe(sparse_output=False)
except TypeError:
    cat_pipe = _build_cat_pipe(sparse=False)


def _encode(pipe, frame, cols, fit=False):
    """Run `pipe` on `frame[cols]`, or return an (n, 0) array when `cols` is empty."""
    if len(cols) > 0:
        selected = frame[cols]
        return pipe.fit_transform(selected) if fit else pipe.transform(selected)
    return np.empty((len(frame), 0), dtype=float)


Xnum_tr = _encode(cont_pipe, df_tr, cont_cols, fit=True)
Xcat_tr = _encode(cat_pipe, df_tr, cat_cols, fit=True)

Xnum_va = _encode(cont_pipe, df_va, cont_cols)
Xcat_va = _encode(cat_pipe, df_va, cat_cols)

Xnum_te = _encode(cont_pipe, df_te, cont_cols)
Xcat_te = _encode(cat_pipe, df_te, cat_cols)

# Continuous features first, then the one-hot columns.
X_tr = np.concatenate([Xnum_tr, Xcat_tr], axis=1)
X_va = np.concatenate([Xnum_va, Xcat_va], axis=1)
X_te = np.concatenate([Xnum_te, Xcat_te], axis=1)

in_dim = X_tr.shape[1]
if not in_dim:
    raise ValueError("No predictor columns found after preprocessing.")

class TabularSurvDataset(torch.utils.data.Dataset):
    """In-memory survival dataset yielding ``(features, time, event)`` triples.

    Args:
        times: Array-like of follow-up times, one per sample.
        events: Array-like of event indicators (bool or numeric), one per
            sample. Any input accepted by ``torch.as_tensor`` works, e.g. a
            list, a NumPy array or a tensor.
        X: Array-like feature matrix whose first axis indexes samples.

    All three inputs are stored as ``float32`` tensors. Conversion goes
    through ``torch.as_tensor``, so an input that is already a float32 array
    or tensor may share memory with the stored tensor.
    """

    def __init__(self, times, events, X):
        self.times = torch.as_tensor(times, dtype=torch.float32)
        # Converted exactly like `times` and `X`; does not require the input
        # to expose a NumPy-style `.astype` method.
        self.events = torch.as_tensor(events, dtype=torch.float32)
        self.X = torch.as_tensor(X, dtype=torch.float32)

    def __len__(self):
        """Return the number of samples (length of `times`)."""
        return len(self.times)

    def __getitem__(self, idx):
        """Return ``(X[idx], times[idx], events[idx])``: features first, then labels."""
        return self.X[idx], self.times[idx], self.events[idx]

# Wrap each split as a features-first dataset.
train_ds = TabularSurvDataset(t_tr, e_tr, X_tr)
valid_ds = TabularSurvDataset(t_va, e_va, X_va)
test_ds  = TabularSurvDataset(t_te, e_te, X_te)

train_times = train_ds.times.numpy()
train_events = train_ds.events.numpy().astype(bool)
if not train_events.any():
    # Both the discretizer fit and the cap below are defined from observed
    # event times; fail early with a clear message instead of an empty-array
    # reduction error.
    raise ValueError("Training split contains no observed events; cannot build the evaluation time grid.")

# The discretization grid is fitted on the training times of observed events only.
discretizer = LabTransDiscreteTime(num_durations=args.n_disc, scheme='quantile')
discretizer.fit(train_times[train_events])

train_max_obs = float(train_times[train_events].max())
# The next float below the max. Times are stored as float32, so the step has
# to be taken in float32: a float64 step would round back to the max itself
# once it is applied to the float32 tensors below.
cap = float(np.nextafter(np.float32(train_max_obs), np.float32(-np.inf)))

# Evaluation horizons: the 10%..90% quantiles of the training times that lie
# at or below the cap (i.e. strictly below the largest observed event time).
base = train_times[train_times <= cap]
qs = np.round(np.arange(0.1, 1.0, 0.1), 1)
time_windows = np.quantile(base, qs)

# Clamp test times so that none reaches the largest training event time.
test_ds.times = torch.clamp(test_ds.times, max=cap)


class TabMLP(nn.Module):
    """Multi-layer perceptron over tabular features.

    The network is a stack of ``Linear -> GELU`` layers (one per entry of
    `hidden`) followed by a linear head, optionally passed through a softmax.

    Args:
        in_dim: Number of input features.
        hidden: Sizes of the hidden layers. An empty sequence or ``None``
            yields a purely linear model.
        out_dim: Number of outputs of the head.
        output_activation: ``'softmax'`` applies a softmax over the last
            axis; any other value leaves the head output unchanged.
    """

    def __init__(self, in_dim, hidden=(32, 16), out_dim=1, output_activation='linear'):
        # The default is an immutable tuple so it cannot be shared and
        # mutated across instances.
        super().__init__()
        layers = []
        prev = in_dim
        for h in (hidden or ()):
            layers += [nn.Linear(prev, h), nn.GELU()]
            prev = h
        self.backbone = nn.Sequential(*layers) if layers else nn.Identity()
        self.head = nn.Linear(prev, out_dim)
        self.act = nn.Softmax(dim=-1) if output_activation == 'softmax' else None

    @staticmethod
    def _ensure_2d(x: torch.Tensor) -> torch.Tensor:
        """Return `x` as a 2-D ``[N, D]`` tensor.

        Scalars become ``[1, 1]``, vectors become ``[1, D]``, and inputs with
        more than two axes have their leading axes flattened.
        """
        if x.dim() == 0:   # scalar
            return x.view(1, 1)
        if x.dim() == 1:   # [D]
            return x.unsqueeze(0)
        if x.dim() > 2:    # [..., D]
            # reshape (not view) so non-contiguous inputs are accepted too.
            return x.reshape(-1, x.size(-1))
        return x

    def forward(self, x):
        """Run the network on `x`.

        Args:
            x: A feature tensor, a list/tuple whose first tensor element is
                the feature tensor, or a dict holding the features under the
                key ``'X'``.

        Returns:
            Tensor of shape ``[N, out_dim]``.

        Raises:
            ValueError: If no feature tensor can be extracted from `x`.
        """
        if not torch.is_tensor(x):
            if isinstance(x, (list, tuple)):
                tensors = [z for z in x if torch.is_tensor(z)]
                if not tensors:
                    raise ValueError("Expected tensor input X or (X, ...).")
                x = tensors[0]  # the first tensor is taken to be X
            elif isinstance(x, dict) and 'X' in x and torch.is_tensor(x['X']):
                x = x['X']
            else:
                raise ValueError("Expected tensor input X.")
        x = self._ensure_2d(x)
        z = self.backbone(x)
        out = self.head(z)
        return self.act(out) if self.act is not None else out

# Head width and output activation used for each supported model.
configs = {
    'DeepCox': {'n_outputs': 1,             'output_activation': 'linear'},
    'DeepAFT': {'n_outputs': 2,             'output_activation': 'linear'},
    'DeepIBS': {'n_outputs': args.n_disc+1, 'output_activation': 'softmax'},
    'DeepHit': {'n_outputs': args.n_disc+1, 'output_activation': 'softmax'},
    'DeepMTLR':{'n_outputs': args.n_disc+1, 'output_activation': 'linear'},
}
cfg = configs[args.model]

net = TabMLP(in_dim=in_dim, hidden=args.hidden, out_dim=cfg['n_outputs'], output_activation=cfg['output_activation']).to(device)
# Compilation is a best-effort optimisation: when it cannot be set up, keep
# the plain network, but say so instead of failing silently.
try:
    net = torch.compile(net)
except Exception as exc:
    print(f"torch.compile unavailable ({type(exc).__name__}: {exc}); using the uncompiled network.")

batch_size = args.batch_size
steps_per_epoch = max(1, len(train_ds) // max(1, batch_size))
max_epochs = args.epochs
# A single LR decay (x0.1) placed at half of training, expressed in optimizer
# steps. NOTE(review): this assumes the model steps the scheduler once per
# batch and uses this batch size; `batch_size` is not forwarded in
# `model_args` below — confirm against the `models` implementation.
milestones_steps = [int(max_epochs * 0.5) * steps_per_epoch]

opt = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
sch = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=milestones_steps, gamma=0.1)

# Mixup is enabled only for a known strategy with a positive alpha; otherwise
# the arguments are normalised so the results file records 'none' / 0.0.
mixup = None
if args.mixup_strategy in ['hmix', 'chmix', 'smix', 'omix'] and args.mixup_alpha > 0:
    mixup = SurvMixup(alpha=args.mixup_alpha, strategy=args.mixup_strategy, device=device)
else:
    args.mixup_alpha = 0.0
    args.mixup_strategy = 'none'

model_args = {
    'net': net,
    'opt': opt,
    'sch': sch,
    'mixup': mixup,
    'epochs': max_epochs,
    'discretizer': discretizer,
}

# Look the model class up by name in the `models` module and train it.
m = getattr(models, args.model)(**model_args)
m.fit(train_ds, valid_ds)

# Survival probabilities of the test samples at each evaluation horizon,
# converted to risk scores (1 - S) for the concordance index.
te_S = m.survival_probability_at_times(test_ds, times=time_windows)
te_risk_all = 1.0 - np.clip(te_S, 1e-8, 1.0)

y_tr = Surv.from_arrays(event=train_ds.events.numpy().astype(bool),
                        time=train_ds.times.numpy().astype(float))
y_te = Surv.from_arrays(event=test_ds.events.numpy().astype(bool),
                        time=test_ds.times.numpy().astype(float))

# IPCW concordance index, truncated at each horizon in turn.
c_test_by_t = np.array([
    concordance_index_ipcw(y_tr, y_te, te_risk_all[:, j], float(time_windows[j]))[0]
    for j in range(te_S.shape[1])
], dtype=float)

test_ibs = float(integrated_brier_score(y_tr, y_te, te_S, time_windows))

# Calibration metrics, computed separately at each horizon.
test_ici_scores, test_dcal_p_values, test_ece_scores = [], [], []
for j, t in enumerate(time_windows):
    te_S_t = te_S[:, j]
    ici = calculate_ici_survival(y_te, te_S_t, t)
    _, d_cal_p = calculate_d_calibration(y_te, te_S_t, t)
    ece = calculate_ece_survival(y_te, te_S_t, t)
    test_ici_scores.append(float(ici))
    test_dcal_p_values.append(float(d_cal_p))
    test_ece_scores.append(float(ece))

# Run identification and horizon-averaged metrics first, then one entry per
# horizon and metric.
results = dict(
    dataset=args.dataset,
    model=args.model,
    mixup_strategy=args.mixup_strategy,
    mixup_alpha=args.mixup_alpha,
    seed=args.seed,
    test_ibs=test_ibs,
    test_td_cindex_avg=float(np.nanmean(c_test_by_t)),
    test_ici_avg=float(np.nanmean(test_ici_scores)),
    test_d_cal_p_avg=float(np.nanmean(test_dcal_p_values)),
    test_ece_avg=float(np.nanmean(test_ece_scores))
)
per_horizon = zip(c_test_by_t, test_ici_scores, test_dcal_p_values, test_ece_scores)
for i, (c_index_i, ici_i, d_cal_p_i, ece_i) in enumerate(per_horizon):
    results[f"test_td_cindex_t{i:02d}"] = float(c_index_i)
    results[f"test_ici_t{i:02d}"] = float(ici_i)
    results[f"test_d_cal_p_t{i:02d}"] = float(d_cal_p_i)
    results[f"test_ece_t{i:02d}"] = float(ece_i)

# One "name,value" row per result, written without a header.
outname = f'output/tabular-{args.dataset}+{args.model}+{args.mixup_strategy}+{args.mixup_alpha:.1f}+{args.seed:02d}.csv'
pd.DataFrame(results, index=[0]).T.to_csv(outname, header=False)
print(f"Saved results to {outname}")
