# %%
import contextlib
import logging
import os
import sys

import click
import numpy as np
import torch

import eviscreen.knowledge_bank.knowledge_bank as knowledge_bank

import gc
from enum import Enum
import argparse
import pandas as pd


# Module-level logger; records are emitted under this module's name.
LOGGER = logging.getLogger(__name__)


# Registry of supported modalities:
#   modality name -> [dotted module path, dataset class name].
# `dataset()` imports the module and looks the class up by name.
_DATASETS = {"fundus": ["eviscreen.knowledge_bank.datasets.fundus", "FundusDataset"]}

# Command-line options. NOTE: arguments are parsed at import time, so this
# file is meant to be executed as a script rather than imported.
parser = argparse.ArgumentParser(description="Process command-line arguments.")
# Key into `_DATASETS` and `dataset_list_dict` (only "fundus" is registered).
parser.add_argument("--modality", type=str, help="modality", default="fundus")
# Used for both the `resize` and `imagesize` arguments of `dataset()`.
parser.add_argument("--image_size", type=int, help="image_size", default=224)
parser.add_argument("--batch_size", type=int, help="batch_size", default=1)
args = parser.parse_args()

# One row per experiment. The driver loop below reads the columns `exp`,
# `normal_path`, `abnormal_path` and `backbone_path`.
# The path is a placeholder and has to be replaced with the real CSV location.
df = pd.read_csv("path_to_knowledge_bank_construction_results.csv")

# Test sets evaluated for each modality; every name must be accepted by the
# `test_set` argument of `dataset()`.
dataset_list_dict = {
    "fundus": ["fundus_val", "fundus_remain_5000", "RIADD_original", "JSIEC_original"],
}

# %%
# GPU ids handed to `set_torch_device`; the returned object exposes `.type`
# and `.index`, which are used below.
gpu = [0]
device = knowledge_bank.utils.set_torch_device(gpu)

# Context manager entered around the evaluation code: it selects the CUDA
# device when one is in use; otherwise an argument-less `suppress()` is used,
# which suppresses nothing and therefore acts as a no-op context.
if "cuda" in device.type.lower():
    device_context = torch.cuda.device(f"cuda:{device.index}")
else:
    device_context = contextlib.suppress()

# %%
def dataset(
    name="fundus",
    data_path="/root",
    subdatasets=("fundus",),
    batch_size=1,
    resize=224,
    imagesize=224,
    num_workers=8,
    augment=False,
    test_set="JSIEC",
):
    """Build a factory that yields the test dataloaders for one evaluation set.

    The module and dataset class registered under ``name`` in ``_DATASETS``
    are imported, and ``test_set`` is translated into a member of that
    module's ``DatasetSplit``.

    Args:
        name: Key into ``_DATASETS``; also the prefix of every loader's
            ``name`` attribute.
        data_path: First positional argument of the dataset class.
        subdatasets: One dataloader is produced per entry; each entry is
            passed to the dataset class as ``classname``.
        batch_size: Batch size of every dataloader.
        resize: Forwarded unchanged to the dataset class.
        imagesize: Forwarded unchanged to the dataset class.
        num_workers: Worker count of every dataloader.
        augment: Accepted for signature compatibility; not used here.
        test_set: Name of the evaluation set (see the table in the body).

    Returns:
        A function ``get_dataloaders_iter(seed)``. Calling it gives a
        generator that yields one ``{"testing": DataLoader}`` dict per
        sub-dataset; the datasets are only constructed while iterating.

    Raises:
        KeyError: If ``name`` is not registered in ``_DATASETS``.
        ValueError: If ``test_set`` is not a recognised name.
    """
    registry_entry = _DATASETS[name]
    module_path = registry_entry[0]
    class_name = registry_entry[1]
    # A non-empty `fromlist` makes `__import__` return the leaf module.
    library = __import__(module_path, fromlist=[class_name])

    # Evaluation-set name -> attribute name on `DatasetSplit`. Several
    # validation aliases share the same `VAL` member.
    split_member_by_test_set = {
        "fundus_remain_5000": "REMAIN_5000",
        "fundus_val": "VAL",
        "RIADD": "RIADD",
        "JSIEC": "JSIEC",
        "JSIEC_original": "JSIEC_original",
        "RIADD_original": "RIADD_original",
        "Fundus_val": "VAL",
        "BRSET": "TEST_BRSET",
        "EDDFS": "TEST_EDDFS",
        "OCTDL": "OCTDL",
        "OCTID": "OCTID",
        "CXR_remain_5000": "CXR_remain_5000",
        "MIMIC_val": "VAL",
        "CheXpert": "CheXpert",
        "derm_remain_5000": "derm_remain_5000",
        "Derm12345": "Derm12345",
        "Derm12345_original": "Derm12345_original",
        "PAD_UFES_20": "PAD_UFES_20",
        "BUSI": "BUSI",
        "MedFMC_Endo": "MedFMC_Endo",
    }

    # Anything that is not a string can never match a name, so it ends up in
    # the same ValueError as an unknown string.
    member_name = (
        split_member_by_test_set.get(test_set) if isinstance(test_set, str) else None
    )
    if member_name is None:
        raise ValueError(f"Invalid test set name: {test_set}")
    # Only the requested member is touched, so a dataset module does not have
    # to define every split listed in the table.
    chosen_split = getattr(library.DatasetSplit, member_name)

    def get_dataloaders_iter(seed):
        for subdataset in subdatasets:
            dataset_cls = vars(library)[class_name]
            eval_dataset = dataset_cls(
                data_path,
                classname=subdataset,
                resize=resize,
                imagesize=imagesize,
                split=chosen_split,
                seed=seed,
            )

            loader = torch.utils.data.DataLoader(
                eval_dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=True,
            )

            # Tag the loader as "<name>" or "<name>_<subdataset>".
            loader.name = name if subdataset is None else name + "_" + subdataset

            yield {"testing": loader}

    return get_dataloaders_iter


# %%
def patch_core_loader(patch_core_paths=(), faiss_on_gpu=True, faiss_num_workers=8, anomaly_scorer_num_nn=1, backbone_path=None):
    """Build a factory that lazily loads knowledge banks from disk.

    Args:
        patch_core_paths: Iterable of directories, each holding one or more
            saved knowledge banks. Defaults to an empty (immutable) tuple.
        faiss_on_gpu: First argument of ``knowledge_bank.common.FaissNN``.
        faiss_num_workers: Second argument of ``knowledge_bank.common.FaissNN``.
        anomaly_scorer_num_nn: Forwarded to ``Knowledge_Bank.load_from_path``.
        backbone_path: Forwarded to ``Knowledge_Bank.load_from_path``.

    Returns:
        A function ``get_knowledge_bank_iter(device)``. Calling it gives a
        generator that yields, for every directory, the list of knowledge
        banks loaded from it. The number of banks is the number of directory
        entries whose name contains ``".faiss"``:

        * exactly one entry: a single bank is loaded without ``prepend``;
        * any other count ``n``: ``n`` banks are loaded, the i-th one with
          ``prepend="Ensemble-<i>-<n>_"`` (an empty list when ``n`` is 0).
    """

    def get_knowledge_bank_iter(device):
        for patch_core_path in patch_core_paths:
            # Collect garbage before loading the banks of the next directory.
            gc.collect()
            n_knowledge_banks = sum(
                ".faiss" in entry for entry in os.listdir(patch_core_path)
            )

            # One dict of extra keyword arguments per bank to load. A lone
            # bank is loaded without a `prepend` keyword at all, so the
            # default of `load_from_path` applies.
            if n_knowledge_banks == 1:
                extra_kwargs_per_bank = [{}]
            else:
                extra_kwargs_per_bank = [
                    {"prepend": "Ensemble-{}-{}_".format(i + 1, n_knowledge_banks)}
                    for i in range(n_knowledge_banks)
                ]

            loaded_knowledge_banks = []
            for extra_kwargs in extra_kwargs_per_bank:
                # Every bank gets its own nearest-neighbour search object.
                nn_method = knowledge_bank.common.FaissNN(faiss_on_gpu, faiss_num_workers)
                knowledge_bank_instance = knowledge_bank.knowledge_bank.Knowledge_Bank(device)
                knowledge_bank_instance.load_from_path(
                    load_path=patch_core_path,
                    device=device,
                    nn_method=nn_method,
                    anomaly_scorer_num_nn=anomaly_scorer_num_nn,
                    backbone_path=backbone_path,
                    **extra_kwargs,
                )
                loaded_knowledge_banks.append(knowledge_bank_instance)

            yield loaded_knowledge_banks

    return get_knowledge_bank_iter



# Driver: for every experiment row in `df` and every test set configured for
# the selected modality, run the test data through the "normal" and the
# "abnormal" knowledge bank of that experiment and save the ground-truth
# anomaly labels next to the results.
#
# Only one neighbour count (16) is evaluated. NOTE(review): the output paths
# below do not contain `anomaly_scorer_num_nn`, so adding further values to
# this list would make every value write to the same locations.
for anomaly_scorer_num_nn in [16]:
    # Placeholder output root; results go to `<tgt_root>/<exp>/retrieved_data`.
    tgt_root = f"path_to_retrieved_data"
    for index, row in df.iterrows():
        exp = row["exp"]
        normal_path = row["normal_path"]
        abnormal_path = row["abnormal_path"]
        backbone_path = row["backbone_path"]
        # The CSV value "No" stands for "no backbone path" and becomes None.
        if str(backbone_path) == "No":
            backbone_path = None
        print(exp, normal_path, abnormal_path, backbone_path)
        print("exp:", exp)
        os.makedirs(os.path.join(tgt_root, f"{exp}", "retrieved_data"), exist_ok=True)
        for test_set in dataset_list_dict[args.modality]:
            print("test_set:", test_set)
            # %%
            seed = 0
            # `dataset()` returns a factory; calling it with the seed gives the
            # generator of dataloader dicts. `subdatasets` is left at its
            # default, so exactly one dataloader is produced per test set.
            dataloader_iter = dataset(name=args.modality, test_set=test_set, batch_size=args.batch_size, resize=args.image_size, imagesize=args.image_size)
            dataloader_iter = dataloader_iter(seed)
            # Same two-step pattern for the banks. Both generators are created
            # inside the test-set loop, so the banks are loaded from disk
            # again for every test set.
            knowledge_bank_iter = patch_core_loader(patch_core_paths=[normal_path], anomaly_scorer_num_nn=anomaly_scorer_num_nn, backbone_path=backbone_path)
            knowledge_bank_iter = knowledge_bank_iter(device)
            abnormal_knowledge_bank_iter = patch_core_loader(patch_core_paths=[abnormal_path], anomaly_scorer_num_nn=anomaly_scorer_num_nn, backbone_path=backbone_path)
            abnormal_knowledge_bank_iter = abnormal_knowledge_bank_iter(device)

            # `n_dataloaders` only appears in the log message;
            # `n_knowledge_banks` bounds how often the bank generators are
            # advanced (each was given exactly one path above).
            n_dataloaders = 1
            n_knowledge_banks = 1

            # %%
            for dataloader_count, dataloaders in enumerate(dataloader_iter):
                print(dataloader_count)
                # NOTE(review): no logging configuration is visible in this
                # file, so these INFO records may not be displayed - confirm.
                LOGGER.info(
                    "Evaluating dataset [{}] ({}/{})...".format(
                        dataloaders["testing"].name, dataloader_count + 1, n_dataloaders
                    )
                )

                knowledge_bank.utils.fix_seeds(seed, device)

                dataset_name = dataloaders["testing"].name

                with device_context:

                    torch.cuda.empty_cache()
                    # Load the banks for the first dataloader only; any later
                    # dataloader would reuse the lists bound here.
                    if dataloader_count < n_knowledge_banks:
                        Knowledge_Bank_list = next(knowledge_bank_iter)
                        AbnormalKnowledge_Bank_list = next(abnormal_knowledge_bank_iter)

                    # True for every sample whose second field is not "good".
                    # Presumably that field is the anomaly-type label -
                    # confirm against the dataset's `data_to_iterate`.
                    anomaly_labels = [
                        x[1] != "good" for x in dataloaders["testing"].dataset.data_to_iterate
                    ]

                    # Pass 1: query the normal bank(s). The returned values are
                    # not used in the code visible here; presumably `predict`
                    # writes its results under `save_path` - confirm in
                    # `Knowledge_Bank.predict`.
                    for i, Knowledge_Bank in enumerate(Knowledge_Bank_list):
                        torch.cuda.empty_cache()
                        LOGGER.info(
                            "Embedding test data with models ({}/{})".format(
                                i + 1, len(Knowledge_Bank_list)
                            )
                        )
                        scores, patch_scores, segmentations, labels_gt, masks_gt, distances = Knowledge_Bank.predict(
                            dataloaders["testing"], 
                            save_path=os.path.join(tgt_root, f"{exp}", "retrieved_data", f"{test_set}_normal")
                        )

                    # Pass 2: same data against the abnormal bank(s), with a
                    # `_abnormal` save path. The result names from pass 1 are
                    # overwritten.
                    for i, AbnormalKnowledge_Bank in enumerate(AbnormalKnowledge_Bank_list):
                        torch.cuda.empty_cache()
                        LOGGER.info(
                            "Embedding test data with models ({}/{})".format(
                                i + 1, len(AbnormalKnowledge_Bank_list)
                            )
                        )
                        scores, patch_scores, segmentations, labels_gt, masks_gt, distances = AbnormalKnowledge_Bank.predict(
                            dataloaders["testing"],
                            save_path=os.path.join(tgt_root, f"{exp}", "retrieved_data", f"{test_set}_abnormal")
                        )

                    # Store the labels as a 0/1 integer array, one file per
                    # experiment and test set.
                    anomaly_labels = np.array(anomaly_labels).astype(int)

                    np.save(os.path.join(tgt_root, f"{exp}", "retrieved_data", f"{test_set}_anomaly_labels.npy"), anomaly_labels)
