import torch
from tqdm import tqdm
import numpy as np
import collections
import argparse
import random
import os
import pickle
from datetime import datetime
from sklearn.metrics import roc_auc_score, roc_curve
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
from collections import Counter
import sys
import pprint
from pathlib import Path

sys.path.insert(0, str(Path(sys.path[0]).parent.absolute()))

from data_utils.dataloader import InfiniteDataLoader, FastDataLoader
from utils.hparams_registry import random_hparams, default_hparams
from utils.misc import seed_hash, data_transform, save_model, load_model, save_result, resampling_weight
from utils.metrics import Metric
from pycox.models.utils import make_subgrid

# Wall-clock start, reported at the end of the run.
start_time = datetime.now()

# Command-line interface. --metric is restricted to the values the script actually
# branches on; any other value would otherwise only fail (NameError) after training.
parser = argparse.ArgumentParser()
parser.add_argument('--surv_model', type=str, default='DeepHit', help="Choose from 'DeepHit', 'NnetSurv', 'PMF'")
parser.add_argument('--fair_model', type=str, default='GroupDRO', help="Choose from 'None', 'Regularization', 'GroupDRO', 'DomainInd', 'Reweighting', 'DomainIndAggregated'")
parser.add_argument('--dataset', type=str, default='mimiccxr', help="Choose from 'mimiccxr', 'areds', 'adni'")
parser.add_argument('--sensitive_attribute', type=str, default='sex', help="Choose from 'sex', 'age', 'race'")
parser.add_argument('--metric', type=str, default='ctd', choices=['ctd', 'brier', 'auc'], help="Choose from 'ctd', 'brier', 'auc'")
parser.add_argument('--pretrained', action='store_true', help="Use pretrained model")
# GPU index exported as CUDA_VISIBLE_DEVICES; the special value 'osc' leaves the environment untouched.
parser.add_argument('--gpu', type=str, default='0')
parser.add_argument('--hparams_seed', type=int, default=0)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--shift', type=str, default='x', help="Choose from 'None', 'x', 'y', 'd'")
parser.add_argument('--group_shift', type=str, default='0', help="Choose from 'None', '0', '1'")

args = parser.parse_args()

# Fairness methods are only implemented on top of DeepHit. An explicit raise is used
# instead of `assert`, which is stripped when Python runs with -O.
if args.fair_model != 'None' and args.surv_model != 'DeepHit':
    raise ValueError('Fair TTE algorithms only works with DeepHit model currently.')
# Pick the class resolver: the plain one when no shift is requested, otherwise the
# shift-aware variant (used for both datasets and models below).
if args.shift == 'None':
    from utils.misc import get_class
else:
    from utils.misc import get_class_shift as get_class

# Restrict visible GPUs before any torch.cuda call, so the restriction cannot be
# silently ignored because CUDA was already initialised. 'osc' leaves the env as-is.
if args.gpu != 'osc':
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

# parameter initialization
# DomainInd / DomainIndAggregated / Reweighting runs are seeded with --hparams_seed;
# every other model is seeded with --seed.
if args.fair_model in ['DomainInd', 'DomainIndAggregated', 'Reweighting']:
    init_seed = args.hparams_seed
else:
    init_seed = args.seed
random.seed(init_seed)
np.random.seed(init_seed)
torch.manual_seed(init_seed)
torch.cuda.manual_seed(init_seed)
# Trade cuDNN autotuning speed for reproducible kernels.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load hyperparameters
# Sample a hyperparameter configuration keyed by a hash of the run identity
# (hparams seed, survival model, dataset, sensitive attribute, metric).
hparams_hash = seed_hash(args.hparams_seed, args.surv_model, args.dataset, args.sensitive_attribute, args.metric)
hparams = random_hparams(args.surv_model, args.fair_model, args.dataset, args.sensitive_attribute,
                         args.metric, args.hparams_seed, hparams_hash)

# Attach run-specific settings; checkpoints and scores are grouped per fairness model.
hparams.update({
    'device': device,
    'pretrained': args.pretrained,
    'model_dir': 'saved_model/%s' % args.fair_model,
    'score_dir': 'output/%s' % args.fair_model,
    'shift': args.shift,
    'group_shift': args.group_shift,
})
for out_dir in (hparams['model_dir'], hparams['score_dir']):
    Path(out_dir).mkdir(parents=True, exist_ok=True)

print(hparams)

# create datasets
dataset_class = get_class(args.dataset)
# Only the training split receives the shift arguments; val/test are built without them.
train_kwargs = {'transform': data_transform}
if args.shift != 'None':
    train_kwargs.update(shift=args.shift, group=int(args.group_shift))
train_dataset = dataset_class(hparams, 'train', **train_kwargs)
# Discretisation is fitted on the training labels; its `.cuts` become the model's time grid.
label_transform = train_dataset.discretize_label()
val_dataset, test_dataset = (dataset_class(hparams, split, transform=data_transform)
                             for split in ('val', 'test'))

# create dataloaders
# Reweighting passes per-sample weights computed from the sensitive-attribute column;
# all other models pass weights=None to the infinite training loader.
sample_weights = None
if args.fair_model == 'Reweighting':
    sample_weights = resampling_weight(train_dataset.mtdt[hparams['sensitive_attribute']].to_numpy())
train_dataloader = iter(InfiniteDataLoader(train_dataset, sample_weights, hparams['batch_size'],
                                           num_workers=args.num_workers, collate_fn=None))
val_dataloader, test_dataloader = (
    FastDataLoader(split_dataset, hparams['test_batch_size'], num_workers=args.num_workers, collate_fn=None)
    for split_dataset in (val_dataset, test_dataset)
)

# Discretisation cut points fitted on the training split.
time_grid_train_np = label_transform.cuts

# create metric
metric = Metric(hparams)
print('time_grid_train_np shape:{}, value:{}'.format(np.shape(time_grid_train_np), time_grid_train_np[0]))

# model initialization: the survival model itself when --fair_model is 'None',
# otherwise the class named by --fair_model.
model_name = args.surv_model if args.fair_model == 'None' else args.fair_model
model_class = get_class(model_name)
model = model_class(hparams, time_grid_train_np, device).to(device)
# model training
# Per-loss-term history since the last checkpoint; reset after every validation pass.
train_loss = collections.defaultdict(list)
# 'ctd' and 'auc' improve upward, 'brier' improves downward.
if args.metric in ['ctd', 'auc']:
    higher_is_better = True
    best_val_acc = float('-inf')
elif args.metric in ['brier']:
    higher_is_better = False
    best_val_acc = float('inf')
else:
    # Fail before training rather than with a NameError at the first checkpoint.
    raise ValueError("Unsupported --metric %r; choose from 'ctd', 'brier', 'auc'" % args.metric)
val_acc_list = []
# Training labels / censoring indicators do not change during training: read them once.
d_list_train = train_dataset.mtdt['indicator']
y_list_train = train_dataset.mtdt['time_to_event']
model.train()
for step in tqdm(range(1, hparams['num_step'] + 1)):
    # mnb = [img, censoring, tte, sensitive]
    mnb = next(train_dataloader)
    mnb = [x.to(device) for x in mnb]
    # update the model
    loss = model.update(mnb[0], mnb[2], mnb[1], mnb[3])  # input, Y, D, S

    for k, v in loss.items():
        train_loss[k].append(v)
    # Validate every `checkpoint_freq` steps and once more at the final step.
    if (step % hparams['checkpoint_freq'] == 0) or (step == hparams['num_step']):
        for k, v in train_loss.items():
            print('Train %s at step %d: %.4f' % (k, step, np.mean(v)))
        model.eval()
        # Gather per-batch arrays and concatenate once: concatenating onto a growing
        # array every batch re-copies everything collected so far (quadratic work).
        # Each list starts with an empty array so an empty loader still yields empty
        # arrays of the right shape and dtype.
        pred_chunks = [np.empty((0, len(time_grid_train_np)))]
        d_chunks = [np.empty(0)]
        y_chunks = [np.empty(0)]
        s_chunks = [np.empty(0)]
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for mnb in tqdm(val_dataloader):
                mnb = [x.to(device) for x in mnb]
                if args.fair_model == 'DomainInd':
                    # DomainInd additionally conditions its prediction on the sensitive attribute.
                    pred = model.predict_surv(mnb[0], mnb[3], batch_size=hparams['test_batch_size'], to_cpu=True, numpy=True)
                else:
                    pred = model.predict_surv(mnb[0], batch_size=hparams['test_batch_size'], to_cpu=True, numpy=True)
                pred_chunks.append(pred)
                d_chunks.append(mnb[1].cpu().numpy())
                y_chunks.append(mnb[2].cpu().numpy())
                s_chunks.append(mnb[3].cpu().numpy())
        pred_list = np.concatenate(pred_chunks, axis=0)
        d_list = np.concatenate(d_chunks, axis=0)
        y_list = np.concatenate(y_chunks, axis=0)
        s_list = np.concatenate(s_chunks, axis=0)
        # Evaluation grid: the distinct observed times of the validation split.
        time_grid_test_np = np.unique(y_list)

        # predict, time_index_train, time_index_test, label_train, label_test, censoring_train, censoring_test, sensitive_attribute
        result_val = metric(pred_list, time_grid_train_np, time_grid_test_np, y_list_train, y_list, d_list_train, d_list, s_list)
        print('Val metric at step %d:' % step)
        pprint.pprint(result_val)
        val_acc = result_val['accuracy']
        val_acc_list.append(val_acc)
        print('Val accuracy (%s) at step %d: %.4f' % (args.metric, step, val_acc))
        # Checkpoint only on strict improvement, so ties keep the earliest best model.
        if higher_is_better:
            improved = val_acc > best_val_acc
        else:
            improved = val_acc < best_val_acc
        if improved:
            best_val_acc = val_acc
            save_model(model, model.optimizer, hparams, args.hparams_seed, args.seed)
            print('Model saved at step %d' % step)
        model.train()
        train_loss = collections.defaultdict(list)

# model evaluation
score = {}
# Index of the best validation checkpoint. np.argmax/argmin return the first
# occurrence on ties, which matches the strict-improvement rule used when saving.
if args.metric in ['ctd', 'auc']:
    best_idx = int(np.argmax(val_acc_list))
elif args.metric in ['brier']:
    best_idx = int(np.argmin(val_acc_list))
else:
    raise ValueError("Unsupported --metric %r; choose from 'ctd', 'brier', 'auc'" % args.metric)
# Validation ran at every multiple of `checkpoint_freq` and additionally at `num_step`.
# When `num_step` is not a multiple of `checkpoint_freq`, the last entry belongs to
# `num_step` rather than to the next multiple, so clamp to report the true step.
best_step = min((best_idx + 1) * hparams['checkpoint_freq'], hparams['num_step'])
# Presumably reloads the checkpoint written by save_model during training (same
# hparams/seed arguments) — confirm against utils.misc.load_model.
load_model(model, model.optimizer, hparams, args.hparams_seed, args.seed)
model.eval()

# Gather per-batch arrays and concatenate once (avoids quadratic re-copying); each
# list starts with an empty array so an empty loader still yields empty arrays.
pred_chunks = [np.empty((0, len(time_grid_train_np)))]
d_chunks = [np.empty(0)]
y_chunks = [np.empty(0)]
s_chunks = [np.empty(0)]
# Inference only: no autograd graph is needed.
with torch.no_grad():
    for mnb in val_dataloader:
        mnb = [x.to(device) for x in mnb]
        if args.fair_model == 'DomainInd':
            pred = model.predict_surv(mnb[0], mnb[3], batch_size=hparams['test_batch_size'], to_cpu=True, numpy=True)
        else:
            pred = model.predict_surv(mnb[0], batch_size=hparams['test_batch_size'], to_cpu=True, numpy=True)
        pred_chunks.append(pred)
        d_chunks.append(mnb[1].cpu().numpy())
        y_chunks.append(mnb[2].cpu().numpy())
        s_chunks.append(mnb[3].cpu().numpy())
pred_list = np.concatenate(pred_chunks, axis=0)
d_list = np.concatenate(d_chunks, axis=0)
y_list = np.concatenate(y_chunks, axis=0)
s_list = np.concatenate(s_chunks, axis=0)
d_list_train = train_dataset.mtdt['indicator']
y_list_train = train_dataset.mtdt['time_to_event']
time_grid_test_np = np.unique(y_list)
result_val = metric(pred_list, time_grid_train_np, time_grid_test_np, y_list_train, y_list, d_list_train, d_list, s_list)
print('Val metric at best step (%d):' % best_step)
pprint.pprint(result_val)
print('Val accuracy (%s) at best step %d: %.4f' % (args.metric, best_step, result_val['accuracy']))

# Evaluate the restored best model on the test split.
# Gather per-batch arrays and concatenate once (avoids quadratic re-copying); each
# list starts with an empty array so an empty loader still yields empty arrays.
pred_chunks = [np.empty((0, len(time_grid_train_np)))]
d_chunks = [np.empty(0)]
y_chunks = [np.empty(0)]
s_chunks = [np.empty(0)]
# Inference only: no autograd graph is needed.
with torch.no_grad():
    for mnb in test_dataloader:
        mnb = [x.to(device) for x in mnb]
        if args.fair_model == 'DomainInd':
            pred = model.predict_surv(mnb[0], mnb[3], batch_size=hparams['test_batch_size'], to_cpu=True, numpy=True)
        else:
            pred = model.predict_surv(mnb[0], batch_size=hparams['test_batch_size'], to_cpu=True, numpy=True)
        pred_chunks.append(pred)
        d_chunks.append(mnb[1].cpu().numpy())
        y_chunks.append(mnb[2].cpu().numpy())
        s_chunks.append(mnb[3].cpu().numpy())
pred_list = np.concatenate(pred_chunks, axis=0)
d_list = np.concatenate(d_chunks, axis=0)
y_list = np.concatenate(y_chunks, axis=0)
s_list = np.concatenate(s_chunks, axis=0)
d_list_train = train_dataset.mtdt['indicator']
y_list_train = train_dataset.mtdt['time_to_event']
# Evaluation grid: the distinct observed times of the test split.
time_grid_test_np = np.unique(y_list)
result_test = metric(pred_list, time_grid_train_np, time_grid_test_np, y_list_train, y_list, d_list_train, d_list, s_list)
print('Test metric at best step (%d):' % best_step)
pprint.pprint(result_test)
print('Test accuracy (%s) at best step %d: %.4f' % (args.metric, best_step, result_test['accuracy']))

# Persist validation and test results for this (hparams_seed, seed) run.
save_result(result_val, result_test, hparams, args.hparams_seed, args.seed)

# Report total wall-clock time since start_time was recorded.
end_time = datetime.now()
elapsed = end_time - start_time
print('Running time: %s' % elapsed)
