import os
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, TensorDataset
from copy import deepcopy

from mixup.survmixup import SurvMixup
from models.deepcox import DeepCox

# Seed the global NumPy and PyTorch RNGs so that network initialisation
# (torch) and any code relying on the global NumPy RNG (presumably the
# mixup / training utilities — TODO confirm) are reproducible across runs.
# The data simulator below uses its own np.random.default_rng and is
# unaffected by these calls.
np.random.seed(123)
torch.manual_seed(123)

def make_rings_survival(n_outer=500, n_inner=500, seed=123, *,
                        r_outer=1.0, r_inner=0.45, radial_noise=0.06,
                        log_risk_outer=3, log_risk_inner=10,
                        base_hazard=0.25, censor_rate=0.18):
    """Simulate a two-ring 2-D survival dataset with ring-dependent hazards.

    Points are scattered around two concentric circles: an outer ring
    (group 0, low risk) and an inner ring (group 1, high risk). Each point
    receives an exponential event time with hazard ``base_hazard * exp(f)``,
    where ``f`` is ``log_risk_inner`` on the inner ring and ``log_risk_outer``
    on the outer ring. It also receives an independent exponential censoring
    time with rate ``censor_rate``.

    The keyword defaults reproduce the original hard-coded setup. The random
    stream is consumed in the same order, so output for a given ``seed`` is
    unchanged.

    Args:
        n_outer: Number of points on the outer (low-risk) ring.
        n_inner: Number of points on the inner (high-risk) ring.
        seed: Seed passed to ``np.random.default_rng``.
        r_outer: Mean radius of the outer ring.
        r_inner: Mean radius of the inner ring.
        radial_noise: Standard deviation of the Gaussian radial jitter.
        log_risk_outer: Log-risk ``f`` for outer-ring points.
        log_risk_inner: Log-risk ``f`` for inner-ring points.
        base_hazard: Baseline hazard multiplying ``exp(f)``.
        censor_rate: Rate of the exponential censoring distribution.

    Returns:
        X: ``(n_outer + n_inner, 2)`` float array of coordinates, outer ring
            first.
        O: Observed times ``min(T, C)``.
        E: Event indicators: 1.0 if the event was observed (``T <= C``),
            else 0.0.
        groups: Integer ring labels (0 = outer / low risk,
            1 = inner / high risk).

    Raises:
        ValueError: If ``censor_rate`` is not positive.
    """
    if censor_rate <= 0:
        raise ValueError(f"censor_rate must be positive, got {censor_rate!r}")
    rng = np.random.default_rng(seed)

    def _ring(n, radius):
        # Angles are drawn before radii; the outer ring is drawn before the
        # inner one. This matches the original draw order, so seeded output
        # is preserved.
        angles = rng.uniform(0, 2 * np.pi, n)
        radii = rng.normal(radius, radial_noise, n)
        return np.c_[radii * np.cos(angles), radii * np.sin(angles)]

    X = np.vstack([_ring(n_outer, r_outer), _ring(n_inner, r_inner)]).astype(float)
    groups = np.array([0] * n_outer + [1] * n_inner)  # 1 = high risk

    f_true = np.where(groups == 1, log_risk_inner, log_risk_outer)
    hazard = base_hazard * np.exp(f_true)

    # Inverse-CDF sampling of exponential event times: T = -log(U) / hazard.
    U = rng.uniform(0, 1, size=len(X))
    T = -np.log(U) / hazard
    C = rng.exponential(scale=1 / censor_rate, size=len(X))
    O = np.minimum(T, C).astype(float)
    E = (T <= C).astype(float)
    return X, O, E, groups

class ToyNN(nn.Module):
    """Small MLP mapping 2-D points to one scalar per point.

    Architecture: Linear(2->64) -> GELU -> Linear(64->64) -> GELU ->
    Linear(64->1). ``forward`` drops the trailing size-1 dimension, so an
    input of shape ``(..., 2)`` yields an output of shape ``(...,)``. The
    output is presumably used by DeepCox as a log-risk score (TODO confirm).
    """

    def __init__(self):
        super().__init__()
        width = 64
        # Layers are created in the same order as a literal Sequential, so
        # seeded parameter initialisation and state_dict keys are unchanged.
        layers = [nn.Linear(2, width), nn.GELU()]
        layers += [nn.Linear(width, width), nn.GELU()]
        layers.append(nn.Linear(width, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.net(x)
        return scores.squeeze(-1)

# Training configuration shared by all three models.
device = 'cpu'
batch_size = 64
epochs = 1000

# Simulated two-ring survival data. The training set holds float32 tensors of
# (covariates X, observed times O = min(T, C), event indicators E).
X, O, E, groups = make_rings_survival()
train_ds = TensorDataset(
    *(torch.tensor(arr, dtype=torch.float32) for arr in (X, O, E))
)
 
def _build_deepcox(net, mixup_strategy):
    """Return a DeepCox trainer for *net* using the shared training settings.

    ``mixup_strategy=None`` gives plain ERM training. Otherwise a SurvMixup
    with that strategy (alpha=0.4, keep_prev=True) is attached. The Adam
    optimiser is created before the SurvMixup object, the same construction
    order as the hand-written ``DeepCox(...)`` calls this replaces.
    """
    optimiser = torch.optim.Adam(net.parameters(), lr=1e-3)
    if mixup_strategy is None:
        mixup = None
    else:
        mixup = SurvMixup(alpha=0.4, strategy=mixup_strategy, keep_prev=True)
    return DeepCox(
        net=net,
        opt=optimiser,
        sch=None,
        mixup=mixup,
        epochs=epochs,
        batch_size=batch_size,
        device=device,
    )


# All three networks are deep copies of one initialisation, so differences
# between the trained models do not come from different starting weights.
erm_net = ToyNN().to(device)
omix_net = deepcopy(erm_net)
hmix_net = deepcopy(erm_net)

erm_model = _build_deepcox(erm_net, None)
omx_model = _build_deepcox(omix_net, 'omix')
hmx_model = _build_deepcox(hmix_net, 'hmix')

for _model in (erm_model, omx_model, hmx_model):
    _model.fit(train_ds)

# 300 x 300 evaluation grid over the plotting window [-1.5, 1.5]^2. G holds
# one (x, y) row per grid cell, in row-major order of the meshgrid, so
# per-point predictions can be reshaped back to GX.shape.
gx = np.linspace(-1.5, 1.5, 300)
gy = np.linspace(-1.5, 1.5, 300)
GX, GY = np.meshgrid(gx, gy)
G = np.column_stack((GX.ravel(), GY.ravel())).astype(np.float32)

# Time horizon for the risk maps: median observed time among subjects whose
# event was observed (E == 1).
t0 = float(np.median(O[E == 1]))

class GridDataset(Dataset):
    """Dataset of ``(x, o, e)`` triples built from an array of grid points.

    ``x`` is the float32 point. ``o`` and ``e`` are float32 zero placeholders,
    so items have the same three-field layout as the ``(X, O, E)`` training
    TensorDataset. Per the original comment, only ``x`` is used when
    predicting.
    """

    def __init__(self, Xgrid):
        n_points = len(Xgrid)
        self.Xgrid = torch.tensor(Xgrid, dtype=torch.float32)
        self.zero = torch.zeros(n_points, dtype=torch.float32)

    def __len__(self):
        return self.Xgrid.shape[0]

    def __getitem__(self, i):
        point = self.Xgrid[i]
        # Placeholder (o, e) pair; the values are never meaningful.
        return point, self.zero[i], self.zero[i]

# Predicted survival probability at t0 for every grid point, one flat array
# per model. Each call gets a fresh GridDataset and time array, as before.
S_erm, S_omx, S_hmx = (
    model.survival_probability_at_times(GridDataset(G), np.array([t0])).reshape(-1)
    for model in (erm_model, omx_model, hmx_model)
)

# Shaded quantity: 1 - S(t0 | x), reshaped back onto the 2-D grid.
shade_erm, shade_omx, shade_hmx = (
    (1.0 - surv).reshape(GX.shape) for surv in (S_erm, S_omx, S_hmx)
)

def panel(ax, shade, title):
    """Draw one risk-map panel on *ax*.

    *shade* is drawn as a 'Blues' image spanning the evaluation-grid extent
    (gx, gy). The module-level training points X are overlaid, coloured by
    ring group, and ticks and spines are removed. Relies on the module
    globals ``gx``, ``gy``, ``X`` and ``groups``.
    """
    ax.imshow(shade, extent=[gx.min(), gx.max(), gy.min(), gy.max()],
              origin='lower', cmap='Blues')

    # NOTE(review): labels are attached, but no legend is drawn in this script.
    point_styles = ((0, 'green', 'Low risk'), (1, 'red', 'High risk'))
    for grp, colour, name in point_styles:
        in_group = groups == grp
        ax.scatter(X[in_group, 0], X[in_group, 1], s=4, color=colour,
                   edgecolor='none', label=name)

    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')
    ax.set_title(title, fontsize=12)
    for border in ax.spines.values():
        border.set_visible(False)

# Three side-by-side panels (ERM, naive mixup, H-mixup) written to one PDF.
fig, axes = plt.subplots(1, 3, figsize=(8, 3), gridspec_kw={'wspace': 0})
panel_specs = (
    (shade_erm, "ERM"),
    (shade_omx, "Naive Interpolation"),
    (shade_hmx, "H-Mixup (Proposed)"),
)
for ax, (shade, title) in zip(axes, panel_specs):
    panel(ax, shade, title)

# `fig` is the current figure created by plt.subplots above, so these calls
# act on the same figure as the pyplot-level equivalents.
fig.tight_layout()
os.makedirs("output", exist_ok=True)
fig.savefig("output/toyfigure.pdf", dpi=500)
