import torch
import json
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data.sampler import Sampler
import os
from collections import defaultdict

from utils.tools import load_dict, set_random_seed

def identity(x):
    """Return ``x`` unchanged; used as the no-op default ``target_transform``."""
    return x


class SetDataset:
    """Class-indexed dataset: item ``i`` is one random batch of class ``cl_list[i]``.

    ``data_file`` is a JSON file holding parallel lists ``image_names`` and
    ``image_labels``. Names are grouped by label, and each group is wrapped in
    a ``SubDataset`` plus a shuffling ``DataLoader``.

    Args:
        data_file: path to the JSON metadata file.
        batch_size: number of samples returned by each ``__getitem__`` call.
        transform: transform applied by ``SubDataset`` to every loaded sample.
        data_root: directory prepended to every entry of ``image_names``.
            The default ``''`` uses the names exactly as stored.
    """

    def __init__(self, data_file, batch_size, transform, data_root=''):
        with open(data_file, 'r') as f:
            self.meta = json.load(f)

        self.cl_list = np.unique(self.meta['image_labels']).tolist()

        # Group sample names by label, preserving their order in the file.
        self.sub_meta = {cl: [] for cl in self.cl_list}
        for x, y in zip(self.meta['image_names'], self.meta['image_labels']):
            self.sub_meta[y].append(x)

        self.sub_dataloader = []
        sub_data_loader_params = dict(batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=0,  # use main thread only or may receive multiple batches
                                      pin_memory=False)
        for cl in self.cl_list:
            # SubDataset's signature is (data_root, sub_meta, cl, ...). The
            # root must come first, or every positional argument shifts.
            sub_dataset = SubDataset(data_root, self.sub_meta[cl], cl, transform=transform)
            self.sub_dataloader.append(torch.utils.data.DataLoader(sub_dataset, **sub_data_loader_params))

    def __getitem__(self, i):
        # A fresh iterator over a shuffling loader yields a new random batch per call.
        return next(iter(self.sub_dataloader[i]))

    def __len__(self):
        return len(self.cl_list)


class SubDataset:
    """Dataset over the samples of a single class label.

    Item ``i`` is read with ``load_dict`` from ``os.path.join(data_root,
    sub_meta[i])`` and passed through ``transform``. It is returned together
    with ``target_transform(cl)``.
    """

    def __init__(self, data_root, sub_meta, cl, transform=transforms.ToTensor(), target_transform=identity):
        # Directory joined in front of every entry of ``sub_meta``.
        self.data_root = data_root
        # Per-sample paths; a sample's position here is its dataset index.
        self.sub_meta = sub_meta
        # Label shared by every sample of this dataset.
        self.cl = cl
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.sub_meta)

    def __getitem__(self, i):
        raw = load_dict(os.path.join(self.data_root, self.sub_meta[i]))
        sample = self.transform(raw)
        label = self.target_transform(self.cl)
        return sample, label


class UserSetDataset:
    """Per-class dataset whose episodes draw samples from selected users.

    For every class ``cl`` in ``range(n_way)``, the entries ``meta[u][cl]`` of
    all ``users`` are concatenated into one ``SubDataset``. A per-class
    sampler then decides, per episode, which users supply support and query
    items.

    Args:
        data_file: file read with ``load_dict`` and indexed as ``meta[user][cl]``.
            Each entry is presumably a list of sample paths relative to
            ``data_root`` (TODO confirm against the data preparation code).
        users: user keys to include.
        n_way: number of classes; classes are ``0 .. n_way - 1``.
        n_shot: support items per class per episode.
        n_query: query items per class per episode.
        transform: transform applied to each loaded sample.
        mode: 'p', 'd', 'm' or 'meta' (support and query come from the same
            users), or 'i' (disjoint support and query users).
        data_root: directory prepended to every sample path.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if ``mode`` is unknown or ``users`` is too small for it.
    """

    # Users drawn per episode for the modes where support and query share users.
    _SHARED_USER_COUNTS = {'p': 1, 'd': 5, 'm': 1, 'meta': 1}

    def __init__(self, data_file, users, n_way, n_shot, n_query, transform, mode, data_root,  **kwargs):

        # One DataLoader batch per class holds both the support and query items.
        self.batch_size = n_shot + n_query
        self.n_shot = n_shot
        self.n_query = n_query
        self.n_user = len(users)
        self.users = users
        self.set_mode(mode)
        # generate class list
        self.cl_list = list(range(n_way))
        self.meta = load_dict(data_file)

        # sub_meta[cl] concatenates the users' samples in ``users`` order.
        # user_meta[cl] maps each user to their samples in that same
        # insertion order. The samplers use this shared order to turn a
        # user's samples into offsets within sub_meta[cl].
        self.sub_meta = defaultdict(list)
        self.user_meta = defaultdict(dict)
        for u in users:
            for cl in self.cl_list:
                self.user_meta[cl][u] = self.meta[u][cl]
                self.sub_meta[cl].extend(self.meta[u][cl])

        self.sub_dataloader = []
        self.samplers = []
        sub_data_loader_params = dict(batch_size=self.batch_size,
                                      shuffle=False,
                                      num_workers=0,  # use main thread only or may receive multiple batches
                                      pin_memory=False)

        for cl in self.cl_list:
            if mode in self._SHARED_USER_COUNTS:
                sampler = RandomUserSampler(user_meta=self.user_meta[cl], n_shot=self.n_shot,
                                            n_query=self.n_query, n_user=self.n_user)
            else:
                # set_mode has rejected every other mode, so this is 'i'.
                sampler = RandomUserSamplerIndependent(user_meta=self.user_meta[cl], n_shot=self.n_shot,
                                                       n_query=self.n_query, n_user_support=self.n_user_support,
                                                       n_user_query=self.n_user_query)
            sub_dataset = SubDataset(data_root, self.sub_meta[cl], cl, transform=transform)
            self.sub_dataloader.append(
                torch.utils.data.DataLoader(sub_dataset, sampler=sampler, **sub_data_loader_params))
            self.samplers.append(sampler)

    def __getitem__(self, i):
        # The sampler decides which indices form this class's single batch.
        return next(iter(self.sub_dataloader[i]))

    def __len__(self):
        return len(self.cl_list)

    def set_mode(self, mode):
        """Set ``mode`` and the per-episode user counts it implies.

        Shared-user modes set ``n_user``. Mode 'i' sets ``n_user_support`` (5)
        and ``n_user_query`` (1).

        Raises:
            ValueError: if ``mode`` is unknown or ``self.users`` has fewer
                users than the mode needs. 'meta' is now validated like the
                other modes instead of failing later when users are drawn.
        """
        self.mode = mode
        if mode in self._SHARED_USER_COUNTS:
            self.n_user = self._SHARED_USER_COUNTS[mode]
            required = self.n_user
        elif mode == 'i':
            self.n_user_support = 5
            self.n_user_query = 1
            required = self.n_user_support + self.n_user_query
        else:
            raise ValueError('There is no mode {} in system'.format(mode))
        if len(self.users) < required:
            raise ValueError("Number of users is not enough to generate data!")


class EpisodicBatchSampler(object):
    """Batch sampler yielding ``n_way`` random class indices per episode.

    Before each episode it re-selects which users the per-class ``samplers``
    draw from, according to ``mode`` (see ``mode_sampler``).

    Args:
        n_classes: number of classes to choose from.
        n_way: number of class indices yielded per episode.
        n_episodes: episodes per full iteration; also returned by ``len``.
        seed: sequence whose first element is passed to ``set_random_seed``
            once, at construction time.
        samplers: per-class samplers exposing ``users``, the user-count
            attributes needed by ``mode``, and ``set_users``.
        mode: 'p' / 'd' / 'm' (one user group for support and query, shared
            by all classes), 'i' (disjoint support and query users, shared by
            all classes) or 'meta' (an independent user draw for each class).
        users_list: stored as ``self.users_list``; not used by this class.
    """

    def __init__(self, n_classes, n_way, n_episodes, seed, samplers, mode, users_list = None):
        self.n_classes = n_classes
        self.n_way = n_way
        self.n_episodes = n_episodes
        self.seed = seed
        # Seeded once here rather than per episode.
        set_random_seed(self.seed[0])
        self.samplers = samplers
        self.mode = mode
        self.users_list = users_list

    def __len__(self):
        return self.n_episodes

    def __iter__(self):
        for _ in range(self.n_episodes):
            # Re-draw users before yielding this episode's class indices.
            self.mode_sampler(self.mode, self.samplers)
            yield torch.randperm(self.n_classes)[:self.n_way]

    def mode_sampler(self, mode, samplers):
        """Assign freshly drawn users to every sampler in ``samplers``.

        Users are drawn without replacement with ``np.random.choice``. In
        'p' / 'd' / 'm' / 'i' mode the pool and counts come from
        ``samplers[0]``. In 'meta' mode each sampler draws from its own pool
        with its own ``n_user``. Any other mode leaves the samplers unchanged.
        """
        sampler = samplers[0]
        if mode in ['p', 'd', 'm']:
            users_select = np.random.choice(sampler.users, size=sampler.n_user, replace=False)
            for s in samplers:
                s.set_users(users_select)
        elif mode == 'i':
            n_support = sampler.n_user_support
            users_select = np.random.choice(sampler.users, size=n_support + sampler.n_user_query,
                                            replace=False)
            # Split by position. users_select[-0:] would return every user
            # when n_user_query is 0, overlapping the support users.
            users_support = users_select[:n_support]
            users_query = users_select[n_support:]
            for s in samplers:
                s.set_users(users_support, users_query)
        elif mode == 'meta':
            for s in samplers:
                # Use this sampler's own pool and count, not samplers[0]'s.
                users_select = np.random.choice(s.users, size=s.n_user, replace=False)
                s.set_users(users_select)



class RandomUserSampler(Sampler):
    """Per-class sampler drawing support and query indices from the same users.

    ``user_meta`` maps each user to that user's samples for one class. Indices
    are offsets into the concatenation of those sample lists in ``user_meta``
    iteration order. For each user in ``users_select`` one random permutation
    is drawn. Its first ``n_shot // n_user`` entries become support indices
    and its last ``n_query // n_user`` entries become query indices. All
    support indices are emitted first, then all query indices.
    """

    def __init__(self, user_meta, n_shot, n_query, n_user, **kwargs):
        self.user_meta = user_meta
        self.n_user = n_user
        self.users = [u for u in self.user_meta.keys()]
        self.n_shot = n_shot
        self.n_query = n_query
        # Items contributed by each selected user (integer division).
        self.support_size = self.n_shot // self.n_user
        self.query_size = self.n_query // self.n_user
        self.users_select = []

    def set_users(self, users):
        """Choose the users that the next iteration samples from."""
        self.users_select = users

    def __len__(self):
        return self.n_shot + self.n_query

    def __iter__(self):
        support_idxs = []
        query_idxs = []
        u_startid = 0
        for u, samples in self.user_meta.items():
            n_samples = len(samples)
            if u in self.users_select:
                perm = (torch.randperm(n_samples) + u_startid).tolist()
                support_idxs.extend(perm[:self.support_size])
                # Take the last ``query_size`` items via an explicit start.
                # perm[-0:] would be the whole permutation when query_size is 0.
                query_idxs.extend(perm[max(n_samples - self.query_size, 0):])
            u_startid += n_samples
        return iter(support_idxs + query_idxs)


class RandomUserSamplerIndependent(Sampler):
    """Per-class sampler with separate support and query user groups.

    Indices are offsets into the concatenation of ``user_meta``'s sample lists
    in iteration order. Each support user contributes the first
    ``n_shot // n_user_support`` entries of a random permutation of its
    samples. Each query user contributes the first ``n_query // n_user_query``
    entries of its own fresh permutation. Support indices come first, then
    query indices.
    """

    def __init__(self, user_meta, n_shot, n_query, n_user_support, n_user_query, **kwargs):
        self.user_meta = user_meta
        self.users = list(self.user_meta.keys())
        self.n_shot, self.n_query = n_shot, n_query
        self.n_user_support = n_user_support
        self.n_user_query = n_user_query
        # Items contributed by each selected user (integer division).
        self.support_size = n_shot // n_user_support
        self.query_size = n_query // n_user_query
        self.users_support = []
        self.users_query = []

    def set_users(self, us, uq):
        """Choose the support users ``us`` and query users ``uq`` for the next pass."""
        self.users_support, self.users_query = us, uq

    def __len__(self):
        return self.n_shot + self.n_query

    def __iter__(self):
        picked_support, picked_query = [], []
        base = 0
        for user, samples in self.user_meta.items():
            count = len(samples)
            if user in self.users_support:
                shuffled = torch.randperm(count) + base
                picked_support += shuffled[:self.support_size].tolist()
            if user in self.users_query:
                shuffled = torch.randperm(count) + base
                picked_query += shuffled[:self.query_size].tolist()
            base += count
        return iter(picked_support + picked_query)
