from util import PatientDataset, collate_fn, map2atc3_label, ddi_rate_score, soft_label_loss_kl, multi_label_metric
import dill, random, torch
from torch.utils.data import DataLoader
from model import myModel # Assuming myModel is defined in model.py
import numpy as np
from sklearn.metrics import jaccard_score
import logging
import os

# =========================================
# setup logging (can reuse the same logger setup)
# logging.FileHandler does not create missing parent directories, so make sure
# the log directory exists before basicConfig() tries to open the file.
os.makedirs("./save", exist_ok=True)
logging.basicConfig(
    filename="./save/inference_log.txt",
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    filemode="w",  # Use 'w' to overwrite for a new inference run
)

logger = logging.getLogger()
# Mirror log records to the console as well. basicConfig() above normally
# attaches a FileHandler to the root logger, so testing `logger.handlers` for
# emptiness would always skip this branch and the console handler would never
# be installed. Look specifically for an existing console handler instead;
# FileHandler subclasses StreamHandler, hence the explicit exclusion. The check
# also prevents duplicate console handlers if this code runs more than once in
# the same process.
if not any(
    isinstance(handler, logging.StreamHandler)
    and not isinstance(handler, logging.FileHandler)
    for handler in logger.handlers
):
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)


# =========================================
# Run configuration. These values should match the ones used for training so
# that the rebuilt data split and model line up with the saved checkpoint.

# Seed the two RNGs this script relies on (Python's `random` and PyTorch).
random.seed(2048)
torch.manual_seed(2048)

# Dataset split fractions: train / validation / test.
train_ratio, val_ratio, test_ratio = 0.8, 0.1, 0.1

# Dimensions handed to the model constructor.
node_features, edge_features = 9, 3
patient_dim = mol_dim = 256

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Directory where models were saved.
save_dir = "./savedModel/v2"

# Epoch number of the checkpoint to load for inference.
# CHANGE THIS TO THE EPOCH NUMBER YOU WANT TO TEST.
MODEL_EPOCH_TO_LOAD = 45

# ===========================================
# load source files
data_path = './data/output/records.pkl'
voc_path = './data/output/voc.pkl'
mapping_path = './data/output/ATC4_mappings.pkl'
graph_data_path = './data/output/graph_data.pkl'
ddi_graph_path = './data/output/ddi_matrix.pkl'


def _load_pickle(path):
    """Open *path* in binary mode and return the object dill deserializes from it."""
    with open(path, 'rb') as handle:
        return dill.load(handle)


# Vocabularies for medications, diagnoses and procedures.
voc = _load_pickle(voc_path)
med_voc, diag_voc, pro_voc = voc['med_voc'], voc['diag_voc'], voc['pro_voc']
med_voc_size = len(med_voc.word2idx)

# Distinct 4-character prefixes of the medication vocabulary keys, sorted, and
# a lookup from each prefix to its position in that sorted list.
atc3_list = sorted({key[:4] for key in med_voc.word2idx.keys()})
atc3_to_index = {prefix: idx for idx, prefix in enumerate(atc3_list)}

data = _load_pickle(data_path)
mapping = _load_pickle(mapping_path)
graph_data_dict = _load_pickle(graph_data_path)

ddi_graph = _load_pickle(ddi_graph_path)
# Float tensor copy of the DDI matrix on the target device (keep this if the
# model uses it).
ddi_graph_tensor = torch.tensor(ddi_graph, dtype=torch.float, device=device)

def main_inference():
    """
    Run inference on the test split using a saved checkpoint.

    Rebuilds the test partition by shuffling the record indices with the
    module-level ``random`` state, wraps it in a ``DataLoader``, instantiates
    ``myModel``, loads the weights stored for ``MODEL_EPOCH_TO_LOAD`` under
    ``save_dir`` and passes everything to ``test_model``.

    Returns:
        None. If the checkpoint file does not exist, an error is logged and
        the function returns without evaluating.
    """
    # ===========================================
    # get data (using the same split as training)
    num_records = len(data)
    train_size = int(train_ratio * num_records)
    val_size = int(val_ratio * num_records)

    # To obtain the IDENTICAL split, this shuffle must see the same random
    # state as the one used during training. No state is loaded here, so the
    # split relies on the fixed seed set at module import.
    indices = list(range(num_records))
    random.shuffle(indices)

    # Everything after the train and validation portions is the test set.
    # Only the test records are materialized; this function never evaluates
    # on the validation portion.
    test_indices = indices[train_size + val_size:]
    test_records = [data[i] for i in test_indices]
    logger.info(f"Size of test set: {len(test_records)}")

    test_dataset = PatientDataset(test_records)
    batch_size = 32  # Can potentially use a larger batch size for inference if memory allows
    # shuffle=False for consistent results across runs.
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
    )

    # ==========================================
    # model
    # Instantiate the model with the same parameters as training so the saved
    # state dict can be loaded into it.
    model = myModel(
        diag_voc_size=len(diag_voc.word2idx),
        pro_voc_size=len(pro_voc.word2idx),
        med_voc_size=med_voc_size,
        med_voc=med_voc,
        graph_data_dict=graph_data_dict,
        ddi_graph=ddi_graph,
        node_features=node_features,
        edge_features=edge_features,
        patient_dim=patient_dim,
        mol_dim=mol_dim,
        latent_dim=32,  # Ensure this matches your CVAE latent_dim
        device=device
    ).to(device)

    # ========================================
    # Load the trained model state dictionary
    model_path = os.path.join(save_dir, f"epoch_{MODEL_EPOCH_TO_LOAD}_model.pth")
    if not os.path.exists(model_path):
        logger.error(f"Model file not found at: {model_path}")
        return

    logger.info(f"Loading model from {model_path}")
    # map_location lets a checkpoint saved on another device load onto `device`.
    model.load_state_dict(torch.load(model_path, map_location=device))
    logger.info("Model loaded successfully.")

    # ========================================
    # Perform Inference
    test_model(model, test_loader, med_voc_size, ddi_graph)


def test_model(model, test_loader, med_voc_size, ddi_graph, epoch=None):
    """
    Evaluates the model on the test dataset and logs metrics.

    For every patient, the model is run once per admission on the visit
    history up to and including that admission. Sigmoid probabilities are
    thresholded at 0.5, mapped to ATC3 level via ``map2atc3_label`` and scored
    per patient with ``multi_label_metric``. The logged summary reports the
    mean Jaccard, PRAUC and F1 over patients, the DDI rate of the predicted
    medication sets and the average number of predicted medications per visit.

    This function only evaluates; it does not write any checkpoint.

    Args:
        model (torch.nn.Module): The trained model.
        test_loader (DataLoader): DataLoader for the test dataset. Each batch
            is iterated as a sequence of patients and each patient as a
            sequence of admissions whose element ``[2]`` holds the medication
            indices.
        med_voc_size (int): Size of the medication vocabulary.
        ddi_graph (np.ndarray): The DDI adjacency matrix.
        epoch (int, optional): Epoch number shown in the logged summary.
            Defaults to the module-level ``MODEL_EPOCH_TO_LOAD``.

    Returns:
        None. Metrics are written to the logger.
    """
    if epoch is None:
        # The summary line is labelled with the checkpoint being evaluated.
        epoch = MODEL_EPOCH_TO_LOAD

    model.eval()  # Set model to evaluation mode

    smm_record = []  # Predicted medication indices, one list of visits per patient
    ja, prauc, avg_f1 = [], [], []  # Per-patient metric values
    med_cnt, visit_cnt = 0, 0  # Counters for average medications per visit
    num_batches = len(test_loader)

    with torch.no_grad():
        for step, batch in enumerate(test_loader, start=1):
            logger.info(f"Processing Test Batch {step}/{num_batches}")
            for patient_data in batch:
                # Buffers are reset for every patient. Resetting them only once
                # per batch would make each patient's metrics include the
                # visits of earlier patients in the same batch and would make
                # the results depend on the batch size.
                y_gt, y_pred, y_pred_prob, y_pred_label = [], [], [], []
                for idx, adm in enumerate(patient_data):
                    # Predict the current admission from the history so far.
                    input_sequence = patient_data[: idx + 1]
                    logits, _, _, _, _, _ = model(input_sequence)
                    logits = logits.squeeze(0)

                    # Multi-hot ground truth over the medication vocabulary.
                    target_label = np.zeros(med_voc_size)
                    target_label[adm[2]] = 1

                    pred_probs = torch.sigmoid(logits).detach().cpu().numpy()
                    pred_labels = (pred_probs >= 0.5).astype(int)

                    pred_labels_index = np.where(pred_labels == 1)[0]
                    y_pred_label.append(sorted(pred_labels_index))
                    visit_cnt += 1
                    med_cnt += len(pred_labels_index)

                    atc3_pred_labels, atc3_true_labels, atc3_pred_probs = map2atc3_label(
                        med_voc.word2idx, atc3_to_index, pred_labels, target_label, pred_probs
                    )
                    y_gt.append(atc3_true_labels)
                    y_pred.append(atc3_pred_labels)
                    y_pred_prob.append(atc3_pred_probs)

                if not y_gt:
                    # A patient without admissions contributes nothing to score.
                    continue

                smm_record.append(y_pred_label)
                adm_ja, adm_prauc, _, _, adm_avg_f1 = multi_label_metric(
                    np.array(y_gt), np.array(y_pred), np.array(y_pred_prob)
                )
                ja.append(adm_ja)
                prauc.append(adm_prauc)
                avg_f1.append(adm_avg_f1)

    if visit_cnt == 0:
        # Avoid dividing by zero and averaging empty lists on an empty test set.
        logger.warning("No admissions found in the test set; nothing to evaluate.")
        return

    ddi_rate = ddi_rate_score(smm_record, ddi_graph)
    metrics = f"Epoch {epoch} Eval:Jaccard: {np.mean(ja):.4f}, PRAUC: {np.mean(prauc):.4f}, F1: {np.mean(avg_f1):.4f}, DDIRate: {ddi_rate:.4f}, AVG_MED: {med_cnt/visit_cnt:.4}"
    logger.info(metrics)


# Start the evaluation only when this file is executed as a script; importing
# the module does not call main_inference().
if __name__ == "__main__":
    main_inference()