import os
import sys
import glob
import random

from functools import partial

import numpy as np
from PIL import Image

import torch
import torchvision.transforms as transforms

from torchnet.dataset import ListDataset, TransformDataset
from torchnet.transform import compose

import protonets
from protonets.data.base import TensorTransform, convert_dict, CudaTransform, EpisodicBatchSampler, SequentialBatchSampler
import sys
import pickle

# Root directory of the mini-ImageNet few-shot data, resolved relative to
# sys.path[0] (so it depends on how the interpreter was launched).
MINIIMAGENET_DATA_DIR  = os.path.join(sys.path[0], '../../../data/miniimagenet-fewshot')
# In-memory cache filled by load_class_images: class key -> batched image data
# for that class, so each class is read from disk at most once per process.
MINIIMAGENET_CACHE = { }
# random.seed(0)

def load_image_path(key, out_field, d):
    """Open the image file named by ``d[key]`` and store it in ``d[out_field]``.

    Args:
        key: field of ``d`` that holds the image path.
        out_field: field of ``d`` that receives the opened image.
        d: sample dict; it is modified in place.

    Returns:
        The same dict ``d``.
    """
    path = d[key]
    image = Image.open(path)
    d[out_field] = image
    return d

def convert_tensor(key, d):
    """Replace the image stored in ``d[key]`` by an inverted float tensor.

    The image is converted to a float32 array, its two axes are swapped, and
    the result is reshaped to ``(1, size[0], size[1])`` where ``size`` is the
    image's ``.size`` attribute. Finally every value ``v`` becomes ``1.0 - v``.

    The reshape requires the array to hold exactly ``size[0] * size[1]``
    elements, i.e. a single-channel image.

    Args:
        key: field of ``d`` holding the image (an object exposing ``.size``
            and convertible to a NumPy array, e.g. a PIL image).
        d: sample dict; it is modified in place.

    Returns:
        The same dict ``d`` with ``d[key]`` replaced by the tensor.
    """
    image = d[key]
    first_dim, second_dim = image.size[0], image.size[1]

    # np.asarray copies only when required. The previous
    # ``np.array(..., copy=False)`` raises ValueError on NumPy >= 2 because
    # the dtype conversion to float32 cannot be done without a copy.
    pixels = np.asarray(image, dtype=np.float32)

    tensor = torch.from_numpy(pixels).transpose(0, 1).contiguous()
    d[key] = 1.0 - tensor.view(1, first_dim, second_dim)
    return d

def rotate_image(key, rot, d):
    """Rotate the image stored in ``d[key]`` by ``rot`` and store the result back.

    Args:
        key: field of ``d`` holding an object with a ``rotate`` method.
        rot: rotation argument forwarded unchanged to ``rotate``.
        d: sample dict; it is modified in place.

    Returns:
        The same dict ``d``.
    """
    original = d[key]
    rotated = original.rotate(rot)
    d[key] = rotated
    return d

def scale_image(key, height, width, d):
    """Resize the image stored in ``d[key]`` and store the result back.

    The target size is passed to ``resize`` as the tuple ``(height, width)``.
    NOTE(review): PIL's ``Image.resize`` documents its size argument as
    ``(width, height)``; the order only matters for non-square targets --
    confirm before using this with ``height != width``.

    Args:
        key: field of ``d`` holding an object with a ``resize`` method.
        height: first element of the size tuple.
        width: second element of the size tuple.
        d: sample dict; it is modified in place.

    Returns:
        The same dict ``d``.
    """
    target_size = (height, width)
    resized = d[key].resize(target_size)
    d[key] = resized
    return d

def load_class_images(d):
    """Return all images of the class named by ``d['class']``, using a cache.

    The class key has the form ``'<alphabet>/<character>/<rot>'``; the
    characters of ``<rot>`` after the first three are parsed as the rotation
    angle. On a cache miss every ``*.png`` file in the class directory is
    opened, rotated, resized to 28x28, converted to a tensor, and the whole
    class is stored as one batch in ``MINIIMAGENET_CACHE``.

    Args:
        d: dict with a ``'class'`` entry.

    Returns:
        A new dict ``{'class': ..., 'data': ...}`` where ``data`` is the
        cached batch for that class.

    Raises:
        Exception: if the class directory contains no ``*.png`` files.
    """
    class_key = d['class']

    if class_key not in MINIIMAGENET_CACHE:
        alphabet, character, rot = class_key.split('/')
        image_dir = os.path.join(MINIIMAGENET_DATA_DIR, 'data', alphabet, character)

        png_pattern = os.path.join(image_dir, '*.png')
        class_images = sorted(glob.glob(png_pattern))
        if not class_images:
            raise Exception("No images found for omniglot class {} at {}. Did you run download_omniglot.sh first?".format(class_key, image_dir))

        angle = float(rot[3:])
        pipeline = compose([
            partial(convert_dict, 'file_name'),
            partial(load_image_path, 'file_name', 'data'),
            partial(rotate_image, 'data', angle),
            partial(scale_image, 'data', 28, 28),
            partial(convert_tensor, 'data'),
        ])
        image_ds = TransformDataset(ListDataset(class_images), pipeline)

        # The batch size equals the dataset length, so the first batch
        # already contains every image of the class.
        loader = torch.utils.data.DataLoader(image_ds, batch_size=len(image_ds), shuffle=False)
        MINIIMAGENET_CACHE[class_key] = next(iter(loader))['data']

    return {'class': class_key, 'data': MINIIMAGENET_CACHE[class_key]}

def extract_episode_for_one_class(n_support, n_query, d):
    """Randomly split one class's examples into support and query sets.

    Args:
        n_support: number of support examples to draw.
        n_query: number of query examples to draw; ``-1`` means "all
            examples that are not used as support".
        d: dict with ``'class'`` and ``'data'``; ``data`` is indexed along
            its first dimension (N x C x H x W per the original comment).

    Returns:
        A dict with the class under ``'class'``, the support examples under
        ``'xs'`` and the query examples under ``'xq'``.
    """
    examples = d['data']
    total = examples.shape[0]

    if n_query == -1:
        n_query = total - n_support

    # One random permutation provides disjoint support and query indices.
    n_selected = n_support + n_query
    shuffled = torch.randperm(total)[:n_selected]
    support_idx, query_idx = shuffled[:n_support], shuffled[n_support:]

    return {
        'class': d['class'],
        'xs': examples[support_idx],
        'xq': examples[query_idx],
    }

def collate(data):
    """Merge a list of per-class dicts into a single batched dict.

    The keys are taken from the first element. For every key whose value in
    the first element is a tensor, each element's value gets a new leading
    dimension and the values are concatenated along it. Every other key is
    collected into a plain list, in input order.

    Args:
        data: non-empty list of dicts that share the same keys.

    Returns:
        A dict with the same keys holding either a batched tensor or a list.
    """
    first = data[0]
    batch_data = {}
    for k, first_value in first.items():
        # isinstance (rather than an exact type comparison) also batches
        # Tensor subclasses such as nn.Parameter instead of listing them.
        if isinstance(first_value, torch.Tensor):
            batch_data[k] = torch.vstack([d[k].unsqueeze(0) for d in data])
        else:
            batch_data[k] = [d[k] for d in data]
    return batch_data

def extract_episodes(n_way, n_support, n_query, n_episodes, samples_by_class):
    """Pre-sample ``n_episodes`` few-shot episodes.

    For each episode, ``n_way`` distinct class indices are drawn with
    ``random.sample``; each chosen class is split into support and query
    examples and the per-class results are merged with ``collate``.

    Args:
        n_way: number of classes per episode.
        n_support: support examples per class.
        n_query: query examples per class.
        n_episodes: number of episodes to generate.
        samples_by_class: indexable collection with one entry per class.

    Returns:
        A list of ``n_episodes`` collated episode dicts.
    """
    episodes = []
    for _ in range(n_episodes):
        chosen = random.sample(range(len(samples_by_class)), n_way)
        per_class = [
            extract_episode_for_one_class(n_support, n_query, samples_by_class[idx])
            for idx in chosen
        ]
        episodes.append(collate(per_class))
    return episodes
        
def load_deprecated(opt, splits):
    """Build one ``DataLoader`` per split that samples episodes on the fly.

    For the ``'val'`` and ``'test'`` splits the ``data.test_way``,
    ``data.test_shot`` and ``data.test_query`` options are used unless they
    are 0, in which case the training options (``data.way``, ``data.shot``,
    ``data.query``) apply. The number of episodes comes from
    ``data.test_episodes`` for those splits and ``data.train_episodes``
    otherwise.

    Args:
        opt: option dict keyed by the ``data.*`` names above, plus
            ``data.sequential``.
        splits: iterable of split names; each must have a matching
            ``mini-imagenet-cache-<split>.pkl`` in ``MINIIMAGENET_DATA_DIR``.

    Returns:
        A dict mapping each split name to its ``DataLoader``.
    """
    eval_splits = ['val', 'test']
    loaders = {}

    for split in splits:
        is_eval = split in eval_splits

        # The test_* options are only looked up for evaluation splits.
        n_way = opt['data.test_way'] if (is_eval and opt['data.test_way'] != 0) else opt['data.way']
        n_support = opt['data.test_shot'] if (is_eval and opt['data.test_shot'] != 0) else opt['data.shot']
        n_query = opt['data.test_query'] if (is_eval and opt['data.test_query'] != 0) else opt['data.query']
        n_episodes = opt['data.test_episodes'] if is_eval else opt['data.train_episodes']

        cache_path = os.path.join(MINIIMAGENET_DATA_DIR, f'mini-imagenet-cache-{split}.pkl')
        with open(cache_path, 'rb') as f:
            data = pickle.load(f)

        # Move the last axis to position 1 (presumably NHWC -> NCHW; confirm
        # against the cache layout).
        images = data['image_data'].transpose((0, 3, 1, 2))
        class_dict = data['class_dict']
        samples = [
            {'class': label, 'data': torch.tensor(images[class_dict[label]]).float()}
            for label in class_dict
        ]

        # CudaTransform is deliberately not part of the pipeline here.
        episode_transform = compose([
            partial(extract_episode_for_one_class, n_support, n_query),
            TensorTransform(),
        ])
        ds = TransformDataset(samples, episode_transform)

        if opt['data.sequential']:
            sampler = SequentialBatchSampler(len(ds))
        else:
            sampler = EpisodicBatchSampler(len(ds), n_way, n_episodes)

        # num_workers=0 is required; otherwise duplicate episodes may be received.
        loaders[split] = torch.utils.data.DataLoader(ds, batch_sampler=sampler, num_workers=0)

    return loaders


def load(opt, splits):
    """Build one ``DataLoader`` of pre-sampled episodes per split.

    For each split the cached images are read from
    ``mini-imagenet-cache-<split>.pkl``, converted to PIL images and passed
    through a fixed pipeline (resize to 84, center-crop to 84, ``ToTensor``,
    per-channel normalization). Then ``n_episodes`` episodes are sampled up
    front with ``extract_episodes``; the returned loader shuffles and yields
    those episodes one at a time.

    For the ``'val'`` and ``'test'`` splits the ``data.test_way``,
    ``data.test_shot`` and ``data.test_query`` options are used unless they
    are 0, in which case the training options apply.

    NOTE(review): only Python's ``random`` module is seeded here. The
    per-class example selection uses ``torch.randperm``, which is governed by
    torch's own seed -- confirm the caller seeds torch if full
    reproducibility is expected.

    Args:
        opt: option dict; must contain ``seed``, ``data.sequential`` and the
            ``data.*`` way/shot/query/episode options relevant to ``splits``.
        splits: iterable of split names.

    Returns:
        A dict mapping each split name to a ``DataLoader`` whose items are
        single episode dicts.

    Raises:
        Warning: if ``opt['data.sequential']`` is truthy (not supported).
    """
    random.seed(opt['seed'])

    # The preprocessing pipeline does not depend on the split, so it is
    # built once instead of once per loop iteration.
    featurestransform = compose([
        transforms.Resize(84),
        transforms.CenterCrop(84),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    eval_splits = ('val', 'test')
    ret = {}
    for split in splits:
        is_eval = split in eval_splits

        # The test_* options are only looked up for evaluation splits.
        if is_eval and opt['data.test_way'] != 0:
            n_way = opt['data.test_way']
        else:
            n_way = opt['data.way']

        if is_eval and opt['data.test_shot'] != 0:
            n_support = opt['data.test_shot']
        else:
            n_support = opt['data.shot']

        if is_eval and opt['data.test_query'] != 0:
            n_query = opt['data.test_query']
        else:
            n_query = opt['data.query']

        n_episodes = opt['data.test_episodes'] if is_eval else opt['data.train_episodes']

        # Reject the unsupported mode before the expensive image loading and
        # episode sampling rather than after it.
        if opt['data.sequential']:
            raise Warning("data.sequential no longer supported")

        cache_path = os.path.join(MINIIMAGENET_DATA_DIR, f'mini-imagenet-cache-{split}.pkl')
        # SECURITY: pickle.load can execute arbitrary code; only point this
        # at cache files from a trusted source.
        with open(cache_path, 'rb') as f:
            data = pickle.load(f)

        # Assumes every entry of data['image_data'] is an array that
        # Image.fromarray accepts -- confirm against the cache format.
        images = torch.stack(
            [featurestransform(Image.fromarray(x)) for x in data['image_data']], 0)
        class_dict = data['class_dict']
        samples = [{'class': c, 'data': images[class_dict[c]]} for c in class_dict]

        # CudaTransform is deliberately not part of the pipeline here.
        ds = TransformDataset(samples, compose([TensorTransform()]))
        print(f'{split} data class num: ', len(ds))
        episodes = extract_episodes(n_way, n_support, n_query, n_episodes, ds)
        print(f'{split} support data size: ', episodes[0]['xs'].size())
        print(f'{split} query data size: ', episodes[0]['xq'].size())
        print(f'{split} num of episode:', len(episodes))

        # Each dataset item is already a full episode, so only batch_size=1
        # is allowed; the collate function unwraps the one-element batch.
        ret[split] = torch.utils.data.DataLoader(
            episodes, batch_size=1, shuffle=True, collate_fn=lambda batch: batch[0])

    return ret

# %%
if __name__ == '__main__':
    # Smoke test: build the training loader and print the contents of the
    # first episode it yields.
    opt = {
        # load() reads opt['seed'] unconditionally; without this key the
        # script fails with KeyError before any data is loaded.
        'seed': 0,
        'data.way': 5,
        'data.shot': 5,
        'data.query': 5,
        'data.train_episodes': 10,
        'data.test_episodes': 10,
        'data.cuda': False,
        'data.sequential': False,
    }
    splits = ['train']
    res = load(opt, splits)
    for split in splits:
        print(split)
        for batch in res[split]:
            for k in batch:
                print(k)
                if k == 'class':
                    # Class labels are a plain list, not a tensor.
                    print(batch[k])
                else:
                    print(batch[k].size())
                    print(type(batch[k]))
            # Only the first episode is inspected.
            break