from benchmark.toolkits import DefaultTaskGen

from torchvision import datasets, transforms
from benchmark.toolkits import ClassificationCalculator as TaskCalculator
from benchmark.toolkits import IDXTaskPipe as TaskPipe
from benchmark.toolkits import DefaultTaskGen
from torch.utils.data import Dataset
import os, urllib
import numpy as np
import torch
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import os, glob
import glob
import os
from shutil import move
from os import rmdir
from torchvision.io import read_image, ImageReadMode
class TaskGen(DefaultTaskGen):
    """Task generator for the ``tiny_classification`` benchmark (Tiny-ImageNet-200).

    Raw data is expected under ``./benchmark/RAW_DATA/tiny-imagenet-200`` with the
    standard ``train/`` and ``val/`` layout of the Tiny-ImageNet-200 archive.
    """

    def __init__(self, dist_id, num_clients = 1, skewness = 0.5, local_hld_rate=0.2, seed=0):
        super().__init__(benchmark='tiny_classification',
                         dist_id=dist_id,
                         num_clients=num_clients,
                         skewness=skewness,
                         rawdata_path='./benchmark/RAW_DATA/tiny-imagenet-200',
                         local_hld_rate=local_hld_rate,
                         seed=seed
                         )
        self.num_classes = 200
        self.taskname = self.get_taskname()
        self.taskpath = os.path.join(self.task_rootpath, self.taskname)
        self.save_task = TaskPipe.save_task
        self.visualize = self.visualize_by_class
        # Describes how the saved task re-creates its datasets; the argument values
        # are source-code strings (note the embedded quotes around the path).
        self.source_dict = {
            'class_path': 'benchmark.tiny_classification.core',
            'class_name': 'Tiny',
            'train_args': {
                'root': '"'+self.rawdata_path+'"',
                'download': 'True',
                'train': 'True'
            },
            'test_args': {
                'root': '"'+self.rawdata_path+'"',
                'download': 'True',
                'train': 'False'
            }
        }

    def load_data(self):
        """Populate ``self.train_data`` and ``self.test_data`` from the raw dataset.

        The validation split is first rearranged into the one-folder-per-class layout
        that ``ImageFolder`` requires (see ``_organize_val_split``).
        """
        self._organize_val_split(os.path.join(self.rawdata_path, 'val'))
        self.train_data, self.test_data, _ = self.tiny_loader(batch_size=64, data_dir=self.rawdata_path)

    @staticmethod
    def _organize_val_split(val_dir):
        """Move flat ``val_dir/images/*`` files into ``val_dir/<wnid>/images/``.

        Labels are read from ``val_dir/val_annotations.txt`` (tab-separated; the first
        column is the file name, the second the class wnid). Running it again is a no-op
        once the flat ``images`` directory has been emptied and removed.

        Raises:
            KeyError: if an image in ``val_dir/images`` has no annotation line.
        """
        flat_dir = os.path.join(val_dir, 'images')
        if not os.path.isdir(flat_dir):
            return  # already reorganised by a previous run

        val_dict = {}
        with open(os.path.join(val_dir, 'val_annotations.txt'), 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.rstrip('\n').split('\t')
                if len(fields) >= 2:  # skip blank / malformed lines
                    val_dict[fields[0]] = fields[1]

        for path in glob.glob(os.path.join(flat_dir, '*')):
            file = os.path.basename(path)  # portable, unlike split('/')
            class_dir = os.path.join(val_dir, val_dict[file], 'images')
            os.makedirs(class_dir, exist_ok=True)
            move(path, os.path.join(class_dir, file))

        # An empty leftover ``val/images`` would be treated by ImageFolder as an extra
        # class named "images"; it sorts before every "n..." wnid, which either shifts
        # all test labels by one relative to train or makes ImageFolder raise.
        if not os.listdir(flat_dir):
            rmdir(flat_dir)

    def tiny_loader(self, batch_size=64, data_dir='./benchmark/RAW_DATA/tiny-imagenet-200/'):
        """Build the Tiny-ImageNet train and validation datasets.

        Args:
            batch_size: unused; retained for backward compatibility with callers.
            data_dir: dataset root containing ``train/`` and ``val/`` class folders.

        Returns:
            ``(trainset, testset, num_label)`` where both sets are ``ImageFolder``
            datasets and ``num_label`` is 200.

        Raises:
            ValueError: if the two splits do not map class names to the same indices,
                which would silently corrupt evaluation labels.
        """
        num_label = 200
        normalize = transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2770, 0.2691, 0.2821))
        transform_train = transforms.Compose(
            [transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize])
        transform_test = transforms.Compose([transforms.ToTensor(), normalize])
        trainset = datasets.ImageFolder(root=os.path.join(data_dir, 'train'), transform=transform_train)
        testset = datasets.ImageFolder(root=os.path.join(data_dir, 'val'), transform=transform_test)
        if trainset.class_to_idx != testset.class_to_idx:
            raise ValueError(
                'train/val class folders do not match; check for stray directories under '
                + os.path.join(data_dir, 'val'))
        return trainset, testset, num_label

