import os
from collections.abc import Callable
from glob import glob

import cv2
import h5py
import numpy as np
import torch
import torch.utils.data as data
from torchvision import transforms

from config import Constants
from utils import (Model_Logger, RandomRotation90, _setup_size,
                   get_image_filename_list, loss_less_dot_resize, random_crop, random_segmentation)

# Glob patterns for the supported image formats, listed in both lower- and
# upper-case spellings.
IMG_EXTENSIONS = [
    '*.png',
    '*.jpeg',
    '*.jpg',
    '*.tif',
    '*.PNG',
    '*.JPEG',
    '*.JPG',
    '*.TIF',
]

# Module-level logger (project ``Model_Logger`` named 'data'); the dataset
# loaders below use it to warn about raw/annotation files that are skipped.
logger = Model_Logger('data')


class Counting_Domains_dataset(object):
    def __init__(self,
                 args,
                 memory_saving: bool = False):
        super().__init__()

        self.source_dataset_path = os.path.join(
            Constants.DATA_FOLDER, Constants.DATASET[args.source_dataset])
        self.target_dataset_path = os.path.join(
            Constants.DATA_FOLDER, Constants.DATASET[args.target_dataset])

        self.source_dataset_type = args.source_dataset_type
        self.target_dataset_type = args.target_dataset_type

        self.source_samples = self.load_data(
            self.source_dataset_path, self.source_dataset_type)
        self.target_samples = self.load_data(
            self.target_dataset_path, self.target_dataset_type)

        self.target_samples_train, self.target_samples_val = self.split_target_domains(
            self.target_samples, args.training_ratio)

        self.combined_samples_train = self.combine_samples(
            self.source_samples, self.target_samples_train)

    def combine_samples(self, source_samples, target_samples):
        combined_samples = []
        source_length = len(source_samples)
        for target_sample in target_samples:
            for _ in range(3):
                random_sample = source_samples[np.random.randint(source_length)]
                combined_samples.append((random_sample, target_sample))
        return combined_samples

    def split_target_domains(self, imgs, training_ratio):
        num_imgs = len(imgs)
        num_train = int(num_imgs * training_ratio)
        indices = np.random.permutation(num_imgs)
        train_indices = indices[:num_train]
        val_indices = indices[num_train:]

        train_imgs = [imgs[i] for i in train_indices]
        val_imgs = [imgs[i] for i in val_indices]

        return train_imgs, val_imgs

    def obtain_samples(self):
        return self.combined_samples_train, self.target_samples_val

    def load_data(self, root, type):
        if type == 'h5py':
            try:
                filename = glob(pathname='*.hdf5', root_dir=root)[0]
            except IndexError:
                raise FileNotFoundError(
                    "hdf5 file not found in {}. Check dataset type and file.".
                    format(root))
            h5_file = h5py.File(os.path.join(root, filename))
            try:
                imgs = np.asarray(h5_file['imgs'])
                dots = np.asarray(h5_file['counts'])
                samples = list(zip(imgs, dots))

            except KeyError:
                raise KeyError(
                    "Not a proper hdf5 structure. It should include \'imgs\' key and \'counts\' key!"
                )
        else:
            file_list = get_image_filename_list(root)
            if file_list is None:
                raise FileNotFoundError(
                    "There are not any supported image files in {}".format(root))
            samples = []
            for dot_filename, raw_filename in file_list:
                dot_filename = os.path.join(root, 'dot', dot_filename)
                raw_filename = os.path.join(root, 'raw', raw_filename)
                if os.path.exists(raw_filename) is not True:
                    logger.warning(
                        "Could not find the raw file {}. Skipping this file"
                        .format(raw_filename))
                    continue
                if os.path.exists(dot_filename) is not True:
                    logger.warning(
                        "Could not find the annotations file {}. Skipping this file"
                        .format(dot_filename))
                    continue
                samples.append((raw_filename, dot_filename))
        return samples


class Binary_Counting_Dataset(data.Dataset):
    """Dataset yielding preprocessed ``(source, target)`` sample pairs.

    Each element of ``samples`` is a ``(source_sample, target_sample)`` pair.
    Every sample is either a ``(raw_path, dot_path)`` tuple of file paths
    (read with OpenCV on access) or an ``(img_array, dot_array)`` tuple of
    in-memory arrays.
    """

    def __init__(self,
                 samples,
                 args):
        self.samples = samples
        self.R_size = _setup_size(args.image_resize, 'Error resize size.')
        self.transform = False  # kept for compatibility; not used here
        self.memory_saving = args.memory_saving  # not consulted by this class

    @staticmethod
    def _read_image(path):
        """Read ``path`` with OpenCV, failing loudly if it cannot be decoded."""
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError("Could not read image file {}".format(path))
        return np.asarray(image)

    @staticmethod
    def _to_gray(array):
        """Return a 2-D grayscale view of ``array`` (BGR, 1-channel or 2-D)."""
        if array.ndim == 3 and array.shape[-1] == 1:
            return array[..., 0]
        if array.ndim == 3:
            return cv2.cvtColor(array, cv2.COLOR_BGR2GRAY)
        return array

    def sample_loader(self, sample):
        """Load and preprocess one ``(img, dot)`` sample.

        Returns:
            ``(img, dot, img_ori, dot_ori)``. ``img`` and ``dot`` are float
            tensors of shape ``(1, H, W)``, resized to ``R_size`` when it is
            set. ``dot`` has every positive value set to 1. ``img_ori`` and
            ``dot_ori`` are the source file paths, or ``None`` for in-memory
            samples.
        """
        if isinstance(sample[0], str):
            img_ori, dot_ori = sample[0], sample[1]
            img = self._read_image(img_ori)
            dot = self._read_image(dot_ori)
        else:
            img_ori = dot_ori = None
            # Copy so the in-place binarisation below never touches stored data.
            img = np.copy(sample[0])
            dot = np.copy(sample[1])

        img = self._to_gray(img)
        dot = self._to_gray(dot)

        # Both become (1, 1, H, W) so the resize and the returned shapes agree
        # whether or not R_size is set.
        img = torch.from_numpy(np.ascontiguousarray(img)).float()[None, None]
        dot = torch.from_numpy(np.ascontiguousarray(dot)).float()[None, None]

        if self.R_size:
            img = transforms.functional.resize(img, self.R_size, antialias=True)
            # Nearest-neighbour keeps annotation dots from being blurred away.
            dot = transforms.functional.resize(
                dot, self.R_size,
                transforms.functional.InterpolationMode.NEAREST_EXACT)
        dot[dot > 0] = 1

        return (img.squeeze(0), dot.squeeze(0), img_ori, dot_ori)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(source, target)``, each as produced by ``sample_loader``."""
        source_sample, target_sample = self.samples[index][0], self.samples[index][1]
        return self.sample_loader(source_sample), self.sample_loader(target_sample)

# Single-domain counting dataset: each item is built from one (img, dot) sample.


class Counting_dataset(data.Dataset):
    """Dataset of preprocessed ``(img, dot)`` counting samples.

    Each element of ``samples`` is either a ``(raw_path, dot_path)`` tuple of
    file paths (read with OpenCV on access) or an ``(img_array, dot_array)``
    tuple of in-memory arrays. The loading strategy is decided per sample
    from its type, so path-based and in-memory samples both work regardless
    of ``args.memory_saving``.
    """

    def __init__(self,
                 samples,
                 args):
        self.samples = samples
        self.memory_saving = args.memory_saving  # informational only
        self.R_size = _setup_size(args.image_resize, 'Error resize size.')
        self.transform = False  # kept for compatibility; not used here

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _read_image(path):
        """Read ``path`` with OpenCV, failing loudly if it cannot be decoded."""
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError("Could not read image file {}".format(path))
        return np.asarray(image)

    @staticmethod
    def _to_gray(array):
        """Return a 2-D grayscale view of ``array`` (BGR, 1-channel or 2-D)."""
        if array.ndim == 3 and array.shape[-1] == 1:
            return array[..., 0]
        if array.ndim == 3:
            return cv2.cvtColor(array, cv2.COLOR_BGR2GRAY)
        return array

    def __getitem__(self, index):
        """Load and preprocess sample ``index``.

        Returns:
            ``(img, dot, img_ori, dot_ori)``. ``img`` and ``dot`` are float
            tensors of shape ``(1, H, W)``, resized to ``R_size`` when it is
            set. ``dot`` has every positive value set to 1. ``img_ori`` and
            ``dot_ori`` are the source file paths, or ``None`` for in-memory
            samples.
        """
        img, dot = self.samples[index]
        # Choose by sample type instead of mutating self.memory_saving, which
        # previously broke path samples once an in-memory one had been seen
        # (and always broke them when memory_saving was False).
        if isinstance(img, str) and isinstance(dot, str):
            img_ori, dot_ori = img, dot
            img = self._read_image(img_ori)
            dot = self._read_image(dot_ori)
        else:
            img_ori = dot_ori = None
            # Copy so the in-place binarisation below never touches stored data.
            img = np.copy(img)
            dot = np.copy(dot)

        img = self._to_gray(img)
        dot = self._to_gray(dot)

        # Both become (1, 1, H, W) so the resize and the returned shapes agree
        # whether or not R_size is set.
        img = torch.from_numpy(np.ascontiguousarray(img)).float()[None, None]
        dot = torch.from_numpy(np.ascontiguousarray(dot)).float()[None, None]

        if self.R_size:
            img = transforms.functional.resize(
                img, self.R_size, antialias=True)
            # Nearest-neighbour keeps annotation dots from being blurred away.
            dot = transforms.functional.resize(
                dot, self.R_size,
                transforms.functional.InterpolationMode.NEAREST_EXACT)
        dot[dot > 0] = 1

        return (img.squeeze(0), dot.squeeze(0), img_ori, dot_ori)


class MNIST_dataset(data.Dataset):
    def __init__(self,
                 dataset_dictionary: str,
                 filelist: list[str],
                 transform: tuple | None = None):
        self.root = dataset_dictionary
        self.transform = transform
        self.filenames = filelist
        self.load_data()

    def __len__(self):
        return self.labels.__len__()

    def __getitem__(self, index):
        img = np.copy(self.imgs[index])
        label = self.labels[index]
        if self.transform is not None:
            img = self.transform(img)

        return img.float(), torch.from_numpy(label).long()

    def load_data(self):
        self.imgs = []
        self.labels = []
        with open(self.filenames, 'r') as f:
            filelist = f.readlines()
        img_dir = os.path.basename(self.filenames)
        for filename in filelist:
            img_filename, label = filename.split(' ')
            img = np.asarray(
                cv2.imread(os.path.join(self.root, img_dir, img_filename)))
            if self.resize is not None:
                img = np.resize(img, (self.resize, self.resize))
            self.imgs.append(img)
            self.labels.append(np.asarray([int(label)]))
