"""
Extract ImageNet features for a fixed number of samples per class.
The subsets are taken from the already extracted features and targets of ImageNet1k.
"""
import argparse
import json
import os
from pathlib import Path
from loguru import logger
import numpy as np
import torch
from tqdm import tqdm

from src.utils.utils import load_features_targets, check_feature_existence
from helper import load_models, format_path
from project_location import FEATURES_ROOT, SUBSET_ROOT

# Default path for ``--model_config``. The file is loaded via ``load_models``;
# every model entry used by ``main`` is expected to provide ``module_names``.
MODELS_CONFIG = "./scripts/configs/models_config_single_model_layer_combination.json"


def _load_subset_indices(idxs_fn):
    """Load the subset indices JSON and flatten it into a single index array.

    The JSON object's values are lists of row indices into the full feature
    matrix (keys are presumably class ids — TODO confirm). The lists are
    concatenated in the file's key order; lists of different lengths are
    supported.

    Raises:
        ValueError: if the file contains no indices at all.
    """
    with open(idxs_fn, 'r', encoding='utf-8') as f:
        indices_map = json.load(f)
    # Flatten in Python first so ragged per-class lists do not trip NumPy's
    # inhomogeneous-shape check.
    indices = np.asarray([idx for class_idxs in indices_map.values() for idx in class_idxs])
    if indices.size == 0:
        raise ValueError(f"No subset indices found in {idxs_fn}")
    return indices


def main(args):
    """Slice the full ImageNet1k features/targets down to a per-class subset.

    For every requested model and each of its configured modules, this loads the
    ``args.split`` features and targets from
    ``args.features_root/<model_id>/<module_name>``. It keeps only the rows listed
    in the subset indices file and saves them as ``features_<split>.pt`` and
    ``targets_<split>.pt`` under ``<formatted output_root_dir>/<model_id>/<module_name>``.
    Modules whose subset already exists, or whose full features cannot be found,
    are skipped.

    Args:
        args: Parsed CLI namespace with ``model_config``, ``features_root``,
            ``model_key``, ``split``, ``num_samples_class``, ``subset_idxs`` and
            ``output_root_dir``.

    Raises:
        ValueError: if no model key is given, if a model key is missing from the
            model config or has no ``module_names``, or if the subset indices
            file contains no indices.
    """
    if args.model_key is None:
        raise ValueError("No model key given; pass at least one via --model_key.")
    model_keys = args.model_key if isinstance(args.model_key, list) else [args.model_key]

    # Create output directory: replace {num_samples_class} and {split} in the path
    out_path_root = format_path(args.output_root_dir, args.num_samples_class, args.split)
    logger.info(f"Formatted {out_path_root=}")

    # Get path to the subset indices file: replace {num_samples_class} and {split} in the path
    idxs_fn = format_path(args.subset_idxs, args.num_samples_class, args.split)
    indices = _load_subset_indices(idxs_fn)
    logger.info(f"Loaded indices from {idxs_fn=}")

    models, _ = load_models(args.model_config)

    # The split is fixed for the whole run, so decide once which files must exist.
    check_train = args.split == 'train'
    check_test = not check_train

    logger.info(f"For each model in model_keys (n={len(model_keys)}) extract the features ...")
    for model_id in tqdm(model_keys, desc=f"Extracting subset of features with {args.num_samples_class} per class."):
        model_cfg = models.get(model_id)
        if model_cfg is None:
            raise ValueError(f"Model {model_id} not found in {args.model_config}")

        module_names = model_cfg.get('module_names', None)
        if module_names is None:
            raise ValueError(f"Module names not found for model {model_id} in {args.model_config}")
        if not isinstance(module_names, list):
            module_names = [module_names]

        logger.info(f"Extracting features for model {model_id} and with {len(module_names)} modules ...")

        for module_name in module_names:
            curr_full_features_dir = Path(args.features_root) / model_id / module_name
            new_subset_features_dir = Path(out_path_root) / model_id / module_name

            if check_feature_existence(new_subset_features_dir, check_train=check_train, check_test=check_test):
                logger.info(f"Features and targets already exist for model {model_id} and module {module_name} in \n{new_subset_features_dir}. Skipping...")
                continue

            try:
                features, targets = load_features_targets(curr_full_features_dir, split=args.split, normalize=True)
            except FileNotFoundError as e:
                # Missing full features for one module should not abort the other modules/models.
                logger.warning(
                    f'\nFeatures or targets of wds_imagenet1k not found for model {model_id} and module {module_name} at:\n{curr_full_features_dir}. Skipping...')
                logger.warning(f'>> Error: {e}\n')
                continue

            # Row-select the subset; indices refer to rows of the full split.
            features_subset = features[indices, :]
            targets_subset = targets[indices]

            if not new_subset_features_dir.exists():
                new_subset_features_dir.mkdir(parents=True, exist_ok=True)
                logger.info(f'Created directory {new_subset_features_dir}')

            torch.save(features_subset, new_subset_features_dir / f'features_{args.split}.pt')
            torch.save(targets_subset, new_subset_features_dir / f'targets_{args.split}.pt')
            logger.info(f'Saved {args.split} features and targets for model {model_id} and module {module_name} to {new_subset_features_dir}\n\n')


if __name__ == "__main__":
    # CLI entry point: parse options and run the subset extraction.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--model_config', default=MODELS_CONFIG,
                        help='Path to the model config file.')
    parser.add_argument('--features_root', default=os.path.join(FEATURES_ROOT, 'wds_imagenet1k'),
                        help='Root directory of the extracted features and targets for ImageNet1k.')
    # Required: main() cannot do anything useful without at least one model key.
    parser.add_argument('--model_key', nargs='+', required=True,
                        help='Model key(s) for which the features are extracted.')
    parser.add_argument('--split', default='train', choices=['train', 'test'],
                        help='Dataset split whose features and targets are subsetted.')
    parser.add_argument('--num_samples_class', default=10, type=int,
                        help='Number of samples per class in the subset.')
    # The path templates below may contain {num_samples_class} / {split}; they are
    # filled in by format_path inside main().
    parser.add_argument('--subset_idxs',
                        default=os.path.join(SUBSET_ROOT,
                                             'imagenet-subset-{num_samples_class}k/imagenet-{num_samples_class}k-{split}.json'),
                        help='Path to the subset indices file.')
    parser.add_argument('--output_root_dir',
                        default=os.path.join(FEATURES_ROOT, 'imagenet-subset-{num_samples_class}k'),
                        help='Root directory for the output features and targets.')
    args = parser.parse_args()

    main(args)
