import os
import json
import torch
import random
import shutil
import numpy as np
from utils import *
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
from market.cfe import CFE

def get_splits(role):
    """Return the ordered list of split names used for ``role``.

    The ``'learnware'`` role uses ``['train', 'eval', 'test']``; any other
    role uses ``['specification', 'test']``. A new list is returned each call.
    """
    if role != 'learnware':
        return ['specification', 'test']
    return ['train', 'eval', 'test']

def get_transform(split):
    """Build the torchvision preprocessing pipeline for ``split``.

    Every split is resized to 224x224, converted to a tensor and normalised
    with the standard ImageNet per-channel mean/std. Only ``'train'`` also
    applies a random horizontal flip, placed between the resize and the
    tensor conversion.
    """
    pipeline = [transforms.Resize((224, 224))]
    if split == 'train':
        pipeline.append(transforms.RandomHorizontalFlip())
    pipeline += [
        transforms.ToTensor(),
        transforms.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225]),
    ]
    return transforms.Compose(pipeline)

class Task:
    """A random selection of classes and image ids for one role instance.

    The selection is generated once and cached as JSON at
    ``<dataset_path>/tasks/<role>/<role_id>.json``; later constructions reload
    it from there. Iterating yields ``(class_id, image_ids)`` pairs where
    ``class_id`` is a string (the class directory name under
    ``<dataset_path>/dataset/<role>``) and ``image_ids`` is a list of ints.
    """

    def __init__(self, cfg, role, role_id):
        data_path = os.path.join(cfg['dataset_path'], 'dataset', role)
        task_file = os.path.join(cfg['dataset_path'], 'tasks', role, f'{role_id}.json')

        if not os.path.exists(task_file):
            self.create_task(task_file, data_path)
        else:
            with open(task_file) as f:
                self.class_to_paths = json.load(f)

    def create_task(self, task_file, data_path, N_class=60):
        """Sample a new task, store it on ``self`` and write it to ``task_file``.

        Draws between 30 and 40 distinct class ids from ``range(N_class)``.
        For each class, counts the entries in ``data_path/<class_id>`` as
        ``N_files`` and samples ``int(0.9 * N_files)`` distinct ids from
        ``1..N_files``.
        """
        os.makedirs(os.path.dirname(task_file), exist_ok=True)

        n_class = random.randint(30, 40)
        classes = random.sample(range(N_class), n_class)

        self.class_to_paths = {}
        for class_id in classes:
            class_path = os.path.join(data_path, str(class_id))
            N_files = len(os.listdir(class_path))
            n_files = int(0.9 * N_files)
            # Key by str so the freshly created mapping matches what json.load
            # returns when the cached file is read back on later runs.
            self.class_to_paths[str(class_id)] = random.sample(range(1, 1 + N_files), n_files)

        # Write to a scratch file and move it into place so an interrupted
        # run cannot leave a truncated JSON file that breaks every later load.
        tmp_file = task_file + '.tmp'
        with open(tmp_file, 'w') as f:
            json.dump(self.class_to_paths, f)
        os.replace(tmp_file, task_file)

    def __iter__(self):
        return iter(self.class_to_paths.items())

    def __len__(self):
        return len(self.class_to_paths)

class ImageFolder:
    """Per-split torchvision ``ImageFolder`` datasets for one role instance.

    On first use, the images selected by ``Task`` are copied from
    ``<dataset_path>/dataset/<role>`` into
    ``<dataset_path>/imagefolder/<role>/<role_id>/<split>/<class_id>/``.
    Later constructions reuse that directory. Targets produced by the
    datasets are the integer class ids (the class directory names), not
    torchvision's contiguous per-split indices.
    """

    def __init__(self, cfg, role, role_id):
        import functools

        self.role = role
        self.role_id = role_id
        self.splits = get_splits(role)
        self.ratios = [0.8, 0.1, 0.1] if role == 'learnware' else [0.2, 0.8]
        data_path = os.path.join(cfg['dataset_path'], 'dataset', role)
        imagefolder_path = os.path.join(cfg['dataset_path'], 'imagefolder', role, str(role_id))

        if not os.path.exists(imagefolder_path):
            # Build under a scratch name and rename once complete, so an
            # interrupted copy is never mistaken for a finished folder.
            partial_path = imagefolder_path + '.partial'
            shutil.rmtree(partial_path, ignore_errors=True)
            os.makedirs(partial_path)
            self.create_imagefolder(cfg, data_path, partial_path)
            os.rename(partial_path, imagefolder_path)

        self.classes = sorted(os.listdir(os.path.join(imagefolder_path, 'test')))

        self.imagefolders = {}
        for split in self.splits:
            folder = datasets.ImageFolder(
                os.path.join(imagefolder_path, split),
                transform=get_transform(split),
            )
            # torchvision indexes classes by the sorted directory names of
            # *this* split, so map through this split's own class list. This
            # keeps labels correct even if some class has no images here.
            folder.target_transform = functools.partial(self._target_transform, classes=folder.classes)
            self.imagefolders[split] = folder

    def _target_transform(self, x, classes=None):
        """Map a torchvision class index ``x`` to the integer class id.

        ``classes`` is the sorted class-name list that ``x`` indexes into.
        It defaults to ``self.classes``, the class names of the ``test`` split.
        """
        if classes is None:
            classes = self.classes
        return int(classes[int(x)])

    def _split_assignment(self, n):
        """Return the split name for each of ``n`` items, in order.

        The first ``int(n * ratios[0])`` items go to the first split. For the
        learnware role, the next ``int(n * ratios[1])`` items go to the second
        split. Everything left over goes to the last split.
        """
        n_first = int(n * self.ratios[0])
        n_second = int(n * self.ratios[1]) if self.role == 'learnware' else 0
        n_rest = n - n_first - n_second
        return [self.splits[0]] * n_first + [self.splits[1]] * n_second + [self.splits[-1]] * n_rest

    def create_imagefolder(self, cfg, data_path, imagefolder_path):
        """Copy this task's images into ``imagefolder_path/<split>/<class_id>/``.

        Within each class, image ids are assigned to splits in their stored
        order by ``_split_assignment``.
        """
        task = Task(cfg, self.role, self.role_id)
        bar = tqdm(task)
        bar.set_description(f'({self.role} / {self.role_id})')

        for class_id, image_ids in bar:
            assignment = self._split_assignment(len(image_ids))
            for split, image_id in zip(assignment, image_ids):
                src_path = os.path.join(data_path, str(class_id), f'{image_id}.jpg')
                dst_path = os.path.join(imagefolder_path, split, str(class_id), f'{image_id}.jpg')
                os.makedirs(os.path.dirname(dst_path), exist_ok=True)
                shutil.copy(src_path, dst_path)

class NICODataset:
    """Image ``DataLoader`` objects, one per split, for a single role instance.

    Only the ``'train'`` loader shuffles. Every loader uses 4 worker
    processes and pinned memory.
    """

    def __init__(self, cfg, role, role_id, batch_size=2048):
        folders = ImageFolder(cfg, role, role_id).imagefolders
        loaders = {}
        for split_name in get_splits(role):
            loaders[split_name] = DataLoader(
                folders[split_name],
                batch_size=batch_size,
                shuffle=split_name == 'train',
                num_workers=4,
                pin_memory=True,
            )
        self.dataloaders = loaders

    def nico_loader(self, split):
        """Return the ``DataLoader`` that was built for ``split``."""
        return self.dataloaders[split]

class NICOFeature:
    """Cached CFE features and labels for every split of one role instance.

    Features are extracted once into
    ``<dataset_path>/features/<role>/<role_id>/{split}_features.npy`` and
    ``{split}_labels.npy``. Every construction then loads them as tensors.
    """

    def __init__(self, cfg, role, role_id):
        self.feature_path = os.path.join(cfg['dataset_path'], 'features', role, str(role_id))
        self.splits = get_splits(role)
        self.batch_size = cfg['batch_size']
        # Check for the files themselves rather than the directory. A run that
        # died mid-extraction leaves the directory with files missing, which
        # would otherwise make every later np.load fail.
        if not self._features_cached():
            self.cfe = CFE()
            self.device = torch.device(select_device(cfg, role_id))
            self.cfe.to(self.device)
            # Feature extraction is inference; disable train-mode behaviour.
            self.cfe.eval()
            os.makedirs(self.feature_path, exist_ok=True)
            self.nicodataset = NICODataset(cfg, role, role_id)
            for split in self.splits:
                self.__feature_extraction(split, role, role_id)

        self.nico_features = {split: torch.from_numpy(np.load(self._file(split, 'features'))) for split in self.splits}
        self.nico_labels = {split: torch.from_numpy(np.load(self._file(split, 'labels'))) for split in self.splits}

    def _file(self, split, kind):
        """Path of the cached ``.npy`` file for ``split`` (kind: 'features' or 'labels')."""
        return os.path.join(self.feature_path, f'{split}_{kind}.npy')

    def _features_cached(self):
        """Return True when every split's feature and label files exist."""
        return all(
            os.path.exists(self._file(split, kind))
            for split in self.splits
            for kind in ('features', 'labels')
        )

    def __feature_extraction(self, split, role, role_id):
        """Run the CFE over ``split`` and save its features and labels as ``.npy``."""
        dataloader = self.nicodataset.nico_loader(split)
        features = []
        labels = []
        # No autograd graph is needed. Moving each batch to the CPU right away
        # keeps device memory flat instead of growing with the split size.
        with torch.no_grad():
            for inputs, label in tqdm(dataloader, desc=f'{role} {role_id} {split}'):
                feature = self.cfe(inputs.to(self.device))
                features.append(feature.cpu().numpy())
                labels.append(label.numpy())
        np.save(self._file(split, 'features'), np.concatenate(features))
        np.save(self._file(split, 'labels'), np.concatenate(labels))

    def specification_data(self):
        """Return ``(features, labels)`` for the first split of this role."""
        return self.get_data(self.splits[0])

    def get_data(self, split):
        """Return the ``(features, labels)`` tensors loaded for ``split``."""
        return self.nico_features[split], self.nico_labels[split]

    def get_loader(self, split=None):
        """Return a ``DataLoader`` over the cached features of ``split``.

        With ``split=None``, return a dict mapping every split name to its
        loader. Only the ``'train'`` loader shuffles.
        """
        if split is None:
            return {name: self.get_loader(name) for name in self.splits}
        features = self.nico_features[split]
        labels = self.nico_labels[split]
        dataset = TensorDataset(features, labels)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=(split == 'train'), num_workers=2)