import numpy as np
from scipy.linalg import block_diag


import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.nn import functional as F

import util


def mean(arr):
    """Return the arithmetic mean of ``arr`` as ``sum(arr) / len(arr)``.

    ``arr`` must support both ``sum`` and ``len``; an empty list raises
    ``ZeroDivisionError``.
    """
    total = sum(arr)
    count = len(arr)
    return total / count


def embed_mask(n_ways, n_shots):
    """Build a block-diagonal 0/1 mask with one ``n_shots x n_shots`` block of ones per way.

    The result is an ``(n_ways * n_shots, n_ways * n_shots)`` float array whose
    entry ``(i, j)`` is 1 exactly when rows ``i`` and ``j`` fall in the same
    block (i.e. the same way when samples are laid out way by way).
    """
    block = np.ones((n_shots, n_shots))
    return block_diag(*(block for _ in range(n_ways)))


class MetaLabeler(nn.Module):
    """Ridge-regression meta-learner with a bank of class centroids for pseudo-labelling.

    ``forward`` encodes a few-shot episode with ``backbone``, fits a
    closed-form ridge head on the support set and scores it on the query set.
    The ``centroid`` buffer (``K`` rows of ``feat_dim``) is maintained by
    ``init_centroid`` / ``cluster_task`` / ``remove_cluster`` / ``load_cluster``
    and used by ``pseudo_labels`` to match task classes to centroids.

    Layout assumptions (from the reshapes below): support and query sets are
    each ordered class-major (all samples of class 0, then class 1, ...);
    methods that ``reshape(self.n_ways, -1, self.feat_dim)`` a whole input
    assume every class's samples are contiguous in that input.
    """

    def __init__(self, backbone, args, feat_dim, intercept=True, extra_reg=True):
        super(MetaLabeler, self).__init__()

        # Episode layout and ridge-regression hyper-parameters.
        self.n_ways = args.n_ways
        self.lam = args.lam
        self.n_shots = args.n_shots
        self.n_queries = args.n_queries
        self.n_per_class = self.n_shots + self.n_queries
        self.backbone = backbone
        self.intercept = intercept
        self.feat_dim = feat_dim
        self.extra_reg = extra_reg

        # Pairwise targets for the hinge embedding loss in ``forward``:
        # +1 for pairs inside the same per-class block (support-support or
        # query-query), -1 everywhere else.
        # NOTE(review): same-class support/query pairs lie outside both blocks
        # and are therefore labelled -1 (pushed apart) — confirm intended.
        support_embed = embed_mask(self.n_ways, self.n_shots)
        query_embed = embed_mask(self.n_ways, args.n_queries)
        combined_embed = block_diag(support_embed, query_embed)
        embed_label = combined_embed * 2 - np.ones_like(combined_embed)
        self.embed_label = util.np_to_cuda(embed_label)

        # Centroid bank used for pseudo-labelling.
        self.K = args.K
        self.centroid_counter = np.zeros([self.K])   # samples averaged into each centroid
        self.filtering_counter = np.zeros([self.K])  # matches since the last remove_cluster

        self.register_buffer("centroid", torch.zeros(self.K, self.feat_dim))
        self.dict_lookup = 0
        self.ptr = 0  # write position for populate_queue(idx=None)

        self.class_count = np.zeros([self.K])
        self.cur_class_count = np.zeros([self.K])

    def add_intercept(self, xs):
        """Append a constant-1 column to ``xs`` (bias term for the ridge head).

        The column is created on ``xs``'s device and dtype rather than a
        hard-coded ``.cuda()``, so CPU inputs work too.
        """
        ones = torch.ones(xs.shape[0], 1, device=xs.device, dtype=xs.dtype)
        return torch.cat([xs, ones], 1)

    def add_classifier(self, K):
        """Attach a linear head ``feat_dim -> K`` used by ``forward_sup``."""
        self.classifier = nn.Linear(self.feat_dim, K)

    def proc_input(self, support_xs, support_ys, query_xs, query_ys):
        """Encode support and query inputs in a single backbone pass.

        Returns:
            ``(combined_feat, combined_ys, offset)`` where the first ``offset``
            rows of ``combined_feat`` / ``combined_ys`` are the support set.
        """
        offset = support_xs.shape[0]

        combined_xs = torch.cat([support_xs, query_xs])
        combined_ys = torch.cat([support_ys, query_ys])

        combined_feat = self.encode(combined_xs)
        return combined_feat, combined_ys, offset

    def forward(self, support_xs, support_ys, query_xs, query_ys):
        """Episode loss: fit ridge weights on the support set and evaluate on queries.

        With ``extra_reg`` the loss also includes the pairwise hinge embedding
        loss and an MSE pulling the support-only weights toward the weights
        fitted on support + query together.

        Returns:
            ``(loss, query_accuracy)``.
        """
        feat, combined_ys, offset = self.proc_input(support_xs, support_ys, query_xs, query_ys)
        design = self.add_intercept(feat) if self.intercept else feat

        support_feat = design[:offset]
        query_feat = design[offset:]

        side_info = self.least_square(support_feat, support_ys)
        query_loss, query_acc = self.predict_on_task(query_feat, query_ys, side_info)

        if not self.extra_reg:
            return query_loss, query_acc

        # The regularisers are only needed here, so they are computed lazily.
        # The embedding loss uses the raw features (before the intercept column).
        l2_dist = torch.cdist(feat, feat)
        embed_loss = F.hinge_embedding_loss(l2_dist, self.embed_label, margin=1.)
        ideal_weights = self.least_square(design, combined_ys)
        full_loss = query_loss + embed_loss + F.mse_loss(side_info, ideal_weights)
        return full_loss, query_acc

    def forward_sup(self, xs, ys):
        """Cross-entropy of ``self.classifier`` on encoded ``xs``.

        Returns ``(loss, 0)``; the second value is a constant placeholder.
        """
        combined_feat = self.encode(xs)

        logits = self.classifier(combined_feat)
        guess_loss = F.cross_entropy(logits, ys)
        return guess_loss, 0

    def sub_pred(self, support_feat):
        """Fit a ridge head on a random subset of the support classes.

        Draws ``n`` with ``2 <= n < self.n_ways`` (``randint``'s upper bound is
        exclusive, so ``n_ways`` must be at least 3). It then picks ``n``
        distinct classes, relabels them ``0..n-1`` in sampled order, and solves
        ``least_square`` on their support rows. Assumes exactly ``n_shots``
        support rows per class, class-major.

        Returns:
            ``(subset, weights)``: the sampled class indices (numpy array) and
            the ``(dim, n)`` weight matrix.
        """
        n = np.random.randint(2, self.n_ways)
        subset = np.random.choice(np.arange(self.n_ways), n, False)
        feat_dim = support_feat.shape[-1]
        sub_feat = support_feat.view(self.n_ways, -1, feat_dim)[subset].view(-1, feat_dim)
        sub_ys = torch.from_numpy(np.repeat(np.arange(n), self.n_shots)).to(support_feat.device)

        return subset, self.least_square(sub_feat, sub_ys, n_ways=n)

    def encode(self, feat_in):
        """Embed ``feat_in`` with ``self.backbone.encode``."""
        return self.backbone.encode(feat_in)

    def group_by_class(self, combined_feat, offset):
        """Regroup ``[support; query]`` features into ``(n_ways, per_class, feat_dim)``.

        Support and query rows of each class are concatenated along dim 1
        (support first). Assumes both parts are ordered class-major.
        """
        support_feat = combined_feat[:offset]
        query_feat = combined_feat[offset:]
        support_feat = support_feat.view(self.n_ways, -1, self.feat_dim)
        query_feat = query_feat.view(self.n_ways, -1, self.feat_dim)
        return torch.cat([support_feat, query_feat], 1)

    def least_square(self, X, y, lam=None, n_ways=None):
        """Closed-form ridge regression onto one-hot targets (dual form).

        Computes ``W = X^T (lam * I + X X^T)^{-1} Y`` with ``Y = one_hot(y)``.
        The inverted matrix is ``(n_samples, n_samples)``.

        Args:
            X: ``(n_samples, dim)`` design matrix.
            y: ``(n_samples,)`` integer labels in ``[0, n_ways)``.
            lam: ridge coefficient; ``None`` means ``self.lam``. An explicit
                ``0`` is honoured rather than replaced by the default.
            n_ways: one-hot width; ``None`` means ``self.n_ways``.

        Returns:
            ``(dim, n_ways)`` weight matrix.
        """
        if lam is None:
            lam = self.lam
        if n_ways is None:
            n_ways = self.n_ways
        X_t = torch.transpose(X, 0, 1)
        eye = torch.eye(X.shape[0], device=X.device, dtype=X.dtype)
        A = lam * eye + torch.matmul(X, X_t)
        y_one_hot = F.one_hot(y, n_ways).to(X.dtype)
        tmp = torch.matmul(X_t, torch.inverse(A))
        return torch.matmul(tmp, y_one_hot)

    def predict_on_task(self, xs, ys, weights):
        """Score ``xs @ weights`` against one-hot ``ys``.

        Returns:
            ``(mse_loss, accuracy)``; the accuracy comes from ``util.accuracy``.
        """
        y_pred = torch.matmul(xs, weights)
        y_target = F.one_hot(ys, self.n_ways).float()
        loss = F.mse_loss(y_pred, y_target)

        acc = util.accuracy(y_pred, ys)[0].item()

        return loss, acc

    @torch.no_grad()
    def populate_queue(self, keys, idx=None):
        """Write ``keys`` into the centroid bank.

        With ``idx is None`` the keys are written as a contiguous run at
        ``self.ptr``, and ``ptr`` advances circularly (``K`` must be a multiple
        of the batch size). Otherwise each key replaces ``self.centroid[idx]``.
        """
        if idx is None:
            batch_size = keys.shape[0]

            assert self.K % batch_size == 0  # for simplicity
            self.centroid[self.ptr:self.ptr + batch_size] = keys
            self.ptr = (self.ptr + batch_size) % self.K

        else:
            self.centroid[idx] = keys

    @torch.no_grad()
    def cluster_task(self, xs):
        """Pseudo-label a task and fold its class means into the matched centroids.

        Each matched centroid is updated as a running mean
        ``(old * n + key) / (n + 1)``, where ``n`` is its ``centroid_counter``.
        ``filtering_counter`` is also incremented for every matched centroid.

        Returns:
            ``True`` if the task was matched (and centroids updated), else ``False``.
        """
        class_keys, class_labels = self.pseudo_labels(xs)
        if class_labels is None:
            return False

        keys = self.centroid[class_labels]

        # ``np.int`` (an alias of ``int``) was removed in NumPy 1.24.
        np_idx = util.cuda_to_np(class_labels).astype(int)
        before = self.centroid_counter[np_idx]
        self.filtering_counter[np_idx] += 1
        after = before + 1
        self.centroid_counter[np_idx] = after
        ratio = before / after
        keys = keys * util.np_to_cuda(ratio).view(-1, 1) + class_keys / util.np_to_cuda(after).view(-1, 1)

        self.populate_queue(keys, class_labels)
        return True

    @torch.no_grad()
    def init_centroid(self, xs):
        """Seed centroids at ``self.ptr`` with the per-class mean feature of ``xs``.

        ``xs`` is reshaped to ``(n_ways, -1, feat_dim)`` after encoding, so each
        class's samples must be contiguous.
        """
        combined_feat = self.encode(xs)
        feat_by_class = combined_feat.reshape(self.n_ways, -1, self.feat_dim)
        feat_by_class = feat_by_class.mean(1)
        self.populate_queue(feat_by_class)

    @torch.no_grad()
    def remove_cluster(self, n_ways, std_factor):
        """Drop rarely matched centroids, then reset ``filtering_counter``.

        ``sum(filtering_counter) / n_ways`` is the number of matched tasks (each
        match increments ``n_ways`` entries). If tasks picked centroids
        uniformly, each count would be binomial with ``p = n_ways / K``. A
        centroid is kept only if its count exceeds
        ``max(mean - std_factor * std, 1)``.
        """
        counter = self.filtering_counter
        p = n_ways / counter.shape[0]
        mean_count = np.mean(counter)
        std = np.sqrt(p * (1 - p) * np.sum(counter) / n_ways)
        threshold = max(mean_count - std_factor * std, 1)

        keep = counter > threshold

        self.centroid = self.centroid[torch.as_tensor(keep, device=self.centroid.device)]
        self.centroid_counter = self.centroid_counter[keep]
        self.K = self.centroid.shape[0]
        self.filtering_counter = np.zeros(self.K)

    def load_cluster(self, file):
        """Replace the centroid bank with a tensor loaded from ``file``.

        The tensor is mapped onto the current centroid device. Both counters
        are resized to the new ``K`` and zeroed (their state right after
        construction), so later ``cluster_task`` / ``remove_cluster`` calls
        index valid entries.

        NOTE: ``torch.load`` unpickles ``file`` — only load trusted files.
        """
        self.centroid = torch.load(file, map_location=self.centroid.device)
        self.K = self.centroid.shape[0]
        self.centroid_counter = np.zeros([self.K])
        self.filtering_counter = np.zeros([self.K])

    @torch.no_grad()
    def pseudo_labels(self, combined_xs, topk=1):
        """Match each class of ``combined_xs`` to its nearest centroid (Euclidean).

        The encoded input is reshaped to ``(n_ways, -1, feat_dim)`` and averaged
        per class. Only the closest centroid (first ``topk`` column) is used.

        Returns:
            ``(class_keys, centroid_idx)`` if all ``n_ways`` classes map to
            distinct centroids, otherwise ``(None, None)``.
        """
        combined_feat = self.encode(combined_xs)

        feat_by_class = combined_feat.reshape(self.n_ways, -1, self.feat_dim)
        class_keys = feat_by_class.mean(1)
        class_vote = torch.cdist(class_keys, self.centroid)

        _, idx = torch.topk(class_vote, topk, largest=False)
        top_choice = idx[:, 0]  # already on the centroid's device

        if torch.unique(top_choice).shape[0] == self.n_ways:
            return class_keys, top_choice
        return None, None

    def label_samples(self, combined_xs):
        """Per-sample pseudo-labels for ``combined_xs``.

        Each class label is repeated ``n_per_class`` times, so the input is
        assumed to hold ``n_per_class`` contiguous samples per class.

        Returns:
            ``(sample_labels, class_labels)`` or ``(None, None)`` if unmatched.
        """
        class_labels = self.pseudo_labels(combined_xs)[1]
        if class_labels is not None:
            return torch.repeat_interleave(class_labels, self.n_per_class), class_labels
        return None, None


class FlatDataset(Dataset):
    """Map-style dataset built by concatenating per-task tensors.

    Call ``add_task`` once per task, then ``merge_tasks`` before indexing or
    taking ``len``. Items are ``(x, y, index)`` triples.
    """

    def __init__(self, input_dims):
        # Previously ``super(Dataset, self)``, which skipped ``Dataset`` in the MRO.
        super(FlatDataset, self).__init__()
        self.input_dims = input_dims  # stored for callers; not used by this class
        self.reset()

    def reset(self):
        """Discard all tasks and return to the un-merged (list) state."""
        self.xs = []
        self.ys = []

    def add_task(self, xs, ys):
        """Queue one task's inputs and labels (before ``merge_tasks``)."""
        self.xs.append(xs)
        self.ys.append(ys)

    def merge_tasks(self):
        """Concatenate the queued task tensors along dim 0.

        Calling it again after a successful merge is a no-op; previously it
        crashed, because ``torch.cat`` was handed a tensor instead of a list.
        """
        if torch.is_tensor(self.xs) and torch.is_tensor(self.ys):
            return
        self.xs = torch.cat(self.xs)
        self.ys = torch.cat(self.ys)
        assert self.xs.shape[0] == self.ys.shape[0]

    def __getitem__(self, item):
        return self.xs[item], self.ys[item], item

    def __len__(self):
        # Only valid after merge_tasks(), when ys is a tensor.
        return self.ys.shape[0]


class SampleBuffer(object):
    """Growing pool of pseudo-labelled samples drawn without replacement.

    ``labeler`` must provide ``label_samples(xs)`` returning a pair whose
    first element is per-sample labels, or ``None`` when ``xs`` cannot be
    labelled.
    """

    def __init__(self, input_dims, labeler, device="cuda"):
        # ``device`` defaults to "cuda", matching the previous hard-coded .cuda().
        self.xs = torch.zeros(0, *input_dims, device=device)
        self.ys = torch.zeros(0, dtype=torch.long, device=device)
        self.labeler = labeler
        self.ptr = 0  # not used by this class; kept for external callers

    def add_batch(self, xs):
        """Label ``xs`` and append it to the pool.

        Returns:
            ``True`` if ``xs`` was labelled and stored, ``False`` otherwise.
        """
        labels, _ = self.labeler.label_samples(xs)
        if labels is None:
            return False
        self.xs = torch.cat([self.xs, xs])
        self.ys = torch.cat([self.ys, labels.long()])
        return True

    def sample_and_remove(self, batch_size):
        """Draw ``batch_size`` distinct samples uniformly and remove them from the pool.

        Raises:
            ValueError: (from NumPy) if ``batch_size`` exceeds the pool size.
        """
        db_size = self.ys.shape[0]
        picked = np.random.choice(db_size, batch_size, replace=False)
        batch_idx = torch.as_tensor(picked, dtype=torch.long, device=self.ys.device)
        xs = self.xs[batch_idx]
        ys = self.ys[batch_idx]

        # Boolean keep-mask in torch; ``np.bool`` was removed in NumPy 1.24.
        keep = torch.ones(db_size, dtype=torch.bool, device=self.ys.device)
        keep[batch_idx] = False

        self.xs = self.xs[keep]
        self.ys = self.ys[keep]

        return xs, ys

    def size(self):
        """Number of samples currently in the pool."""
        return self.ys.shape[0]