# %%
import os
import sys
import glob
import random

from functools import partial

import numpy as np
from PIL import Image

import torch
import torchvision.transforms as transforms

from torchnet.dataset import ListDataset, TransformDataset
from torchnet.transform import compose

import protonets
from protonets.data.base import TensorTransform, convert_dict, CudaTransform, EpisodicBatchSampler, SequentialBatchSampler
import sys
import pickle

# Directory containing the ``few-shot-<split>.npz`` archives read by ``load``.
# It is resolved against ``sys.path[0]`` rather than the current working
# directory, so the relative ``../../../data`` part is anchored to that entry.
CIFAR_DATA_DIR  = os.path.join(sys.path[0], '../../../data/cifar-100-fewshot')
# CIFAR_DATA_DIR = '../../data/cifar-100-fewshot'
# MINIIMAGENET_CACHE = { }
# random.seed(0)
# # %%
# data = np.load(os.path.join(CIFAR_DATA_DIR, 'few-shot-test.npz'))
# print(data['features'][0])
# %%
def load_image_path(key, out_field, d):
    """Open the image file named by ``d[key]`` and store it in ``d[out_field]``.

    Args:
        key: Dict key whose value is passed to ``Image.open``.
        out_field: Dict key under which the opened image is stored.
        d: Sample dict; it is updated in place.

    Returns:
        The same dict ``d``, so the function can be chained in a pipeline.
    """
    image_source = d[key]
    d[out_field] = Image.open(image_source)
    return d

def convert_tensor(key, d):
    """Replace the image at ``d[key]`` with an inverted float32 tensor.

    The image is converted to a float32 array, its first two axes are swapped,
    and the result is viewed as ``(1, size[0], size[1])`` where ``size`` is the
    image's own ``size`` attribute. Every value ``v`` becomes ``1.0 - v``.
    Because of the final ``view``, the image must hold exactly
    ``size[0] * size[1]`` values (i.e. a single channel).

    Args:
        key: Dict key holding the image. It is presumably a PIL image whose
            ``size`` is ``(width, height)`` -- confirm against the caller.
        d: Sample dict; it is updated in place.

    Returns:
        The same dict ``d`` with ``d[key]`` replaced by the tensor.
    """
    image = d[key]
    dim0, dim1 = image.size[0], image.size[1]

    # np.asarray copies only when it has to. The previous
    # np.array(..., copy=False) raises ValueError on NumPy >= 2.0 whenever a
    # copy is required, which is always the case for a dtype conversion.
    pixels = np.asarray(image, dtype=np.float32)

    swapped = torch.from_numpy(pixels).transpose(0, 1).contiguous()
    d[key] = 1.0 - swapped.view(1, dim0, dim1)
    return d

def rotate_image(key, rot, d):
    """Replace ``d[key]`` with the result of calling its ``rotate(rot)`` method.

    Args:
        key: Dict key holding the object to rotate.
        rot: Rotation argument forwarded unchanged to ``rotate``.
        d: Sample dict; it is updated in place.

    Returns:
        The same dict ``d``.
    """
    rotated = d[key].rotate(rot)
    d[key] = rotated
    return d

def scale_image(key, height, width, d):
    """Replace ``d[key]`` with the result of ``d[key].resize((height, width))``.

    Args:
        key: Dict key holding the object to resize.
        height: First element of the size tuple passed to ``resize``.
        width: Second element of the size tuple passed to ``resize``.
        d: Sample dict; it is updated in place.

    Returns:
        The same dict ``d``.
    """
    # NOTE(review): PIL's Image.resize documents its size as (width, height),
    # while this passes (height, width). Harmless for square targets -- confirm
    # the intended order before using this with non-square sizes.
    target_size = (height, width)
    d[key] = d[key].resize(target_size)
    return d

def extract_episode_for_one_class(n_support, n_query, d):
    """Randomly split one class's examples into support and query sets.

    Args:
        n_support: Number of support examples to draw.
        n_query: Number of query examples to draw; ``-1`` means "all examples
            left over after the support set".
        d: Dict with ``'class'`` (the label) and ``'data'`` (examples indexed
            along the first dimension).

    Returns:
        Dict with keys ``'class'``, ``'xs'`` (support examples) and ``'xq'``
        (query examples), in that order.
    """
    pool = d['data']
    total = pool.shape[0]

    # -1 is a sentinel for "use everything that is not in the support set".
    query_count = total - n_support if n_query == -1 else n_query

    # One random permutation; its first slice feeds the support set and the
    # following slice the query set, so the two never share an example.
    picked = torch.randperm(total)[:n_support + query_count]

    episode = {'class': d['class']}
    episode['xs'] = pool[picked[:n_support]]
    episode['xq'] = pool[picked[n_support:]]
    return episode

def collate(data):
    """Merge a list of per-class sample dicts into a single batch dict.

    For every key of the first dict, the values from all dicts are gathered.
    Tensor values are stacked along a new leading dimension; any other values
    are collected into a plain list in input order.

    Args:
        data: Non-empty sequence of dicts that share the keys of ``data[0]``.

    Returns:
        Dict with the same keys as ``data[0]``.

    Raises:
        IndexError: If ``data`` is empty.
    """
    batch_data = {}
    for k, first_value in data[0].items():
        values = [d[k] for d in data]
        # isinstance (rather than an exact type comparison) so that tensor
        # subclasses are stacked too instead of falling into the list branch.
        if isinstance(first_value, torch.Tensor):
            batch_data[k] = torch.vstack([v.unsqueeze(0) for v in values])
        else:
            batch_data[k] = values
    return batch_data

def extract_episodes(n_way, n_support, n_query, n_episodes, samples_by_class):
    """Pre-sample ``n_episodes`` few-shot episodes.

    Each episode picks ``n_way`` distinct class indices with ``random.sample``,
    draws a support/query split for each picked class, and collates the
    per-class results into one batch dict.

    Args:
        n_way: Number of classes per episode.
        n_support: Support examples per class.
        n_query: Query examples per class (forwarded unchanged).
        n_episodes: Number of episodes to build.
        samples_by_class: Indexable, sized collection with one entry per class.

    Returns:
        List of ``n_episodes`` collated episode dicts.
    """
    def build_one_episode():
        class_ids = random.sample(range(len(samples_by_class)), n_way)
        per_class = [
            extract_episode_for_one_class(n_support, n_query, samples_by_class[class_id])
            for class_id in class_ids
        ]
        return collate(per_class)

    return [build_one_episode() for _ in range(n_episodes)]


def _resolve_episode_params(opt, split):
    """Return ``(n_way, n_support, n_query, n_episodes)`` for ``split``."""
    is_eval = split in ['val', 'test']

    def pick(test_key, train_key):
        # A test-time option of 0 means "reuse the training value". The test
        # key is only read for 'val'/'test' splits.
        if is_eval and opt[test_key] != 0:
            return opt[test_key]
        return opt[train_key]

    n_way = pick('data.test_way', 'data.way')
    n_support = pick('data.test_shot', 'data.shot')
    n_query = pick('data.test_query', 'data.query')
    n_episodes = opt['data.test_episodes'] if is_eval else opt['data.train_episodes']
    return n_way, n_support, n_query, n_episodes


def _load_samples_by_class(split, featurestransform):
    """Load ``few-shot-<split>.npz`` and group the transformed images by target."""
    archive_path = os.path.join(CIFAR_DATA_DIR, f'few-shot-{split}.npz')

    # np.load returns an archive object that keeps the file open until it is
    # closed; read both arrays inside the `with` so the handle is released.
    with np.load(archive_path) as data:
        raw_features = data['features']
        targets = data['targets']

    features = torch.stack(
        [featurestransform(Image.fromarray(x)) for x in raw_features], 0)

    return [
        {'class': str(c), 'data': features[targets == c]}
        for c in np.unique(targets)
    ]


def load(opt, splits):
    """Build one DataLoader of pre-sampled few-shot episodes per split.

    For each split, ``few-shot-<split>.npz`` is read from ``CIFAR_DATA_DIR``,
    every image is resized/cropped to 32, converted to a tensor and
    normalised, and the images are grouped by target. A fixed list of episodes
    is then sampled up front and wrapped in a shuffling DataLoader that yields
    one episode dict per iteration.

    Args:
        opt: Options mapping. Reads ``'seed'``, ``'data.way'``,
            ``'data.shot'``, ``'data.query'``, ``'data.train_episodes'``,
            ``'data.cuda'`` and ``'data.sequential'``. For ``'val'``/``'test'``
            splits it also reads ``'data.test_way'``, ``'data.test_shot'``,
            ``'data.test_query'`` (0 falls back to the training value) and
            ``'data.test_episodes'``.
        splits: Iterable of split names used to build the file name.

    Returns:
        Dict mapping each split name to its ``torch.utils.data.DataLoader``.

    Raises:
        Warning: If ``opt['data.sequential']`` is truthy (and ``splits`` is
            non-empty); sequential loading is no longer supported.
    """
    random.seed(opt['seed'])

    # The image pipeline does not depend on the split, so build it once.
    featurestransform = compose([
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    ret = { }
    for split in splits:
        # Reject the unsupported mode before doing any expensive loading.
        if opt['data.sequential']:
            raise Warning("data.sequential no longer supported")

        n_way, n_support, n_query, n_episodes = _resolve_episode_params(opt, split)
        samples = _load_samples_by_class(split, featurestransform)

        # Per-sample transforms from protonets.data.base, applied when the
        # dataset is indexed.
        transforms_ = [TensorTransform()]
        if opt['data.cuda']:
            transforms_.append(CudaTransform())
        ds = TransformDataset(samples, compose(transforms_))

        print(f'{split} data class num: ', len(ds))
        episodes = extract_episodes(n_way, n_support, n_query, n_episodes, ds)
        print(f'{split} support data size: ', episodes[0]['xs'].size())
        print(f'{split} query data size: ', episodes[0]['xq'].size())
        print(f'{split} num of episode:', len(episodes))

        # Each element is already a full episode, so only batch_size=1 is
        # allowed; the collate_fn unwraps the single-element batch.
        ret[split] = torch.utils.data.DataLoader(episodes, batch_size=1, shuffle=True, collate_fn=lambda data:data[0])

    return ret

# %%
if __name__ == '__main__':
    # Smoke test: build the 'train' loader with a small configuration and
    # print the keys, shapes and types of the first episode it yields.
    opt = {
        'seed': 0,
        'data.way': 5,
        'data.shot': 5,
        'data.query': 5,
        'data.train_episodes': 10,
        'data.test_episodes': 10,
        'data.cuda': False,
        'data.sequential': False
    }
    splits = ['train']
    res = load(opt, splits)
    for split in splits:
        print(split)
        for batch in res[split]:
            for k, value in batch.items():
                print(k)
                if k == 'class':
                    # Class labels are a plain list, so print them directly.
                    print(value)
                    continue
                print(value.size())
                print(type(value))
            # Only the first episode is inspected.
            break
# %%
