#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Oct 10 23:51:58 2022

@author: yaoyao
"""


import numpy as np
import torch
import torchvision
# from itertools import chain
from PIL import Image
import os
import cv2

class TINYIMAGENET_LT():
    """In-memory, long-tailed Tiny-ImageNet split with binary bird labels.

    The constructor reads every image of the requested split from ``path``
    with OpenCV, sub-samples each class according to an imbalance profile
    and finally collapses the class labels into a binary target: ``1`` for
    the classes listed in ``BIRD_CLASSES`` and ``0`` for everything else.

    Attributes set by the constructor:
        data: ``np.ndarray`` of shape ``(N, H, W, 3)`` with ``H == W == 32``
            when resizing is requested and ``64`` otherwise.  Channels are
            in the order returned by ``cv2.imread`` (BGR).
        targets: binary labels, one per row of ``data`` (a ``list`` of
            ``int``; an ``np.ndarray`` when ``split=True``).
        num_per_cls_dict: number of samples kept for each original class.
    """

    # Original class indices (line positions in wnids.txt) that are mapped
    # to the positive label; every other class becomes the negative label.
    BIRD_CLASSES = (11, 20, 21, 22)

    def __init__(self,
                 path,
                 resize='False',
                 dtype=np.float32,
                 cls_num=200,
                 imb_type='exp',
                 imb_factor=0.01,
                 rand_number=0,
                 train=True,
                 transform=None,
                 target_transform=None,
                 index=None, split=False):
        """Load one split, make it long-tailed and binarise the labels.

        Args:
            path: dataset root containing ``wnids.txt``, ``train/`` and
                ``val/``.
            resize: ``'True'`` / ``'False'`` (case-insensitive) or a bool;
                when true, images are resized to 32x32 instead of 64x64.
            dtype: numpy dtype of the image array.
            cls_num: number of classes used to build the imbalance profile.
            imb_type: ``'exp'``, ``'step'`` or anything else for a flat
                profile (see ``get_img_num_per_cls``).
            imb_factor: ratio between the smallest and the largest class.
            rand_number: seed passed to ``np.random.seed`` (this seeds
                numpy's global generator).
            train: load ``train/`` when true, ``val/`` otherwise.
            transform: optional callable applied to each sample in
                ``__getitem__``.
            target_transform: optional callable applied to each target in
                ``__getitem__``.
            index: indices selecting a subset of the imbalanced data; only
                used when ``split`` is true.
            split: when true, keep only the rows selected by ``index``.

        Raises:
            ValueError: if ``split`` is true but ``index`` is ``None``.
            OSError: if a listed image cannot be read.
        """
        self.cls_num = cls_num
        self.transform = transform
        self.target_transform = target_transform
        np.random.seed(rand_number)

        # Fail before the expensive load: indexing with None would silently
        # add a leading axis instead of selecting a subset.
        if split and index is None:
            raise ValueError('`index` must be provided when `split` is True')

        side = 32 if self._wants_resize(resize) else 64

        # The wnid -> label mapping is needed by both splits.
        wnids = self._read_wnids(path)
        wnid_to_label = {wnid: i for i, wnid in enumerate(wnids)}

        if train:
            self.data, self.targets = self._load_train(
                path, wnids, wnid_to_label, side, dtype)
        else:
            self.data, self.targets = self._load_val(
                path, wnid_to_label, side, dtype)

        # The imbalance profile is applied to whichever split was loaded.
        img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor)
        self.gen_imbalanced_data(img_num_list)

        if split:
            # Index through torch so that tensor indices are accepted too.
            new_data = torch.as_tensor(self.data)[index]
            new_targets = torch.as_tensor(self.targets)[index]
            self.data = new_data.numpy()
            self.targets = new_targets.numpy()

    @staticmethod
    def _wants_resize(resize):
        """Interpret the ``resize`` argument ('True'/'False' string or bool)."""
        if isinstance(resize, str):
            return resize.lower() == 'true'
        return bool(resize)

    @staticmethod
    def _read_wnids(path):
        """Return the non-empty lines of ``<path>/wnids.txt`` in file order."""
        with open(os.path.join(path, 'wnids.txt'), 'r') as f:
            return [line.strip() for line in f if line.strip()]

    @staticmethod
    def _read_image(img_file, side):
        """Read one image as an ``(side, side, 3)`` BGR array."""
        img = cv2.imread(img_file)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise OSError('could not read image: %s' % img_file)
        if img.shape[0] != side or img.shape[1] != side:
            img = cv2.resize(img, (side, side), interpolation=cv2.INTER_AREA)
        return img

    def _load_train(self, path, wnids, wnid_to_label, side, dtype):
        """Load every training image; return ``(data, labels)`` arrays."""
        image_blocks = []
        label_blocks = []
        for i, wnid in enumerate(wnids):
            if (i + 1) % 20 == 0:
                print('loading training data for synset %d / %d' % (i + 1, len(wnids)))
            # The per-class boxes file lists the image file names.
            boxes_file = os.path.join(path, 'train', wnid, '%s_boxes.txt' % wnid)
            with open(boxes_file, 'r') as f:
                filenames = [line.split('\t')[0] for line in f if line.strip()]

            image_dir = os.path.join(path, 'train', wnid, 'images')
            block = np.zeros((len(filenames), side, side, 3), dtype=dtype)
            for j, name in enumerate(filenames):
                block[j] = self._read_image(os.path.join(image_dir, name), side)

            image_blocks.append(block)
            label_blocks.append(
                np.full(len(filenames), wnid_to_label[wnid], dtype=np.int64))

        data = np.concatenate(image_blocks, axis=0)
        labels = np.concatenate(label_blocks, axis=0)
        return data, labels

    def _load_val(self, path, wnid_to_label, side, dtype):
        """Load validation images of known wnids; return ``(data, labels)``."""
        img_files = []
        labels = []
        with open(os.path.join(path, 'val', 'val_annotations.txt'), 'r') as f:
            for line in f:
                fields = line.split('\t')
                if len(fields) < 2:
                    continue
                wnid = fields[1].strip()
                # Keep only images whose class appears in wnids.txt.
                if wnid in wnid_to_label:
                    img_files.append(fields[0])
                    labels.append(wnid_to_label[wnid])

        image_dir = os.path.join(path, 'val', 'images')
        data = np.zeros((len(img_files), side, side, 3), dtype=dtype)
        for i, name in enumerate(img_files):
            data[i] = self._read_image(os.path.join(image_dir, name), side)
        return data, np.array(labels, dtype=np.int64)

    def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
        """Return the number of samples to keep for each of ``cls_num`` classes.

        The largest class gets ``len(self.data) / cls_num`` samples.
        ``'exp'`` decays geometrically down to ``imb_factor`` times that
        size; ``'step'`` keeps the first half at full size and shrinks the
        rest by ``imb_factor``; any other value keeps every class at full
        size.  Counts are truncated to ``int``.
        """
        img_max = len(self.data) / cls_num

        if imb_type == 'exp':
            # max(..., 1) avoids a division by zero for a single class.
            last = max(cls_num - 1, 1)
            return [int(img_max * (imb_factor ** (cls_idx / last)))
                    for cls_idx in range(cls_num)]

        if imb_type == 'step':
            head = cls_num // 2
            # The tail takes the remainder so odd cls_num still yields
            # exactly cls_num entries.
            tail = cls_num - head
            return [int(img_max)] * head + [int(img_max * imb_factor)] * tail

        return [int(img_max)] * cls_num

    def gen_imbalanced_data(self, img_num_per_cls):
        """Sub-sample ``self.data`` per class and binarise ``self.targets``.

        The sorted unique labels in ``self.targets`` are paired in order
        with ``img_num_per_cls``; for each class a random subset of at most
        that many samples is kept.  ``self.num_per_cls_dict`` records how
        many samples were actually kept.  Afterwards ``self.targets`` is a
        list holding ``1`` for classes in ``BIRD_CLASSES`` and ``0``
        otherwise.
        """
        targets_np = np.array(self.targets, dtype=np.int64)
        classes = np.unique(targets_np)
        # np.random.shuffle(classes)
        self.num_per_cls_dict = dict()

        kept_data = []
        kept_targets = []
        for the_class, the_img_num in zip(classes, img_num_per_cls):
            idx = np.where(targets_np == the_class)[0]
            np.random.shuffle(idx)
            selec_idx = idx[:the_img_num]
            # A class may hold fewer samples than requested; use the real
            # count so data and targets stay the same length.
            kept = len(selec_idx)
            self.num_per_cls_dict[int(the_class)] = kept
            kept_data.append(self.data[selec_idx, ...])
            kept_targets.append(np.full(kept, the_class, dtype=np.int64))

        if kept_data:
            self.data = np.vstack(kept_data)
            fine_targets = np.concatenate(kept_targets)
        else:
            # No classes at all (empty split): keep an empty array with the
            # original trailing shape instead of failing in vstack.
            self.data = np.asarray(self.data)[:0]
            fine_targets = np.zeros(0, dtype=np.int64)

        # birds: positive (1); everything else: negative (0)
        is_bird = np.isin(fine_targets, self.BIRD_CLASSES)
        self.targets = is_bird.astype(np.int64).tolist()

    def get_cls_num_list(self):
        """Return the kept sample count of classes ``0 .. cls_num - 1``.

        Classes with no recorded samples are reported as ``0``.
        """
        return [self.num_per_cls_dict.get(i, 0) for i in range(self.cls_num)]

    def __getitem__(self, idx):
        """Return ``(sample, target)`` at ``idx`` after optional transforms."""
        sample = self.data[idx]
        target = int(self.targets[idx])

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.targets)
