import numpy as np
import scipy
from scipy.stats import t
from tqdm import tqdm

import util

import torch
from sklearn import metrics
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def mean_confidence_interval(data, confidence=0.95):
    """Return the sample mean and the half-width of its confidence interval.

    The interval is the two-sided Student-t interval ``mean +/- h`` built
    from the standard error of the mean with ``n - 1`` degrees of freedom.

    Args:
        data: 1-D sequence of numeric samples.
        confidence: Two-sided confidence level in (0, 1).

    Returns:
        A ``(mean, h)`` tuple, where ``h`` is the half-width of the interval.
        ``h`` is NaN when fewer than two samples are given, because the
        standard error is undefined in that case.
    """
    a = 1.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    # Use the public ``ppf`` rather than the private ``_ppf``: the private
    # method skips argument validation and is not part of SciPy's stable API.
    h = se * t.ppf((1 + confidence) / 2., n - 1)
    return m, h


def np_normalize(x):
    """L2-normalize each row of a 2-D array.

    Every row is divided by its Euclidean norm taken along axis 1. Rows
    whose norm is zero are not special-cased, so the division yields
    non-finite values for them.
    """
    squared_sum = np.square(x).sum(axis=1, keepdims=True)
    row_norm = np.sqrt(squared_sum)
    return x / row_norm


def normalize(x):
    """L2-normalize a tensor along dimension 1.

    Each slice along dim 1 is divided by its Euclidean norm; the norm keeps
    the reduced dimension so it broadcasts against ``x``. Zero-norm slices
    are not special-cased.
    """
    squared_sum = torch.sum(torch.pow(x, 2), 1, keepdim=True)
    l2_norm = torch.pow(squared_sum, 1. / 2)
    return torch.div(x, l2_norm)


def meta_test(net, testloader, opt, classifier='LR'):
    """Evaluate ``net`` on every episode yielded by ``testloader``.

    The network is switched to eval mode and gradients are disabled. Each
    batch is reduced to its first element per field before being scored by
    ``meta_test_single`` (presumably the loader uses batch size 1 -- confirm
    against the caller).

    Returns:
        The ``(mean, half_width)`` tuple produced by
        ``mean_confidence_interval`` over the per-episode accuracies.
    """
    net = net.eval()
    accuracies = []

    with torch.no_grad():
        for batch in tqdm(testloader):
            # Keep only index 0 of each field in the batch.
            episode = [field[0] for field in batch]
            accuracies.append(meta_test_single(net, episode, opt, classifier))
    return mean_confidence_interval(accuracies)


def meta_test_single(net, data, opt, classifier='LR'):
    """Score a single episode and return its query accuracy.

    Support and query inputs are concatenated and encoded in one forward
    pass through ``net.encode``, optionally L2-normalized when
    ``opt.is_norm`` is truthy, converted to NumPy, and split back into
    support and query features before being passed to ``base_learner``.

    Args:
        net: Model exposing an ``encode`` method.
        data: Sequence ``(support_xs, support_ys, query_xs, query_ys)``;
            the label entries must provide ``.numpy()``.
        opt: Options object; ``is_norm`` is read here and the object is
            forwarded to ``base_learner``.
        classifier: Name of the base learner to use.

    Returns:
        The accuracy value returned by ``base_learner``.
    """
    support_xs, support_ys, query_xs, query_ys = data
    # Presumably moves the inputs to the GPU when one is available -- see util.
    support_xs, query_xs = util.to_cuda_maybe([support_xs, query_xs])

    # Remember where the support rows end so the features can be split again.
    n_support = support_xs.shape[0]
    features = net.encode(torch.cat([support_xs, query_xs]))

    if opt.is_norm:
        features = normalize(features)

    features = util.cuda_to_np(features)

    episode = [
        features[:n_support],
        support_ys.numpy(),
        features[n_support:],
        query_ys.numpy(),
    ]
    return base_learner(episode, opt, classifier=classifier)


def base_learner(data, opt, classifier="LR"):
    """Fit the chosen classifier on support features and score the queries.

    Args:
        data: Sequence ``(support_features, support_ys, query_features,
            query_ys)``.
        opt: Options object. ``test_C`` and ``use_bias`` are read for 'LR';
            the object is forwarded to ``Proto`` for 'Proto'.
        classifier: One of 'LR', 'SVM', 'NN', 'Cosine' or 'Proto'.

    Returns:
        Accuracy of the predictions against ``query_ys``, as computed by
        ``metrics.accuracy_score``.

    Raises:
        NotImplementedError: If ``classifier`` is not one of the names above.
    """
    support_features, support_ys, query_features, query_ys = data

    # Build a scikit-learn estimator for the two fit/predict style learners.
    estimator = None
    if classifier == 'LR':
        estimator = LogisticRegression(
            penalty='l2',
            C=opt.test_C,
            random_state=0,
            solver='lbfgs',
            max_iter=1000,
            fit_intercept=opt.use_bias,
            multi_class='multinomial',
        )
    elif classifier == 'SVM':
        svm = SVC(gamma='auto',
                  C=1,
                  kernel='linear',
                  decision_function_shape='ovr')
        estimator = make_pipeline(StandardScaler(), svm)

    if estimator is not None:
        estimator.fit(support_features, support_ys)
        predictions = estimator.predict(query_features)
    elif classifier == 'NN':
        predictions = NN(support_features, support_ys, query_features)
    elif classifier == 'Cosine':
        predictions = Cosine(support_features, support_ys, query_features)
    elif classifier == 'Proto':
        predictions = Proto(support_features, support_ys, query_features, opt)
    else:
        raise NotImplementedError('classifier not supported: {}'.format(classifier))

    return metrics.accuracy_score(query_ys, predictions)


def Proto(support, support_ys, query, opt):
    """Prototypical-network classifier.

    Each class prototype is the mean of the support features carrying that
    label; every query is assigned the label of the prototype with the
    smallest squared Euclidean distance (ties go to the smallest label).

    Prototypes are grouped by ``support_ys`` rather than by reshaping the
    support set to ``(n_ways, n_shots)``. The result therefore does not
    depend on the support rows being stored class-major, on labels being
    ``0..n_ways-1``, or on every class having the same number of shots.
    For a class-major support set labelled ``0..n_ways-1`` the predictions
    match the reshape-based formulation.

    Args:
        support: Array of support features, one row per sample.
        support_ys: Labels of the support rows.
        query: Array of query features, one row per sample.
        opt: Unused; kept so existing callers keep working.

    Returns:
        1-D NumPy array with one predicted label per query row.
    """
    feat_dim = np.shape(support)[-1]
    support = np.reshape(support, (-1, feat_dim))
    query = np.reshape(query, (-1, feat_dim))
    labels = np.reshape(np.asarray(support_ys), (-1,))

    # np.unique returns the labels sorted, which fixes the prototype order.
    classes = np.unique(labels)
    prototypes = np.stack([support[labels == c].mean(axis=0) for c in classes])

    # Squared distances with shape (n_query, n_classes).
    sq_dist = ((query[:, np.newaxis, :] - prototypes[np.newaxis, :, :]) ** 2).sum(-1)
    return classes[np.argmin(sq_dist, axis=-1)]


def NN(support, support_ys, query):
    """1-nearest-neighbour classifier under squared Euclidean distance.

    Args:
        support: 2-D array of support features, one row per sample.
        support_ys: Labels indexed in the same order as the support rows.
        query: 2-D array of query features, one row per sample.

    Returns:
        A list holding, for each query row, the label of its closest
        support row (the first one on ties).
    """
    # Broadcast (n_query, dim, 1) against (1, dim, n_support).
    support_t = support.T[np.newaxis]
    query_col = np.expand_dims(query, axis=2)

    delta = query_col - support_t
    sq_dist = np.sum(delta * delta, axis=1)
    nearest = sq_dist.argmin(axis=1)
    return [support_ys[i] for i in nearest]


def Cosine(support, support_ys, query):
    """Nearest-neighbour classifier under cosine similarity.

    Support and query rows are scaled to unit length, their pairwise dot
    products are taken, and each query receives the label of the support
    row with the highest similarity (the first one on ties).

    Returns:
        A list with one label per query row.
    """
    support_unit = support / np.linalg.norm(support, axis=1, keepdims=True)
    query_unit = query / np.linalg.norm(query, axis=1, keepdims=True)

    # Rows index queries, columns index support samples.
    similarity = np.matmul(query_unit, support_unit.transpose())
    best = similarity.argmax(axis=1)
    return [support_ys[i] for i in best]
