import os, re
from tqdm import tqdm
from pathlib import Path
from typing import List, Optional

import scipy.io as sio
from torch.utils.data import Subset
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

from torchvision import transforms as T


# Normalization statistics passed to T.Normalize by the 'vit_base' branch of
# get_input_transform (logged there as "IN-1k standardization").
# Each list holds three values, one per image channel.
# NOTE(review): presumably the usual ImageNet RGB mean/std — confirm channel order.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]


def get_input_transform(model_name: str, model, img_size: int):
    """Build the input preprocessing pipeline for a model.

    The pipeline is picked by substring checks on the lower-cased
    ``model_name``, tried in this order:

    * contains ``'b-cos'``: re-compose the steps found at
      ``model.transform.transforms.transforms`` without the last one
      (per the log message, the AddInverse step).
    * contains ``'vit_base'``: bicubic resize to ``img_size + 24``, centre
      crop to ``img_size`` x ``img_size``, ``ToTensor`` and normalisation
      with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    * anything else: build the transform from the data config that timm
      resolves for ``model``.

    Args:
        model_name: Model identifier; only used for the substring checks.
        model: Model object. The 'b-cos' branch reads its ``transform``
            attribute; the fallback branch hands it to timm.
        img_size: Side length of the centre crop in the 'vit_base' branch.
            The other branches do not use it.

    Returns:
        The composed transform for the selected branch.
    """
    lowered = model_name.lower()

    if 'b-cos' in lowered:
        print('b-cos transform: Ommiting AddInverse...')
        model_steps = model.transform.transforms.transforms
        # Keep every step except the final one.
        return T.Compose(model_steps[:-1])

    if 'vit_base' in lowered:
        print('vit_base transform: setting IN-1k standardization...')
        from torchvision.transforms import InterpolationMode

        resize = T.Resize(
            size=img_size+24,
            interpolation=InterpolationMode.BICUBIC,
            max_size=None,
            antialias=True,
        )
        center_crop = T.CenterCrop(size=(img_size, img_size))
        to_tensor = T.ToTensor()
        normalize = T.Normalize(IMAGENET_MEAN, IMAGENET_STD)
        return T.Compose([resize, center_crop, to_tensor, normalize])

    # Fallback: let timm derive the preprocessing from the model itself.
    print('extracting transform from the model...')
    from timm.data import resolve_model_data_config
    from timm.data.transforms_factory import create_transform

    data_config = resolve_model_data_config(model)
    return create_transform(**data_config)



def matches(name: str, keywords: List[str]) -> bool:
    """Return True if any keyword occurs in *name* as a whole word or phrase.

    Matching is case-insensitive: both ``name`` and each keyword are
    lower-cased before comparison. A space inside a keyword matches either
    one whitespace character or a hyphen in ``name``, so ``"great white"``
    matches ``"great-white shark"``. Keywords are otherwise treated
    literally (regex metacharacters are escaped) and must sit on word
    boundaries, so ``"shark"`` does not match ``"sharkskin"``.

    Args:
        name: Text to search, e.g. a class description.
        keywords: Candidate words or phrases.

    Returns:
        True if at least one keyword matches, False otherwise (including
        when ``keywords`` is empty).
    """
    # Lower-case once so the search agrees with the lower-cased keywords;
    # previously a name containing capitals could never match.
    haystack = name.lower()

    def _pattern(keyword: str) -> str:
        # re.escape turns a space into r"\ "; swap that for a class that
        # accepts one whitespace character or a hyphen.
        body = re.escape(keyword.lower()).replace(r"\ ", r"[\s-]")
        return r"\b" + body + r"\b"

    return any(re.search(_pattern(kw), haystack) for kw in keywords)


def extract_subset(
    dataset: ImageFolder,
    classes: List[str],
    devkit_path: Path,
) -> Subset:
    """Restrict *dataset* to the samples whose class name matches *classes*.

    Class names come from the ``'synsets'`` table of the MATLAB file at
    ``devkit_path``: field 1 of each row is used as the synset ID and
    field 2 (lower-cased) as its name. A sample is kept when the name of
    the directory holding its file is a synset whose name satisfies
    ``matches(name, classes)``.

    Args:
        dataset: Image dataset; its ``samples`` list of ``(path, target)``
            pairs is scanned.
        classes: Keywords handed to ``matches``. When empty, no filtering
            is done.
        devkit_path: Path to the ``.mat`` file containing ``'synsets'``.
            NOTE(review): presumably the ILSVRC devkit ``meta.mat`` — confirm.

    Returns:
        ``dataset`` itself when ``classes`` is empty; otherwise a ``Subset``
        over the indices of the matching samples (possibly empty).
    """
    # Nothing to filter on: hand back the dataset unchanged (not a Subset).
    if len(classes) == 0:
        return dataset

    synset_table = sio.loadmat(devkit_path, squeeze_me=True)['synsets']

    # synset ID -> lower-cased class name
    name_by_synset = {}
    for row in synset_table:
        synset_id = str(row[1])
        name_by_synset[synset_id] = str(row[2]).lower()

    wanted_synsets = set()
    for synset_id, class_name in name_by_synset.items():
        if matches(class_name, classes):
            wanted_synsets.add(synset_id)

    kept_indices = []
    for position, (file_path, _) in enumerate(dataset.samples):
        # The immediate parent directory of the file is compared to synset IDs.
        parent_dir = Path(file_path).parts[-2]
        if parent_dir in wanted_synsets:
            kept_indices.append(position)

    return Subset(dataset, kept_indices)



