

import os
import math
import errno
from typing import List, Dict, Optional
import torch.nn as nn
import numpy as np
from PIL import Image
import h5py

from continuum import ClassIncremental, InstanceIncremental
from continuum.datasets import (
    CIFAR100, ImageNet100, TinyImageNet200, ImageFolderDataset, Core50
)
from .utils import get_dataset_class_names
from MTIL_datasets.stanford_cars import StanfordCars as MTILStanfordCars
from MTIL_datasets.fgvc_aircraft import FGVCAircraft as MTILFGVCAircraft
from MTIL_datasets.caltech101 import Caltech101 as MTILCaltech101
from MTIL_datasets.dtd import DescribableTextures as MTILDTD
from MTIL_datasets.eurosat import EuroSAT as MTILEuroSAT
from MTIL_datasets.oxford_flowers import OxfordFlowers as MTILOxfordFlowers
from MTIL_datasets.food101 import Food101 as MTILFood101
from MTIL_datasets.mnist import MNIST as MTILMNIST
from MTIL_datasets.oxford_pets import OxfordPets as MTILOxfordPets
from MTIL_datasets.sun397 import SUN397 as MTILSUN397
from MTIL_datasets.country211 import Country211 as MTILCountry211
from MTIL_datasets.sst2 import SST2 as MTILSST2
from MTIL_datasets.hatefulmemes import HatefulMemes as MTILHatefulMemes
from MTIL_datasets.gtsrb import GTSRB as MTILGTSRB
from MTIL_datasets.resisc import RESISC45 as MTILRESISC45
from MTIL_datasets.fer2013 import FER2013 as MTILFER2013
from MTIL_datasets.ucf101 import UCF101 as MTILUCF101
from MTIL_datasets.cifar10 import CIFAR10 as MTILCIFAR10
from MTIL_datasets.stl10 import STL10 as MTILSTL10
from MTIL_datasets.voc2007 import VOC2007 as MTILVOC2007
from MTIL_datasets.imagenet_r import ImageNetR as MTILImageNetR
from MTIL_datasets.kitti_distance import KittiDistance as MTILKittiDistance
from MTIL_datasets.pcam import PCam as MTILPCam
from MTIL_datasets.clevr_count import CLEVRCount as MTILCLEVRCount


class ImageNet1000(ImageFolderDataset):
    """Continuum dataset for ImageNet-1k laid out as ImageFolder trees.

    ``data_path`` is the dataset root and is expected to contain a ``train``
    and a ``val`` sub-folder, each holding one sub-directory per class.

    :param data_path: Root folder containing ``train/`` and ``val/``.
    :param train: Load ``train/`` if True, otherwise ``val/``.
    :param download: Forwarded to ``ImageFolderDataset``.
    """

    def __init__(
            self,
            data_path: str,
            train: bool = True,
            download: bool = False,
    ):
        super().__init__(data_path=data_path, train=train, download=download)
        # Remember the root so get_data() can derive the split folder from it
        # on every call instead of appending to an already-joined path.
        self._root_data_path = self.data_path

    def get_data(self):
        """Point ``data_path`` at the split folder, then load it via ImageFolderDataset."""
        split = "train" if self.train else "val"
        root = getattr(self, "_root_data_path", self.data_path)
        # Idempotent: repeated calls keep resolving to root/<split>.
        self.data_path = os.path.join(root, split)
        return super().get_data()


class MTILImageFolderCL(ImageFolderDataset):
    """Generic Continuum-compatible adapter for MTIL datasets.

    Builds a class-folder cache at ``<dataset_dir>/cl_cache/<split>/<class>/``
    from the MTIL wrapper's item list so that Continuum's ImageFolderDataset
    can load it:

    * items whose ``impath`` is a file path are symlinked (or copied when
      symlinks are not permitted);
    * items stored as ``('h5', path, key, index)`` tuples or held as PIL
      images are written out as numbered PNG files.

    :param data_root: Root passed to the MTIL wrapper as ``root``.
    :param mtil_cls: MTIL dataset class to instantiate.
    :param train: Use the wrapper's ``train_x`` items if True, else ``test``.
    :param download: Forwarded to ImageFolderDataset.
    :param mtil_kwargs: Extra keyword arguments for ``mtil_cls``.
    """

    def __init__(self, data_root: str, mtil_cls, train: bool = True, download: bool = False, mtil_kwargs: Optional[Dict] = None):
        mtil_kwargs = mtil_kwargs or {}
        # MTIL wrappers differ in constructor signature; try the richest first.
        # NOTE(review): a TypeError raised *inside* a wrapper's constructor is
        # caught here too and triggers the next fallback.
        try:
            self._mtil = mtil_cls(root=data_root, num_shots=0, seed=1, **mtil_kwargs)
        except TypeError:
            try:
                self._mtil = mtil_cls(root=data_root, seed=1, **mtil_kwargs)
            except TypeError:
                self._mtil = mtil_cls(root=data_root, **mtil_kwargs)
        self.dataset_dir = self._mtil.dataset_dir
        self.cache_split = 'train' if train else 'test'
        self.cache_dir = os.path.join(self.dataset_dir, 'cl_cache', self.cache_split)
        self._ensure_cache(train)
        super().__init__(data_path=self.cache_dir, train=train, download=download)

    @staticmethod
    def _class_dir_name(classname) -> str:
        """Map a class name to a single directory component.

        Names such as "F/A-18" would otherwise create nested folders and
        ImageFolder would report the outer folder ("F") as the class.
        """
        name = str(classname)
        for sep in {"/", os.sep, os.altsep}:
            if sep:
                name = name.replace(sep, "_")
        return name

    def _cache_is_populated(self) -> bool:
        """Return True if the cache directory exists and has at least one entry."""
        if not os.path.isdir(self.cache_dir):
            return False
        with os.scandir(self.cache_dir) as entries:
            return any(entries)

    def _safe_symlink(self, src: str, dst: str):
        """Create ``dst`` pointing at ``src``; copy instead if symlinks are not permitted.

        An existing entry at ``dst`` (even a dangling link) is left untouched.
        """
        if os.path.lexists(dst):
            return
        # Symlink targets are resolved relative to the link's own directory,
        # so a relative ``src`` must be made absolute or the link would dangle.
        target = os.path.abspath(src)
        try:
            os.symlink(target, dst)
        except OSError as e:
            if e.errno in (errno.EPERM, errno.EACCES, errno.ENOTSUP):
                import shutil
                shutil.copy2(target, dst)
            else:
                raise

    @staticmethod
    def _read_h5_image(impath, handles: Dict):
        """Decode an ``('h5', path, key, index)`` item into an RGB PIL image.

        ``handles`` caches open h5py files by path so each file is opened once
        per cache build. Returns None when the item cannot be read or decoded.
        """
        _, fpath, key, index = impath
        try:
            h5 = handles.get(fpath)
            if h5 is None:
                h5 = h5py.File(fpath, 'r')
                handles[fpath] = h5
            arr = np.asarray(h5[key][int(index)])
            # CHW -> HWC if needed
            if arr.ndim == 3 and arr.shape[0] in (1, 3) and arr.shape[-1] not in (1, 3):
                arr = np.transpose(arr, (1, 2, 0))
            if arr.dtype != np.uint8:
                # NOTE(review): plain cast; float data scaled to [0, 1] would
                # collapse to 0/1 here — confirm the stored value range.
                arr = arr.astype(np.uint8)
            if arr.ndim == 2:
                return Image.fromarray(arr, mode='L').convert('RGB')
            if arr.shape[-1] == 1:
                return Image.fromarray(arr.squeeze(-1), mode='L').convert('RGB')
            return Image.fromarray(arr, mode='RGB')
        except Exception:
            return None

    def _ensure_cache(self, train: bool):
        """Populate ``cache_dir`` from the MTIL items unless it already has content."""
        if self._cache_is_populated():
            return
        # NOTE(review): an interrupted build leaves a non-empty cache that is
        # treated as complete on the next run; delete cl_cache/ to rebuild.
        os.makedirs(self.cache_dir, exist_ok=True)
        items = self._mtil.train_x if train else self._mtil.test
        # Per-class-folder counters give unique PNG names for decoded images.
        per_class_count: Dict[str, int] = {}
        h5_handles: Dict[str, object] = {}
        skipped = 0
        try:
            for item in items:
                class_dir = os.path.join(self.cache_dir, self._class_dir_name(item.classname))
                os.makedirs(class_dir, exist_ok=True)
                impath = item.impath
                if isinstance(impath, str):
                    self._safe_symlink(impath, os.path.join(class_dir, os.path.basename(impath)))
                    continue
                if isinstance(impath, tuple) and len(impath) == 4 and impath[0] == 'h5':
                    img = self._read_h5_image(impath, h5_handles)
                    if img is None:
                        skipped += 1
                        continue
                elif isinstance(impath, Image.Image):
                    img = impath
                else:
                    # Unsupported impath type.
                    skipped += 1
                    continue
                c = per_class_count.get(class_dir, 0)
                per_class_count[class_dir] = c + 1
                dst = os.path.join(class_dir, f"{c:08d}.png")
                try:
                    img.convert('RGB').save(dst)
                except Exception:
                    skipped += 1
        finally:
            for handle in h5_handles.values():
                handle.close()
        if skipped:
            import warnings
            warnings.warn(
                f"MTILImageFolderCL: skipped {skipped} item(s) that could not be "
                f"cached while building {self.cache_dir}"
            )

    def get_data(self):
        """Load the cached class-folder tree via ImageFolderDataset."""
        return super().get_data()


class StanfordCarsCL(ImageFolderDataset):
    """Continuum-compatible StanfordCars built on a symlinked cache.

    The MTIL StanfordCars parser reads the splits; a class-folder tree of
    symlinks is then created under ``<data_root>/stanford_cars/cl_cache/<split>``
    so that ImageFolderDataset can load it.

    :param data_root: Global dataset root containing ``stanford_cars/``.
    :param train: Use the ``train_x`` items if True, else ``test``.
    :param download: Forwarded to ImageFolderDataset.
    """

    def __init__(self, data_root: str, train: bool = True, download: bool = False):
        self.dataset_dir = os.path.join(data_root, 'stanford_cars')
        self.cache_split = 'train' if train else 'test'
        self.cache_dir = os.path.join(self.dataset_dir, 'cl_cache', self.cache_split)
        self._ensure_cache(data_root, train)
        super().__init__(data_path=self.cache_dir, train=train, download=download)

    @staticmethod
    def _class_dir_name(classname) -> str:
        """Map a class name to a single directory component.

        Names such as "Ram C/V Cargo Van Minivan 2012" would otherwise create
        nested folders and ImageFolder would report the outer folder as the class.
        """
        name = str(classname)
        for sep in {"/", os.sep, os.altsep}:
            if sep:
                name = name.replace(sep, "_")
        return name

    def _cache_is_populated(self) -> bool:
        """Return True if the cache directory exists and has at least one entry."""
        if not os.path.isdir(self.cache_dir):
            return False
        with os.scandir(self.cache_dir) as entries:
            return any(entries)

    def _safe_symlink(self, src: str, dst: str):
        """Create ``dst`` pointing at ``src``; copy instead if symlinks are not permitted.

        An existing entry at ``dst`` (even a dangling link) is left untouched.
        """
        if os.path.lexists(dst):
            return
        # Symlink targets are resolved relative to the link's own directory,
        # so a relative ``src`` must be made absolute or the link would dangle.
        target = os.path.abspath(src)
        try:
            os.symlink(target, dst)
        except OSError as e:
            # Fallback: if symlink not permitted (e.g., Windows), copy the file
            if e.errno in (errno.EPERM, errno.EACCES, errno.ENOTSUP):
                import shutil
                shutil.copy2(target, dst)
            else:
                raise

    def _ensure_cache(self, data_root: str, train: bool):
        """Populate ``cache_dir`` with per-class symlinks unless it already has content."""
        if self._cache_is_populated():
            return  # cache exists
        os.makedirs(self.cache_dir, exist_ok=True)
        mtil = MTILStanfordCars(root=data_root, num_shots=0, seed=1)
        items = mtil.train_x if train else mtil.test
        for item in items:
            class_dir = os.path.join(self.cache_dir, self._class_dir_name(item.classname))
            os.makedirs(class_dir, exist_ok=True)
            dst = os.path.join(class_dir, os.path.basename(item.impath))
            self._safe_symlink(item.impath, dst)

    def get_data(self):
        """Load the cached class-folder tree (``data_path`` is already ``cache_dir``)."""
        return super().get_data()


def _sorted_subdir_names(path):
    """Return the sorted names of the immediate sub-directories of ``path``."""
    with os.scandir(path) as entries:
        return sorted(entry.name for entry in entries if entry.is_dir())


def get_dataset(cfg, is_train, transforms=None):
    """Build the raw Continuum dataset and its class names for ``cfg.dataset``.

    :param cfg: Config object. Reads ``dataset`` and ``dataset_root``, plus
        ``workdir`` for datasets whose class names come from
        ``get_dataset_class_names``. ``imagenet100_splits_dir`` (optional)
        overrides the folder holding ``train_100.txt`` / ``val_100.txt``.
    :param is_train: Load the training split if True, otherwise the test split.
    :param transforms: Unused here; ``build_cl_scenarios`` passes transforms
        to the scenario instead.
    :return: ``(dataset, classes_names)``.
    :raises ValueError: If ``cfg.dataset`` is not a supported dataset name.
    """
    if cfg.dataset == "cifar100":
        data_path = os.path.join(cfg.dataset_root, cfg.dataset)
        dataset = CIFAR100(
            data_path=data_path,
            download=True,
            train=is_train,
        )
        classes_names = dataset.dataset.classes

    elif cfg.dataset == "tinyimagenet":
        data_path = os.path.join(cfg.dataset_root, cfg.dataset)
        dataset = TinyImageNet200(
            data_path,
            train=is_train,
            download=True
        )
        classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)

    elif cfg.dataset == "imagenet100":
        data_path = os.path.join(cfg.dataset_root, "ImageNet")
        # Historical default kept for backward compatibility; override with
        # cfg.imagenet100_splits_dir on other machines.
        splits_dir = getattr(cfg, "imagenet100_splits_dir", None) or (
            '/home/dhw/yjz_workspace/project1_y/CIL_ours_compare_v3_lr_5e_3_1router_l2/'
            'Continual-CLIP/dataset_reqs/imagenet100_splits'
        )
        dataset = ImageNet100(
            data_path,
            train=is_train,
            data_subset=os.path.join(splits_dir, "train_100.txt" if is_train else "val_100.txt")
        )
        classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)

    elif cfg.dataset == "imagenet1000":
        data_path = os.path.join(cfg.dataset_root, cfg.dataset)
        dataset = ImageNet1000(
            data_path,
            train=is_train
        )
        classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)

    elif cfg.dataset == "core50":
        data_path = os.path.join(cfg.dataset_root, cfg.dataset)
        dataset = Core50(
            data_path,
            scenario="domains",
            classification="category",
            train=is_train
        )
        classes_names = [
            "plug adapters", "mobile phones", "scissors", "light bulbs", "cans",
            "glasses", "balls", "markers", "cups", "remote controls"
        ]

    elif cfg.dataset in ("stanfordcars", "stanford_cars", "StanfordCars"):
        # MTIL wrapper expects a root that contains 'stanford_cars/'.
        dataset = StanfordCarsCL(
            data_root=cfg.dataset_root,
            train=is_train,
            download=False,
        )
        # `.dataset` is only set once get_data() has run; otherwise fall back
        # to the cache folder names (ImageFolder also sorts them by name).
        try:
            classes_names = dataset.dataset.classes  # type: ignore[attr-defined]
        except Exception:
            classes_names = _sorted_subdir_names(dataset.data_path)

    elif cfg.dataset in ("aircraft", "caltech101", "dtd", "eurosat", "oxford_flowers", "food101", "mnist", "oxford_pets", "sun397",
                         "country211", "sst2", "hatefulmemes", "gtsrb", "resisc45", "fer2013", "ucf101", "cifar10", "stl10",
                         "voc2007", "imagenet_r", "kitti_distance", "pcam", "clevr_count"):
        key2cls = {
            "aircraft": MTILFGVCAircraft,
            "caltech101": MTILCaltech101,
            "dtd": MTILDTD,
            "eurosat": MTILEuroSAT,
            "oxford_flowers": MTILOxfordFlowers,
            "food101": MTILFood101,
            "mnist": MTILMNIST,
            "oxford_pets": MTILOxfordPets,
            "sun397": MTILSUN397,
            "country211": MTILCountry211,
            "sst2": MTILSST2,
            "hatefulmemes": MTILHatefulMemes,
            "gtsrb": MTILGTSRB,
            "resisc45": MTILRESISC45,
            "fer2013": MTILFER2013,
            "ucf101": MTILUCF101,
            "cifar10": MTILCIFAR10,
            "stl10": MTILSTL10,
            "voc2007": MTILVOC2007,
            "imagenet_r": MTILImageNetR,
            "kitti_distance": MTILKittiDistance,
            "pcam": MTILPCam,
            "clevr_count": MTILCLEVRCount,
        }
        # Special kwargs per dataset (e.g., VOC2007 single-label mode)
        special_kwargs: Dict[str, Dict] = {
            "voc2007": {"single_label": True},
        }
        dataset = MTILImageFolderCL(
            data_root=cfg.dataset_root,
            mtil_cls=key2cls[cfg.dataset],
            train=is_train,
            download=False,
            mtil_kwargs=special_kwargs.get(cfg.dataset, {}),
        )
        try:
            classes_names = dataset.dataset.classes  # type: ignore[attr-defined]
        except Exception:
            classes_names = _sorted_subdir_names(dataset.data_path)

    else:
        raise ValueError(f"'{cfg.dataset}' is a invalid dataset.")

    return dataset, classes_names


def build_cl_scenarios(cfg, is_train, transforms) -> nn.Module:

    dataset, classes_names = get_dataset(cfg, is_train)

    if cfg.scenario == "class":
        # Build balanced increments if cil_splits is specified: first r tasks get base+1, rest base
        num_classes = len(classes_names)
        # Determine target number of tasks K:
        # Prefer scalar cfg.cil_splits; if it's list/ListConfig (multi-dataset), derive K from cfg.increment.
        raw_cil = getattr(cfg, 'cil_splits', 0)
        cil_splits_scalar = None
        # scalar forms
        try:
            if isinstance(raw_cil, (int, str)):
                cil_splits_scalar = int(raw_cil)
        except Exception:
            cil_splits_scalar = None
        # list/ListConfig -> ambiguous in multi-dataset; ignore and use increment
        try:
            from omegaconf import ListConfig  # type: ignore
            if isinstance(raw_cil, (list, tuple, ListConfig)):
                cil_splits_scalar = None
        except Exception:
            pass

        scenario = None
        # Determine class order: default to contiguous [0..num_classes-1] if not provided
        default_order = list(range(num_classes))
        class_order_to_use = getattr(cfg, 'class_order', None)
        # Coerce Hydra ListConfig to list if needed
        try:
            from omegaconf import ListConfig  # type: ignore
            if isinstance(class_order_to_use, ListConfig):
                class_order_to_use = list(class_order_to_use)
        except Exception:
            pass
        if class_order_to_use in (None, []):
            class_order_to_use = default_order
        # decide K
        k = None
        if isinstance(cil_splits_scalar, int) and cil_splits_scalar > 0:
            k = min(cil_splits_scalar, max(1, num_classes))
        else:
            # derive from cfg.increment (robust to ListConfig/list/str)
            inc_val = getattr(cfg, 'increment', None)
            inc_int = None
            try:
                from omegaconf import ListConfig  # type: ignore
            except Exception:
                ListConfig = tuple()  # type: ignore
            try:
                if isinstance(inc_val, (list, tuple)):
                    inc_int = int(inc_val[0]) if len(inc_val) > 0 else None
                elif isinstance(inc_val, ListConfig):  # type: ignore
                    inc_int = int(inc_val[0]) if len(inc_val) > 0 else None
                elif inc_val is not None:
                    inc_int = int(inc_val)
            except Exception:
                inc_int = None
            if inc_int and inc_int > 0:
                k = max(1, math.ceil(num_classes / inc_int))
            else:
                # try from initial_increment
                init_val = getattr(cfg, 'initial_increment', None)
                try:
                    if isinstance(init_val, (list, tuple)):
                        init_int = int(init_val[0]) if len(init_val) > 0 else None
                    elif isinstance(init_val, ListConfig):  # type: ignore
                        init_int = int(init_val[0]) if len(init_val) > 0 else None
                    elif init_val is not None:
                        init_int = int(init_val)
                    else:
                        init_int = None
                except Exception:
                    init_int = None
                if init_int and init_int > 0:
                    k = max(1, math.ceil(num_classes / init_int))

        # Always construct full increments list; avoid int-only path to prevent sum mismatch on last task
        if not k or k <= 0:
            k = 1
        k = min(k, max(1, num_classes))
        base = num_classes // k
        r = num_classes % k
        increments = [base + 1] * r + [base] * (k - r)
        # Construct scenario with full list increments; omit initial_increment to avoid sum mismatch
        scenario = ClassIncremental(
            dataset,
            increment=increments,
            transformations=transforms.transforms,  # Convert Compose into list
            class_order=class_order_to_use,
        )

    elif cfg.scenario == "domain":
        scenario = InstanceIncremental(
            dataset,
            transformations=transforms.transforms,
        )

    elif cfg.scenario == "task-agnostic":
        NotImplementedError("Method has not been implemented. Soon be added.")

    else:
        ValueError(f"You have entered `{cfg.scenario}` which is not a defined scenario, " 
                    "please choose from {{'class', 'domain', 'task-agnostic'}}.")

    return scenario, classes_names