# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
import argparse
import os
import PIL
import torch
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


class EvidenceAwareReasoningDataset(Dataset):
    """Dataset of pre-computed features plus neighbor evidence from two knowledge banks.

    Each item bundles, for one sample:
      * the sample's own pre-extracted features,
      * the features of its nearest neighbors in a "normal" knowledge bank,
      * the stored distances to the normal bank,
      * the features of its nearest neighbors in an "abnormal" knowledge bank,
      * the stored distances to the abnormal bank,
      * its anomaly label.

    Neighbor *indices* are read from per-sample ``.npy`` files. The neighbor
    *features* are reconstructed on the fly from each knowledge bank's search
    index.
    """

    # Supported split names -> file/directory prefix used under ``root``.
    _SPLIT_TO_DATASET = {
        'train': 'fundus_remain_5000',
        'val': 'fundus_val',
        'JSIEC_original': 'JSIEC_original',
        'RIADD_original': 'RIADD_original',
    }

    def __init__(self, split, args, knowledge_bank_instance_normal, knowledge_bank_instance_abnormal,
                 *, root="path_to_retrieved_data"):
        """Load the per-split distance and label arrays and remember the banks.

        Args:
            split: One of ``'train'``, ``'val'``, ``'JSIEC_original'``,
                ``'RIADD_original'``.
            args: Stored as ``self.args``; not read inside this class.
            knowledge_bank_instance_normal: Object exposing
                ``anomaly_scorer.nn_method.search_index.reconstruct_batch``.
                Used to rebuild the neighbor features for the normal bank.
            knowledge_bank_instance_abnormal: Same interface, for the
                abnormal bank.
            root: Directory that holds the retrieved data. The default keeps
                the original placeholder path.

        Raises:
            ValueError: If ``split`` is not a supported split name.
        """
        super().__init__()
        self.args = args

        if split not in self._SPLIT_TO_DATASET:
            raise ValueError(f"Invalid split: {split}")
        dataset_name = self._SPLIT_TO_DATASET[split]

        # Per-sample files (features_{i}.npy, query_nns_{i}.npy) live in these folders.
        self.normal_directory = os.path.join(root, f'{dataset_name}_normal')
        self.abnormal_directory = os.path.join(root, f'{dataset_name}_abnormal')

        # Whole-split arrays, indexed by sample index in __getitem__.
        self.distances_from_normal = np.load(os.path.join(root, f'{dataset_name}_distances.npy'))
        self.distances_from_abnormal = np.load(os.path.join(root, f'{dataset_name}_abnormal_distances.npy'))
        self.labels = np.load(os.path.join(root, f'{dataset_name}_anomaly_labels.npy'))

        self.knowledge_bank_instance_normal = knowledge_bank_instance_normal
        self.knowledge_bank_instance_abnormal = knowledge_bank_instance_abnormal

    @staticmethod
    def _retrieve_neighbor_features(knowledge_bank_instance, neighbor_indices):
        """Reconstruct bank features for a 2-D grid of neighbor indices.

        Args:
            knowledge_bank_instance: Object exposing
                ``anomaly_scorer.nn_method.search_index.reconstruct_batch``.
            neighbor_indices: Integer numpy array of shape ``[num_queries, k]``.

        Returns:
            A float16 tensor of shape ``[num_queries, k, feature_dim]``.
        """
        num_queries, num_neighbors = neighbor_indices.shape
        # Use one batched call instead of num_queries * k single lookups.
        # The ids are passed as a flat, C-contiguous int64 array.
        flat_indices = np.ascontiguousarray(neighbor_indices.reshape(-1), dtype=np.int64)
        search_index = knowledge_bank_instance.anomaly_scorer.nn_method.search_index
        flat_features = search_index.reconstruct_batch(flat_indices)
        # Restore the [num_queries, k] layout; -1 infers the feature dimension.
        return torch.from_numpy(flat_features).view(num_queries, num_neighbors, -1).to(torch.float16)

    def __getitem__(self, index):
        """Return the evidence tuple for sample ``index``.

        Returns:
            ``(features, normal_neighbor_features, distances_from_normal,
            abnormal_neighbor_features, distances_from_abnormal, label)``.
            ``features`` is a float32 tensor. Both neighbor-feature tensors
            are float16 with shape ``[num_queries, k, feature_dim]``. The
            distances and the label are taken as-is from the loaded numpy
            arrays.
        """
        # Sample's own features (original author notes [256, 1024]; TODO confirm).
        cur_features = torch.from_numpy(
            np.load(os.path.join(self.normal_directory, f'features_{index}.npy'))
        ).float()

        # Neighbor indices into each bank (original author notes roughly [256, 16]).
        normal_neighbor_indices = np.load(os.path.join(self.normal_directory, f'query_nns_{index}.npy'))
        abnormal_neighbor_indices = np.load(os.path.join(self.abnormal_directory, f'query_nns_{index}.npy'))

        cur_retrieved_features_from_normal = self._retrieve_neighbor_features(
            self.knowledge_bank_instance_normal, normal_neighbor_indices)
        cur_retrieved_features_from_abnormal = self._retrieve_neighbor_features(
            self.knowledge_bank_instance_abnormal, abnormal_neighbor_indices)

        cur_labels = self.labels[index]
        cur_distances_from_normal = self.distances_from_normal[index]
        cur_distances_from_abnormal = self.distances_from_abnormal[index]

        return (cur_features, cur_retrieved_features_from_normal, cur_distances_from_normal,
                cur_retrieved_features_from_abnormal, cur_distances_from_abnormal, cur_labels)

    def __len__(self):
        """Number of samples, i.e. the length of the label array."""
        return self.labels.shape[0]
