import torch
from torch.nn.utils.rnn import pack_sequence, pad_sequence
import numpy as np
import os
import pickle
from collections import Counter
import hashlib
import cv2
from data_utils.datasets import MIMICCXRDataset, UKBiobankMIDataset, UKBiobankPDDataset, UKBiobankStrokeDataset, ADNIDataset, AREDSDataset
from data_utils.datasets_shift import MIMICCXRDatasetShift, AREDSDatasetShift, ADNIDatasetShift
from algorithms.surv_model import DeepHit, NnetSurv, PMFSurv
from algorithms.fair_model import GroupDRO, Regularization, DomainInd, DomainIndAggregated, Reweighting
from torchxrayvision.datasets import normalize


def seed_hash(*args):
    """
    Return a deterministic integer in ``[0, 2**31)`` derived from ``args``.

    The positional arguments are rendered with ``str`` as a tuple, UTF-8
    encoded, and hashed with MD5. The hex digest is read as an integer and
    reduced modulo 2**31, so the result is reproducible across runs and can
    be used as a random seed.
    """
    payload = str(args).encode("utf-8")
    digest_as_int = int(hashlib.md5(payload).hexdigest(), 16)
    return digest_as_int % (2**31)


def get_class(class_name):
    """Resolve a dataset alias or global name to the object bound in this module.

    The short aliases 'mimiccxr', 'adni' and 'areds' map to
    'MIMICCXRDataset', 'ADNIDataset' and 'AREDSDataset'. Any other value is
    looked up unchanged in this module's ``globals()``.

    Raises:
        NotImplementedError: if the resolved name is not a module-level global.
    """
    aliases = {
        'mimiccxr': 'MIMICCXRDataset',
        'adni': 'ADNIDataset',
        'areds': 'AREDSDataset',
    }
    resolved = aliases.get(class_name, class_name)
    namespace = globals()
    if resolved in namespace:
        return namespace[resolved]
    raise NotImplementedError("Class not found: {}".format(resolved))


def get_class_shift(class_name):
    """Resolve a dataset alias to its distribution-shift variant in this module.

    'mimiccxr', 'adni' and 'areds' map to 'MIMICCXRDatasetShift',
    'ADNIDatasetShift' and 'AREDSDatasetShift'. Any other value is looked up
    unchanged in this module's ``globals()``.

    Raises:
        NotImplementedError: if the resolved name is not a module-level global.
    """
    class_name = {
        'mimiccxr': 'MIMICCXRDatasetShift',
        'adni': 'ADNIDatasetShift',
        'areds': 'AREDSDatasetShift',
    }.get(class_name, class_name)
    module_globals = globals()
    if class_name not in module_globals:
        raise NotImplementedError("Class not found: {}".format(class_name))
    return module_globals[class_name]


def data_transform(hparams, data):
    """Convert one raw sample array into a float32, channel-first array.

    Args:
        hparams: config mapping; only ``hparams['dataset']`` is read.
        data: numpy array for a single sample. For 'areds'/'mimiccxr' it must
            have at least 3 axes (axis 2 is moved to the front). For 'adni'
            it is treated as a 3-axis volume. Exact layout/dtype is assumed
            from the callers; TODO confirm.

    Returns:
        np.ndarray of dtype float32:
          * 'areds' / 'mimiccxr': ``data / 255`` with axes reordered
            ``(2, 0, 1)``.
          * 'adni': values below 1e-7 set to 0, min-max scaled to [0, 1],
            axes reordered ``(2, 0, 1)``, then replicated 3 times along a new
            leading axis. A constant volume, where min == max, becomes all
            zeros instead of NaN.

    Raises:
        NotImplementedError: for any other ``hparams['dataset']``.
    """
    dataset = hparams['dataset']
    if dataset in ('areds', 'mimiccxr'):
        return np.transpose((data / 255).astype(np.float32), (2, 0, 1))
    if dataset == 'adni':
        # astype() returns a copy, so the in-place clamp below does not touch
        # the caller's array.
        data = data.astype(np.float32)
        data[data < 1e-7] = 0
        lo = np.min(data)
        span = np.max(data) - lo
        if span == 0:
            # Constant volume: plain min-max scaling would compute 0/0 and
            # return an all-NaN sample.
            data = np.zeros_like(data)
        else:
            data = (data - lo) / span
        data = np.transpose(data, (2, 0, 1))
        # Replicate into 3 identical channels along a new leading axis.
        return np.stack([data] * 3, axis=0)
    raise NotImplementedError


def save_model(model, optimizer, hparams, hparams_seed, seed):
    """Save model/optimizer state dicts and ``hparams`` to a checkpoint file.

    The file is written to ``hparams['model_dir']`` as
    ``<surv_model>_<fair_model>_<dataset>_<sensitive_attribute>_<metric>_
    <pretrained>_<shift>_<group_shift>_<hparams_seed>_<seed>.ckpt``
    (no line break). ``hparams_seed`` and ``seed`` are formatted with ``%d``
    and must be integers. The directory is created if it does not exist.
    """
    outfile = '%s_%s_%s_%s_%s_%s_%s_%s_%d_%d.ckpt' % (hparams['surv_model'], hparams['fair_model'], hparams['dataset'], hparams['sensitive_attribute'], hparams['metric'], hparams['pretrained'], hparams['shift'], hparams['group_shift'], hparams_seed, seed)
    model_dir = hparams['model_dir']
    if model_dir:
        # Create the directory up front so a finished run is not lost to a
        # FileNotFoundError raised by torch.save.
        os.makedirs(model_dir, exist_ok=True)
    torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'hparams': hparams}, os.path.join(model_dir, outfile))


def load_model(model, optimizer, hparams, hparams_seed, seed):
    """Load weights saved by the matching checkpoint into ``model`` and return it.

    The checkpoint path is ``hparams['model_dir']/<surv_model>_<fair_model>_
    <dataset>_<sensitive_attribute>_<metric>_<pretrained>_<shift>_
    <group_shift>_<hparams_seed>_<seed>.ckpt`` (no line break). It is loaded
    onto ``hparams['device']``, and the model is moved to that device.

    ``optimizer`` is accepted but not used. The optimizer state stored in the
    checkpoint is NOT restored.
    """
    name_fields = (hparams['surv_model'], hparams['fair_model'], hparams['dataset'], hparams['sensitive_attribute'], hparams['metric'], hparams['pretrained'], hparams['shift'], hparams['group_shift'], hparams_seed, seed)
    ckpt_path = os.path.join(hparams['model_dir'], '%s_%s_%s_%s_%s_%s_%s_%s_%d_%d.ckpt' % name_fields)
    device = hparams['device']
    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from a trusted model_dir.
    state_dict = torch.load(ckpt_path, map_location=device, weights_only=False)['model_state_dict']
    if hparams['fair_model'] == 'GroupDRO':
        # Assign q before the strict load. Presumably this lets a q whose
        # shape differs from the freshly built model's still load without a
        # size-mismatch error. TODO confirm against GroupDRO.
        model.q = state_dict['q']
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    return model


def save_result(score_val, score_test, hparams, hparams_seed, seed):
    """Pickle validation and test scores to ``hparams['score_dir']``.

    The file is named ``score_<surv_model>_<fair_model>_<dataset>_
    <sensitive_attribute>_<metric>_<pretrained>_<shift>_<group_shift>_
    <hparams_seed>_<seed>.pkl`` (no line break). Its content is
    ``{'val': score_val, 'test': score_test}``. The directory is created if
    it does not exist.
    """
    outfile = 'score_%s_%s_%s_%s_%s_%s_%s_%s_%d_%d.pkl' % (hparams['surv_model'], hparams['fair_model'], hparams['dataset'], hparams['sensitive_attribute'], hparams['metric'], hparams['pretrained'], hparams['shift'], hparams['group_shift'], hparams_seed, seed)
    score_dir = hparams['score_dir']
    if score_dir:
        # Avoid losing an evaluation run to a missing output directory.
        os.makedirs(score_dir, exist_ok=True)
    with open(os.path.join(score_dir, outfile), 'wb') as f:
        pickle.dump({'val': score_val, 'test': score_test}, f)


def resampling_weight(sensitive_attribute_list):
    """Return inverse-frequency sample weights for non-negative integer labels.

    Each sample gets ``1 / n``, where ``n`` is the number of samples sharing
    its label, so every group has the same total weight.

    Args:
        sensitive_attribute_list: sequence or 1-D array of non-negative ints
            (the input must be accepted by ``np.bincount``).

    Returns:
        float64 numpy array with one weight per input sample.
    """
    counts = np.bincount(sensitive_attribute_list)
    # Look up each sample's count before dividing. Every looked-up count is
    # >= 1, so label values absent from the input (count 0) never cause a
    # divide-by-zero RuntimeWarning.
    return 1.0 / counts[sensitive_attribute_list]