import torch
from tqdm import tqdm
import numpy as np
import collections
import argparse
import random
import os
import pickle
from datetime import datetime
from sklearn.metrics import roc_auc_score, roc_curve
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
from collections import Counter
import sys
import pprint
from pathlib import Path

sys.path.insert(0, str(Path(sys.path[0]).parent.absolute()))

from data_utils.dataloader import InfiniteDataLoader, FastDataLoader
from utils.hparams_registry import random_hparams, default_hparams
from utils.misc import seed_hash, data_transform, save_model, load_model, save_result, resampling_weight
from utils.metrics import Metric
from pycox.models.utils import make_subgrid

# Wall-clock start; read again at the end of the script to report the running time.
start_time = datetime.now()

# Command-line interface, declared as (flag, add_argument keyword arguments)
# pairs and registered in the order listed.
_CLI_OPTIONS = (
    ('--surv_model', {'type': str, 'default': 'DeepHit',
                      'help': "Choose from 'DeepHit', 'NnetSurv', 'PMF'"}),
    ('--fair_model', {'type': str, 'default': 'GroupDRO',
                      'help': "Choose from 'None', 'Regularization', 'GroupDRO', 'DomainInd', 'Reweighting', 'DomainIndAggregated'"}),
    ('--dataset', {'type': str, 'default': 'mimiccxr',
                   'help': "Choose from 'mimiccxr', 'areds', 'adni'"}),
    ('--sensitive_attribute', {'type': str, 'default': 'sex',
                               'help': "Choose from 'sex', 'age', 'race'"}),
    ('--metric', {'type': str, 'default': 'ctd',
                  'help': "Choose from 'ctd', 'brier', 'auc'"}),
    ('--pretrained', {'action': 'store_true', 'help': "Use pretrained model"}),
    ('--gpu', {'type': str, 'default': '0'}),
    ('--hparams_seed', {'type': int, 'default': 0}),
    ('--seed', {'type': int, 'default': 0}),
    ('--num_workers', {'type': int, 'default': 8}),
    ('--shift', {'type': str, 'default': 'x',
                 'help': "Choose from 'None', 'x', 'y', 'd'"}),
    ('--group_shift', {'type': str, 'default': '0',
                       'help': "Choose from 'None', '0', '1'"}),
)

parser = argparse.ArgumentParser()
for _flag, _flag_options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_flag_options)

# Parsed from the process command line; every later section reads its settings from `args`.
args = parser.parse_args()

# Any fairness method other than 'None' is only accepted together with the
# DeepHit survival model. The check goes through argparse rather than `assert`
# so that it still runs under `python -O` (which strips assert statements) and
# the user gets a usage message instead of a traceback.
if args.fair_model != 'None' and args.surv_model != 'DeepHit':
    parser.error('Fair TTE algorithms only works with DeepHit model currently.')

# `get_class` is bound to the plain lookup when no shift is requested and to
# `get_class_shift` otherwise; the rest of the script only uses the name `get_class`.
if args.shift == 'None':
    from utils.misc import get_class
else:
    from utils.misc import get_class_shift as get_class

# parameter initialization
# Runs using one of these fairness methods seed the RNGs from --hparams_seed;
# every other run seeds them from --seed.
_HPARAM_SEEDED_MODELS = ['DomainInd', 'DomainIndAggregated', 'Reweighting']
_rng_seed = args.hparams_seed if args.fair_model in _HPARAM_SEEDED_MODELS else args.seed
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(_rng_seed)

# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

# Any --gpu value other than 'osc' is exported as CUDA_VISIBLE_DEVICES;
# with 'osc' the environment variable is left untouched.
if args.gpu != 'osc':
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

# load hyperparameters
# The value returned by seed_hash is handed to random_hparams as its last argument.
_hparams_hash = seed_hash(args.hparams_seed, args.surv_model, args.dataset,
                          args.sensitive_attribute, args.metric)
hparams = random_hparams(args.surv_model, args.fair_model, args.dataset,
                         args.sensitive_attribute, args.metric, args.hparams_seed,
                         _hparams_hash)

# Run-specific settings layered on top of the sampled hyperparameters,
# written one key at a time in the order listed.
_runtime_settings = (
    ('device', device),
    ('pretrained', args.pretrained),
    ('model_dir', 'saved_model/%s' % args.fair_model),
    ('shift', args.shift),
    ('group_shift', args.group_shift),
    ('repr_dir', 'repr'),
)
for _key, _value in _runtime_settings:
    hparams[_key] = _value

# Make sure the output directory for the representation pickle exists.
Path(hparams['repr_dir']).mkdir(parents=True, exist_ok=True)

print(hparams)

# create datasets
dataset_class = get_class(args.dataset)

# Only the training split receives the shift arguments, and only when a shift
# was requested; the test split is always built without them.
_train_kwargs = {'transform': data_transform}
if args.shift != 'None':
    _train_kwargs['shift'] = args.shift
    _train_kwargs['group'] = int(args.group_shift)
train_dataset = dataset_class(hparams, 'train', **_train_kwargs)
label_transform = train_dataset.discretize_label()
test_dataset = dataset_class(hparams, 'test', transform=data_transform)

# create dataloaders
# Both splits are loaded with the 'test_batch_size' hyperparameter.
_batch_size = hparams['test_batch_size']
_loader_options = {'num_workers': args.num_workers, 'collate_fn': None}
train_dataloader = FastDataLoader(train_dataset, _batch_size, **_loader_options)
test_dataloader = FastDataLoader(test_dataset, _batch_size, **_loader_options)

# Cut points exposed by the label transform obtained from the training split.
time_grid_train_np = label_transform.cuts

# create metric
metric = Metric(hparams)
_grid_shape = np.shape(time_grid_train_np)
_first_cut = time_grid_train_np[0]
print('time_grid_train_np shape:{}, value:{}'.format(_grid_shape, _first_cut))

# model initialization
# Without a fairness method the survival model class is instantiated directly;
# otherwise the class named by --fair_model is used instead.
_model_name = args.surv_model if args.fair_model == 'None' else args.fair_model
_model_class = get_class(_model_name)
model = _model_class(hparams, time_grid_train_np, device).to(device)

# model evaluation

# Restore the saved weights for this (hparams_seed, seed) pair, then switch to eval mode.
load_model(model, None, hparams, args.hparams_seed, args.seed)
model.eval()

def _collect_split(model, dataloader, device, feature_dim):
    """Run ``model.get_repr`` over every batch of ``dataloader`` and gather the results.

    Batches are indexed positionally: element 0 is passed to ``model.get_repr``
    and elements 1, 2 and 3 are stored under the keys ``'d'``, ``'y'`` and
    ``'s'`` respectively.

    Args:
        model: object exposing ``get_repr(batch_input)`` that returns a tensor.
        dataloader: iterable yielding sequences of tensors.
        device: device every batch element is moved to before the forward pass.
        feature_dim: number of columns of the representation matrix.

    Returns:
        dict with keys ``'repr'``, ``'y'``, ``'s'`` and ``'d'``, each a NumPy
        array holding the per-batch values concatenated along axis 0.
    """
    # Every accumulator starts with an empty float64 array so that the final
    # arrays have the same dtype and shape as concatenating batch by batch onto
    # np.empty(...) would give, including when the loader yields no batches.
    repr_chunks = [np.empty((0, feature_dim))]
    d_chunks = [np.empty(0)]
    y_chunks = [np.empty(0)]
    s_chunks = [np.empty(0)]

    # Inference only: no autograd graph is needed for the extracted features.
    with torch.no_grad():
        for mnb in tqdm(dataloader):
            mnb = [x.to(device) for x in mnb]
            repr_chunks.append(model.get_repr(mnb[0]).cpu().detach().numpy())
            d_chunks.append(mnb[1].cpu().numpy())
            y_chunks.append(mnb[2].cpu().numpy())
            s_chunks.append(mnb[3].cpu().numpy())

    # One concatenate per array instead of one per batch (which re-copies
    # everything gathered so far on every iteration).
    return {
        'repr': np.concatenate(repr_chunks, axis=0),
        'y': np.concatenate(y_chunks, axis=0),
        's': np.concatenate(s_chunks, axis=0),
        'd': np.concatenate(d_chunks, axis=0),
    }


# Each split gets its own fresh accumulators. Previously a misspelled reset
# (`rerp_list`) left the train representations in place, so the test 'repr'
# array contained train rows followed by test rows.
output_train = _collect_split(model, train_dataloader, device, hparams['feature_dim'])
output_test = _collect_split(model, test_dataloader, device, hparams['feature_dim'])

output = {'train': output_train, 'test': output_test}

# save result

# Values substituted into the file name, in order: the eight %s fields are the
# run configuration and the two %d fields are the hparams seed and the seed.
_name_fields = (
    args.surv_model,
    args.fair_model,
    args.dataset,
    args.sensitive_attribute,
    args.metric,
    args.pretrained,
    args.shift,
    args.group_shift,
    args.hparams_seed,
    args.seed,
)
outfile = 'repr_%s_%s_%s_%s_%s_%s_%s_%s_%d_%d.pkl' % _name_fields
_out_path = os.path.join(hparams['repr_dir'], outfile)

# Serialize the {'train': ..., 'test': ...} dictionary into the repr directory.
with open(_out_path, 'wb') as _handle:
    pickle.dump(output, _handle)

end_time = datetime.now()
_elapsed = end_time - start_time

print('Running time: %s' % _elapsed)
