import json
import yaml
import numpy as np
import random
import copy
import torch
import torch.optim as optim
from sklearn.metrics import f1_score, roc_auc_score, roc_curve

# Config loader: parses a YAML file (yaml.safe_load) and exposes its top-level keys as attributes.
class Config:
    """
    Attribute-style, read-only view over a YAML configuration file.

    Top-level keys of the parsed document are exposed as attributes, e.g. a
    file containing ``lr: 0.1`` gives ``Config(path).lr == 0.1``.

    Args:
        config_file (str | os.PathLike): Path to the YAML file.
    """

    def __init__(self, config_file):
        with open(config_file, "r", encoding="utf-8") as f:
            # safe_load returns None for an empty document; fall back to an
            # empty mapping so lookups raise AttributeError instead of TypeError.
            self._config = yaml.safe_load(f) or {}

    def __getattr__(self, name):
        """Return the config value for ``name``; raise AttributeError if absent."""
        # __getattr__ runs only after normal lookup fails. Read _config through
        # __dict__ so an instance without it (copy/deepcopy/pickle create one
        # without calling __init__) raises AttributeError instead of recursing
        # back into __getattr__ forever.
        config = self.__dict__.get("_config")
        if config is not None and name in config:
            return config[name]
        raise AttributeError(f"'Config' object has no attribute '{name}'")

def get_optimizer(optimizer_name, model, learning_rate):
    """
    Build an optimizer over ``model.parameters()``.

    Args:
        optimizer_name (str): One of "adai", "adam", "adamw", "sgd".
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        learning_rate (float): Passed to the optimizer as ``lr``.
    Returns:
        The constructed optimizer. All other hyper-parameters (betas, eps,
        weight decay, momentum) are fixed per optimizer below.
    Raises:
        ValueError: If ``optimizer_name`` is not one of the names above.
        ImportError: If "adai" is requested but ``adai_optim`` cannot be imported.
    """
    if optimizer_name == "adai":
        # adai_optim is not imported at module level; import it lazily here so
        # the other optimizers keep working when the package is not installed.
        try:
            import adai_optim
        except ImportError as err:
            raise ImportError(
                "optimizer 'adai' requires the 'adai_optim' package (Adai optimizer)"
            ) from err
        optimizer = adai_optim.Adai(
            model.parameters(),
            lr=learning_rate,
            betas=(0.1, 0.99),
            eps=1e-3,
            weight_decay=5e-4,
            decoupled=True,
        )
    elif optimizer_name == "adam":
        optimizer = optim.Adam(
            model.parameters(),
            lr=learning_rate,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=1e-4,
            amsgrad=False,
        )
    elif optimizer_name == "adamw":
        optimizer = optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=1e-4,
            betas=(0.9, 0.999),
            eps=1e-8,
        )
    elif optimizer_name == "sgd":
        optimizer = optim.SGD(
            model.parameters(),
            lr=learning_rate,
            momentum=0.9,
        )
    else:
        raise ValueError(f"Unknown optimizer name: {optimizer_name}")
    return optimizer


class Subsample(object):
    """
    Randomly crop a fixed-length window from an ECG signal.

    Args:
        subsample_length (int): Length of subsampled data.
    """
    def __init__(self, subsample_length: int):

        assert isinstance(subsample_length, int)
        self.subsample_length = subsample_length

    def __call__(self, sample):
        """
        Args:
            sample (Dict): {"data": Array of shape (12, sequence_length).,
                            "label": label}
        Returns:
            sample (Dict): {"data": Array of shape (12, subsample_length).,
                            "label": label}
        Raises:
            ValueError: If sequence_length < subsample_length.
        """
        data, label = sample["data"], sample["label"]
        sequence_length = data.shape[1]
        if sequence_length < self.subsample_length:
            raise ValueError(
                f"signal length {sequence_length} is shorter than "
                f"subsample_length {self.subsample_length}"
            )

        # randint's upper bound is exclusive: +1 so the window ending on the last
        # sample can be drawn, and a signal exactly subsample_length long works.
        start = np.random.randint(0, sequence_length - self.subsample_length + 1)
        subsampled_data = data[:, start:start + self.subsample_length]

        return {"data": subsampled_data, "label": label}

class ToTensor(object):
    """Turn the numpy arrays stored in a sample dict into torch tensors."""

    def __call__(self, sample):
        """
        Args:
            sample (Dict): {"data": np.ndarray, "label": np.ndarray}
        Returns:
            Dict: {"data": torch.Tensor, "label": torch.Tensor}. Both tensors come
            from ``torch.from_numpy``, so they share memory and dtype with the
            input arrays; no float/long cast is applied here.
        """
        data_array = sample["data"]
        label_array = sample["label"]
        converted = {}
        converted["data"] = torch.from_numpy(data_array)
        converted["label"] = torch.from_numpy(label_array)
        return converted

class SubsampleEval(Subsample):
    """
    Subsampling for evaluation mode.

    Splits the whole signal into windows of ``subsample_length`` samples that
    start every ``subsample_length // 2`` samples, zero-padding the tail so the
    windows fit. Windows that would still run past the padded signal are dropped.

    Args:
        subsample_length (int): Length of subsampled data. Must be >= 2 so the
            window stride (``subsample_length // 2``) is positive.
    """

    def _pad_signal(self, data):
        """
        Zero-pad axis 1 up to the next multiple of ``subsample_length // 2``.

        A length that is already an exact multiple still receives one full
        stride of zeros (the pad length is always in ``[1, stride]``).

        Args:
            data (np.ndarray): Signal of shape (num_leads, sequence_length).
        Returns:
            padded_data (np.ndarray): Signal of shape
                (num_leads, sequence_length + pad_length). The pad is float64
                zeros, so the result follows NumPy type promotion with ``data``.
        """
        chunk_length = self.subsample_length // 2
        pad_length = chunk_length - data.shape[1] % chunk_length
        # Use the input's own lead count rather than assuming 12 leads.
        pad = np.zeros([data.shape[0], pad_length])
        return np.concatenate([data, pad], axis=-1)

    def __call__(self, sample):
        """
        Args:
            sample (Dict): {"data": Array of shape (12, sequence_length).,
                            "label": label}
        Returns:
            sample (Dict): {"data": Array of shape (12, num_split, subsample_length).,
                            "label": label}
        Raises:
            ValueError: If subsample_length < 2 (the window stride would be 0).
        """
        data, label = sample["data"], sample["label"]
        window_length = self.subsample_length
        stride = window_length // 2
        if stride <= 0:
            raise ValueError(
                f"subsample_length must be >= 2 for evaluation subsampling, got {window_length}"
            )

        padded_data = self._pad_signal(data)
        slice_indices = np.arange(0, data.shape[1], stride)
        # Keep only windows that end inside the padded signal. Filtering explicitly
        # (instead of catching IndexError and dropping one window) also handles
        # cases where several trailing windows overflow, e.g. odd window lengths.
        slice_indices = slice_indices[slice_indices + window_length <= padded_data.shape[1]]
        target_locs = slice_indices[:, np.newaxis] + np.arange(window_length)[np.newaxis]

        eval_subsamples = padded_data[:, target_locs]
        return {"data": eval_subsamples, "label": label}

class ProcessLabel(object):
    """
    Collapse a per-diagnosis label vector into a single 3-way class id.

    Class ids: 1 = target diagnosis, 0 = normal, 2 = anything else. The target
    flag is checked first, so a sample flagged as both target and normal becomes 1.

    Args:
        normal_index (int): Position of the normal flag in the label vector.
        target_index (int): Position of the target-diagnosis flag.
        num_classes (int): Stored on the instance; not read by ``__call__``.
    """

    def __init__(
        self,
        normal_index: int,
        target_index: int,
        num_classes: int = 3
    ) -> None:
        self.normal_index, self.target_index, self.num_classes = (
            normal_index, target_index, num_classes
        )

    def __call__(self, sample):
        """
        Args:
            sample (Dict): {"data": data, "label": indexable of per-diagnosis
                            flags; only the target and normal entries are read}
        Returns:
            sample (Dict): {"data": data (passed through unchanged),
                            "label": 0-d np.ndarray holding 0, 1 or 2}
        """
        data = sample["data"]
        label = sample["label"]

        if label[self.target_index]:
            class_id = 1
        else:
            class_id = 0 if label[self.normal_index] else 2

        return {"data": data, "label": np.array(class_id)}


def _calc_class_weight(labels, normal_index, target_index):
    """
    Calculate class weight for target dx and others (1 for normal labels).

    Args:
        labels (np.ndarray): Label data array of shape [num_sample, num_classes]
    Returns:
        class_weight (np.ndarray):
    """
    num_samples = labels.shape[0]
    # Extract normal and target dx labels
    normal_labels = labels[:, normal_index]
    target_labels = labels[:, target_index]
    # Validate normal and target dx label do not overlap
    # assert((normal_labels & target_labels).sum() == 0)

    num_normal = normal_labels.sum()
    num_target = target_labels.sum()
    num_others = num_samples - (num_normal + num_target)

    class_weights = [1, num_target/num_normal, num_others/num_normal]
    return np.array(class_weights)

# augmentation
def add_noise(x, noise_level=0.01):
    """
    Augment a tensor with additive Gaussian noise.

    Args:
        x (torch.Tensor): Input tensor.
        noise_level (float): Standard deviation multiplier for the noise.
    Returns:
        torch.Tensor: ``x + N(0, 1) * noise_level`` sampled elementwise (same
        shape/dtype/device as ``x``); the result is not clipped.
    """
    gaussian = torch.randn_like(x)
    return x + gaussian * noise_level


# Raw (original) data preprocessing.
# Input data (.npy): ~780k ("78w") samples, with only 0-1 normalization applied.
# Because the dataset is large, the .npy was split into 6 pieces.
# `dataset` below is an object exposing `data` and `labels` array attributes.
def remove_nan_or_inf(dataset):
    """
    Drop every sample that contains a NaN or +/-Inf anywhere.

    Expects ``dataset.data`` to be 3-D (samples first) with ``dataset.labels``
    aligned on the first axis. The dataset's ``data``/``labels`` attributes are
    replaced with the filtered arrays and the same object is returned.
    """
    finite_rows = np.isfinite(dataset.data).all(axis=(1, 2))
    dataset.data = dataset.data[finite_rows]
    dataset.labels = dataset.labels[finite_rows]
    return dataset

def remove_nan_or_inf_inplace(dataset):
    """
    Drop every sample that contains a NaN or +/-Inf anywhere.

    "In place" refers to updating the given dataset object: its ``data`` and
    ``labels`` attributes are rebound to filtered copies, and the same object
    is returned. Expects 3-D ``dataset.data`` with samples on axis 0.
    """
    bad_values = np.isnan(dataset.data)
    bad_values |= np.isinf(dataset.data)
    keep = ~bad_values.any(axis=(1, 2))
    dataset.data = dataset.data[keep]
    dataset.labels = dataset.labels[keep]
    return dataset

def remove_out_of_range(dataset, min_val=-11, max_val=11):
    """
    Drop samples with any value strictly below ``min_val`` or above ``max_val``.

    Values equal to a bound are kept, and NaN never triggers removal (NaN
    comparisons are False). Expects 3-D ``dataset.data`` with samples on axis 0;
    ``data``/``labels`` are rebound on ``dataset``, which is returned.
    """
    too_low = dataset.data < min_val
    too_high = dataset.data > max_val
    keep = ~np.any(too_low | too_high, axis=(1, 2))
    dataset.data = dataset.data[keep]
    dataset.labels = dataset.labels[keep]
    return dataset

def remove_statistical_outliers(dataset, mean_range=(-1.5, 1.5), std_range=(0.05, 2)):
    """
    Drop samples whose per-channel mean or standard deviation is out of bounds.

    Mean and population std (ddof=0) are taken over axis 1 of ``dataset.data``
    (assumes 3-D data, e.g. (samples, time, leads) — confirm layout). A sample is
    dropped if, for any entry along the remaining axis, the mean is outside
    ``mean_range`` or the std is outside ``std_range``. Bounds are exclusive
    (equal values are kept), and NaN statistics never trigger removal.
    ``data``/``labels`` are rebound on ``dataset``, which is returned.
    """
    data = dataset.data
    channel_mean = np.mean(data, axis=1)
    channel_std = np.std(data, axis=1)

    mean_bad = (channel_mean < mean_range[0]) | (channel_mean > mean_range[1])
    std_bad = (channel_std < std_range[0]) | (channel_std > std_range[1])
    drop = (mean_bad | std_bad).any(axis=1)

    dataset.data = data[~drop]
    dataset.labels = dataset.labels[~drop]
    return dataset

# def remove_spikes_batch(data, diff_threshold=10, batch_size=1000):
#     """
#     check spike in batch
#     """
#     clean_data = []
#     for i in range(0, data.shape[0], batch_size):
#         batch = data[i:i + batch_size]
#         diffs = np.diff(batch, axis=1)
#         spikes = np.abs(diffs) > diff_threshold
#         invalid_samples = np.any(spikes, axis=(1, 2))
#         clean_batch = batch[~invalid_samples]
#         clean_data.append(clean_batch)
#     return np.concatenate(clean_data, axis=0)

def remove_spikes(dataset, diff_threshold=10):
    """
    Drop samples containing a jump between consecutive values along axis 1
    whose magnitude is strictly greater than ``diff_threshold``.

    Expects 3-D ``dataset.data`` with samples on axis 0; ``data``/``labels``
    are rebound on ``dataset``, which is returned.
    """
    jump_sizes = np.abs(np.diff(dataset.data, axis=1))
    has_spike = (jump_sizes > diff_threshold).any(axis=(1, 2))
    keep = ~has_spike
    dataset.data = dataset.data[keep]
    dataset.labels = dataset.labels[keep]
    return dataset

def clean_ecg_data(dataset, min_val=-11, max_val=11, mean_range=(-1.5, 1.5), std_range=(0.05, 2), diff_threshold=10):
    """
    Run all sample-quality filters in sequence and return the cleaned dataset.

    The caller's thresholds are forwarded to each filter; previously they
    were ignored in favor of hard-coded defaults.

    Args:
        dataset: Object with ``data`` and ``labels`` arrays; its attributes are
            rebound by each filter.
        min_val, max_val: Value bounds for ``remove_out_of_range``.
        mean_range, std_range: Bounds for ``remove_statistical_outliers``.
        diff_threshold: Jump threshold for ``remove_spikes``.
    Returns:
        The filtered dataset.
    """
    dataset = remove_nan_or_inf_inplace(dataset)
    dataset = remove_out_of_range(dataset, min_val=min_val, max_val=max_val)
    dataset = remove_statistical_outliers(dataset, mean_range=mean_range, std_range=std_range)
    dataset = remove_spikes(dataset, diff_threshold=diff_threshold)
    return dataset

# def clean_ecg_data_in_batches(data, batch_size=2000):
#     cleaned_batches = []
#     for i in range(0, data.shape[0], batch_size):
#         batch = data[i:i + batch_size]
#         batch = clean_ecg_data(batch)
#         cleaned_batches.append(batch)
#     return np.concatenate(cleaned_batches, axis=0)

def select_dataset(dataset, index):
    """
    Return an independent copy of ``dataset`` containing only samples whose
    label column ``index`` equals 1.

    The input dataset is not modified. All other attributes are deep-copied
    as before, but the full ``data``/``labels`` arrays are never copied: they
    are mapped straight to their selected subsets through the deepcopy memo,
    so peak memory is the subset size rather than a second full dataset.

    Example notes from the original dataset:
        dataset.data: (17443, 5000, 12), dataset.labels: (17443, 71)
        NORM: 4
        STE: 57, 28 samples in total
        LVH: 7, 2137 samples in total

    Args:
        dataset: Object with ``data`` and ``labels`` arrays aligned on axis 0.
        index (int): Label column to select on.
    Returns:
        A new dataset object with filtered ``data`` and ``labels``.
    """
    selected_mask = dataset.labels[:, index] == 1.
    selected_data = dataset.data[selected_mask]
    selected_labels = dataset.labels[selected_mask]

    # Pre-seed the memo so deepcopy substitutes the subsets for the big arrays
    # instead of copying them in full.
    memo = {id(dataset.data): selected_data, id(dataset.labels): selected_labels}
    dataset_new = copy.deepcopy(dataset, memo)
    # Assign explicitly as well, in case a custom __deepcopy__ ignores the memo.
    dataset_new.data = selected_data
    dataset_new.labels = selected_labels
    return dataset_new


class Monitor(object):
    """
    Accumulate per-batch losses and predictions and compute evaluation metrics.

    ``store_loss`` keeps a running loss total and sample count.
    ``store_result`` appends batches of labels and predictions (torch tensors)
    as numpy arrays to ``ytrue_record`` / ``ypred_record``.
    """

    def __init__(self):
        self.num_data = 0
        self.total_loss = 0
        # Concatenated labels / predictions; None until the first store_result call.
        self.ytrue_record = None
        self.ypred_record = None

    def _concat_array(self, record, new_data: np.ndarray):
        """
        Append ``new_data`` to ``record`` along axis 0.

        Args:
            record (np.ndarray | None): Previously accumulated array, or None.
            new_data (np.ndarray): Batch to append.
        Returns:
            np.ndarray: ``new_data`` if ``record`` is None, else the concatenation.
        """
        if record is None:
            return new_data
        return np.concatenate([record, new_data])

    def store_loss(self, loss: float, num_data: int) -> None:
        """
        Args:
            loss (float): Mini batch loss value. It is added to the running
                total as-is, and ``average_loss`` divides that total by the
                summed ``num_data``. NOTE(review): for a per-sample average,
                callers should pass a batch-summed loss (e.g. mean * batch
                size); confirm this at call sites.
            num_data (int): Number of data in mini batch.
        Returns:
            None
        """
        self.total_loss += loss
        self.num_data += num_data

    def store_result(self, y_trues: torch.Tensor, y_preds: torch.Tensor) -> None:
        """
        Args:
            y_trues (torch.Tensor): Batch of true labels.
            y_preds (torch.Tensor): Batch of predictions with 0 - 1 values.
        Returns:
            None
        """
        # Move to host memory and drop autograd history before storing.
        y_trues = y_trues.cpu().detach().numpy()
        y_preds = y_preds.cpu().detach().numpy()

        self.ytrue_record = self._concat_array(self.ytrue_record, y_trues)
        self.ypred_record = self._concat_array(self.ypred_record, y_preds)
        assert len(self.ytrue_record) == len(self.ypred_record)

    def average_loss(self) -> float:
        """
        Returns:
            average_loss (float): ``total_loss / num_data``. Raises
            ZeroDivisionError if no data has been stored.
        """
        return self.total_loss / self.num_data

    def _find_optimal_threshold(self, ytrue, ypred):
        """
        Find the cutoff that maximizes ``tpr - fpr`` (Youden's J) on the ROC curve.
        From `https://github.com/helme/ecg_ptbxl_benchmarking/blob/06187fbc28992f26e15e44058d49f92e1485b079/code/utils/utils.py#L78`.

        Args:
            ytrue (np.ndarray): Binary labels for one class.
            ypred (np.ndarray): Scores for the same class.
        Returns:
            The selected entry of ``roc_curve``'s thresholds.
        """
        fpr, tpr, threshold = roc_curve(ytrue, ypred)
        optimal_idx = np.argmax(tpr - fpr)
        return threshold[optimal_idx]

    def fmax_score(self):
        """
        Per-class F1 at a per-class threshold chosen on the same records.

        ```
        the threshold is optimized on the respective test set for
        each classification task and classifier under consideration.
        ```
        Details of metric from CAFA challenge paper
        (`https://genomebiology.biomedcentral.com/track/pdf/10.1186/s13059-016-1037-6.pdf`)

        Note: the threshold is the Youden-J optimum from
        ``_find_optimal_threshold``, not a search over thresholds for maximal F1.

        Returns:
            list: One F1 score per column of ``ytrue_record``.
        """
        num_classes = self.ytrue_record.shape[1]
        fmax_scores = []

        for i in range(num_classes):
            ytrue_i = self.ytrue_record[:, i]
            ypred_i = self.ypred_record[:, i]
            threshold = self._find_optimal_threshold(ytrue_i, ypred_i)
            # roc_curve treats score >= threshold as positive; use the same
            # comparison so the binarized predictions match the chosen ROC point.
            fmax_scores.append(f1_score(ytrue_i, ypred_i >= threshold))
        return fmax_scores

    def macro_auc_roc(self):
        """
        Macro-averaged ROC-AUC over the columns of the stored records.

        Columns whose true labels contain fewer than two distinct values are
        skipped, because ROC-AUC is undefined for them.

        Returns:
            float: Mean per-class ROC-AUC, or 0.0 if every column was skipped.
        """
        n_classes = self.ytrue_record.shape[1]
        auc_scores = []

        for i in range(n_classes):
            ytrue_i = self.ytrue_record[:, i]
            if len(set(ytrue_i)) < 2:
                continue
            auc_scores.append(roc_auc_score(ytrue_i, self.ypred_record[:, i]))

        if not auc_scores:
            return 0.0
        return sum(auc_scores) / len(auc_scores)

    def macro_f1(self):
        """
        Macro F1 of the arg-max class prediction against the stored labels.

        Returns:
            float: ``f1_score(ytrue_record, argmax(ypred_record, axis=1), average='macro')``.
        """
        y_preds = np.argmax(self.ypred_record, axis=1)
        return f1_score(self.ytrue_record, y_preds, average='macro')

def set_seed(seed):
    """
    Seed Python, NumPy and torch (CPU and all CUDA devices) RNGs with ``seed``,
    and configure cuDNN for deterministic, non-benchmarked kernels.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
