import os
import random
import zlib
from collections import Counter
from concurrent.futures import ProcessPoolExecutor
from os.path import join, expanduser, join, dirname, realpath

from PIL import Image
from torchvision import transforms
from torchvision.transforms import RandAugment, RandomResizedCrop, v2, ColorJitter,RandomHorizontalFlip, Resize
from torchvision.transforms.functional import to_tensor

# from tralo.utils import get_from_repository, random_crop, resize, imread
# from utils import files
import Dataloader_funcs.files
# `import Dataloader_funcs.files` alone only binds `Dataloader_funcs`; the code
# below refers to the bare name `files`, so bind it explicitly.
from Dataloader_funcs import files
from Dataloader_funcs.utils import ThreeAugment

class ILSVRC2012(object):

    def __init__(self, split, version='imagenet235px', image_size=224, shuffle=False, 
                limit_classes_to=None, scale=1, download=False, chunks=None,
                aug=None,
                
                ):
        self.image_size = image_size
        self.version = version

        super().__init__()

        if chunks is not None:
            chunks = [f'{i:02d}' for i in chunks]

        if download:
            download_chunks = chunks if chunks is not None else list(range(25))

            file_type = {'imagenet': 'zip', 'imagenet235px': 'tar'}[version]

            print(version)

            with ProcessPoolExecutor(max_workers=4) as executor:
                futures = [executor.submit(files.get_from_s3, chunk=i, name=version, file_type=file_type) for i in download_chunks]
                [f.result() for f in futures]

        if aug is None:
            self.transform = transforms.Compose([
                transforms.Resize((image_size, image_size)),
            ])
        elif aug == 'resize_crop':
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size, scale=(0.3, 1.0),),
            ])
        elif aug == 'resize_crop+aug3':
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size, scale=(0.3, 1.0),),
                v2.RandomApply([v2.Grayscale(num_output_channels=3)], p=0.5),
                v2.RandomApply([v2.GaussianBlur(7, sigma=(0.3, 2))], p=0.5),
                v2.RandomSolarize(threshold=100, p=0.5),
                v2.RandomHorizontalFlip(p=0.5)
            ])            

        self.scale = scale
        self.data_path = files.get_dataset_path(version)

        try:
            all_files = open(expanduser(join(self.data_path, 'list_of_files.txt'))).read().split()
        except FileNotFoundError:
            print('list of files could not be found. create new one.')
            folders = [f for f in os.listdir(self.data_path) if f[-4:] not in {'.tar', '.txt'}]
            
            if chunks is not None:
                folders = chunks
            
            all_files = []
            for folder in folders:
                all_files += [join(folder, file) for file in os.listdir(join(self.data_path, folder))]

            # open(expanduser(join('~/datasets', f'ILSVRC2012_{version}', 'list_of_files.txt')), 'w').write('\n'.join(all_files))

        self.classes = sorted(list(set(f.split('/')[-1][:9] for f in all_files)))
        self.classes = [c for c in self.classes if c.startswith('n') and len(c)==9]

        if split == 'train':
            all_files = [s for s in all_files if hash(s) % 15 < 10]
        elif split == 'val':
            all_files = [s for s in all_files if 10 <= hash(s) % 15 < 12]
        elif split == 'test':
            all_files = [s for s in all_files if 12 <= hash(s) % 15 < 14]

        elif split == 'train+':
            all_files = [s for s in all_files if hash(s) % 15 < 14]
        elif split == 'val+':
            all_files = [s for s in all_files if 14 <= hash(s) % 15]
        elif split == 'test+':
            raise ValueError('subset test+ does not exist')

        if shuffle:
            random.shuffle(all_files)

        self.class_counts = Counter(s[:9] for s in all_files)

        print('class frequencies, min: {}, max: {}'.format(min(self.class_counts.values()), max(self.class_counts.values())))

        some_classes = [sorted(self.class_counts.keys())[i] for i in [0, 1, 2, 3, 50, 100, 200, 300]]
        print(some_classes)
        print([self.class_counts[c] for c in some_classes])

        self.all_files = all_files
        # self.classes = set([c.split('/')[-1] for c in self.classes])
        self.label_name_to_class_id = {c: i for i, c in enumerate(self.classes)}

        # import ipdb; ipdb.set_trace()
        if limit_classes_to is not None:

            if type(limit_classes_to) == str:
                valid_classes = open(limit_classes_to).read().split('\n')
                classes = [c for c in self.classes if c in valid_classes]
                self.label_name_to_class_id = {c: i for i, c in enumerate(classes)}
                print(f'{len(classes)} classes remaining')
                all_files = [f for f in all_files if f[:9] in valid_classes]

            elif type(limit_classes_to) == dict:
                valid_classes = sorted(list(set(limit_classes_to.keys())))
                all_values = sorted(list(set(limit_classes_to.values())))
                self.label_name_to_class_id = {k: all_values.index(v) for k, v in limit_classes_to.items()}
                print(f'{len(valid_classes)} classes remaining')

                all_files = [f for f in all_files if f[:9] in valid_classes]

            else:
                raise ValueError()

        self.sample_ids = tuple(all_files)

    def __len__(self):
        return len(self.sample_ids)

    def __getitem__(self, index):
        filename = self.sample_ids[index]

        label_name = filename.split('/')[-1].split('_')[0]

        img = Image.open(expanduser(join(self.data_path, filename))).convert('RGB')

        label = self.label_name_to_class_id[label_name]

        return img, label

    def install(self):
        raise IOError('The ImageNet dataset must be downloaded manually from image-net.org (after registration)')

class ImageNet(object):
    def __init__(self, split='train+',version='imagenet',  image_size=224, chunks=None):
        super().__init__()
        # from eckerlab_vision.lib.datasets import ILSVRC2012
        self.imagenet = ILSVRC2012(split, version=version, chunks=chunks, download=False)

        if split[:2] == 'tr':
            self.transform = transforms.Compose([
                RandomHorizontalFlip(),
                RandomResizedCrop(image_size),
                ColorJitter(.3,.3,.3,0.),
                ThreeAugment(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],  # Normalize with ImageNet mean and std
                                std=[0.229, 0.224, 0.225]),
            ])
        else:
            self.transform = transforms.Compose([
                transforms.Resize(256),             # Resize the shorter side to 256
                transforms.CenterCrop(224),         # Then center crop to 224x224
                transforms.ToTensor(),              # Convert PIL image to tensor
                transforms.Normalize(mean=[0.485, 0.456, 0.406],  # Normalize with ImageNet mean and std
                                std=[0.229, 0.224, 0.225]),
            ])


    def __len__(self):
        return len(self.imagenet)

    def __getitem__(self, idx):
        sample, label = self.imagenet[idx]
        return self.transform(sample), label

if __name__ == "__main__":
    def split_list(lst, k):
        """Cut ``lst`` into ``k`` consecutive slices with integer-floored boundaries."""
        total = len(lst)
        pieces = []
        for part in range(k):
            start = (part * total) // k
            stop = ((part + 1) * total) // k
            pieces.append(lst[start:stop])
        return pieces

    # Both variables are read from the environment and must be set.
    world_size = int(os.environ["WORLD_SIZE"])
    print('world', world_size)

    local_rank = int(os.environ["LOCAL_RANK"])

    # Each rank downloads its own share of the 26 chunk ids.
    # NOTE(review): the share is selected by LOCAL_RANK but sized by WORLD_SIZE,
    # and ILSVRC2012 defaults to 25 chunks -- confirm both are intended.
    chunk_groups = split_list(list(range(26)), world_size)
    my_chunks = chunk_groups[local_rank]

    # version = 'imagenet235px'
    version = 'imagenet'

    dataset = ILSVRC2012('train', version=version, image_size=224, download=True, chunks=my_chunks)
    val_dataset = ILSVRC2012('val', version=version, image_size=224, download=True, chunks=my_chunks)