# Modified from https://github.com/hbzju/PiCO/blob/main/utils/cifar100.py

import os.path
import pickle
from typing import Any, Callable, Optional, Tuple
import numpy as np
from PIL import Image
import imageio
import PIL
import os

from collections import defaultdict

from tqdm.autonotebook import tqdm
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torchvision.datasets as dsets

from utils.wide_resnet import WideResNet
from utils.utils_algo import generate_uniform_cv_candidate_labels, generate_hierarchical_cv_candidate_labels
from utils.cutout import Cutout
from utils.autoaugment import CIFAR10Policy, ImageNetPolicy


def load_tiny_imagenet(args):
    """Build the partial-label Tiny ImageNet training loader and the validation loader.

    The training split is preloaded into memory, each sample receives a
    candidate-label set (hierarchical or uniform, depending on
    ``args.hierarchical``), and both splits are wrapped in
    ``DistributedSampler``s, so a distributed process group is expected.

    Args:
        args: Experiment options. Reads ``batch_size``, ``workers`` and
            ``hierarchical``; the uniform branch also reads ``partial_rate``
            and ``noisy_rate``; the hierarchical branch receives ``args`` whole.

    Returns:
        Tuple ``(partial_training_dataloader, partialY_matrix, train_sampler,
        test_loader)``: the training loader over ``tinyImagenet_Partialize``,
        the candidate-label matrix, the training ``DistributedSampler`` and the
        loader over the labelled 'val' split.
    """
    original_train = TinyImageNetDataset(mode='train')
    ori_data = original_train.img_data
    # Labels are already integral; convert directly to int64 rather than
    # going through a float32 ``torch.Tensor`` first.
    ori_labels = torch.as_tensor(original_train.label_data, dtype=torch.long)

    print('original data:', type(ori_data), "\n size:", len(ori_data))

    # Tiny ImageNet's labelled 'val' split is used as the evaluation set.
    test_dataset = TinyImageNetDataset(mode='val')
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size * 4,
        shuffle=False,
        num_workers=args.workers,
        sampler=torch.utils.data.distributed.DistributedSampler(test_dataset, shuffle=False),
    )

    if args.hierarchical:
        # NOTE(review): this requests the 'cifar100' hierarchy for Tiny ImageNet
        # labels; confirm generate_hierarchical_cv_candidate_labels supports that.
        partialY_matrix = generate_hierarchical_cv_candidate_labels('cifar100', ori_labels, args)
    else:
        partialY_matrix = generate_uniform_cv_candidate_labels(ori_labels, args.partial_rate, args.noisy_rate)

    # Sanity check: each sample's true label must lie inside its candidate set,
    # i.e. the candidate matrix masked by the one-hot true labels sums to N.
    true_one_hot = torch.zeros(partialY_matrix.shape)
    true_one_hot[torch.arange(partialY_matrix.shape[0]), ori_labels] = 1
    if torch.sum(partialY_matrix * true_one_hot) == partialY_matrix.shape[0]:
        print('Partial labels correctly loaded !')
    else:
        print('Inconsistent permutation !')

    print('Average candidate num: ', partialY_matrix.sum(1).mean())

    partial_training_dataset = tinyImagenet_Partialize(ori_data, partialY_matrix.float(), ori_labels.float())

    train_sampler = torch.utils.data.distributed.DistributedSampler(partial_training_dataset)

    partial_training_dataloader = torch.utils.data.DataLoader(
        dataset=partial_training_dataset,
        batch_size=args.batch_size,
        # A sampler is always supplied, so DataLoader-level shuffling stays off;
        # the DistributedSampler (shuffle=True by default) does the shuffling.
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler,
        drop_last=True,
    )

    return partial_training_dataloader, partialY_matrix, train_sampler, test_loader



class tinyImagenet_Partialize(Dataset):
    """Training set yielding two independently augmented views of each image.

    Args:
        images: Indexable image collection accepted by ``Image.fromarray``
            (presumably the uint8 ``img_data`` array of the training split).
        given_partial_label_matrix: Candidate-label row for every sample.
        true_labels: Ground-truth label for every sample; its length defines
            the dataset length.
    """

    # ImageNet channel statistics shared by both augmentation pipelines.
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, images, given_partial_label_matrix, true_labels):
        self.ori_images = images
        self.given_partial_label_matrix = given_partial_label_matrix
        self.true_labels = true_labels

        # Weak view: random resized crop and horizontal flip only.
        weak_steps = [
            transforms.RandomResizedCrop(64),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(self._MEAN, self._STD),
        ]
        # Strong view: additionally the ImageNetPolicy augmentation (applied on
        # the PIL image) and Cutout (applied on the tensor).
        strong_steps = [
            transforms.RandomResizedCrop(64),
            transforms.RandomHorizontalFlip(),
            ImageNetPolicy(),
            transforms.ToTensor(),
            Cutout(n_holes=1, length=32),
            transforms.Normalize(self._MEAN, self._STD),
        ]
        self.transform1 = transforms.Compose(weak_steps)
        self.transform2 = transforms.Compose(strong_steps)

    def __len__(self):
        return len(self.true_labels)

    def __getitem__(self, index):
        """Return ``(view1, view2, partial_label, true_label, index)``."""
        pil_image = Image.fromarray(self.ori_images[index])
        # transform1 runs before transform2 so random draws keep their order.
        weak_view = self.transform1(pil_image)
        strong_view = self.transform2(pil_image)
        return (
            weak_view,
            strong_view,
            self.given_partial_label_matrix[index],
            self.true_labels[index],
            index,
        )


dir_structure_help = r"""
TinyImageNetPath
├── test
│   └── images
│       ├── test_0.JPEG
│       ├── t...
│       └── ...
├── train
│   ├── n01443537
│   │   ├── images
│   │   │   ├── n01443537_0.JPEG
│   │   │   ├── n...
│   │   │   └── ...
│   │   └── n01443537_boxes.txt
│   ├── n01629819
│   │   ├── images
│   │   │   ├── n01629819_0.JPEG
│   │   │   ├── n...
│   │   │   └── ...
│   │   └── n01629819_boxes.txt
│   ├── n...
│   │   ├── images
│   │   │   ├── ...
│   │   │   └── ...
├── val
│   ├── images
│   │   ├── val_0.JPEG
│   │   ├── v...
│   │   └── ...
│   └── val_annotations.txt
├── wnids.txt
└── words.txt
"""

def download_and_unzip(URL, root_dir):
  """Placeholder for downloading and extracting the dataset archive.

  Automatic download is not implemented, so this always raises.

  Args:
    URL: Address of the dataset archive; echoed in the error message.
    root_dir: Intended extraction directory (currently unused).

  Raises:
    NotImplementedError: Always, asking the user to fetch ``URL`` manually.
  """
  error_message = "Download is not yet implemented. Please, go to {URL} urself."
  # The template has a *named* field, so it must be filled by keyword;
  # ``format(URL)`` raised KeyError('URL') instead of the intended error.
  raise NotImplementedError(error_message.format(URL=URL))

def _add_channels(img, total_channels=3):
  """Pad an image up to ``total_channels`` channels by repeating its last channel.

  A 2-D (grayscale) image first gets a trailing channel axis. Images that
  already have at least ``total_channels`` channels are returned unchanged.

  Args:
    img: NumPy image array, shaped (H, W) or (H, W, C).
    total_channels: Minimum number of channels in the result (default 3).

  Returns:
    Array shaped (H, W, max(C, total_channels)).
  """
  while img.ndim < 3:  # the third axis holds the channels
    img = np.expand_dims(img, axis=-1)
  missing = total_channels - img.shape[-1]
  if missing > 0:
    # Same result as appending the last channel once per missing channel,
    # but done in one concatenation (and cannot loop on a 0-channel input).
    img = np.concatenate([img, np.repeat(img[..., -1:], missing, axis=-1)], axis=-1)
  return img

"""Creates a paths datastructure for the tiny imagenet.
Args:
  root_dir: Where the data is located
  download: Download if the data is not there
Members:
  label_id:
  ids:
  nit_to_words:
  data_dict:
"""
class TinyImageNetPaths:
  """Index the image files and labels of an extracted Tiny ImageNet directory.

  Args:
    root_dir: Directory holding ``train/``, ``val/``, ``test/``, ``wnids.txt``
      and ``words.txt``.
    download: If True, call ``download_and_unzip`` on ``root_dir`` first.

  Members:
    ids: WordNet ids read from ``wnids.txt`` (blank lines skipped); a class's
      integer label is its index in this list.
    nid_to_words: defaultdict mapping an id to the words listed for it in
      ``words.txt``.
    paths: Dict with keys 'train', 'val' and 'test'. 'train' and 'val' hold
      ``(img_path, label_id, nid, bbox)`` tuples; 'test' holds bare image
      paths. Directory listings are sorted so the order is reproducible
      across machines and filesystems.
  """

  def __init__(self, root_dir, download=False):
    if download:
      download_and_unzip('http://cs231n.stanford.edu/tiny-imagenet-200.zip',
                         root_dir)
    train_path = os.path.join(root_dir, 'train')
    val_path = os.path.join(root_dir, 'val')
    test_path = os.path.join(root_dir, 'test')

    wnids_path = os.path.join(root_dir, 'wnids.txt')
    words_path = os.path.join(root_dir, 'words.txt')

    self._make_paths(train_path, val_path, test_path,
                     wnids_path, words_path)

  def _make_paths(self, train_path, val_path, test_path,
                  wnids_path, words_path):
    # Class ids, one per line; a class's label is its position in this list.
    with open(wnids_path, 'r', encoding='utf-8') as idf:
      self.ids = [nid.strip() for nid in idf if nid.strip()]

    # Human-readable words per id: "<nid>\t<word>, <word>, ...".
    self.nid_to_words = defaultdict(list)
    with open(words_path, 'r', encoding='utf-8') as wf:
      for line in wf:
        if not line.strip():
          continue
        nid, labels = line.split('\t', 1)
        self.nid_to_words[nid].extend(word.strip() for word in labels.split(','))

    self.paths = {
      'train': [],  # [img_path, id, nid, box]
      'val': [],  # [img_path, id, nid, box]
      'test': []  # img_path
    }

    # Test split (unlabelled). The official archive stores the images in
    # test/images/; fall back to test/ itself for flattened layouts.
    test_img_dir = os.path.join(test_path, 'images')
    if not os.path.isdir(test_img_dir):
      test_img_dir = test_path
    self.paths['test'] = [os.path.join(test_img_dir, fname)
                          for fname in sorted(os.listdir(test_img_dir))]

    # Validation split: "<fname> <nid> <x0> <y0> <x1> <y1>" per line.
    with open(os.path.join(val_path, 'val_annotations.txt'), encoding='utf-8') as valf:
      for line in valf:
        if not line.strip():
          continue
        fname, nid, x0, y0, x1, y1 = line.split()
        fname = os.path.join(val_path, 'images', fname)
        bbox = int(x0), int(y0), int(x1), int(y1)
        label_id = self.ids.index(nid)
        self.paths['val'].append((fname, label_id, nid, bbox))

    # Training split: one sub-directory per class holding images/ and a
    # "<nid>_boxes.txt" file with "<fname> <x0> <y0> <x1> <y1>" lines.
    for nid in sorted(os.listdir(train_path)):
      class_dir = os.path.join(train_path, nid)
      if not os.path.isdir(class_dir):
        continue  # skip stray files, e.g. .DS_Store
      anno_path = os.path.join(class_dir, nid + '_boxes.txt')
      imgs_path = os.path.join(class_dir, 'images')
      label_id = self.ids.index(nid)
      with open(anno_path, 'r', encoding='utf-8') as annof:
        for line in annof:
          if not line.strip():
            continue
          fname, x0, y0, x1, y1 = line.split()
          fname = os.path.join(imgs_path, fname)
          bbox = int(x0), int(y0), int(x1), int(y1)
          self.paths['train'].append((fname, label_id, nid, bbox))


class TinyImageNetDataset(Dataset):
  """Tiny ImageNet split built on top of TinyImageNetPaths.

  Args:
    root_dir: Root directory for the data.
    mode: One of "train", "test", or "val".
    preload: If True, decode every image into memory at construction time.
    load_transform: Optional iterable of callables, applied (only when
      ``preload`` is True) to the whole ``(img_data, label_data)`` pair. Each
      returns a tuple whose first two items replace them; an optional third
      item (a dict) is merged into ``transform_results``.
    transform: Transformation applied at retrieval time.
    download: Download the dataset (see download_and_unzip).
    max_samples: If given, keep a random subset of at most this many samples.
  Members:
    img_data: uint8 array of shape (N, 64, 64, 3) when preloaded.
    label_data: int array of shape (N,) when preloaded (all zeros in 'test').
  """

  # Normalisation for 'val' images. These must match the ImageNet statistics
  # used for training images (tinyImagenet_Partialize); CIFAR-100 statistics
  # here would feed evaluation inputs from a different distribution.
  _EVAL_MEAN = [0.485, 0.456, 0.406]
  _EVAL_STD = [0.229, 0.224, 0.225]

  def __init__(self, root_dir="/Projects/data/tiny-imagenet-200", mode='train', preload=True, load_transform=None,
               transform=None, download=False, max_samples=None):
    tinp = TinyImageNetPaths(root_dir, download)
    self.mode = mode
    self.label_idx = 1  # label position in (img_path, label_id, nid, bbox)
    self.preload = preload
    self.transform = transform
    self.transform_results = dict()

    self.IMAGE_SHAPE = (64, 64, 3)

    self.img_data = []
    self.label_data = []

    self.max_samples = max_samples
    self.samples = tinp.paths[mode]
    self.samples_num = len(self.samples)

    self.test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=self._EVAL_MEAN, std=self._EVAL_STD),
    ])

    if self.max_samples is not None:
      self.samples_num = min(self.max_samples, self.samples_num)
      # Permute indices, not the records: np.random.permutation would first
      # turn the ragged (path, label, nid, bbox) tuples into an array, which
      # recent NumPy rejects with ValueError.
      keep = np.random.permutation(len(self.samples))[:self.samples_num]
      self.samples = [self.samples[i] for i in keep]

    if self.preload:
      load_desc = "Preloading {} data...".format(mode)
      self.img_data = np.zeros((self.samples_num,) + self.IMAGE_SHAPE,
                               dtype=np.uint8)
      self.label_data = np.zeros((self.samples_num,), dtype=int)
      for idx in tqdm(range(self.samples_num), desc=load_desc):
        s = self.samples[idx]
        self.img_data[idx] = self._read_image(s)
        if mode != 'test':
          self.label_data[idx] = s[self.label_idx]

      if load_transform:
        for lt in load_transform:
          result = lt(self.img_data, self.label_data)
          self.img_data, self.label_data = result[:2]
          if len(result) > 2:
            self.transform_results.update(result[2])

  def _read_image(self, sample):
    # Decode one sample from disk and pad grayscale images to 3 channels.
    # 'test' samples are bare path strings (indexing them would yield a single
    # character); the other splits store (img_path, label_id, nid, bbox).
    path = sample if self.mode == 'test' else sample[0]
    return _add_channels(imageio.imread(path))

  def __len__(self):
    return self.samples_num

  def __getitem__(self, idx):
    """Return ``(img, label)``; ``label`` is None in 'test' mode.

    'val' images go through ``test_transform`` (tensor + normalisation);
    ``transform``, if set, is applied afterwards to whatever that produced.
    """
    if self.preload:
      img = self.img_data[idx]
      lbl = None if self.mode == 'test' else self.label_data[idx]
    else:
      s = self.samples[idx]
      img = self._read_image(s)
      lbl = None if self.mode == 'test' else s[self.label_idx]

    # Convert the NumPy array into a PIL image.
    img = PIL.Image.fromarray(img.astype('uint8'), 'RGB')

    # The validation split always gets the predefined evaluation transform.
    if self.mode == 'val':
      img = self.test_transform(img)

    # Apply any additional user-supplied transform.
    if self.transform:
      img = self.transform(img)

    return img, lbl
