import zipfile
import os
from PIL import Image, UnidentifiedImageError
import torch
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from pathlib import Path

# Path to your mini-imagenet zip file
zip_file_path = 'datasets/ImageNet/mini-imagenet.zip'
# Directory the archive is unpacked into.
# NOTE(review): despite the '.zip' suffix, this path is created and used as a
# directory (see os.makedirs below) -- confirm the name is intentional.
extract_to = 'datasets/ImageNet/imagenet.zip'

# Extensions (matched case-insensitively) eligible for the one-time preview
image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff')

# Ensure the extract directory exists
os.makedirs(extract_to, exist_ok=True)

# Unpack every archive member; preview the first image that can be opened.
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
    # List all files in the zip
    file_list = zip_ref.namelist()

    first_image_shown = False
    # Use tqdm to show progress for extraction
    for file_name in tqdm(file_list, desc="Extracting images"):
        # Extract each file
        extracted_path = zip_ref.extract(file_name, extract_to)

        # Once the preview has been shown nothing else is needed per member,
        # so skip the extension test and the filesystem stat entirely.
        if first_image_shown:
            continue
        if not file_name.lower().endswith(image_extensions):
            continue
        if not os.path.isfile(extracted_path):
            continue

        try:
            # The context manager closes the image's file handle once the
            # preview is done instead of leaving it open for the whole run.
            with Image.open(extracted_path) as img:
                plt.imshow(img)
                plt.title(f"First Extracted Image: {file_name}")
                plt.axis('off')  # Hide axes
                plt.show()
            first_image_shown = True
        except UnidentifiedImageError:
            # Not a readable image; the next matching member is tried instead.
            print(f"Skipping file (unidentified image): {file_name}")

# Preprocessing applied to every image (standard for ResNet): resize, convert
# to a tensor, then normalize each channel with the statistics below.
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]

_preprocessing_steps = [
    transforms.Resize((224, 224)),   # ResNet standard input size
    transforms.ToTensor(),
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
]
transform = transforms.Compose(_preprocessing_steps)

# Load the dataset from the extracted directory
dataset = datasets.ImageFolder(root=extract_to, transform=transform)

# 80/20 train/test split; the test split receives whatever remains so the two
# sizes always add up to len(dataset).
total_samples = len(dataset)
train_size = int(0.8 * total_samples)
test_size = total_samples - train_size

# Seed the split so re-running the script reproduces the same partition;
# an unseeded split would write a different train/test membership every run.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=split_generator
)

# Save both splits with torch.save.
# NOTE(review): this serializes the split objects returned by random_split,
# presumably references into `dataset` rather than decoded pixel data --
# confirm consumers of the .pth files still have the extracted images.
torch.save(train_dataset, 'train_dataset_mini_imagenet.pth')
torch.save(test_dataset, 'test_dataset_mini_imagenet.pth')

print(f"Datasets saved: {train_size} training samples and {test_size} testing samples.")
