import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from Dataloader_funcs.DL_registry import *

# Generate label map by assigning a unique integer to each class ID
def create_label_map(image_files):
    """Map each class ID found in ``image_files`` to a consecutive integer.

    The class ID is the part of a filename that precedes the first
    underscore. IDs are numbered from 0 in the order they are first seen.

    Args:
        image_files (iterable of str): Image filenames.

    Returns:
        dict: ``{class_id: integer_label}``.
    """
    class_to_index = {}
    for filename in image_files:
        prefix, _, _ = filename.partition('_')
        # The next free id always equals the number of classes seen so far.
        class_to_index.setdefault(prefix, len(class_to_index))
    return class_to_index

def split_valid_set(dataset, val_split_ratio=0.2, transform=None):
    """Split ``dataset`` into reproducible training and validation subsets.

    The shuffle is driven by a private generator seeded with 42, so every
    call on a dataset of the same length yields the same partition while the
    global RNG state is never touched.

    Args:
        dataset: Sized, indexable dataset accepted by ``random_split``.
        val_split_ratio (float): Fraction of the samples assigned to the
            validation subset, in ``[0, 1]``. The resulting size is truncated
            towards zero.
        transform (callable, optional): Stored as ``val_dataset.transform``.
            NOTE: the subset returned by ``random_split`` does not apply this
            attribute by itself; it is only attached for the caller to use.

    Returns:
        tuple: ``(train_dataset, val_dataset)``.

    Raises:
        ValueError: If ``val_split_ratio`` lies outside ``[0, 1]``.
    """
    if not 0 <= val_split_ratio <= 1:
        raise ValueError(
            f"val_split_ratio must be within [0, 1], got {val_split_ratio!r}"
        )

    # Calculate split sizes
    total_size = len(dataset)
    valid_size = int(total_size * val_split_ratio)
    train_size = total_size - valid_size

    # A dedicated generator keeps the split fixed without reseeding (and then
    # having to restore) the process-wide CPU/CUDA generators.
    split_rng = torch.Generator().manual_seed(42)

    # Split dataset into validation and training sets
    val_dataset, train_dataset = random_split(
        dataset, [valid_size, train_size], generator=split_rng
    )
    val_dataset.transform = transform

    return train_dataset, val_dataset

# Custom Dataset Class
class ImageNetDataset(Dataset):
    """Flat-directory image dataset whose labels come from filename prefixes.

    Every ``*.JPEG`` file directly inside ``root_dir`` is one sample. The text
    before the first underscore of the filename (``n03786901`` in
    ``n03786901_18245.JPEG``) is the sample's class ID.
    """

    def __init__(self, root_dir, transform=None, lazy_load=True, label_map=None, subset=-1):
        """
        Args:
            root_dir (str): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
            lazy_load (bool): If True, only filenames are kept and each image
                is read from disk on access. If False, every selected image is
                decoded up front and held in memory.
            label_map (dict, optional): Mapping from class ID to integer
                label. When omitted it is built from the files in
                ``root_dir``; ids then follow ``os.listdir`` order, which is
                arbitrary.
            subset (int, optional): Maximum number of files to keep. A
                negative value (the default) or ``None`` keeps every file.
        """
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.lazy_load = lazy_load
        image_files = [f for f in os.listdir(root_dir) if f.endswith('.JPEG')]  # Filter only JPEG images

        # The label map is built from every file, before any subsetting, so
        # class ids do not depend on the requested subset size.
        if label_map is None:
            self.label_map = create_label_map(image_files)
        else:
            self.label_map = label_map

        # Slicing with the -1 sentinel would silently drop the last file, so
        # only a non-negative limit is applied.
        if subset is not None and subset >= 0:
            image_files = image_files[:subset]

        if lazy_load:
            self.data = image_files
        else:
            self.data, self.label = self.__store_images__(image_files)

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, integer_label)`` for the sample at ``idx``.

        In lazy mode the image is read from disk; in eager mode a copy of the
        cached image is used so a transform cannot modify the cache.
        """
        if self.lazy_load:
            img_name = self.data[idx]
            img_path = os.path.join(self.root_dir, img_name)

            # Extract the label from filename (e.g., n03786901_18245.JPEG -> n03786901)
            class_label = img_name.split('_')[0]

            with Image.open(img_path) as img:
                image = img.convert('RGB')
            id_label = self.label_map[class_label]
        else:
            image = self.data[idx].copy()
            id_label = self.label[idx]

        # Apply transformations if provided
        if self.transform is not None:
            image = self.transform(image)

        return image, id_label

    def __store_images__(self, image_files):
        """Decode every file in ``image_files`` and return ``(images, labels)``.

        Both lists are parallel to ``image_files``: RGB images and their
        integer labels looked up in ``self.label_map``.
        """
        images = []
        labels = []
        for img_name in tqdm(image_files):
            img_path = os.path.join(self.root_dir, img_name)

            # Extract the label from filename (e.g., n03786901_18245.JPEG -> n03786901)
            class_label = img_name.split('_')[0]
            id_label = self.label_map[class_label]

            # convert() yields an independent in-memory image, so the file
            # can be closed right away.
            with Image.open(img_path) as img:
                images.append(img.convert('RGB'))
            labels.append(id_label)
        return images, labels
            
@register_DL('ImageNet')
def getImageNetDataloaders(root_dir, batch_size=32, shuffle=True,num_workers=4, valid_transform=transforms.ToTensor(), subset=-1):
    """Build the training and validation DataLoaders.

    Training images are read from ``root_dir + '/train_data/train'`` and go
    through a random augmentation pipeline. Validation images are read from
    ``root_dir + '/valid_data/valid'``, use ``valid_transform`` and share the
    training dataset's label map so class ids agree between the two.

    Args:
        root_dir (str): Directory containing ``train_data`` and ``valid_data``.
        batch_size (int): Batch size for both loaders.
        shuffle (bool): Whether the training loader shuffles; the validation
            loader never does.
        num_workers (int): Worker count for both loaders.
        valid_transform (callable): Transform applied to validation images.
        subset (int): Forwarded to the training ``ImageNetDataset`` only.

    Returns:
        tuple: ``(train_dataloader, valid_dataloader)``.
    """
    # Photometric jitter: brightness / contrast / saturation / hue.
    color_jitter = transforms.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.1,
    )
    # Normalize with ImageNet mean and std.
    imagenet_norm = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    augmentation = transforms.Compose([
        transforms.RandomResizedCrop(224),   # random crop resized to 224x224
        transforms.RandomHorizontalFlip(),   # random horizontal flip
        color_jitter,
        transforms.ToTensor(),               # image -> tensor
        imagenet_norm,
    ])

    train_root = root_dir+'/train_data/train'
    valid_root = root_dir+'/valid_data/valid'

    # The training set is built first because its label map is reused below.
    train_set = ImageNetDataset(root_dir=train_root, transform=augmentation, subset=subset)
    valid_set = ImageNetDataset(
        root_dir=valid_root,
        transform=valid_transform,
        label_map=train_set.label_map,
    )

    # Options shared by both loaders; only shuffling differs.
    loader_options = dict(batch_size=batch_size, num_workers=num_workers)
    train_loader = DataLoader(train_set, shuffle=shuffle, **loader_options)
    valid_loader = DataLoader(valid_set, shuffle=False, **loader_options)
    return train_loader, valid_loader