from util import PatientDataset, collate_fn, map2atc3_label, ddi_rate_score, soft_label_loss_kl, multi_label_metric
import dill, random, torch
from torch.utils.data import DataLoader
from model import myModel
import numpy as np
from sklearn.metrics import jaccard_score
import logging
import os

# Destination of the run log and the record layout shared by the file and
# console handlers.
LOG_FILE = "./save/4.txt"
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"

# basicConfig opens the log file right away and raises FileNotFoundError when
# its parent directory is missing, so make sure the directory exists first.
os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)

logging.basicConfig(
    filename=LOG_FILE,
    level=logging.INFO,
    format=LOG_FORMAT,
    filemode="w",  # start a fresh log file on every run
)

# Root logger: receives the file handler from basicConfig above.
logger = logging.getLogger()

# Mirror every INFO-and-above record to the console with the same layout.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter(LOG_FORMAT)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

# =========================================
# setup
# One seed for every random number generator this script touches. NumPy is
# seeded as well so that any NumPy-based randomness is reproducible too.
SEED = 2048
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

# Fractions of the records assigned to each split (see main()).
train_ratio = 0.8
val_ratio = 0.1
test_ratio = 0.1

# Model hyper-parameters forwarded to myModel in main().
node_features = 9
edge_features = 3
patient_dim = 512
mol_dim = 256

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
save_dir = "./savedModel/v1"
print(device)
# ===========================================
# load source file
data_path = './data/output/records.pkl'
voc_path = './data/output/voc.pkl'
mapping_path = './data/output/ATC4_mappings.pkl'
graph_data_path = './data/output/graph_data.pkl'
ddi_graph_path = './data/output/ddi_matrix.pkl'


def _load_pickle(path):
    """Return the object deserialized with dill from the file at ``path``."""
    # NOTE: dill (like pickle) can execute arbitrary code while loading;
    # only point these paths at trusted files.
    with open(path, 'rb') as handle:
        return dill.load(handle)


# Vocabularies for medications, diagnoses and procedures.
voc = _load_pickle(voc_path)
med_voc = voc['med_voc']
diag_voc = voc['diag_voc']
pro_voc = voc['pro_voc']
med_voc_size = len(med_voc.word2idx)

# Distinct 4-character prefixes of the medication codes, sorted, together with
# a prefix -> position lookup.
atc3_list = sorted({code[:4] for code in med_voc.word2idx})
atc3_to_index = dict(zip(atc3_list, range(len(atc3_list))))

data = _load_pickle(data_path)
mapping = _load_pickle(mapping_path)
graph_data_dict = _load_pickle(graph_data_path)

# The DDI matrix is kept both as loaded and as a float tensor on `device`.
ddi_graph = _load_pickle(ddi_graph_path)
ddi_graph_tensor = torch.tensor(ddi_graph, dtype=torch.float, device=device)

def main():
    """Split the loaded records, build the model and run the training loop.

    The records are shuffled with the ``random`` module and cut into
    train/validation/test partitions according to the module-level ratios.
    Each visit of each patient yields one training sample whose input is the
    patient's history up to and including that visit. The losses of all
    samples in a batch are summed and a single optimizer step is taken per
    batch. After every epoch the model is scored on the validation loader via
    ``eval``.
    """
    # ===========================================
    # get data
    num_records = len(data)
    train_size = int(train_ratio * num_records)
    val_size = int(val_ratio * num_records)

    indices = list(range(num_records))
    random.shuffle(indices)

    # Train and validation take their configured share; the test split gets
    # every remaining record.
    val_end = train_size + val_size
    train_records = [data[i] for i in indices[:train_size]]
    val_records = [data[i] for i in indices[train_size:val_end]]
    test_records = [data[i] for i in indices[val_end:]]
    print(f"Size of Train Set: {len(train_records)}, Size of Validate Set: {len(val_records)}, Size of Test Set: {len(test_records)}")
    train_dataset = PatientDataset(train_records)
    val_dataset = PatientDataset(val_records)
    test_dataset = PatientDataset(test_records)

    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    # The test loader is built alongside the others but is not consumed
    # anywhere in this function.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    # ==========================================
    # model
    model = myModel(
        diag_voc_size=len(diag_voc.word2idx),
        pro_voc_size=len(pro_voc.word2idx),
        med_voc_size=med_voc_size,
        med_voc=med_voc,
        graph_data_dict=graph_data_dict,
        ddi_graph=ddi_graph,
        node_features=node_features,
        edge_features=edge_features,
        patient_dim=patient_dim,
        mol_dim=mol_dim,
        latent_dim=32,
        device=device
    ).to(device)

    # ========================================
    # train
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.001)
    num_epochs = 80
    for epoch in range(num_epochs):
        model.train()

        for step, batch in enumerate(train_loader):
            optimizer.zero_grad()
            # batch metrics
            batch_loss = 0.0  # summed loss over every visit in the batch
            batch_jaccard = 0.0  # summed per-visit Jaccard
            batch_samples = 0  # number of visits seen in the batch

            for patient_data in batch:
                for idx, adm in enumerate(patient_data):
                    # History up to and including the current visit.
                    history = patient_data[: idx + 1]
                    logits, _, build_loss, _, _, _ = model(history)
                    logits = logits.squeeze(0)

                    # Multi-hot target: adm[2] holds the indices of the
                    # medications recorded for this visit.
                    target = np.zeros(med_voc_size, dtype=np.float32)
                    target[adm[2]] = 1
                    target_tensor = torch.from_numpy(target).to(device)
                    batch_loss = batch_loss + criterion(logits, target_tensor)

                    # Jaccard between the thresholded prediction and the
                    # target; sklearn expects (y_true, y_pred).
                    pred_probs = torch.sigmoid(logits).detach().cpu().numpy()
                    pred_labels = (pred_probs >= 0.5).astype(int)
                    batch_jaccard += jaccard_score(target, pred_labels)
                    batch_samples += 1

                    # Auxiliary loss returned by the model for this sample.
                    batch_loss = batch_loss + build_loss

            if batch_samples == 0:
                # No visit was processed, so batch_loss is still a plain
                # float: there is nothing to back-propagate and no mean
                # Jaccard to report.
                continue

            batch_loss.backward()
            optimizer.step()
            logger.info(f"Epoch {epoch + 1} Batch {step + 1} Train Jaccard: {batch_jaccard / batch_samples:.4f}")

        # NOTE(review): eval receives the 0-based epoch index, whereas the
        # training log line above reports epoch + 1.
        eval(model, val_loader, med_voc_size, epoch)


def eval(model, val_loader, med_voc_size, epoch):
    """Score ``model`` on ``val_loader`` and log the aggregated metrics.

    For every visit the model is fed the patient's history up to and
    including that visit. Predictions are thresholded at 0.5, mapped to the
    ATC3 label space with ``map2atc3_label`` and scored per patient with
    ``multi_label_metric``. The per-patient scores are averaged, and the DDI
    rate of the predicted medication sets is computed with ``ddi_rate_score``.

    NOTE: the function name shadows the builtin ``eval`` inside this module;
    it is kept unchanged for existing callers.

    Args:
        model: Callable returning a 6-tuple whose first item holds the logits.
        val_loader: Iterable of batches; each batch is an iterable of
            patients and each patient a sequence of admissions whose item
            ``[2]`` indexes the medications of that admission.
        med_voc_size: Length of the multi-hot medication target vector.
        epoch: Epoch number used in the log line and the checkpoint name.
    """
    model.eval()
    smm_record = []  # one entry per patient: list of predicted sets per visit
    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    med_cnt, visit_cnt = 0, 0

    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for batch in val_loader:
            for patient_data in batch:
                # Fresh accumulators for every patient. Sharing them across
                # the batch would score a growing prefix of the batch and
                # store the same list object in smm_record many times.
                y_gt, y_pred, y_pred_prob, y_pred_label = [], [], [], []

                for idx, adm in enumerate(patient_data):
                    history = patient_data[: idx + 1]
                    logits, _, _, _, _, _ = model(history)
                    logits = logits.squeeze(0)

                    target_label = np.zeros(med_voc_size)
                    target_label[adm[2]] = 1

                    pred_probs = torch.sigmoid(logits).detach().cpu().numpy()
                    pred_labels = (pred_probs >= 0.5).astype(int)

                    pred_labels_index = np.where(pred_labels == 1)[0]
                    y_pred_label.append(sorted(pred_labels_index))
                    visit_cnt += 1
                    med_cnt += len(pred_labels_index)

                    atc3_pred_labels, atc3_true_labels, atc3_pred_probs = map2atc3_label(
                        med_voc.word2idx, atc3_to_index, pred_labels, target_label, pred_probs
                    )
                    y_gt.append(atc3_true_labels)
                    y_pred.append(atc3_pred_labels)
                    y_pred_prob.append(atc3_pred_probs)

                if not y_gt:
                    # Patient without visits: nothing to score.
                    continue

                smm_record.append(y_pred_label)
                adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = multi_label_metric(
                    np.array(y_gt), np.array(y_pred), np.array(y_pred_prob)
                )
                ja.append(adm_ja)
                prauc.append(adm_prauc)
                avg_p.append(adm_avg_p)
                avg_r.append(adm_avg_r)
                avg_f1.append(adm_avg_f1)

    ddi_rate = ddi_rate_score(smm_record, ddi_graph)
    # Mean number of predicted medications per visit; 0.0 when no visit was seen.
    avg_med = med_cnt / visit_cnt if visit_cnt else 0.0
    metrics = f"Epoch {epoch} Eval:Jaccard: {np.mean(ja):.4f}, PRAUC: {np.mean(prauc):.4f}, F1: {np.mean(avg_f1):.4f}, DDIRate: {ddi_rate:.4f}, AVG_MED: {avg_med:.4}"
    # Checkpointing is currently disabled; uncomment the save call to enable it.
    model_save_path = os.path.join(save_dir, f"epoch_{epoch}_model.pth")
    # torch.save(model.state_dict(), model_save_path)
    logger.info(metrics)

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
