import os
import torch
import os.path
import shutil
import logging
import tarfile
from typing import Any, Callable, List, Optional, Tuple, Union

from PIL import Image
from torch import Generator

from torchvision.datasets.utils import (
    download_and_extract_archive,
    verify_str_arg,
)
from torchvision.datasets.vision import VisionDataset

# Configure logging
# NOTE(review): basicConfig runs at import time and configures the *root*
# logger (INFO level, timestamped format) for the whole process. Per the stdlib
# docs it does nothing if the root logger already has handlers. Libraries
# usually leave this to the application entry point; consider moving it there.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger used by the download / cleanup code in this file.
logger = logging.getLogger(__name__)

# Compound archive suffixes that os.path.splitext would only strip in part
# (splitext("x.tar.gz") -> ("x.tar", ".gz")).
_COMPOUND_ARCHIVE_SUFFIXES = (".tar.gz", ".tar.bz2", ".tar.xz")


def _archive_stem(filename):
    """Return ``filename`` without its archive extension.

    Compound suffixes such as ``.tar.gz`` are removed whole (matched
    case-insensitively); any other name falls back to ``os.path.splitext``.
    """
    lowered = filename.lower()
    for suffix in _COMPOUND_ARCHIVE_SUFFIXES:
        if lowered.endswith(suffix) and len(filename) > len(suffix):
            return filename[: -len(suffix)]
    return os.path.splitext(filename)[0]


def clean_corrupted_files(directory, filename):
    """
    Clean up potentially corrupted downloaded files.

    Removes ``directory/filename`` and, if present, the directory the archive
    would have been extracted to (``filename`` minus its archive extension).
    Removal errors are logged rather than raised, so this is a best-effort
    cleanup.

    Args:
        directory (str): Directory where the file is located
        filename (str): Name of the file to clean up. Names that would resolve
            to ``directory`` itself or its parent (``""``, ``"."``, ``".."``)
            are ignored so the cleanup can never delete the whole directory.
    """
    unsafe_names = ("", os.curdir, os.pardir)
    if filename in unsafe_names:
        return

    filepath = os.path.join(directory, filename)
    if os.path.exists(filepath):
        try:
            logger.info(f"Removing potentially corrupted file: {filepath}")
            os.remove(filepath)
        except OSError as e:
            logger.error(f"Error removing file {filepath}: {e}")

    # Also clean up the extracted directory if it exists
    extracted_dir = _archive_stem(filename)
    if extracted_dir in unsafe_names:
        # e.g. "..tar.gz" would otherwise point at directory/. and rmtree it
        return
    extracted_path = os.path.join(directory, extracted_dir)
    # isdir() is already False for missing paths; no separate exists() needed.
    if os.path.isdir(extracted_path):
        try:
            logger.info(f"Removing potentially corrupted directory: {extracted_path}")
            shutil.rmtree(extracted_path)
        except OSError as e:
            logger.error(f"Error removing directory {extracted_path}: {e}")

class Caltech256:
    """Caltech256 with a reproducible train/test split and matching DataLoaders.

    The full dataset is built with ``preprocess`` as its transform, downloading
    into ``location`` if necessary. It is then partitioned by
    ``torch.utils.data.random_split`` using a generator seeded with ``seed``:
    ``int(len * train_fraction)`` samples go to the training set and the rest
    to the test set.

    Attributes:
        train_dataset, test_dataset: ``random_split`` subsets of one dataset.
        train_loader: shuffled DataLoader over ``train_dataset``.
        test_loader: unshuffled DataLoader over ``test_dataset``.
        classnames: readable names derived from the category folder names.
    """

    def __init__(self,
                 preprocess,
                 location=os.path.expanduser('~/data'),
                 batch_size=32,
                 num_workers=16,
                 train_fraction=0.8,
                 seed=0):
        full_dataset = PyTorchCaltech256(location, transform=preprocess, download=True)

        n_total = len(full_dataset)
        n_train = int(n_total * train_fraction)
        lengths = [n_train, n_total - n_train]
        split_generator = Generator().manual_seed(seed)
        self.train_dataset, self.test_dataset = torch.utils.data.random_split(
            full_dataset, lengths, generator=split_generator
        )

        shared = {"batch_size": batch_size, "num_workers": num_workers}
        self.train_loader = torch.utils.data.DataLoader(self.train_dataset, shuffle=True, **shared)
        self.test_loader = torch.utils.data.DataLoader(self.test_dataset, **shared)

        def to_classname(folder):
            # Keep the text after the last '.', turn '-' into spaces and drop a
            # trailing ' 101' marker.
            return folder.split('.')[-1].replace('-', ' ').replace(' 101', '')

        self.classnames = [to_classname(folder) for folder in full_dataset.categories]

class Caltech101:
    """Caltech101 with a reproducible train/test split and matching DataLoaders.

    The full dataset is built with ``preprocess`` as its transform, downloading
    into ``location`` if necessary. It is then partitioned by
    ``torch.utils.data.random_split`` using a generator seeded with ``seed``:
    ``int(len * train_fraction)`` samples go to the training set and the rest
    to the test set.

    Attributes:
        train_dataset, test_dataset: ``random_split`` subsets of one dataset.
        train_loader: shuffled DataLoader over ``train_dataset``.
        test_loader: unshuffled DataLoader over ``test_dataset``.
        classnames: readable names derived from the category folder names.
    """

    def __init__(self,
                 preprocess,
                 location=os.path.expanduser('~/data'),
                 batch_size=32,
                 num_workers=16,
                 train_fraction=0.8,
                 seed=0):
        full_dataset = PyTorchCaltech101(location, transform=preprocess, download=True)

        n_total = len(full_dataset)
        n_train = int(n_total * train_fraction)
        lengths = [n_train, n_total - n_train]
        split_generator = Generator().manual_seed(seed)
        self.train_dataset, self.test_dataset = torch.utils.data.random_split(
            full_dataset, lengths, generator=split_generator
        )

        # NOTE(review): this appears to be a no-op for loading. torch's Subset
        # fetches samples from the wrapped dataset, whose transform is already
        # ``preprocess``. Kept for any caller that reads test_dataset.transform.
        self.test_dataset.transform = preprocess

        shared = {"batch_size": batch_size, "num_workers": num_workers}
        self.train_loader = torch.utils.data.DataLoader(self.train_dataset, shuffle=True, **shared)
        self.test_loader = torch.utils.data.DataLoader(self.test_dataset, **shared)

        def to_classname(folder):
            # Keep the text after the last '.', turn '-' into spaces and drop a
            # ' 101' marker. Underscores in folder names are left unchanged.
            return folder.split('.')[-1].replace('-', ' ').replace(' 101', '')

        self.classnames = [to_classname(folder) for folder in full_dataset.categories]
        

class PyTorchCaltech101(VisionDataset):
    """`Caltech 101 <https://data.caltech.edu/records/20086>`_ Dataset.

    .. warning::

        ``annotation`` targets need `scipy <https://docs.scipy.org/doc/>`_ to load
        target files from `.mat` format. ``category``-only use does not import it.

    Args:
        root (string): Root directory of dataset where directory
            ``caltech101`` exists or will be saved to if download is set to True.
        target_type (string or list, optional): Type of target to use, ``category`` or
            ``annotation``. Can also be a list to output a tuple with all specified
            target types.  ``category`` represents the target class, and
            ``annotation`` is a list of points from a hand-generated outline.
            Defaults to ``category``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    # Final directory name under ``self.root`` for each dataset component,
    # paired with the archive file names it may arrive packed in.
    # NOTE(review): the Annotations archive names are assumed from the upstream
    # caltech-101.zip layout; confirm if the published archive changes.
    _COMPONENTS: Tuple[Tuple[str, Tuple[str, ...]], ...] = (
        ("101_ObjectCategories", ("101_ObjectCategories.tar.gz",)),
        ("Annotations", ("Annotations.tar", "Annotations.tar.gz")),
    )

    def __init__(
        self,
        root: str,
        target_type: Union[List[str], str] = "category",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(
            os.path.join(root, "caltech101"),
            transform=transform,
            target_transform=target_transform,
        )
        os.makedirs(self.root, exist_ok=True)
        if isinstance(target_type, str):
            target_type = [target_type]
        self.target_type = [
            verify_str_arg(t, "target_type", ("category", "annotation"))
            for t in target_type
        ]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to download it"
            )

        image_root = os.path.join(self.root, "101_ObjectCategories")
        # Only sub-directories are categories; stray top-level files are skipped.
        self.categories = sorted(
            entry for entry in os.listdir(image_root)
            if os.path.isdir(os.path.join(image_root, entry))
        )
        # "BACKGROUND_Google" is not a real class.
        if "BACKGROUND_Google" in self.categories:
            self.categories.remove("BACKGROUND_Google")

        # For some reason, the category names in "101_ObjectCategories" and
        # "Annotations" do not always match. This is a manual map between the
        # two. Defaults to using same name, since most names are fine.
        name_map = {
            "Faces": "Faces_2",
            "Faces_easy": "Faces_3",
            "Motorbikes": "Motorbikes_16",
            "airplanes": "Airplanes_Side_2",
        }
        self.annotation_categories = [name_map.get(c, c) for c in self.categories]

        # Flat sample list: self.index[k] is the 1-based image number inside
        # category self.y[k]. Images are assumed to be numbered contiguously.
        self.index: List[int] = []
        self.y: List[int] = []
        for i, c in enumerate(self.categories):
            # Count only .jpg files (as the Caltech256 loader does) so stray
            # entries do not create indices with no image behind them.
            n = sum(
                1 for item in os.listdir(os.path.join(image_root, c))
                if item.endswith(".jpg")
            )
            self.index.extend(range(1, n + 1))
            self.y.extend([i] * n)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where the type of target specified by target_type.
        """
        img = Image.open(
            os.path.join(
                self.root,
                "101_ObjectCategories",
                self.categories[self.y[index]],
                f"image_{self.index[index]:04d}.jpg",
            )
        )

        target: Any = []
        for t in self.target_type:
            if t == "category":
                target.append(self.y[index])
            elif t == "annotation":
                # scipy is only needed to read .mat annotations, so import it
                # lazily and keep category-only use free of the dependency.
                import scipy.io

                data = scipy.io.loadmat(
                    os.path.join(
                        self.root,
                        "Annotations",
                        self.annotation_categories[self.y[index]],
                        f"annotation_{self.index[index]:04d}.mat",
                    )
                )
                target.append(data["obj_contour"])
        target = tuple(target) if len(target) > 1 else target[0]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        """Return True when the directories needed for the requested targets exist."""
        if not os.path.exists(os.path.join(self.root, "101_ObjectCategories")):
            return False
        # Annotation targets additionally need the unpacked Annotations directory.
        if "annotation" in getattr(self, "target_type", ()):
            return os.path.exists(os.path.join(self.root, "Annotations"))
        return True

    def __len__(self) -> int:
        return len(self.index)

    def download(self) -> None:
        """Download ``caltech-101.zip`` and unpack its contents into ``self.root``.

        The zip is expected to wrap the data in nested archives under
        ``caltech-101/`` (e.g. ``101_ObjectCategories.tar.gz``). Those are
        moved or unpacked into ``self.root``, and leftover wrapper folders are
        removed afterwards.

        Raises:
            RuntimeError: if the download fails or the required directories are
                still missing afterwards.
        """
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        # Clean up potentially corrupted files first
        clean_corrupted_files(self.root, "caltech-101.zip")

        try:
            logger.info(f"Downloading Caltech101 dataset to {self.root}")
            download_and_extract_archive(
                "https://data.caltech.edu/records/mzrjq-6wc02/files/caltech-101.zip?download=1",
                self.root,
                filename="caltech-101.zip",
                md5="3138e1922a9193bfa496528edbbc45d0",
                remove_finished=True,
            )
        except Exception as e:
            logger.error(f"Error downloading Caltech101: {e}")
            raise RuntimeError(f"Error downloading dataset: {e}") from e

        self._log_root_contents()

        nested_dir = os.path.join(self.root, "caltech-101")
        for target_name, archive_names in self._COMPONENTS:
            self._materialize(target_name, nested_dir, archive_names)

        self._remove_leftovers(nested_dir)

        # If we still can't find the directories, raise an error
        if not self._check_integrity():
            raise RuntimeError(
                "Failed to extract the dataset properly. Please check the extracted files manually."
            )

    def _log_root_contents(self) -> None:
        """Log the top two levels of ``self.root`` to help debug unexpected archive layouts."""
        logger.info(f"Contents of {self.root}:")
        for item in os.listdir(self.root):
            item_path = os.path.join(self.root, item)
            if not os.path.isdir(item_path):
                logger.info(f"  File: {item}")
                continue
            logger.info(f"  Directory: {item}")
            try:
                for subitem in os.listdir(item_path):
                    logger.info(f"    - {subitem}")
            except OSError as e:
                logger.warning(f"Could not list contents of {item_path}: {e}")

    def _materialize(self, target_name: str, nested_dir: str, archive_names: Tuple[str, ...]) -> None:
        """Make ``self.root/target_name`` exist by moving or unpacking it if needed.

        The sources are tried in this order:
        1. an already-unpacked copy inside ``nested_dir``;
        2. each archive in ``nested_dir``;
        3. each archive directly in ``self.root``.

        Extraction failures are logged rather than raised; ``download`` checks
        the final result.
        """
        target = os.path.join(self.root, target_name)
        if os.path.exists(target):
            logger.info(f"{target_name} directory already exists")
            return

        unpacked = os.path.join(nested_dir, target_name)
        if os.path.isdir(unpacked):
            logger.info(f"Moving {unpacked} to {self.root}")
            shutil.move(unpacked, self.root)
            return

        candidates = [os.path.join(nested_dir, name) for name in archive_names]
        candidates += [os.path.join(self.root, name) for name in archive_names]
        for archive in candidates:
            if not os.path.exists(archive):
                continue
            logger.info(f"Extracting {archive} to {self.root}")
            try:
                self._extract_tar(archive, self.root)
            except Exception as e:  # best effort: corrupt/truncated archives raise varied types
                logger.error(f"Error extracting {archive}: {e}")
                continue
            if os.path.exists(target):
                return
        logger.warning(f"Could not locate or unpack {target_name} under {self.root}")

    @staticmethod
    def _extract_tar(archive: str, dest: str) -> None:
        """Extract a (possibly compressed) tar archive into ``dest``.

        Uses the ``"data"`` extraction filter when this Python provides it
        (3.12+ and security backports). That filter refuses absolute paths and
        members that would land outside ``dest``.
        """
        with tarfile.open(archive) as tar:
            if hasattr(tarfile, "data_filter"):
                tar.extractall(dest, filter="data")
            else:
                tar.extractall(dest)

    def _remove_leftovers(self, nested_dir: str) -> None:
        """Delete the ``__MACOSX`` folder and the ``caltech-101`` wrapper once it has been emptied."""
        macosx_dir = os.path.join(self.root, "__MACOSX")
        if os.path.exists(macosx_dir):
            logger.info(f"Removing __MACOSX directory: {macosx_dir}")
            shutil.rmtree(macosx_dir)

        # The wrapper is only dropped when no unpacked component directory is
        # left inside it. Its nested archives have been processed by now.
        if os.path.exists(nested_dir) and not any(
            os.path.exists(os.path.join(nested_dir, name)) for name, _ in self._COMPONENTS
        ):
            logger.info(f"Removing empty caltech-101 directory: {nested_dir}")
            shutil.rmtree(nested_dir)

    def extra_repr(self) -> str:
        return "Target type: {target_type}".format(**self.__dict__)


class PyTorchCaltech256(VisionDataset):
    """`Caltech 256 <https://data.caltech.edu/records/20087>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``caltech256`` exists or will be saved to if download is set to True.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(
            os.path.join(root, "caltech256"),
            transform=transform,
            target_transform=target_transform,
        )
        os.makedirs(self.root, exist_ok=True)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to download it"
            )

        image_root = os.path.join(self.root, "256_ObjectCategories")
        # Only sub-directories are categories; a stray top-level file would
        # otherwise make the os.listdir below fail.
        self.categories = sorted(
            entry for entry in os.listdir(image_root)
            if os.path.isdir(os.path.join(image_root, entry))
        )

        # Flat sample list: self.index[k] is the 1-based image number inside
        # category self.y[k]. Images are assumed to be numbered contiguously.
        self.index: List[int] = []
        self.y: List[int] = []
        for i, c in enumerate(self.categories):
            # Only .jpg files are images; other entries in a folder are ignored.
            n = sum(
                1 for item in os.listdir(os.path.join(image_root, c))
                if item.endswith(".jpg")
            )
            self.index.extend(range(1, n + 1))
            self.y.extend([i] * n)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        # File names are "<NNN>_<MMMM>.jpg", NNN being the 1-based category position.
        img = Image.open(
            os.path.join(
                self.root,
                "256_ObjectCategories",
                self.categories[self.y[index]],
                f"{self.y[index] + 1:03d}_{self.index[index]:04d}.jpg",
            )
        )

        target = self.y[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        """Return True when the image directory exists under ``self.root``."""
        return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))

    def __len__(self) -> int:
        return len(self.index)

    def download(self) -> None:
        """Download ``256_ObjectCategories.tar`` and unpack it into ``self.root``.

        If the archive unpacks one level deeper (inside ``caltech-256/`` or
        ``caltech256/``), the image directory is moved up to ``self.root``.

        Raises:
            RuntimeError: if the download fails or the image directory is still
                missing afterwards.
        """
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        # Clean up potentially corrupted files first
        clean_corrupted_files(self.root, "256_ObjectCategories.tar")

        try:
            logger.info(f"Downloading Caltech256 dataset to {self.root}")
            download_and_extract_archive(
                "https://data.caltech.edu/records/nyy15-4j048/files/256_ObjectCategories.tar?download=1",
                self.root,
                filename="256_ObjectCategories.tar",
                md5="67b4f42ca05d46448c6bb8ecd2220f6d",
                remove_finished=True,
            )
        except Exception as e:
            logger.error(f"Error downloading Caltech256: {e}")
            raise RuntimeError(f"Error downloading dataset: {e}") from e

        self._log_root_contents()
        self._relocate_nested_categories()

        # If we still can't find the directory, raise an error
        if not self._check_integrity():
            raise RuntimeError(
                "Failed to extract the Caltech256 dataset properly. Please check the extracted files manually."
            )

    def _log_root_contents(self) -> None:
        """Log the top two levels of ``self.root`` to help debug unexpected archive layouts."""
        logger.info(f"Contents of {self.root}:")
        for item in os.listdir(self.root):
            item_path = os.path.join(self.root, item)
            if not os.path.isdir(item_path):
                logger.info(f"  File: {item}")
                continue
            logger.info(f"  Directory: {item}")
            try:
                for subitem in os.listdir(item_path):
                    logger.info(f"    - {subitem}")
            except OSError as e:
                logger.warning(f"Could not list contents of {item_path}: {e}")

    def _relocate_nested_categories(self) -> None:
        """Move ``256_ObjectCategories`` up into ``self.root`` if it was unpacked inside a wrapper folder."""
        target = os.path.join(self.root, "256_ObjectCategories")
        if os.path.exists(target):
            logger.info("256_ObjectCategories directory already exists")
            return

        for subdir in ("caltech-256", "caltech256"):
            wrapper = os.path.join(self.root, subdir)
            if not os.path.exists(wrapper):
                continue
            nested = os.path.join(wrapper, "256_ObjectCategories")
            # Skip once the target is in place: moving onto an existing
            # directory would make shutil.move raise.
            if os.path.exists(nested) and not os.path.exists(target):
                logger.info(f"Found 256_ObjectCategories at {nested}")
                logger.info(f"Moving {nested} to {self.root}")
                shutil.move(nested, self.root)

            # Clean up the wrapper directory if it's now empty
            if os.path.exists(wrapper) and not os.listdir(wrapper):
                logger.info(f"Removing empty directory: {wrapper}")
                shutil.rmtree(wrapper)