import os
import random
import os.path as osp
import tarfile
import zipfile
from collections import defaultdict
import gdown
import errno
import warnings
import json
import numpy as np
import h5py
from torch.utils.data import Dataset as TorchDataset
from PIL import Image


class Datum:
    def __init__(self, impath="", label=0, domain=0, classname=""):
        self._impath = impath
        self._label = label
        self._domain = domain
        self._classname = classname

    @property
    def impath(self):
        return self._impath

    @property
    def label(self):
        return self._label

    @property
    def domain(self):
        return self._domain

    @property
    def classname(self):
        return self._classname


class DatasetBase:
    dataset_dir = "" 
    domains = []

    def __init__(self, train_x=None, train_u=None, val=None, test=None):
        self._train_x = train_x
        self._train_u = train_u
        self._val = val
        self._test = test
        self._num_classes = self.get_num_classes(train_x)
        self._lab2cname, self._classnames = self.get_lab2cname(train_x)

    @property
    def train_x(self):
        return self._train_x

    @property
    def train_u(self):
        return self._train_u

    @property
    def val(self):
        return self._val

    @property
    def test(self):
        return self._test

    @property
    def lab2cname(self):
        return self._lab2cname

    @property
    def classnames(self):
        return self._classnames

    @property
    def num_classes(self):
        return self._num_classes

    @staticmethod
    def get_num_classes(data_source):
        label_set = set()
        for item in data_source:
            label_set.add(item.label)
        return max(label_set) + 1

    @staticmethod
    def get_lab2cname(data_source):
        container = set()
        for item in data_source:
            container.add((item.label, item.classname))
        mapping = {label: classname for label, classname in container}
        labels = list(mapping.keys())
        labels.sort()
        classnames = [mapping[label] for label in labels]
        return mapping, classnames

    @property
    def template(self):
        return self.templates[0]

    def check_input_domains(self, source_domains, target_domains):
        assert len(source_domains) > 0, "source_domains (list) is empty"
        assert len(target_domains) > 0, "target_domains (list) is empty"
        self.is_input_domain_valid(source_domains)
        self.is_input_domain_valid(target_domains)

    def is_input_domain_valid(self, input_domains):
        for domain in input_domains:
            if domain not in self.domains:
                raise ValueError(
                    "Input domain must belong to {}, "
                    "but got [{}]".format(self.domains, domain)
                )

    def download_data(self, url, dst, from_gdrive=True):
        if not osp.exists(osp.dirname(dst)):
            os.makedirs(osp.dirname(dst))

        if from_gdrive:
            gdown.download(url, dst, quiet=False)
        else:
            raise NotImplementedError

        print("Extracting file ...")

        if dst.endswith(".zip"):
            zip_ref = zipfile.ZipFile(dst, "r")
            zip_ref.extractall(osp.dirname(dst))
            zip_ref.close()

        elif dst.endswith(".tar"):
            tar = tarfile.open(dst, "r:")
            tar.extractall(osp.dirname(dst))
            tar.close()

        elif dst.endswith(".tar.gz"):
            tar = tarfile.open(dst, "r:gz")
            tar.extractall(osp.dirname(dst))
            tar.close()

        else:
            raise NotImplementedError

        print("File extracted to {}".format(osp.dirname(dst)))

    def generate_fewshot_dataset(
        self, *data_sources, num_shots=-1, repeat=False
    ):
        if num_shots < 1:
            if len(data_sources) == 1:
                return data_sources[0]
            return data_sources

        print(f"Creating a {num_shots}-shot dataset")

        output = []

        for data_source in data_sources:
            tracker = self.split_dataset_by_label(data_source)
            dataset = []

            for label, items in tracker.items():
                if len(items) >= num_shots:
                    sampled_items = random.sample(items, num_shots)
                else:
                    if repeat:
                        sampled_items = random.choices(items, k=num_shots)
                    else:
                        sampled_items = items
                dataset.extend(sampled_items)

            output.append(dataset)

        if len(output) == 1:
            return output[0]

        return output

    def split_dataset_by_label(self, data_source):
        output = defaultdict(list)

        for item in data_source:
            output[item.label].append(item)

        return output

    def split_dataset_by_domain(self, data_source):
        output = defaultdict(list)

        for item in data_source:
            output[item.domain].append(item)

        return output


class DatasetWrapper(TorchDataset):

    def __init__(self, data_source, transform=None, is_train=False):
        self.data_source = data_source
        self.transform = transform
        self.is_train = is_train
        self._h5_cache = {}
        self._h5_info_printed = set()

    def __len__(self):
        return len(self.data_source)

    def __getitem__(self, idx):
        item = self.data_source[idx]

        impath = item.impath
        if isinstance(impath, str):
            img0 = Image.open(impath).convert("RGB")
        elif isinstance(impath, tuple) and len(impath) == 4 and impath[0] == 'h5':
            _, fpath, key, index = impath
            f = self._h5_cache.get(fpath)
            if f is None:
                f = h5py.File(fpath, 'r')
                self._h5_cache[fpath] = f
                # Print file info once
                try:
                    keys = list(f.keys())
                    print(f"H5 open: {os.path.basename(fpath)} keys={keys[:5]}{'...' if len(keys) > 5 else ''}")
                except Exception as e:
                    print(f"H5 open (keys) failed for {fpath}: {e}")
            # Print dataset info per file the first time we see this path
            if fpath not in self._h5_info_printed:
                try:
                    ds = f[key]
                    print(f"H5 dataset: {os.path.basename(fpath)}[{key}] shape={getattr(ds, 'shape', '?')} dtype={getattr(ds, 'dtype', '?')}")
                except Exception as e:
                    print(f"H5 dataset info failed for {fpath}[{key}]: {e}")
                self._h5_info_printed.add(fpath)
            arr = f[key][int(index)]
            arr = np.asarray(arr)
            # Convert CHW -> HWC if needed
            if arr.ndim == 3 and arr.shape[0] in (1, 3) and arr.shape[-1] not in (1, 3):
                arr = np.transpose(arr, (1, 2, 0))
            # Ensure HWC and uint8
            if arr.ndim == 3 and arr.shape[-1] in (1, 3):
                pass
            else:
                raise ValueError(f"Unexpected H5 image shape: {arr.shape}")
            if arr.dtype != np.uint8:
                arr = arr.astype(np.uint8)
            if arr.shape[-1] == 1:
                img0 = Image.fromarray(arr.squeeze(-1), mode='L').convert('RGB')
            else:
                img0 = Image.fromarray(arr, mode='RGB')
        else:
            # Fallback: if already PIL Image
            if isinstance(impath, Image.Image):
                img0 = impath
            else:
                raise ValueError("Unsupported impath type in DatasetWrapper")

        if self.transform:
            img = self.transform(img0)
        else:
            img = img0

        return img, item.label


def check_isfile(fpath):
    isfile = osp.isfile(fpath)
    if not isfile:
        warnings.warn('No file found at "{}"'.format(fpath))
    return isfile


def mkdir_if_missing(dirname):
    if not osp.exists(dirname):
        try:
            os.makedirs(dirname)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise


def listdir_nohidden(path, sort=False):
    items = [f for f in os.listdir(path) if not f.startswith(".")]
    if sort:
        items.sort()
    return items


def read_json(fpath):
    with open(fpath, "r") as f:
        obj = json.load(f)
    return obj


def write_json(obj, fpath):
    mkdir_if_missing(osp.dirname(fpath))
    with open(fpath, "w") as f:
        json.dump(obj, f, indent=4, separators=(",", ": "))