from typing import Any
from pathlib import Path

from PIL import Image
import numpy as np


try:
    import imagenet_c
except ImportError:
    imagenet_c = None
    print("Warning: 'imagenet_c' package is not installed. Some functionalities may not work.")

from torchvision.transforms import v2 as transforms
from .image_utils import (
    load_image, random_split, ImageDataset, ImageTransformDataset, ImageSubset,
    IndexedDataset)


import logging
txt_logger = logging.getLogger("sfda_reg")

# Integer code for every selectable domain name.
# Gender names and race names live in two independent code spaces (both start
# at 0): UTKFace compares gender codes with the second underscore-separated
# field of a file name and race codes with the third one.
domain_map = {
    **{name: code for code, name in enumerate(("male", "female"))},
    **{
        name: code
        for code, name in enumerate(
            ("white", "black", "asian", "indian", "others"))
    },
}



class UTKFace(ImageDataset):
    """UTKFace images paired with the age encoded in each file name.

    Every ``*.jpg`` found recursively under the dataset root is considered.
    A file is kept only when the first three underscore-separated fields of
    its name parse as integers; they are read as age, gender and race.
    Optionally the dataset is restricted to a single gender or race.

    Attributes set by the constructor (all aligned index-by-index):
        path_list: sorted paths of the retained images.
        ages: ``float32`` array of ages.
        genders: integer array of gender codes.
        races: integer array of race codes.
    """

    def __init__(self, utkface_path: str | Path, domain: str = "original"):
        """Scan the dataset directory and apply the optional domain filter.

        Args:
            utkface_path: Root directory searched recursively for ``*.jpg``.
            domain: ``"male"``/``"female"`` filters by gender; ``"white"``,
                ``"black"``, ``"asian"``, ``"indian"`` or ``"others"``
                filters by race; any other value keeps every valid sample.
        """
        # NOTE(review): the base-class initializer is not invoked here (the
        # original did not call it either) -- confirm ImageDataset needs none.

        # The annotation promises ``str``, but ``rglob`` only exists on Path.
        root = Path(utkface_path)

        # Single pass: validate each name once and keep the sorted order.
        valid_paths: list[Path] = []
        invalid_paths: list[Path] = []
        for p in sorted(root.rglob("*.jpg")):
            bucket = valid_paths if self._is_valid_name(p.name) else invalid_paths
            bucket.append(p)
        txt_logger.info(f"Filtered {len(invalid_paths)} invalid filenames. Examples: {invalid_paths[:5]}")
        self.path_list = valid_paths
        txt_logger.info(f"Sorted path examples: {self.path_list[:10]}")

        # Split every name once and read the three leading fields from it.
        fields = [p.name.split("_") for p in self.path_list]
        self.ages: np.ndarray = np.array(
            [float(f[0]) for f in fields], dtype=np.float32)
        # An explicit integer dtype keeps the arrays integral when empty.
        self.genders = np.array([int(f[1]) for f in fields], dtype=int)
        self.races = np.array([int(f[2]) for f in fields], dtype=int)

        mask = np.ones(len(self.path_list), dtype=bool)
        if domain in ["male", "female"]:
            gender_value = domain_map[domain]
            mask = self.genders == gender_value
            txt_logger.info(f"Filtering by gender: {domain} ({gender_value}), remain {np.sum(mask)} samples.")
        elif domain in ["white", "black", "asian", "indian", "others"]:
            race_value = domain_map[domain]
            mask = self.races == race_value
            txt_logger.info(f"Filtering by race: {domain} ({race_value}), remain {np.sum(mask)} samples.")
        else:
            txt_logger.info(f"No filter for domain {domain}, with {np.sum(mask)} samples.")

        # Apply the same mask to every per-sample attribute so that
        # path_list, ages, genders and races stay aligned after filtering.
        self.path_list = [p for p, keep in zip(self.path_list, mask) if keep]
        self.ages = self.ages[mask]
        self.genders = self.genders[mask]
        self.races = self.races[mask]

    def _is_valid_name(self, name: str) -> bool:
        """Return True when the first three ``_``-separated fields are integers."""
        parts = name.split("_")
        if len(parts) < 3:
            return False
        try:
            int(parts[0])  # age
            int(parts[1])  # gender
            int(parts[2])  # race
        except ValueError:
            return False
        return True

    def __len__(self) -> int:
        """Number of retained images."""
        return len(self.path_list)

    def __getitem__(self, i: int) -> tuple[Image.Image, float]:
        """Load the ``i``-th image and return it with its age label.

        The label is the ``float32`` element stored in ``self.ages``.
        """
        p = self.path_list[i]
        img = load_image(p)
        return img, self.ages[i]




def get_utkface(fetch_dset: dict[str, Any],
                domain: str,
                path: str,
                apply_normalization: bool = True):
    """Build the train/val UTKFace datasets, each in a plain and an augmented view.

    Args:
        fetch_dset: Configuration dict. ``"tr_val_split"`` is required; the
            optional keys are ``"train_aug_type"`` (default ``"basic"``),
            ``"aug_type"`` (default ``"basic"``) and ``"seed"`` (default 42).
            A ``tr_val_split`` greater than 1 loads fixed validation indices
            from ``configs/data/utkface-all_val_indices.npy``; otherwise it
            is the train fraction of a seeded random split.
        domain: ``"original"`` or one of the gender/race names understood by
            ``UTKFace``.
        path: Root directory of the UTKFace images.
        apply_normalization: When true, the tensor conversion is followed by
            normalization with the ImageNet mean/std (pre-trained weights).

    Returns:
        ``(train, train_aug, val, val_aug)``, each wrapped in
        ``IndexedDataset``. ``train`` uses the ``train_aug_type`` transform,
        ``val`` the ``"val"`` transform, and both ``*_aug`` datasets the
        ``aug_type`` transform.

    Raises:
        KeyError: If ``fetch_dset`` has no ``"tr_val_split"`` entry.
        ValueError: If ``domain`` is not supported.
    """
    # for target domain
    train_aug_type = fetch_dset.get("train_aug_type", 'basic')
    aug_type = fetch_dset.get("aug_type", 'basic')
    tr_val_split = fetch_dset["tr_val_split"]
    seed = fetch_dset.get("seed", 42)

    # Fail early with a clear message: previously an unknown domain skipped
    # the transform setup and crashed later with an UnboundLocalError.
    supported_domains = (
        "original", "female", "male", "white", "black", "asian", "indian", "others")
    if domain not in supported_domains:
        raise ValueError(
            f"Unsupported UTKFace domain {domain!r}; expected one of {supported_domains}.")

    if apply_normalization:  # related to pre-trained weights (on ImageNet)
        to_tensor = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            ])
    else:
        to_tensor = transforms.ToTensor()

    used_train_aug_type = train_aug_type
    used_val_aug_type = 'val'
    used_aug_type = aug_type
    train_transform = get_utkface_transform(to_tensor, used_train_aug_type)
    val_transform = get_utkface_transform(to_tensor, used_val_aug_type)
    aug_transform = get_utkface_transform(to_tensor, used_aug_type)

    txt_logger.info(f"Train set aug is [{used_train_aug_type}] | Val set aug type is [{used_val_aug_type}] | additional aug [{used_aug_type}].")

    utkface_path = Path(path)
    dataset = UTKFace(utkface_path, domain)

    if tr_val_split > 1:
        # NOTE(review): the index file is named "all"; presumably it was made
        # for the unfiltered dataset, so the indices may not match a gender-
        # or race-filtered one -- confirm before combining the two.
        tr_val_file = "configs/data/utkface-all_val_indices.npy"
        txt_logger.info(f"Dataset Division - UTKFace: load val indices from {tr_val_file!r}")
        val_indices: np.ndarray = np.load(tr_val_file)
        txt_logger.info(f"Example val_indices: {val_indices[:10]}")
        # Every position that is not a validation index belongs to train.
        train_mask = np.ones(len(dataset), dtype=np.bool_)
        train_mask[val_indices] = False
        train_indices = np.arange(len(dataset))[train_mask]

        train_raw_ds = ImageSubset(dataset, train_indices.tolist())
        val_raw_ds = ImageSubset(dataset, val_indices.tolist())
    else:
        txt_logger.info(f"Dataset Division - UTKFace: split randomly,  with train ratio {tr_val_split} and seed is {seed}.")
        train_num = int(len(dataset) * tr_val_split)
        train_raw_ds, val_raw_ds = random_split(dataset, train_num, seed)
    txt_logger.info(f"Dataset Division - UTKFace: train data - [{len(train_raw_ds)}] | val data - [{len(val_raw_ds)}]")

    # The same raw subsets are wrapped twice: once with the split-specific
    # transform and once with the additional augmentation transform.
    train_ds = ImageTransformDataset(train_raw_ds, train_transform)
    train_aug_ds = ImageTransformDataset(train_raw_ds, aug_transform)
    val_ds = ImageTransformDataset(val_raw_ds, val_transform)
    val_aug_ds = ImageTransformDataset(val_raw_ds, aug_transform)

    return IndexedDataset(train_ds), IndexedDataset(train_aug_ds), IndexedDataset(
        val_ds), IndexedDataset(val_aug_ds)


def get_utkface_transform(to_tensor, aug_type):
    match aug_type:
        case "basic":
            aug_transform = transforms.Compose(
        [
            transforms.Resize((256, 256)),
            transforms.RandomCrop((224, 224)),
            transforms.RandomHorizontalFlip(),
            to_tensor,
        ])
        case "val":
            aug_transform = transforms.Compose(
                [
                    transforms.Resize((256, 256)),
                    transforms.CenterCrop((224, 224)),
                    to_tensor,
                ])
        case "val_noNormalization":
            aug_transform = transforms.Compose(
                [
                    transforms.Resize((256, 256)),
                    transforms.CenterCrop((224, 224)),
                    transforms.ToTensor(),
                ])
        
    return aug_transform

