import os
import sys
import shutil
import torch
import pickle
import time
import numpy as np
import torchvision
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from PIL import Image

# Channel-wise mean/std normalization (the standard torchvision ImageNet
# statistics), applied by the ImageNet loaders in this module.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

def _imagenet_class_folder(class_index_root, class_index):
    """Return the sub-directory name that ImageFolder(class_index_root) maps to ``class_index``.

    Uses the same rule as torchvision's ``find_classes`` (sorted names of the
    immediate sub-directories) without indexing every image file.

    Raises:
        ValueError: if ``class_index`` is outside ``[0, number_of_classes)``.
    """
    classes = sorted(entry.name for entry in os.scandir(class_index_root) if entry.is_dir())
    if not 0 <= class_index < len(classes):
        raise ValueError(
            f"target_classes={class_index!r} is not a valid class index; "
            f"{class_index_root} has {len(classes)} class folders"
        )
    return classes[class_index]


def _empty_directory(path):
    """Create ``path`` if needed, then delete everything inside it (best effort)."""
    os.makedirs(path, exist_ok=True)
    for filename in os.listdir(path):
        file_path = os.path.join(path, filename)
        try:
            # Check links first so a symlink to a directory is unlinked, not followed.
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except OSError as e:
            print(f'Failed to delete {file_path}. Reason: {e}')


def get_imagenet_iter(data_type, image_dir, batch_size, num_threads, device_id, num_gpus, crop, manual_seed, val_size=256, world_size=1,
                      local_rank=0, save_path="/home/admin1/2Tsdc/Syh/Training-free-quant/mixed_bit/data/train", target_classes=200,
                      class_index_root='/home/admin1/dataset/Dataset/imagenet/train'):
    """Build a shuffling training DataLoader over a single ImageNet class.

    ``save_path`` is emptied, the class whose ImageFolder index is
    ``target_classes`` (indices taken from ``class_index_root``) is copied
    from ``image_dir`` into it, and an augmented ImageFolder over
    ``save_path`` is wrapped in a DataLoader.

    Args:
        data_type: ``'train'`` builds the loader; any other value returns None.
        image_dir: directory with one sub-directory per class to copy from.
        batch_size: DataLoader batch size.
        num_threads: number of DataLoader worker processes.
        device_id, num_gpus, crop, manual_seed, val_size, world_size, local_rank:
            accepted for interface compatibility; not used.
        save_path: scratch directory. ALL of its existing contents are deleted.
        target_classes: class *index* (not a count) to extract.
        class_index_root: directory whose sorted sub-directory names define
            the class-index mapping.

    Returns:
        A ``DataLoader`` when ``data_type == 'train'``, otherwise ``None``.

    Raises:
        ValueError: if ``target_classes`` is not a valid class index.
    """
    print("get_imagenet_iter")
    if data_type != 'train':
        # 'val' (and any other value) is not implemented; keep returning None.
        return None

    class_to_folder = _imagenet_class_folder(class_index_root, target_classes)

    _empty_directory(save_path)

    source_folder = os.path.join(image_dir, class_to_folder)
    target_folder = os.path.join(save_path, class_to_folder)
    if not os.path.exists(target_folder):
        os.makedirs(target_folder, exist_ok=True)
        print("Create")

    for filename in os.listdir(source_folder):
        shutil.copy(os.path.join(source_folder, filename), os.path.join(target_folder, filename))

    file_count = sum(
        1 for name in os.listdir(target_folder) if os.path.isfile(os.path.join(target_folder, name))
    )
    print(file_count)

    train_dataset = datasets.ImageFolder(
        save_path,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    return DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_threads,
        pin_memory=True
    )

def get_imagenet_all_iter(data_type, image_dir, batch_size, num_threads, device_id, num_gpus, crop, manual_seed, val_size=256, world_size=1,
                      local_rank=0):
    """Return a shuffling DataLoader over the full ImageNet training folder.

    The dataset root is the fixed path below; ``image_dir`` and the
    device/distribution arguments are accepted but not used.

    Args:
        data_type: only ``'train'`` is supported; anything else returns None.
        batch_size: DataLoader batch size.
        num_threads: number of DataLoader worker processes.

    Returns:
        A ``DataLoader`` for ``'train'``, otherwise ``None``.
    """
    if data_type != 'train':
        return None

    augment = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    full_train_set = datasets.ImageFolder(
        root='/home/admin1/dataset/Dataset/imagenet/train',
        transform=augment,
    )
    return DataLoader(
        dataset=full_train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_threads,
        pin_memory=True,
    )

def get_cifar10_iter(data_type, image_dir, batch_size, num_threads, device_id, num_gpus, crop, manual_seed, val_size=256, world_size=1,
                      local_rank=0, save_path="/home/admin1/2Tsdc/Syh/Training-free-quant/mixed_bit/cifar10/train", target_classes=2,
                      cifar_root='/home/admin1/dataset/Dataset/cifar-10'):
    """Build a shuffling training DataLoader over one CIFAR-10 class.

    ``save_path`` is emptied, every CIFAR-10 training image whose label equals
    ``target_classes`` is written to ``save_path/<target_classes>/<idx>.png``,
    and an augmented ImageFolder over ``save_path`` is wrapped in a DataLoader.

    Args:
        data_type: ``'train'`` builds the loader; any other value returns None.
        image_dir, device_id, num_gpus, crop, manual_seed, val_size,
            world_size, local_rank: accepted for interface compatibility; not used.
        batch_size: DataLoader batch size.
        num_threads: number of DataLoader worker processes.
        save_path: scratch directory. ALL of its existing contents are deleted.
        target_classes: CIFAR-10 label to extract.
        cifar_root: CIFAR-10 root passed to ``datasets.CIFAR10`` (downloaded if missing).

    Returns:
        A ``DataLoader`` when ``data_type == 'train'``, otherwise ``None``.
    """
    print("get_cifar10_iter")
    if data_type != 'train':
        return None

    # Empty save_path completely. Deleting only files would leave stale class
    # directories from earlier runs, which ImageFolder would treat as extra
    # (empty) classes.
    if os.path.isdir(save_path):
        for name in os.listdir(save_path):
            entry_path = os.path.join(save_path, name)
            if os.path.isdir(entry_path) and not os.path.islink(entry_path):
                shutil.rmtree(entry_path)
            else:
                os.remove(entry_path)
        print("Delete End")

    class_dir = os.path.join(save_path, str(target_classes))
    # makedirs raises on failure, so reaching the next line means it exists.
    os.makedirs(class_dir, exist_ok=True)
    print("Create")

    # With no transform the dataset yields PIL images directly. Filtering on
    # .targets first avoids decoding the images of every other class.
    cifar10_dataset = datasets.CIFAR10(root=cifar_root, train=True, download=True)
    for idx, label in enumerate(cifar10_dataset.targets):
        if label == target_classes:
            img, _ = cifar10_dataset[idx]
            img.save(os.path.join(class_dir, f"{idx}.png"))
    print("end")

    train_dataset = datasets.ImageFolder(
        root=save_path,
        transform=transforms.Compose([
            transforms.Resize(224),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
    )

    return DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_threads,
        pin_memory=True
    )

def get_cifar10_all_iter(data_type, image_dir, batch_size, num_threads, device_id, num_gpus, crop, manual_seed, val_size=256, world_size=1,
                      local_rank=0):
    """Return a shuffling DataLoader over the full CIFAR-10 training split.

    Images are resized to 224, randomly cropped/flipped and normalized with
    the CIFAR-10 statistics below. The dataset is downloaded to the fixed root
    if missing; ``image_dir`` and the device/distribution arguments are unused.

    Args:
        data_type: only ``'train'`` is supported; anything else returns None.
        batch_size: DataLoader batch size.
        num_threads: number of DataLoader worker processes.

    Returns:
        A ``DataLoader`` for ``'train'``, otherwise ``None``.
    """
    if data_type != 'train':
        return None

    pipeline = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    cifar_train = datasets.CIFAR10(
        root='/home/admin1/dataset/Dataset/cifar-10',
        train=True,
        download=True,
        transform=pipeline,
    )
    return DataLoader(
        dataset=cifar_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_threads,
        pin_memory=True,
    )

        

       

        