import contextlib
import logging
import os
import sys

import click
import numpy as np
import torch
from tqdm import tqdm

# Append the parent of this script's directory to the import search path.
# Presumably this is what lets the `eviscreen` package imports below resolve
# when the script is executed directly — TODO confirm against the repo layout.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import eviscreen.knowledge_bank as knowledge_bank
from eviscreen.knowledge_bank.pos_embed import interpolate_pos_embed

# Module-level logger; its level is configured in the `__main__` guard below.
LOGGER = logging.getLogger(__name__)

# Registry of supported datasets, keyed by the name given on the command line.
# Each value is [module path, class name]; the `dataset` command imports the
# module and looks the class up by name.
_DATASETS = {"fundus_5000": ["knowledge_bank.datasets.fundus", "FundusDataset"]}


@click.group(chain=True)
@click.argument("results_path", type=str)
@click.option("--gpu", type=int, default=[0], multiple=True, show_default=True)
@click.option("--seed", type=int, default=0, show_default=True)
@click.option("--log_group", type=str, default="group")
@click.option("--log_project", type=str, default="project")
# NOTE: the two flags below are accepted and forwarded to the result callback,
# which currently never reads them.
@click.option("--save_segmentation_images", is_flag=True)
@click.option("--save_knowledge_bank", is_flag=True)
# The chunk size is used as a divisor and as a slice width by the result
# callback, so zero or negative values are rejected while parsing.
@click.option(
    "--chunk_size",
    type=click.IntRange(min=1),
    default=5000,
    show_default=True,
    help="Number of training samples embedded and subsampled per chunk.",
)
def main(**kwargs):
    """Build and save knowledge banks for one or more datasets.

    Chain the `knowledge_bank_construction`, `sampler` and `dataset`
    sub-commands after the group options. The group itself does no work:
    the values returned by the chained sub-commands are combined by the
    registered result callback.
    """


def _embed_chunk(bank, chunk_loader, device, start_idx, batch_size, subset_idx, use_norm):
    """Embed every batch of one dataset chunk and tag the patch metadata.

    Args:
        bank: Knowledge-bank object exposing ``_embed``.
        chunk_loader: DataLoader over the chunk. It has to iterate in dataset
            order, because global indices are derived from the batch position.
        device: Torch device the image batches are moved to.
        start_idx: Index of the chunk's first sample within the full dataset.
        batch_size: Batch size used by ``chunk_loader``.
        subset_idx: Zero-based index of the chunk.
        use_norm: Forwarded to ``bank._embed`` as ``norm``.

    Returns:
        Tuple ``(features, metadata)``: the per-batch features concatenated
        along axis 0, and the flat list of metadata entries returned by
        ``_embed``, each extended with ``global_image_idx`` and ``chunk_idx``.
    """
    features_per_batch = []
    chunk_metadata = []

    with tqdm(
        chunk_loader,
        desc=f"Extracting Chunk {subset_idx+1}",
        position=1,
        leave=False,
    ) as data_iterator:
        for batch_idx, image_batch in enumerate(data_iterator):
            if isinstance(image_batch, dict):
                image_batch = image_batch["image"]

            with torch.no_grad():
                input_image = image_batch.to(torch.float).to(device)
                features, metadata = bank._embed(
                    input_image, provide_patch_metadata=True, norm=use_norm
                )

            # Dataset-wide indices covered by this batch. `image_idx` is the
            # position of the image inside the batch, so indexing the range
            # maps it to its position in the full training dataset.
            batch_start = start_idx + batch_idx * batch_size
            global_image_indices = range(batch_start, batch_start + batch_size)
            for meta in metadata:
                meta["global_image_idx"] = global_image_indices[meta["image_idx"]]
                meta["chunk_idx"] = subset_idx

            features_per_batch.append(features)
            chunk_metadata.extend(metadata)

    # assumes `_embed` returns array-likes accepted by np.concatenate — TODO confirm
    return np.concatenate(features_per_batch, axis=0), chunk_metadata


@main.result_callback()
def run(
    methods,
    results_path,
    gpu,
    seed,
    log_group,
    log_project,
    save_segmentation_images,
    save_knowledge_bank,
    chunk_size,
):
    """Build, fit and save a knowledge bank for every configured dataset.

    For each dataset the training set is split into consecutive chunks of
    ``chunk_size`` samples. Every chunk is embedded, reduced by the bank's
    feature sampler, written to disk (features and metadata), and finally all
    sampled chunks are concatenated and handed to the anomaly scorer.

    Args:
        methods: ``(key, factory)`` pairs returned by the chained sub-commands.
            The keys ``get_dataloaders``, ``get_sampler`` and
            ``get_knowledge_bank`` are required.
        results_path: Root folder handed to ``create_storage_folder``.
        gpu: GPU ids handed to ``set_torch_device``.
        seed: Seed for the dataloader factory and for ``fix_seeds``.
        log_group: Group name handed to ``create_storage_folder``.
        log_project: Project name handed to ``create_storage_folder``. When its
            last ``_``-separated token is ``normnew``, ``norm=True`` is passed
            to ``_embed``.
        save_segmentation_images: Accepted but not used by this callback.
        save_knowledge_bank: Accepted but not used by this callback; the banks
            are always saved.
        chunk_size: Number of training samples per chunk; must be positive.

    Raises:
        click.BadParameter: If ``chunk_size`` is not positive.
        click.UsageError: If a required sub-command was not chained.
        ValueError: If a training dataset is empty.
    """
    # Validate before any folder is created, so a doomed run leaves no traces.
    if chunk_size < 1:
        raise click.BadParameter(
            "chunk_size must be a positive integer.", param_hint="--chunk_size"
        )

    methods = dict(methods)
    missing = [
        key
        for key in ("get_dataloaders", "get_sampler", "get_knowledge_bank")
        if key not in methods
    ]
    if missing:
        raise click.UsageError(
            "Missing sub-command result(s): {}. Chain the dataset, sampler and "
            "knowledge_bank_construction commands.".format(", ".join(missing))
        )

    run_save_path = knowledge_bank.utils.create_storage_folder(
        results_path, log_project, log_group, mode="iterate"
    )

    list_of_dataloaders = methods["get_dataloaders"](seed)

    device = knowledge_bank.utils.set_torch_device(gpu)
    # Device context here is specifically set and used later
    # because there was GPU memory-bleeding which I could only fix with
    # context managers. On non-CUDA devices an empty `suppress()` serves as a
    # do-nothing context manager.
    device_context = (
        torch.cuda.device("cuda:{}".format(device.index))
        if "cuda" in device.type.lower()
        else contextlib.suppress()
    )

    # Loop-invariant: the normalisation switch is encoded in the project name.
    use_norm = log_project.split('_')[-1] == 'normnew'

    for dataloader_count, dataloaders in enumerate(list_of_dataloaders):
        train_loader = dataloaders["training"]
        # The per-chunk loaders reuse the settings of the original loader.
        loader_params = {
            "batch_size": train_loader.batch_size,
            "num_workers": train_loader.num_workers,
            "pin_memory": train_loader.pin_memory,
        }
        dataset_name = train_loader.name
        LOGGER.info(
            "Evaluating dataset [{}] ({}/{})...".format(
                dataset_name,
                dataloader_count + 1,
                len(list_of_dataloaders),
            )
        )

        knowledge_bank.utils.fix_seeds(seed, device)

        knowledge_bank_save_path = os.path.join(run_save_path, "models", dataset_name)
        os.makedirs(knowledge_bank_save_path, exist_ok=True)

        training_dataset = train_loader.dataset
        dataset_size = len(training_dataset)
        if dataset_size == 0:
            raise ValueError(
                "Training dataset [{}] is empty; nothing to embed.".format(dataset_name)
            )
        # Ceiling division: the last chunk may be shorter than chunk_size.
        num_subsets = (dataset_size + chunk_size - 1) // chunk_size

        with device_context:
            torch.cuda.empty_cache()
            imagesize = training_dataset.imagesize
            sampler = methods["get_sampler"](device)
            banks = methods["get_knowledge_bank"](imagesize, sampler, device)
            if len(banks) > 1:
                LOGGER.info(
                    "Utilizing Knowledge_Bank Ensemble (N={}).".format(len(banks))
                )

            for bank_idx, bank in enumerate(banks):
                torch.cuda.empty_cache()
                if bank.backbone.seed is not None:
                    knowledge_bank.utils.fix_seeds(bank.backbone.seed, device)
                LOGGER.info(
                    "Training models ({}/{})".format(bank_idx + 1, len(banks))
                )
                torch.cuda.empty_cache()

                LOGGER.info(
                    f"Dataset too large. Splitting into {num_subsets} chunks of size ~{chunk_size}."
                )

                _ = bank.forward_modules.eval()
                all_sampled_features = []
                for subset_idx in range(num_subsets):
                    start_idx = subset_idx * chunk_size
                    end_idx = min(start_idx + chunk_size, dataset_size)

                    LOGGER.info(
                        f"Processing chunk {subset_idx + 1}/{num_subsets} (indices {start_idx}-{end_idx-1})..."
                    )

                    subset = torch.utils.data.Subset(
                        training_dataset, range(start_idx, end_idx)
                    )
                    subset_dataloader = torch.utils.data.DataLoader(
                        subset, **loader_params
                    )

                    chunk_features, chunk_metadata = _embed_chunk(
                        bank,
                        subset_dataloader,
                        device,
                        start_idx,
                        loader_params["batch_size"],
                        subset_idx,
                        use_norm,
                    )
                    LOGGER.info(f"Chunk {subset_idx + 1} extracted features with shape: {chunk_features.shape}")

                    LOGGER.info(f"Applying sampler to chunk {subset_idx + 1} features...")
                    # A sampler percentage of exactly 0.99 is treated as
                    # "keep everything": the sampler is bypassed entirely.
                    if bank.featuresampler.percentage == 0.99:
                        sampled_chunk_features = chunk_features
                        sampled_chunk_indices = np.arange(len(chunk_features))
                    else:
                        sampled_chunk_features, sampled_chunk_indices = (
                            bank.featuresampler.run_with_indices(chunk_features)
                        )

                    all_sampled_features.append(sampled_chunk_features)
                    LOGGER.info(f"Chunk {subset_idx + 1} sampled features shape: {sampled_chunk_features.shape}. Adding to memory bank.")

                    # Persist each chunk's sampled features together with the
                    # metadata entries of exactly the retained rows.
                    np.save(f"{knowledge_bank_save_path}/sampled_chunk_features_{subset_idx}.npy", sampled_chunk_features)
                    np.save(
                        f"{knowledge_bank_save_path}/sampled_chunk_metadata_{subset_idx}.npy",
                        [chunk_metadata[row] for row in sampled_chunk_indices],
                    )

                    # Drop the chunk-sized temporaries before the next chunk.
                    del chunk_features, chunk_metadata, sampled_chunk_features, sampled_chunk_indices
                    torch.cuda.empty_cache()

                LOGGER.info("All chunks have been processed and sampled.")
                final_memory_bank = np.concatenate(all_sampled_features, axis=0)
                LOGGER.info(f"Final memory bank created with shape: {final_memory_bank.shape}")

                bank.anomaly_scorer.fit(detection_features=[final_memory_bank])
                LOGGER.info("Training complete.")

            torch.cuda.empty_cache()

            for bank_idx, bank in enumerate(banks):
                prepend = (
                    "Ensemble-{}-{}_".format(bank_idx + 1, len(banks))
                    if len(banks) > 1
                    else ""
                )
                bank.save_to_path(knowledge_bank_save_path, prepend)

        # No early exit here: every configured dataset gets processed, and the
        # callback simply returns once the loop is done.
        LOGGER.info("\n\n-----\n")


@main.command("knowledge_bank_construction")
# Pretraining-specific parameters.
@click.option("--backbone_names", "-b", type=str, multiple=True, default=[])
@click.option("--load_backbone", "-lb", type=str)
@click.option("--layers_to_extract_from", "-le", type=str, multiple=True, default=[])
# Parameters for Glue-code (to merge different parts of the pipeline).
@click.option("--pretrain_embed_dimension", type=int, default=1024)
@click.option("--target_embed_dimension", type=int, default=1024)
@click.option("--preprocessing", type=click.Choice(["mean", "conv"]), default="mean")
@click.option("--aggregation", type=click.Choice(["mean", "mlp"]), default="mean")
# Nearest-Neighbour Anomaly Scorer parameters.
@click.option("--anomaly_scorer_num_nn", type=int, default=5)
# Patch-parameters.
@click.option("--patchsize", type=int, default=3)
@click.option("--patchscore", type=str, default="max")
@click.option("--patchoverlap", type=float, default=0.0)
@click.option("--patchsize_aggregate", "-pa", type=int, multiple=True, default=[])
# NN on GPU.
@click.option("--faiss_on_gpu", is_flag=True)
@click.option("--faiss_num_workers", type=int, default=8)
def knowledge_bank_construction(
    backbone_names,
    load_backbone,
    layers_to_extract_from,
    pretrain_embed_dimension,
    target_embed_dimension,
    preprocessing,
    aggregation,
    patchsize,
    patchscore,
    patchoverlap,
    anomaly_scorer_num_nn,
    patchsize_aggregate,
    faiss_on_gpu,
    faiss_num_workers,
):
    """Configure how knowledge banks are built, one per backbone.

    With several backbones, each `--layers_to_extract_from` value must be
    prefixed with the zero-based backbone index, e.g. `0.layer2`. A backbone
    name may carry a `.seed-<int>` suffix, which is stored on the backbone.

    Returns the pair `("get_knowledge_bank", factory)`; the factory takes
    `(input_shape, sampler, device)` and returns the list of loaded banks.
    """
    # NOTE: `preprocessing`, `aggregation`, `patchscore`, `patchoverlap` and
    # `patchsize_aggregate` are accepted on the command line but are not used
    # anywhere in this function.
    backbone_names = list(backbone_names)
    if len(backbone_names) > 1:
        # Route each "<backbone index>.<layer>" entry to its backbone.
        layers_to_extract_from_coll = [[] for _ in backbone_names]
        for layer in layers_to_extract_from:
            backbone_idx, _, layer_name = layer.partition(".")
            layers_to_extract_from_coll[int(backbone_idx)].append(layer_name)
    else:
        layers_to_extract_from_coll = [layers_to_extract_from]

    def get_knowledge_bank(input_shape, sampler, device):
        """Instantiate and load one knowledge bank per configured backbone."""
        loaded_knowledge_banks = []
        for backbone_name, backbone_layers in zip(
            backbone_names, layers_to_extract_from_coll
        ):
            backbone_seed = None
            if ".seed-" in backbone_name:
                # The seed is whatever follows the last "-" of the full name.
                seed_text = backbone_name.split("-")[-1]
                backbone_name = backbone_name.split(".seed-")[0]
                backbone_seed = int(seed_text)

            backbone = knowledge_bank.backbones.load(backbone_name)
            if load_backbone:
                # Load onto the CPU so that a checkpoint saved on a GPU can be
                # restored on hosts without that device; load_state_dict then
                # copies the tensors into the backbone's own parameters.
                # NOTE(security): torch.load unpickles the file — only point
                # --load_backbone at checkpoints from a trusted source.
                checkpoint = torch.load(load_backbone, map_location="cpu")
                checkpoint = checkpoint["teacher"]

                # Strip the "backbone." marker from the affected keys so they
                # line up with the backbone's own state-dict names.
                for key in [k for k in checkpoint if "backbone." in k]:
                    checkpoint[key.replace("backbone.", "")] = checkpoint.pop(key)

                interpolate_pos_embed(backbone, checkpoint)
                # strict=False: missing/unexpected keys are reported, not fatal.
                load_result = backbone.load_state_dict(checkpoint, strict=False)
                LOGGER.info("Backbone checkpoint load result: %s", load_result)

            backbone.name, backbone.seed = backbone_name, backbone_seed

            nn_method = knowledge_bank.common.FaissNN(faiss_on_gpu, faiss_num_workers)

            knowledge_bank_instance = knowledge_bank.knowledge_bank.Knowledge_Bank(device)
            knowledge_bank_instance.load(
                backbone=backbone,
                layers_to_extract_from=backbone_layers,
                device=device,
                input_shape=input_shape,
                pretrain_embed_dimension=pretrain_embed_dimension,
                target_embed_dimension=target_embed_dimension,
                patchsize=patchsize,
                featuresampler=sampler,
                anomaly_scorer_num_nn=anomaly_scorer_num_nn,
                nn_method=nn_method,
            )
            loaded_knowledge_banks.append(knowledge_bank_instance)
        return loaded_knowledge_banks

    return ("get_knowledge_bank", get_knowledge_bank)


@main.command("sampler")
@click.argument("name", type=str)
@click.option("--percentage", "-p", type=float, default=0.1, show_default=True)
def sampler(name, percentage):
    """Select the feature sampler used to subsample embedded features.

    NAME is one of `identity`, `greedy_coreset` or `approx_greedy_coreset`.

    Returns the pair `("get_sampler", factory)`; the factory takes the torch
    device and returns the sampler instance.
    """
    # Sampler classes are resolved lazily inside the lambdas, i.e. only when
    # the factory is actually called with a device.
    factories = {
        "identity": lambda device: knowledge_bank.sampler.IdentitySampler(),
        "greedy_coreset": lambda device: knowledge_bank.sampler.GreedyCoresetSampler(
            percentage, device
        ),
        "approx_greedy_coreset": lambda device: (
            knowledge_bank.sampler.ApproximateGreedyCoresetSampler(percentage, device)
        ),
    }

    # Fail while parsing: an unknown name used to yield a `None` sampler that
    # only blew up much later, deep inside the run.
    if name not in factories:
        raise click.BadParameter(
            "unknown sampler {!r}; choose one of: {}".format(
                name, ", ".join(factories)
            ),
            param_hint="name",
        )

    def get_sampler(device):
        """Create the configured sampler for `device`."""
        return factories[name](device)

    return ("get_sampler", get_sampler)


@main.command("dataset")
@click.argument("name", type=str)
@click.argument("data_path", type=click.Path(exists=True, file_okay=False))
@click.option("--subdatasets", "-d", multiple=True, type=str, required=True)
@click.option("--train_val_split", type=float, default=1, show_default=True)
@click.option("--batch_size", default=4, type=int, show_default=True)
@click.option("--num_workers", default=8, type=int, show_default=True)
@click.option("--resize", default=256, type=int, show_default=True)
@click.option("--imagesize", default=224, type=int, show_default=True)
@click.option("--augment", is_flag=True)
@click.option("--train_scale", default="all", type=str, show_default=True)
@click.option("--category", default='normal', type=str, show_default=True)
def dataset(
    name,
    data_path,
    subdatasets,
    train_val_split,
    batch_size,
    resize,
    imagesize,
    num_workers,
    augment,
    train_scale,
    category,
):
    """Select the dataset and configure its dataloaders.

    NAME must be a key of the dataset registry and DATA_PATH an existing
    directory. One training dataloader is created per `--subdatasets` value;
    a validation dataloader is added only when `--train_val_split` is below 1.

    Returns the pair `("get_dataloaders", factory)`; the factory takes the
    seed and returns a list of `{"training": ..., "validation": ...}` dicts.
    """
    import importlib

    # Report an unknown dataset as a usage error instead of a bare KeyError.
    if name not in _DATASETS:
        raise click.BadParameter(
            "unknown dataset {!r}; choose one of: {}".format(
                name, ", ".join(sorted(_DATASETS))
            ),
            param_hint="name",
        )
    module_path, class_name = _DATASETS[name]
    dataset_library = importlib.import_module(module_path)

    def make_loader(split_dataset):
        """Wrap a dataset in a non-shuffling loader with the CLI settings."""
        return torch.utils.data.DataLoader(
            split_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
        )

    def get_dataloaders(seed):
        """Build the training (and optional validation) loaders per sub-dataset."""
        dataset_cls = getattr(dataset_library, class_name)
        dataloaders = []
        for subdataset in subdatasets:
            # Arguments shared by the training and validation datasets.
            shared_kwargs = {
                "classname": subdataset,
                "resize": resize,
                "train_val_split": train_val_split,
                "imagesize": imagesize,
                "seed": seed,
            }

            train_dataset = dataset_cls(
                data_path,
                split=dataset_library.DatasetSplit.TRAIN,
                augment=augment,
                train_scale=train_scale,
                category=category,
                **shared_kwargs,
            )
            train_dataloader = make_loader(train_dataset)

            # The loader carries a display name that later also names the
            # folder the models are saved under.
            train_dataloader.name = name
            if subdataset is not None:
                train_dataloader.name += "_" + subdataset

            if train_val_split < 1:
                # The validation dataset is built without augment, train_scale
                # and category, so the dataset class defaults apply for those.
                val_dataset = dataset_cls(
                    data_path,
                    split=dataset_library.DatasetSplit.VAL,
                    **shared_kwargs,
                )
                val_dataloader = make_loader(val_dataset)
            else:
                val_dataloader = None

            dataloaders.append(
                {
                    "training": train_dataloader,
                    "validation": val_dataloader,
                }
            )
        return dataloaders

    return ("get_dataloaders", get_dataloaders)


if __name__ == "__main__":
    # Script entry point: switch on INFO-level logging, record the exact
    # invocation for reproducibility, then hand control to the click group.
    logging.basicConfig(level=logging.INFO)
    invocation = " ".join(sys.argv)
    LOGGER.info("Command line arguments: {}".format(invocation))
    main()
