import os
from pathlib import Path
import shutil
import tarfile
from typing import Optional

import requests
from tqdm import tqdm
from generators.utils.arrows import make_arrows_dataset
from generators.utils.padded_and_downscaled import generate_downscaled_dataset, generate_padded_dataset, generate_sampled_dataset, generate_stretched_dataset, generate_upscaled_dataset
from generators import LIDBenchmarkDatasetGenerator, PCALIDBenchmarkDatasetGenerator
from generators.utils.pca import crescent_moon_pca, exp_pca, gaussian4_pca, gaussian_pca, spaghetti_pca, sphere4_pca, sphere_pca, spiral_pca, train_or_load_pca, uniform_pca
import numpy as np
from sklearn.decomposition import PCA


class SampledDatasetGenerator(LIDBenchmarkDatasetGenerator):
    """Generator whose artifacts come from ``generate_sampled_dataset``.

    The sampling step is appended to the dataset name as ``_step<N>`` before
    the base initializer is invoked.
    """

    def __init__(
        self,
        dataset_root_dir: str = "data",
        dataset_name="e1_sampled_fmnist",
        sampling_step=1,
        N_train=50_000,
        N_val=5_000,
        N_test=5_000,
        copy_val_to_test=True,
        seed: int = 0,
        *args,
        **kwargs,
    ):
        """Store the sampling options and delegate the rest to the base class.

        ``copy_val_to_test`` is only stored here; whoever consumes it is not
        visible in this class. Extra ``*args`` / ``**kwargs`` are forwarded
        untouched to the base initializer.
        """
        # Attributes are assigned before the base initializer runs, exactly
        # as the class has always done.
        self.copy_val_to_test = copy_val_to_test
        self.sampling_step = sampling_step

        suffixed_name = dataset_name + f"_step{sampling_step}"
        base_positional = (
            dataset_root_dir, suffixed_name, N_train, N_val, N_test, seed)
        super().__init__(*base_positional, *args, **kwargs)

    def _generate_artifacts(self):
        """Return ``{"dataset", "labels"}`` built by ``generate_sampled_dataset``.

        Side effect: ``self.n_train`` is overwritten with the number of
        returned rows minus ``self.n_val`` and ``self.n_test``.
        """
        images, targets = generate_sampled_dataset(
            self.sampling_step, self.n_train, type="FMNIST")

        # The train size follows the number of rows actually returned.
        self.n_train = images.shape[0] - self.n_val - self.n_test

        return dict(dataset=images, labels=targets)


class ThreeDIdentDatasetGenerator(LIDBenchmarkDatasetGenerator):
    """Download the 3DIdent archives from Zenodo and lay them out on disk.

    This generator works through files rather than returned arrays: the
    train/test tarballs are fetched into ``<dataset_dir>/base_datasets``,
    unpacked there, and the unpacked train part is divided into ``train`` and
    ``val`` directories. ``_generate_artifacts`` therefore returns ``{}``.

    NOTE(review): ``self.dataset_dir``, ``self.test_path``, ``self.n_train``
    and ``self.logger`` are not defined in this class - presumably they are
    provided by ``LIDBenchmarkDatasetGenerator``; confirm.
    """

    def __init__(
        self,
        dataset_root_dir: str = "data",
        dataset_name="e2_3dident",
        N_train=225_000,
        N_val=25_000,
        N_test=25_000,
        seed=0,
    ):
        """Forward all arguments, by keyword, to the base initializer."""
        super().__init__(
            dataset_root_dir=dataset_root_dir,
            dataset_name=dataset_name,
            N_train=N_train,
            N_val=N_val,
            N_test=N_test,
            seed=seed,
        )

    def _download_file(self, url: str, local_filename: str):
        """Stream ``url`` into ``local_filename`` unless it is already there.

        A file counts as already downloaded when its size on disk equals the
        ``Content-Length`` reported by the server (see
        ``check_if_already_downloaded``).

        Args:
            url: Address to fetch.
            local_filename: Destination path (callers in this class pass a
                ``pathlib.Path``).

        Returns:
            ``local_filename``, unchanged.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        block_size = 8192

        # The context manager releases the streamed connection on every
        # path, including the "already downloaded" early return.
        with requests.get(url, stream=True) as response:
            # Check the status first so the headers of an error page are
            # never mistaken for the size of the real file.
            response.raise_for_status()
            total_size = int(response.headers.get("content-length", 0))

            if self.check_if_already_downloaded(local_filename, total_size):
                self.logger.info(
                    f"File {local_filename} already downloaded, skipping...")
                return local_filename

            with tqdm(total=total_size, unit="B", unit_scale=True,
                      desc=f"Downloading {local_filename}") as progress_bar, \
                    open(local_filename, 'wb') as f:
                for chunk in response.iter_content(chunk_size=block_size):
                    progress_bar.update(len(chunk))
                    f.write(chunk)

        return local_filename

    def check_if_already_downloaded(self, local_filename: str, expected_size: int) -> bool:
        """Return True when ``local_filename`` exists with ``expected_size`` bytes.

        A missing or unreadable file is reported as "not downloaded".
        """
        try:
            downloaded_size = os.path.getsize(local_filename)
        except OSError:
            # Covers a missing file and permission problems alike.
            return False
        return downloaded_size == expected_size

    def _generate_artifacts(self):
        """Download, unpack and split 3DIdent; return an empty artifact dict.

        The split step is best-effort: any failure is logged together with
        its cause and otherwise ignored, so re-running on an already prepared
        dataset directory does not abort.
        """
        self.base_datasets_dir = Path(self.dataset_dir) / "base_datasets"
        self.base_datasets_dir.mkdir(exist_ok=True, parents=True)

        self.download_3dident_dataset()

        try:
            self.split_3dident_artifacts_to_train_val_test()
        except Exception as e:
            # ``Exception`` (not a bare except) lets Ctrl-C / SystemExit
            # through; the cause is logged so real failures are not mistaken
            # for "already split".
            self.logger.info(
                "Skipping split_3dident_artifacts_to_train_val_test() in 3D Ident - perhaps you already did it?"
                f" ({type(e).__name__}: {e})")

        # special case - we already did everything, but returning empty dict for compliance
        return {}

    def split_3dident_artifacts_to_train_val_test(self):
        """Move the unpacked 3DIdent files into train / val / test locations.

        Train images whose numeric file stem is below ``self.n_train`` go to
        ``train/images``, the rest to ``val/images``; the two latent arrays
        are sliced at the same index. The unpacked test files are moved to
        ``self.test_path``.

        Raises:
            FileExistsError: If ``train/images`` or ``val/images`` already
                exists (taken as the sign of a previous split).
        """
        # Same defensive Path() wrap as in ``_generate_artifacts``.
        dataset_dir = Path(self.dataset_dir)
        base_train = self.base_datasets_dir / "3dident" / "train"
        base_test = self.base_datasets_dir / "3dident" / "test"

        # will raise error if dir exist
        (dataset_dir / "train" / "images").mkdir(parents=True)
        (dataset_dir / "val" / "images").mkdir(parents=True)

        for img in (base_train / "images").glob("*.png"):
            destination = "train" if int(img.stem) < self.n_train else "val"
            # str() keeps shutil.move usable on Python < 3.9, where path-like
            # arguments are not documented as supported.
            shutil.move(str(img), str(dataset_dir / destination / "images"))

        for fname in ["latents.npy", "raw_latents.npy"]:
            x = np.load(base_train / fname)
            np.save(dataset_dir / "train" / fname, x[:self.n_train])
            np.save(dataset_dir / "val" / fname, x[self.n_train:])

        for name in ("images", "latents.npy", "raw_latents.npy"):
            shutil.move(str(base_test / name), str(self.test_path))

    def download_3dident_dataset(self):
        """Fetch both 3DIdent tarballs and unpack them into ``base_datasets_dir``."""
        self.test_url = "https://zenodo.org/records/4502485/files/3dident_test.tar?download=1"
        self.train_url = "https://zenodo.org/records/4502485/files/3dident_train.tar?download=1"

        for url, file_name in [(self.train_url, "train.tar"), (self.test_url, "test.tar")]:
            tar_path = self.base_datasets_dir / file_name
            self._download_file(url, tar_path)
            self._untar_file(tar_path, self.base_datasets_dir)

    def _untar_file(self, path_tofile, extract_direcotry):
        """Extract ``path_tofile`` into ``extract_direcotry`` if it is a tar archive.

        A file that is not a tar archive is left alone (a log line is
        emitted); nothing is raised for it.
        """
        if not tarfile.is_tarfile(path_tofile):
            self.logger.info(
                f"File {path_tofile} is not a tar archive, nothing extracted.")
            return

        with tarfile.open(path_tofile) as f:
            if hasattr(tarfile, "data_filter"):
                # The archive comes from the network: where the interpreter
                # supports extraction filters, refuse members that would
                # escape the target directory.
                f.extractall(path=extract_direcotry, filter="data")
            else:
                f.extractall(path=extract_direcotry)


class CrescentMoonPCADatasetGenerator(PCALIDBenchmarkDatasetGenerator):
    """Generator whose artifacts come from ``crescent_moon_pca``.

    The radius is appended to the dataset name as ``_radius<value>``.
    """

    def __init__(
        self,
        dataset_root_dir: str = "data",
        dataset_name="e7_crescent_moon",
        N_train=100_000,
        N_val=1_000,
        N_test=1_000,
        radius=3.00,
        seed=0,
        pca: Optional[PCA] = None,
    ):
        """Record ``radius`` and hand everything else to the base initializer."""
        # ``radius`` is stored before the base initializer runs, as before.
        self.radius = radius
        suffixed_name = dataset_name + f"_radius{radius}"

        base_kwargs = dict(
            dataset_root_dir=dataset_root_dir,
            dataset_name=suffixed_name,
            N_train=N_train,
            N_val=N_val,
            N_test=N_test,
            seed=seed,
            pca=pca,
        )
        super().__init__(**base_kwargs)

    def _generate_artifacts(self):
        """Return ``{"dataset", "lid", "coefficients"}`` from ``crescent_moon_pca``.

        One call produces ``n_train + n_val + n_test`` samples.
        """
        n_samples = self.n_train + self.n_val + self.n_test
        points, lid_values, coeffs = crescent_moon_pca(
            n_samples, self.pca, self.radius, self.seed)

        return dict(dataset=points, lid=lid_values, coefficients=coeffs)


class UniformPCADatasetGenerator(PCALIDBenchmarkDatasetGenerator):
    """Generator whose artifacts come from ``uniform_pca``."""

    def __init__(
        self,
        dataset_root_dir: str = "data",
        dataset_name: str = "e2_uniform_pca",
        N_train=100_000,
        N_val=1_000,
        N_test=1_000,
        dim=20,
        seed=0,
        pca: Optional[PCA] = None,
    ):
        """Record ``dim`` and hand everything else to the base initializer."""
        # ``dim`` is stored before the base initializer runs, as before.
        self.dim = dim

        base_kwargs = dict(
            dataset_root_dir=dataset_root_dir,
            dataset_name=dataset_name,
            N_train=N_train,
            N_val=N_val,
            N_test=N_test,
            seed=seed,
            pca=pca,
        )
        super().__init__(**base_kwargs)

    def _generate_artifacts(self):
        """Return ``{"dataset", "lid", "coefficients"}`` from ``uniform_pca``.

        One call produces ``n_train + n_val + n_test`` samples.
        """
        n_samples = self.n_train + self.n_val + self.n_test
        points, lid_values, coeffs = uniform_pca(
            n_samples, self.dim, self.pca, self.seed)

        return dict(dataset=points, lid=lid_values, coefficients=coeffs)


class GaussianPCADatasetGenerator(PCALIDBenchmarkDatasetGenerator):
    """Generator whose artifacts come from ``gaussian_pca``."""

    def __init__(
        self,
        dataset_root_dir: str = "data",
        dataset_name="e3_gaussian_pca",
        N_train=100_000,
        N_val=1_000,
        N_test=1_000,
        dim=20,
        seed=0,
        pca: Optional[PCA] = None,
    ):
        """Record ``dim`` and hand everything else to the base initializer."""
        # ``dim`` is stored before the base initializer runs, as before.
        self.dim = dim

        base_kwargs = dict(
            dataset_root_dir=dataset_root_dir,
            dataset_name=dataset_name,
            N_train=N_train,
            N_val=N_val,
            N_test=N_test,
            seed=seed,
            pca=pca,
        )
        super().__init__(**base_kwargs)

    def _generate_artifacts(self):
        """Return ``{"dataset", "lid", "coefficients"}`` from ``gaussian_pca``.

        One call produces ``n_train + n_val + n_test`` samples.
        """
        n_samples = self.n_train + self.n_val + self.n_test
        points, lid_values, coeffs = gaussian_pca(
            n_samples, self.dim, self.pca, self.seed)

        return dict(dataset=points, lid=lid_values, coefficients=coeffs)


class SpherePCADatasetGenerator(PCALIDBenchmarkDatasetGenerator):
    """Generator whose artifacts come from ``sphere_pca``.

    The radius is appended to the dataset name as ``_radius<value>``.
    """

    def __init__(
        self,
        dataset_root_dir: str = "data",
        dataset_name: str = "e4_sphere_pca",
        N_train=100_000,
        N_val=1_000,
        N_test=1_000,
        dim=20,
        radius=1,
        seed=0,
        pca: Optional[PCA] = None,
    ):
        """Record ``dim`` / ``radius`` and delegate the rest to the base class."""
        # Both attributes are stored before the base initializer runs.
        self.dim = dim
        self.radius = radius

        suffixed_name = dataset_name + f"_radius{radius}"
        base_kwargs = dict(
            dataset_root_dir=dataset_root_dir,
            dataset_name=suffixed_name,
            N_train=N_train,
            N_val=N_val,
            N_test=N_test,
            seed=seed,
            pca=pca,
        )
        super().__init__(**base_kwargs)

    def _generate_artifacts(self):
        """Return ``{"dataset", "lid", "coefficients"}`` from ``sphere_pca``.

        One call produces ``n_train + n_val + n_test`` samples.
        """
        n_samples = self.n_train + self.n_val + self.n_test
        points, lid_values, coeffs = sphere_pca(
            n_samples, self.dim, self.pca, self.radius, self.seed)

        return dict(dataset=points, lid=lid_values, coefficients=coeffs)


class ExpPCADatasetGenerator(PCALIDBenchmarkDatasetGenerator):
    """Generator whose artifacts come from ``exp_pca``."""

    def __init__(
        self,
        dataset_root_dir: str = "data",
        dataset_name: str = "e6_exp_pca",
        N_train=100_000,
        N_val=1_000,
        N_test=1_000,
        seed=0,
        pca: Optional[PCA] = None,
    ):
        """Forward all arguments, by keyword, to the base initializer."""
        base_kwargs = dict(
            dataset_root_dir=dataset_root_dir,
            dataset_name=dataset_name,
            N_train=N_train,
            N_val=N_val,
            N_test=N_test,
            seed=seed,
            pca=pca,
        )
        super().__init__(**base_kwargs)

    def _generate_artifacts(self):
        """Return ``{"dataset", "lid", "coefficients"}`` from ``exp_pca``.

        One call produces ``n_train + n_val + n_test`` samples.
        """
        n_samples = self.n_train + self.n_val + self.n_test
        points, lid_values, coeffs = exp_pca(n_samples, self.pca, self.seed)

        return dict(dataset=points, lid=lid_values, coefficients=coeffs)


class SpiralPCADatasetGenerator(PCALIDBenchmarkDatasetGenerator):
    """Generator whose artifacts come from ``spiral_pca``."""

    def __init__(
        self,
        dataset_root_dir: str = "data",
        dataset_name: str = "e1_spiral_pca",
        N_train=100_000,
        N_val=1_000,
        N_test=1_000,
        seed=0,
        pca: Optional[PCA] = None,
    ):
        """Forward all arguments, by keyword, to the base initializer."""
        base_kwargs = dict(
            dataset_root_dir=dataset_root_dir,
            dataset_name=dataset_name,
            N_train=N_train,
            N_val=N_val,
            N_test=N_test,
            seed=seed,
            pca=pca,
        )
        super().__init__(**base_kwargs)

    def _generate_artifacts(self):
        """Return ``{"dataset", "lid", "coefficients"}`` from ``spiral_pca``.

        One call produces ``n_train + n_val + n_test`` samples.
        """
        n_samples = self.n_train + self.n_val + self.n_test
        points, lid_values, coeffs = spiral_pca(n_samples, self.pca, self.seed)

        return dict(dataset=points, lid=lid_values, coefficients=coeffs)

class StretchedDatasetGenerator(LIDBenchmarkDatasetGenerator):
    """Generator whose artifacts come from ``generate_stretched_dataset``.

    The power is appended to the dataset name as ``_power<value>``.
    """

    def __init__(
        self,
        dataset_root_dir: str = "data",
        dataset_name: str = "e5_stretched",
        N_train=50_000,
        N_val=5_000,
        N_test=5_000,
        copy_val_to_test=True,
        power=4,
        seed=0,
        *args,
        **kwargs,
    ):
        """Store the stretching options and delegate the rest to the base class.

        ``copy_val_to_test`` is only stored here; whoever consumes it is not
        visible in this class. Extra ``*args`` / ``**kwargs`` are forwarded
        untouched to the base initializer.
        """
        # Attributes are assigned before the base initializer runs.
        self.power = power
        self.copy_val_to_test = copy_val_to_test

        suffixed_name = dataset_name + f"_power{power}"
        base_positional = (
            dataset_root_dir, suffixed_name, N_train, N_val, N_test, seed)
        super().__init__(*base_positional, *args, **kwargs)

    def _generate_artifacts(self):
        """Return ``{"dataset", "labels"}`` from ``generate_stretched_dataset``."""
        images, targets = generate_stretched_dataset(
            dataset_name='FMNIST', power=self.power)

        return dict(dataset=images, labels=targets)



class DownscaledDatasetGenerator(LIDBenchmarkDatasetGenerator):
    """Generator whose artifacts come from ``generate_downscaled_dataset``."""

    def __init__(
        self,
        dataset_root_dir: str = "data",
        dataset_name="e5_downscaled_fmnist",
        N_train=50_000,
        N_val=5_000,
        N_test=5_000,
        copy_val_to_test=True,
        seed: int = 0,
        *args,
        **kwargs,
    ):
        """Store ``copy_val_to_test`` and delegate the rest to the base class.

        The flag is only stored here; whoever consumes it is not visible in
        this class. Extra ``*args`` / ``**kwargs`` are forwarded untouched.
        """
        # The flag is assigned before the base initializer runs.
        self.copy_val_to_test = copy_val_to_test

        base_positional = (
            dataset_root_dir, dataset_name, N_train, N_val, N_test, seed)
        super().__init__(*base_positional, *args, **kwargs)

    def _generate_artifacts(self):
        """Return ``{"dataset", "labels"}`` from ``generate_downscaled_dataset``."""
        images, targets = generate_downscaled_dataset(dataset_name='FMNIST')

        return dict(dataset=images, labels=targets)


class UpscaledDatasetGenerator(LIDBenchmarkDatasetGenerator):
    """Generator whose artifacts come from ``generate_upscaled_dataset``."""

    def __init__(
        self,
        dataset_root_dir: str = "data",
        dataset_name="e5_upscaled_fmnist",
        N_train=50_000,
        N_val=5_000,
        N_test=5_000,
        copy_val_to_test=True,
        seed: int = 0,
        *args,
        **kwargs,
    ):
        """Store ``copy_val_to_test`` and delegate the rest to the base class.

        The flag is only stored here; whoever consumes it is not visible in
        this class. Extra ``*args`` / ``**kwargs`` are forwarded untouched.
        """
        # The flag is assigned before the base initializer runs.
        self.copy_val_to_test = copy_val_to_test

        base_positional = (
            dataset_root_dir, dataset_name, N_train, N_val, N_test, seed)
        super().__init__(*base_positional, *args, **kwargs)

    def _generate_artifacts(self):
        """Return ``{"dataset", "labels"}`` from ``generate_upscaled_dataset``."""
        upscaled_images, targets = generate_upscaled_dataset(
            dataset_name='FMNIST')

        return dict(dataset=upscaled_images, labels=targets)


class PaddedDatasetGenerator(LIDBenchmarkDatasetGenerator):
    """Generator whose artifacts come from ``generate_padded_dataset``.

    The number of additional dimensions is appended to the dataset name as
    ``_adddim<value>``.
    """

    def __init__(
        self,
        dataset_root_dir: str = "data",
        dataset_name="e5_padded_fmnist",
        N_train=50_000,
        N_val=5_000,
        N_test=5_000,
        copy_val_to_test=True,
        seed: int = 0,
        additional_dimension: int = 0,
        *args,
        **kwargs,
    ):
        """Store the padding options and delegate the rest to the base class.

        ``copy_val_to_test`` is only stored here; whoever consumes it is not
        visible in this class. Extra ``*args`` / ``**kwargs`` are forwarded
        untouched to the base initializer.
        """
        # Attributes are assigned before the base initializer runs.
        self.additional_dimension = additional_dimension
        self.copy_val_to_test = copy_val_to_test

        suffixed_name = dataset_name + f"_adddim{additional_dimension}"
        base_positional = (
            dataset_root_dir, suffixed_name, N_train, N_val, N_test, seed)
        super().__init__(*base_positional, *args, **kwargs)

    def _generate_artifacts(self):
        """Return ``{"dataset", "labels"}`` from ``generate_padded_dataset``."""
        images, targets = generate_padded_dataset(
            dataset_name="FMNIST",
            additional_dimensions=self.additional_dimension,
        )

        return dict(dataset=images, labels=targets)


class ArrowsDatasetGenerator(LIDBenchmarkDatasetGenerator):
    """Generator whose artifacts come from ``make_arrows_dataset``."""

    def __init__(
        self,
        dataset_root_dir: str = "data",
        dataset_name="e2_arrows",
        N_train=100_000,
        N_val=10_000,
        N_test=10_000,
        seed: int = 0,
        max_arrows: int = 4,
        size: int = 32,
        *args,
        **kwargs,
    ):
        """Store ``max_arrows`` / ``size`` and delegate the rest to the base class.

        Extra ``*args`` / ``**kwargs`` are forwarded untouched to the base
        initializer.
        """
        # Attributes are assigned before the base initializer runs.
        self.max_arrows = max_arrows
        self.size = size

        base_positional = (
            dataset_root_dir, dataset_name, N_train, N_val, N_test, seed)
        super().__init__(*base_positional, *args, **kwargs)

    def _generate_artifacts(self):
        """Return ``{"dataset", "lid"}`` from ``make_arrows_dataset``.

        One call produces ``n_train + n_test + n_val`` samples.
        """
        n_samples = self.n_train + self.n_test + self.n_val
        images, lid_values = make_arrows_dataset(
            n_samples, self.max_arrows, self.size, self.seed)

        return dict(dataset=images, lid=lid_values)


class Sphere4DatasetGenerator(PCALIDBenchmarkDatasetGenerator):
    """Generator whose artifacts come from ``sphere4_pca``."""

    def __init__(
        self,
        dataset_root_dir: str = "data",
        dataset_name="e8_sphere4_pca",
        N_train=100_000,
        N_val=1_000,
        N_test=1_000,
        dim=6,
        seed=0,
        pca: Optional[PCA] = None,
    ):
        """Record ``dim`` and hand everything else to the base initializer."""
        # ``dim`` is stored before the base initializer runs, as before.
        self.dim = dim

        base_kwargs = dict(
            dataset_root_dir=dataset_root_dir,
            dataset_name=dataset_name,
            N_train=N_train,
            N_val=N_val,
            N_test=N_test,
            seed=seed,
            pca=pca,
        )
        super().__init__(**base_kwargs)

    def _generate_artifacts(self):
        """Return ``{"dataset", "lid", "coefficients"}`` from ``sphere4_pca``.

        One call produces ``n_train + n_val + n_test`` samples.
        """
        n_samples = self.n_train + self.n_val + self.n_test
        points, lid_values, coeffs = sphere4_pca(
            n_samples, self.dim, self.pca, self.seed)

        return dict(dataset=points, lid=lid_values, coefficients=coeffs)

class Gaussian4DatasetGenerator(PCALIDBenchmarkDatasetGenerator):
    """Generator whose artifacts come from ``gaussian4_pca``."""

    def __init__(
        self,
        dataset_root_dir: str = "data",
        dataset_name="e8_gaussian4_pca",
        N_train=100_000,
        N_val=1_000,
        N_test=1_000,
        dim=5,
        seed=0,
        pca: Optional[PCA] = None,
    ):
        """Record ``dim`` and hand everything else to the base initializer."""
        # ``dim`` is stored before the base initializer runs, as before.
        self.dim = dim

        base_kwargs = dict(
            dataset_root_dir=dataset_root_dir,
            dataset_name=dataset_name,
            N_train=N_train,
            N_val=N_val,
            N_test=N_test,
            seed=seed,
            pca=pca,
        )
        super().__init__(**base_kwargs)

    def _generate_artifacts(self):
        """Return ``{"dataset", "lid", "coefficients"}`` from ``gaussian4_pca``.

        One call produces ``n_train + n_val + n_test`` samples.
        """
        n_samples = self.n_train + self.n_val + self.n_test
        points, lid_values, coeffs = gaussian4_pca(
            n_samples, self.dim, self.pca, self.seed)

        return dict(dataset=points, lid=lid_values, coefficients=coeffs)

class SpaghettiDatasetGenerator(PCALIDBenchmarkDatasetGenerator):
    """Generator whose artifacts come from ``spaghetti_pca``."""

    def __init__(
        self,
        dataset_root_dir: str = "data",
        dataset_name="e8_spaghetti_pca",
        N_train=100_000,
        N_val=1_000,
        N_test=1_000,
        dim=20,
        seed=0,
        pca: Optional[PCA] = None,
    ):
        """Record ``dim`` and hand everything else to the base initializer."""
        # ``dim`` is stored before the base initializer runs, as before.
        self.dim = dim

        base_kwargs = dict(
            dataset_root_dir=dataset_root_dir,
            dataset_name=dataset_name,
            N_train=N_train,
            N_val=N_val,
            N_test=N_test,
            seed=seed,
            pca=pca,
        )
        super().__init__(**base_kwargs)

    def _generate_artifacts(self):
        """Return ``{"dataset", "lid", "coefficients"}`` from ``spaghetti_pca``.

        One call produces ``n_train + n_val + n_test`` samples.
        """
        n_samples = self.n_train + self.n_val + self.n_test
        points, lid_values, coeffs = spaghetti_pca(
            n_samples, self.dim, self.pca, self.seed)

        return dict(dataset=points, lid=lid_values, coefficients=coeffs)


