import numpy as np
import torch
from ogb.utils import smiles2graph
from torch.utils.data import Dataset
from torch_geometric.data import Data
import sys
from itertools import product

from sklearn.metrics import (
    roc_auc_score,
    f1_score,
    average_precision_score,
)

def llprint(message):
    """Write ``message`` to standard output and flush immediately.

    No newline is appended, so consecutive calls continue on the same line.

    :param message: text handed to ``sys.stdout.write``.
    """
    stream = sys.stdout
    stream.write(message)
    stream.flush()


def ddi_rate_score(record, ddi_graph):
    """Return the fraction of co-prescribed medication pairs that interact.

    Every unordered pair of medications inside one admission is looked up in
    ``ddi_graph``; a pair counts as interacting when either ``ddi_graph[a, b]``
    or ``ddi_graph[b, a]`` is non-zero.

    :param record: iterable of patients; each patient is an iterable of
        admissions; each admission is a sequence of medication indices.
    :param ddi_graph: 2-D adjacency matrix indexed by medication index. Any
        numeric or boolean dtype is accepted; non-zero means "interacts".
    :return: interacting pairs divided by all pairs, or ``0`` when no pair
        was evaluated.

    Admissions that are empty, or that contain an index outside
    ``[0, ddi_graph.shape[0])``, are skipped (the latter with a printed
    diagnostic) and contribute to neither count.
    """
    all_cnt = 0
    dd_cnt = 0
    graph_size = ddi_graph.shape[0]
    for patient in record:
        for adm in patient:
            med = np.array(adm)
            if len(med) == 0:
                continue
            lowest, highest = np.min(med), np.max(med)
            if highest >= graph_size or lowest < 0:
                print(f"index out of range (min={lowest}, max={highest}, graph_size={graph_size})")
                continue

            # All unordered index pairs (a < b) within this admission.
            idx_pairs = np.triu_indices(len(med), k=1)
            med_i = med[idx_pairs[0]]
            med_j = med[idx_pairs[1]]

            try:
                # Compare against zero before OR-ing: a bare `|` is a bitwise
                # operator that rejects float matrices and would over-count
                # integer entries greater than one.
                ddi_pairs = (ddi_graph[med_i, med_j] != 0) | (ddi_graph[med_j, med_i] != 0)
                dd_cnt += np.sum(ddi_pairs)
                all_cnt += len(ddi_pairs)
            except IndexError as e:
                # Still reachable for non-square graphs, where only the first
                # axis was range-checked above.
                print(f": {e}")
                print(f"i={med_i}, j={med_j}")
                print(f"shape: {ddi_graph.shape}")
                continue

    if all_cnt == 0:
        return 0
    return dd_cnt / all_cnt

def map2atc3_label(word2idx, atc3_to_index, l1, l2, p):
    """Collapse per-code labels and scores onto coarser 4-character groups.

    Each key of ``word2idx`` is truncated to its first four characters
    (presumably the ATC level-3 prefix) and looked up in ``atc3_to_index``.
    Keys are paired with the entries of ``l1`` / ``l2`` / ``p`` by the
    iteration order of ``word2idx``.

    :param word2idx: mapping whose keys are the fine-grained codes.
    :param atc3_to_index: mapping from 4-character prefix to output position.
    :param l1: first indicator vector, one entry per code.
    :param l2: second indicator vector, one entry per code.
    :param p: score vector, one entry per code.
    :return: ``(l1_mapped, l2_mapped, p_mapped)`` as plain lists of length
        ``len(atc3_to_index)``. A group's indicator is 1 when any of its codes
        equals 1; a group's score is the maximum over its codes, or ``0.0``
        for a group no code maps to.
    """
    n_groups = len(atc3_to_index)
    codes = word2idx.keys()

    def _collapse_flags(flags):
        # A group is switched on as soon as one of its codes is flagged.
        collapsed = np.zeros(n_groups, dtype=int)
        for code, flag in zip(codes, flags):
            if flag == 1:
                collapsed[atc3_to_index[code[:4]]] = 1
        return collapsed

    first_mapped = _collapse_flags(l1)
    second_mapped = _collapse_flags(l2)

    # Max-pool the scores per group; -inf is the identity element for max.
    pooled = np.full(n_groups, -np.inf, dtype=float)
    seen = np.zeros(n_groups, dtype=bool)
    for code, score in zip(codes, p):
        slot = atc3_to_index[code[:4]]
        pooled[slot] = max(pooled[slot], score)
        seen[slot] = True
    # Groups that received no code must not leak -inf to the caller.
    pooled[~seen] = 0.0

    return first_mapped.tolist(), second_mapped.tolist(), pooled.tolist()

def graph_batch_from_smile(smiles_list):
    """Merge the molecular graphs of several SMILES strings into one ``Data``.

    Each string is converted with ``smiles2graph``. The resulting graphs are
    stacked into a single disconnected graph: edge indices are shifted by the
    number of nodes that precede each graph, and a ``batch`` vector records
    which input graph every node came from.

    :param smiles_list: iterable of SMILES strings.
    :return: ``Data`` with tensor fields ``edge_index``, ``edge_attr``,
        ``batch`` and ``x``, plus the integer ``num_nodes`` (total node count).
    """
    mol_graphs = [smiles2graph(smiles) for smiles in smiles_list]

    shifted_edges, bond_feats, atom_feats, membership = [], [], [], []
    node_offset = 0
    for graph_id, mol in enumerate(mol_graphs):
        n_nodes = mol['num_nodes']
        # Re-base this graph's node ids so they stay unique after merging.
        shifted_edges.append(mol['edge_index'] + node_offset)
        bond_feats.append(mol['edge_feat'])
        atom_feats.append(mol['node_feat'])
        membership.append(np.full(n_nodes, graph_id, dtype=np.int64))
        node_offset += n_nodes

    stacked = {
        'edge_index': np.concatenate(shifted_edges, axis=-1),
        'edge_attr': np.concatenate(bond_feats, axis=0),
        'batch': np.concatenate(membership, axis=0),
        'x': np.concatenate(atom_feats, axis=0),
    }
    fields = {}
    for key, array in stacked.items():
        fields[key] = torch.from_numpy(array)
    fields['num_nodes'] = node_offset
    return Data(**fields)


class PatientDataset(Dataset):
    """Thin ``Dataset`` wrapper around an in-memory sequence of records."""

    def __init__(self, records):
        """Store the records without copying or transforming them.

        :param records: EHR records [[[diag], [pro], [med]], ...]
        """
        self.records = records

    def __len__(self):
        """Return the number of stored records."""
        stored = self.records
        return len(stored)

    def __getitem__(self, idx):
        """Return the record at position ``idx`` exactly as it was stored."""
        stored = self.records
        return stored[idx]


def collate_fn(batch):
    """Identity collate function: hand the batch back untouched.

    :param batch: the list of samples assembled by the loader.
    :return: the very same ``batch`` object, with no stacking or conversion.
    """
    passthrough = batch
    return passthrough


def compute_ddi_loss(r, ddi_graph, alpha=1.0, beta=10.0):
    """Compute a differentiable penalty for co-recommending interacting drugs.

    For every drug pair ``(i, j)`` with ``i < j`` the penalty is
    ``a_ij * r_i * r_j * (term1 + term2)`` where::

        term1 = (1 - r_i)^alpha * sigmoid(beta * (r_j - r_i))
        term2 = (1 - r_j)^alpha * sigmoid(beta * (r_i - r_j))

    The pair penalties are summed and divided by
    ``batch_size * (sum(a_ij) + 1e-6)``.

    :param r: drug scores of shape ``(batch_size, num_drugs)``. A 1-D tensor
        of shape ``(num_drugs,)`` is treated as a batch of one.
    :param ddi_graph: matrix indexable as ``ddi_graph[i, j]`` with index
        tensors. Only entries above the diagonal (``i < j``) are read.
    :param alpha: exponent applied to ``(1 - r)``.
    :param beta: scale of the score difference inside the sigmoid.
    :return: scalar loss tensor.
    """
    # `r.shape == 1` would compare a torch.Size with an int and never be
    # true; the number of dimensions is what decides whether to add a batch
    # axis.
    if r.dim() == 1:
        r = r.unsqueeze(0)
    batch_size, num_drugs = r.shape
    i, j = torch.triu_indices(num_drugs, num_drugs, offset=1)

    r_i = r[:, i]  # (batch_size, num_pairs)
    r_j = r[:, j]  # (batch_size, num_pairs)

    # (num_pairs,) interaction indicator for each pair
    a_ij = ddi_graph[i, j]
    term1 = (1 - r_i).pow(alpha) * torch.sigmoid(beta * (r_j - r_i))  # (batch_size, num_pairs)
    term2 = (1 - r_j).pow(alpha) * torch.sigmoid(beta * (r_i - r_j))  # (batch_size, num_pairs)

    # (batch_size, num_pairs)
    pair_loss = a_ij * r_i * r_j * (term1 + term2)
    # The epsilon keeps the division finite when the graph has no edges.
    loss = pair_loss.sum() / (batch_size * (a_ij.sum() + 1e-6))

    return loss


def multi_label_metric(y_gt, y_pred, y_prob):
    """Compute sample-averaged multi-label metrics for a batch of predictions.

    Row ``b`` of each argument describes one sample. A label counts as
    present where the entry equals 1.

    :param y_gt: 2-D array of ground-truth indicators.
    :param y_pred: 2-D array of predicted indicators, same layout as ``y_gt``.
    :param y_prob: per-label scores, same layout as ``y_gt``; used only for
        the average-precision score.
    :return: ``(jaccard, prauc, precision, recall, f1)``, each the mean of the
        per-sample value:

        * jaccard   - ``|pred & gt| / |pred | gt|``, 0 when both are empty
        * prauc     - ``average_precision_score`` of the row
        * precision - ``|pred & gt| / |pred|``, 0 when nothing is predicted
        * recall    - ``|pred & gt| / |gt|``, 0 when the ground truth is empty
        * f1        - harmonic mean of that sample's precision and recall,
          0 when both are 0
    """
    def precision_auc(y_gt, y_prob):
        per_sample = []
        for b in range(len(y_gt)):
            per_sample.append(
                average_precision_score(y_gt[b], y_prob[b], average="macro")
            )
        return np.mean(per_sample)

    def f1_from(precision, recall):
        # The harmonic mean is undefined when both inputs are zero.
        if precision + recall == 0:
            return 0
        return 2 * precision * recall / (precision + recall)

    jaccard_scores, prc_scores, recall_scores, f1_scores = [], [], [], []
    # One pass per sample: every set-based metric derives from the same
    # predicted/target index sets.
    for b in range(y_gt.shape[0]):
        target = set(np.where(y_gt[b] == 1)[0])
        predicted = set(np.where(y_pred[b] == 1)[0])
        n_inter = len(predicted & target)
        n_union = len(predicted | target)

        # Test the size of the union, not the set itself, so a sample with no
        # true and no predicted labels scores 0 instead of dividing by zero.
        jaccard_scores.append(0 if n_union == 0 else n_inter / n_union)

        prc = 0 if len(predicted) == 0 else n_inter / len(predicted)
        rec = 0 if len(target) == 0 else n_inter / len(target)
        prc_scores.append(prc)
        recall_scores.append(rec)
        f1_scores.append(f1_from(prc, rec))

    prauc = precision_auc(y_gt, y_prob)
    ja = np.mean(jaccard_scores)

    return ja, prauc, np.mean(prc_scores), np.mean(recall_scores), np.mean(f1_scores)

