import os
import requests
import tarfile
import shutil
import random
import numpy as np
import torch
from torch.utils.data import Dataset, Subset
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

def load_tiny_imagenet(root_dir, *, train_transform=None, val_transform=None):
    """
    Load the Tiny ImageNet dataset from ``root_dir``, downloading it first if needed.

    The dataset is considered present when ``root_dir/tiny-imagenet-200/train``
    is a directory. Otherwise the zip archive is downloaded into ``root_dir``,
    extracted to ``root_dir/tiny-imagenet-200``, and the archive is deleted.

    Args:
        root_dir: Directory where the Tiny ImageNet dataset is (or will be)
            stored; created if it does not exist.
        train_transform: Optional transform forwarded to the training
            ``TinyImageNetDataset`` (None keeps that class's default).
        val_transform: Optional transform forwarded to the validation
            ``TinyImageNetDataset`` (None keeps that class's default).

    Returns:
        train_dataset: ``TinyImageNetDataset`` for the 'train' split
        val_dataset: ``TinyImageNetDataset`` for the 'val' split
        class_to_idx: Dictionary mapping class directory names to indices,
            assigned in sorted name order
    """
    import tempfile
    import zipfile

    # HTTPS so the archive cannot be tampered with in transit.
    tiny_imagenet_url = "https://cs231n.stanford.edu/tiny-imagenet-200.zip"

    # Create directory if it doesn't exist
    os.makedirs(root_dir, exist_ok=True)

    dataset_path = os.path.join(root_dir, 'tiny-imagenet-200')
    train_dir = os.path.join(dataset_path, 'train')

    if not os.path.isdir(train_dir):
        print(f"Tiny ImageNet dataset not found at {dataset_path}. Downloading...")

        zip_path = os.path.join(root_dir, 'tiny-imagenet-200.zip')
        download_file(tiny_imagenet_url, zip_path)

        print("Extracting dataset...")
        # Extract into a scratch directory inside root_dir (same filesystem, so the
        # final move is a rename) and only move the result into place once
        # extraction has finished. Extracting straight into root_dir would let an
        # interrupted run leave a partial 'train' directory that the check above
        # would later accept as a complete dataset.
        with tempfile.TemporaryDirectory(dir=root_dir) as staging_dir:
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(staging_dir)
            extracted_path = os.path.join(staging_dir, 'tiny-imagenet-200')
            # A leftover dataset_path without 'train' is incomplete; replace it.
            if os.path.exists(dataset_path):
                shutil.rmtree(dataset_path)
            os.replace(extracted_path, dataset_path)

        # Remove the zip file to save space
        os.remove(zip_path)
        print("Dataset extraction complete.")

    # One label per class directory under train/, numbered in sorted order.
    class_dirs = [d for d in os.listdir(train_dir) if os.path.isdir(os.path.join(train_dir, d))]
    class_to_idx = {class_dir: i for i, class_dir in enumerate(sorted(class_dirs))}

    train_dataset = TinyImageNetDataset(
        root_dir=dataset_path,
        split='train',
        class_to_idx=class_to_idx,
        transform=train_transform,
    )

    val_dataset = TinyImageNetDataset(
        root_dir=dataset_path,
        split='val',
        class_to_idx=class_to_idx,
        transform=val_transform,
    )

    print(f"Loaded Tiny ImageNet with {len(train_dataset)} training and {len(val_dataset)} validation images across {len(class_to_idx)} classes.")

    return train_dataset, val_dataset, class_to_idx

def download_file(url, dest_path, timeout=60, chunk_size=64 * 1024):
    """
    Download ``url`` to ``dest_path`` while showing a tqdm progress bar.

    Args:
        url: HTTP(S) URL to fetch.
        dest_path: Local file path to write; overwritten if it already exists.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a stalled
            server cannot hang the caller forever.
        chunk_size: Number of bytes requested per ``iter_content`` chunk.

    Raises:
        requests.HTTPError: if the server responds with a 4xx/5xx status.
    """
    # A streamed response holds its connection open until closed; ``with``
    # guarantees it is released even if writing fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Fail fast instead of saving an HTML error page under dest_path.
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))

        with open(dest_path, 'wb') as f, tqdm(
                desc=f"Downloading {os.path.basename(dest_path)}",
                # None tells tqdm the size is unknown (no content-length header).
                total=total_size or None,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
            for data in response.iter_content(chunk_size):
                bar.update(f.write(data))

class TinyImageNetDataset(Dataset):
    """
    Tiny ImageNet samples read from an extracted ``tiny-imagenet-200`` directory.

    Each item is ``(image, label)``: ``image`` is the RGB image after
    ``transform`` and ``label`` is the integer index from ``class_to_idx``.
    """

    # Side length of the black stand-in image returned for unreadable files
    # (matches the 64x64 placeholder this loader has always used).
    _PLACEHOLDER_SIZE = 64

    def __init__(self, root_dir, split='train', class_to_idx=None, transform=None):
        """
        Args:
            root_dir: Root directory of the extracted Tiny ImageNet dataset
                (the folder containing ``train/`` and ``val/``).
            split: 'train' or 'val'.
            class_to_idx: Mapping from class directory name to label index.
                If None, it is built from the sorted sub-directories of
                ``root_dir/train``. Classes absent from the mapping are skipped
                in both splits.
            transform: Optional callable applied to each PIL image; defaults to
                ``transforms.ToTensor()``.

        Raises:
            ValueError: if ``split`` is not 'train' or 'val'.
        """
        if split not in ('train', 'val'):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        self.root_dir = root_dir
        self.split = split
        self.class_to_idx = (
            class_to_idx if class_to_idx is not None else self._discover_classes(root_dir)
        )
        self.transform = transform if transform is not None else transforms.ToTensor()

        self.images = []
        self.labels = []

        if split == 'train':
            self._index_train()
        else:
            self._index_val()

    @staticmethod
    def _discover_classes(root_dir):
        """Map each class directory under ``root_dir/train`` to an index, in sorted order."""
        train_dir = os.path.join(root_dir, 'train')
        class_dirs = sorted(
            d for d in os.listdir(train_dir) if os.path.isdir(os.path.join(train_dir, d))
        )
        return {class_dir: i for i, class_dir in enumerate(class_dirs)}

    def _index_train(self):
        """Collect ``train/<class>/images/*.JPEG`` paths and their labels."""
        train_dir = os.path.join(self.root_dir, 'train')
        # Sort both listings: os.listdir order is arbitrary, and a stable order
        # keeps sample indices (e.g. for Subset) reproducible across machines.
        for class_dir in sorted(os.listdir(train_dir)):
            class_path = os.path.join(train_dir, class_dir)
            if not os.path.isdir(class_path) or class_dir not in self.class_to_idx:
                continue

            label = self.class_to_idx[class_dir]
            images_dir = os.path.join(class_path, 'images')
            for img_file in sorted(os.listdir(images_dir)):
                if img_file.endswith('.JPEG'):
                    self.images.append(os.path.join(images_dir, img_file))
                    self.labels.append(label)

    def _index_val(self):
        """Collect validation paths and labels from ``val/val_annotations.txt``."""
        val_dir = os.path.join(self.root_dir, 'val')
        val_annotations_file = os.path.join(val_dir, 'val_annotations.txt')
        # Without the annotations file the split stays empty.
        if not os.path.exists(val_annotations_file):
            return

        with open(val_annotations_file, 'r') as f:
            for line in f:
                parts = line.split()
                # Skip blank or malformed lines (e.g. a trailing empty line)
                # rather than failing with IndexError.
                if len(parts) < 2:
                    continue
                img_file, class_dir = parts[0], parts[1]
                if class_dir in self.class_to_idx:
                    self.images.append(os.path.join(val_dir, 'images', img_file))
                    self.labels.append(self.class_to_idx[class_dir])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]

        # Only file loading is best-effort; errors raised by the transform are
        # real bugs and must not be hidden behind placeholder samples.
        try:
            # ``with`` closes the underlying file once the pixels are decoded.
            with Image.open(img_path) as raw:
                img = raw.convert('RGB')
        except OSError as e:  # missing, truncated or unidentifiable image file
            print(f"Error loading {img_path}: {e}")
            return self._placeholder(), label

        if self.transform:
            img = self.transform(img)
        return img, label

    def _placeholder(self):
        """Return a stand-in sample for an unreadable image.

        The blank image is passed through ``transform`` so the placeholder has
        the same shape/type as real samples and can be batched with them.
        """
        size = self._PLACEHOLDER_SIZE
        if self.transform:
            return self.transform(Image.new('RGB', (size, size)))
        return torch.zeros((3, size, size))
        
