import os
os.makedirs('output', exist_ok=True)
import argparse
import torch
import models
import numpy as np
import pandas as pd

from copy import deepcopy
from sksurv.util import Surv
from sksurv.metrics import concordance_index_ipcw, integrated_brier_score

from data.mnist import make_mnist_survival
from data.octmnist import make_octmnist_survival
from data.pathmnist import make_pathmnist_survival
from data.tissuemnist import make_tissuemnist_survival
from data.tiny_imagenet import make_tinyimagenet_survival
from data.organmnist3d import make_organmnist3d_survival
from data.retinamnist import make_retinamnist_survival

from mixup.survmixup import SurvMixup
from tools.evaluate import *
from tools.labeltransform import LabTransDiscreteTime
from networks import MNISTModel, ImagenetModel, MNIST3DModel

# Permit faster reduced-internal-precision kernels (e.g. TF32) for float32 matmuls.
torch.set_float32_matmul_precision('high')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Command-line interface.  `choices` makes argparse reject unsupported values
# immediately instead of failing deep inside the training loop (or, for the
# mixup strategy, silently falling back to plain ERM training).
parser = argparse.ArgumentParser(description='Script to reproduce experimental results on SurvMix')
parser.add_argument(
    '--dataset', type=str, default='mnist',
    choices=['mnist', 'octmnist', 'pathmnist', 'tissuemnist', 'tinyimagenet', 'organmnist3d', 'retinamnist'],
    help='Dataset to use, mnist | octmnist | pathmnist | tissuemnist | tinyimagenet | organmnist3d | retinamnist')
parser.add_argument(
    '--model', type=str, default='DeepHit',
    choices=['DeepCox', 'DeepAFT', 'DeepIBS', 'DeepHit', 'DeepMTLR'],
    help='Survival analysis method to use, DeepCox | DeepAFT | DeepIBS | DeepHit | DeepMTLR')
parser.add_argument(
    '--backbone', type=str, default='resnet18',
    help='Backbone network architecture (ignored for mnist, octmnist, tissuemnist and organmnist3d, '
         'which use dataset-specific networks)')
parser.add_argument(
    '--mixup_strategy', type=str, default='hmix',
    choices=['erm', 'hmix', 'chmix', 'smix', 'omix'],
    help='Mixup strategy to use, erm | hmix | chmix | smix | omix')
parser.add_argument('--mixup_alpha', type=float, default=0, help='Mixup alpha value, ignored if mixup_strategy is erm')
parser.add_argument('--keep_prev', action='store_true', help='Whether to keep prev class distribution in mixup')
# Kept as a string so several seeds can be given at once, e.g. "--seed 0,1,2".
parser.add_argument('--seed', type=str, default='0',
                    help='Comma-separated random seed(s) for reproducibility, e.g. 0,1,2; one run per seed')
args = parser.parse_args()

print(args)

# ---------------------------------------------------------------------------
# Static experiment configuration (identical for every seed).
# ---------------------------------------------------------------------------

# Number of discrete time bins used by the label discretizer and by the
# discrete-time heads (DeepIBS / DeepHit / DeepMTLR output n_disc + 1 values).
n_disc = 50

# Output-head configuration for each survival method.
configs = {
    'DeepCox': {'n_outputs': 1, 'output_activation': 'linear'},
    'DeepAFT': {'n_outputs': 2, 'output_activation': 'linear'},
    'DeepIBS': {'n_outputs': n_disc + 1, 'output_activation': 'softmax'},
    'DeepHit': {'n_outputs': n_disc + 1, 'output_activation': 'softmax'},
    'DeepMTLR': {'n_outputs': n_disc + 1, 'output_activation': 'linear'},
}

# Dataset name -> factory called as f(seed=...) returning
# (train_ds, valid_ds, test_ds, risk_assign).
DATASET_BUILDERS = {
    'mnist': make_mnist_survival,
    'octmnist': make_octmnist_survival,
    'pathmnist': make_pathmnist_survival,
    'tissuemnist': make_tissuemnist_survival,
    'tinyimagenet': make_tinyimagenet_survival,
    'organmnist3d': make_organmnist3d_survival,
    'retinamnist': make_retinamnist_survival,
}

# Datasets trained with a dedicated network; --backbone is ignored for them.
FIXED_BACKBONES = {
    'mnist': 'mnist_cnn',
    'octmnist': 'mnist_cnn',
    'tissuemnist': 'mnist_cnn',
    'organmnist3d': 'cnn_3d',
}

# Dataset name -> (batch_size, max_epochs, AdamW learning rate).
# pathmnist has a data builder but no training configuration yet.
TRAIN_CONFIGS = {
    'mnist': (128, 200, 1e-3),
    'octmnist': (128, 200, 1e-3),
    'tissuemnist': (128, 200, 1e-3),
    'organmnist3d': (32, 100, 1e-3),
    'tinyimagenet': (512, 200, 1e-2),
    'retinamnist': (64, 200, 1e-2),
}

# Fail fast, before any (expensive) dataset construction.
if args.dataset not in DATASET_BUILDERS:
    raise ValueError(f"Unknown dataset {args.dataset!r}; expected one of {sorted(DATASET_BUILDERS)}")
if args.dataset not in TRAIN_CONFIGS:
    raise NotImplementedError(
        f"No training configuration (batch size / epochs / learning rate) is defined for dataset {args.dataset!r}")
if args.model not in configs:
    raise ValueError(f"Unknown model {args.model!r}; expected one of {sorted(configs)}")

# Resolve the backbone once, so the value recorded in the results CSV matches
# both the network actually trained and the output file name.
args.backbone = FIXED_BACKBONES.get(args.dataset, args.backbone)

seeds = [int(s) for s in args.seed.split(',')]

# ---------------------------------------------------------------------------
# One full train / evaluate cycle per seed; each writes one CSV to output/.
# ---------------------------------------------------------------------------
for seed in seeds:
    args.seed = seed  # recorded in the results and used in the output file name
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # ---- Data -------------------------------------------------------------
    train_transforms, test_transforms = None, None
    if args.dataset == 'tinyimagenet':
        from tools.augmentation import train_transforms, test_transforms
    train_ds, valid_ds, test_ds, _risk_assign = DATASET_BUILDERS[args.dataset](seed=seed)

    discretizer = LabTransDiscreteTime(num_durations=n_disc, scheme='quantile')
    discretizer.fit(train_ds.times.numpy())

    # Evaluation horizons: the 10%..90% quantiles of the test times, restricted
    # to times no later than the last observed test event (or the last test
    # time if there are no events).
    test_times = test_ds.times.numpy()
    test_events = test_ds.events.numpy().astype(bool)
    last_event_time = float(test_times[test_events].max()) if test_events.any() else float(test_times.max())
    time_windows = np.quantile(test_times[test_times <= last_event_time], np.round(np.arange(0.1, 1.0, 0.1), 1))

    model_name = args.model
    mixup_alpha = args.mixup_alpha
    results = {
        'dataset': args.dataset,
        'backbone': args.backbone,
        'model': model_name,
        'mixup_strategy': args.mixup_strategy,
        'mixup_alpha': mixup_alpha,
        'keep_prev': str(args.keep_prev),
        'seed': seed,
    }

    # ---- Network, optimizer and LR schedule -------------------------------
    head_kwargs = configs[model_name]
    if args.dataset in ('mnist', 'octmnist', 'tissuemnist'):
        net = MNISTModel(**head_kwargs)
    elif args.dataset == 'organmnist3d':
        net = MNIST3DModel(**head_kwargs)
    else:
        net = ImagenetModel(**head_kwargs, backbone=args.backbone)

    batch_size, max_epochs, lr = TRAIN_CONFIGS[args.dataset]
    # NOTE(review): batch_size only shapes the LR schedule here; it is not passed
    # to the model, which presumably uses its own batch size — confirm they agree.
    steps_per_epoch = len(train_ds) // batch_size
    # Decay the LR by 10x at 50% and 75% of training; milestones are expressed in
    # scheduler steps (assumes the model steps `sch` once per batch — TODO confirm).
    milestones_steps = [int(max_epochs * 0.5) * steps_per_epoch, int(max_epochs * 0.75) * steps_per_epoch]
    opt = torch.optim.AdamW(net.parameters(), lr=lr, weight_decay=1e-6)
    sch = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=milestones_steps, gamma=0.1)

    net = net.to(device)
    net = torch.compile(net)

    # 'erm' means plain empirical risk minimisation, i.e. no mixup.
    if args.mixup_strategy in ['hmix', 'chmix', 'smix', 'omix']:
        mixup = SurvMixup(alpha=args.mixup_alpha, strategy=args.mixup_strategy, keep_prev=args.keep_prev, device=device)
    else:
        mixup = None

    model_args = {
        'net': net,
        'opt': opt,
        'sch': sch,
        'mixup': mixup,
        'epochs': max_epochs,
        'discretizer': discretizer,
        'train_transform': train_transforms,
        'test_transform': test_transforms,
    }

    m = getattr(models, model_name)(**model_args)
    m.fit(train_ds, valid_ds)

    # ---- Evaluation -------------------------------------------------------
    # Survival probabilities at each horizon; risk = 1 - S(t), clipped so S > 0.
    te_S = m.survival_probability_at_times(test_ds, times=time_windows)
    te_risk_all = 1.0 - np.clip(te_S, 1e-8, 1.0)

    y_tr = Surv.from_arrays(event=train_ds.events.numpy().astype(bool),
                            time=train_ds.times.numpy().astype(float))
    y_te = Surv.from_arrays(event=test_ds.events.numpy().astype(bool),
                            time=test_ds.times.numpy().astype(float))

    # Uno's IPCW concordance truncated at each horizon (train set supplies the
    # censoring distribution).
    c_test_by_t = np.array([
        concordance_index_ipcw(y_tr, y_te, te_risk_all[:, j], float(time_windows[j]))[0]
        for j in range(te_S.shape[1])
    ], dtype=float)

    test_ibs = float(integrated_brier_score(y_tr, y_te, te_S, time_windows))

    test_ici_scores = []
    test_dcal_p_values = []
    test_ece_scores = []
    for j, t in enumerate(time_windows):
        te_S_t = te_S[:, j]
        test_ici_scores.append(calculate_ici_survival(y_te, te_S_t, t))
        test_dcal_p_values.append(calculate_d_calibration(y_te, te_S_t, t)[1])
        test_ece_scores.append(calculate_ece_survival(y_te, te_S_t, t))

    results["test_ibs"] = test_ibs
    results["test_td_cindex_avg"] = float(np.nanmean(c_test_by_t))
    results["test_ici_avg"] = float(np.nanmean(test_ici_scores))
    results["test_d_cal_p_avg"] = float(np.nanmean(test_dcal_p_values))
    results["test_ece_avg"] = float(np.nanmean(test_ece_scores))

    per_time = zip(c_test_by_t, test_ici_scores, test_dcal_p_values, test_ece_scores)
    for i, (c_idx, ici, d_cal_p, ece) in enumerate(per_time):
        results[f"test_td_cindex_t{i:02d}"] = float(c_idx)
        results[f"test_ici_t{i:02d}"] = float(ici)
        results[f"test_d_cal_p_t{i:02d}"] = float(d_cal_p)
        results[f"test_ece_t{i:02d}"] = float(ece)

    # One key/value row per metric (transposed single-row frame, no header).
    results_df = pd.DataFrame(results, index=[0]).T
    results_df.to_csv(f'output/{args.dataset}+{args.backbone}+{args.model}+{args.mixup_strategy}+{args.mixup_alpha:.1f}+{str(args.keep_prev)}+{args.seed:02d}.csv', header=False)
