import numpy as np
import torch
from enum import Enum
from torch.nn.functional import one_hot
from itertools import islice


def accuracy(targ, pred):
    """Return the fraction of rows where ``pred`` and ``targ`` agree on the arg-max class.

    Both tensors are reduced with ``argmax`` along ``dim=1``; the result is a
    Python float (mean of a 0/1 match indicator).
    """
    predicted_classes = pred.argmax(dim=1)
    target_classes = targ.argmax(dim=1)
    hits = predicted_classes.eq(target_classes).float()
    return hits.mean().item()

def accuracy_mse(targ, pred):
    """Return the mean squared error between ``targ`` and ``softmax(pred)``.

    The softmax is taken over the last dimension of ``pred``; the squared
    differences are averaged over every element and returned as a Python
    float.
    """
    probabilities = torch.softmax(pred, dim=-1)
    squared_error = (targ - probabilities).pow(2)
    return squared_error.mean().item()

def label_gradient_alignment(mat, labels):
    """Return the normalised alignment between ``mat`` and a label-agreement matrix.

    The label matrix is ``labels @ labels.T`` with every entry below 1
    replaced by -1. Both matrices are mean-centred, multiplied element-wise,
    divided by the product of their 2-norms and summed into a Python float.
    """
    centered_mat = mat - mat.mean()

    agreement = torch.matmul(labels, labels.T)
    agreement = agreement.masked_fill(agreement < 1, -1)
    centered_labels = agreement - agreement.mean()

    scale = torch.norm(centered_mat, 2) * torch.norm(centered_labels, 2)
    elementwise = centered_mat * centered_labels / scale
    return elementwise.sum().item()

def frobenius_norm(mat):
    """Return the Frobenius norm of ``mat`` as a Python float."""
    norm_value = mat.norm(p="fro")
    return norm_value.item()

def mean(mat):
    """Return the average of all elements of ``mat`` as a Python float."""
    average = mat.mean()
    return average.item()

def conditional_number(mat):
    """Return the ratio of the last to the first eigenvalue of ``mat``.

    Eigenvalues come from ``torch.linalg.eigvalsh`` (upper triangle of
    ``mat``), which returns them in ascending order, so the ratio is
    largest / smallest. A NaN ratio (e.g. 0 / 0) is replaced by 100000.0.
    """
    spectrum = torch.linalg.eigvalsh(mat, UPLO='U')
    smallest = spectrum[0]
    largest = spectrum[-1].item()
    ratio = (largest / smallest).item()
    return np.nan_to_num(ratio, copy=True, nan=100000.0)

def eigenvalue_score(mat):
    """Return ``-sum(log(ev + k) + 1 / (ev + k))`` over the eigenvalues of ``mat``.

    Eigenvalues are taken from ``torch.linalg.eigvalsh`` using the upper
    triangle of ``mat``; ``k`` is a fixed shift of 1.
    """
    spectrum = torch.linalg.eigvalsh(mat, UPLO='U')
    shift = 1  # 1e-5
    shifted = spectrum + shift
    per_eigenvalue = torch.log(shifted) + 1. / shifted
    return -torch.sum(per_eigenvalue).item()


class MetricType(Enum):
    """Named handles for the scoring functions accepted as ``metric``.

    NOTE: every attribute below is bound to a plain function. ``Enum``
    treats functions as methods rather than members, so e.g.
    ``MetricType.ACC`` evaluates to the ``accuracy`` function itself and the
    enum has no real members. Callers rely on this: they invoke the
    attribute directly and compare it by identity.
    """

    ACC = accuracy
    MSE = accuracy_mse
    FRO = frobenius_norm
    MEAN = mean
    COND = conditional_number
    EIG = eigenvalue_score
    LGA = label_gradient_alignment

    def require_only_matrix(self):
        """Return True when ``self`` is one of FRO, MEAN, COND or EIG.

        Those metrics are called with the kernel matrix alone; every other
        value yields False. Comparison is by identity.
        """
        matrix_only = (
            MetricType.FRO,
            MetricType.MEAN,
            MetricType.COND,
            MetricType.EIG,
        )
        return any(self is candidate for candidate in matrix_only)

def slight_train(network, train_loader, train_iters, device):
    """Run a few Adam / cross-entropy optimisation steps on ``network`` in place.

    Args:
        network: module to optimise; it is switched to train mode. If its
            forward pass returns a tuple, element 1 is used as the logits.
        train_loader: iterable yielding ``(inputs, targets)`` batches.
        train_iters: number of batches (optimisation steps) to run. Values
            <= 0 run no step.
        device: device the batches are moved to before the forward pass.
    """
    network.train()
    optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
    criterion = torch.nn.CrossEntropyLoss()
    # islice yields exactly `train_iters` batches. The previous
    # `enumerate` + `if i > train_iters: break` ran one step too many and
    # fetched a further batch from the loader just to discard it.
    for inputs, targets in islice(train_loader, max(train_iters, 0)):
        # `.to` accepts CPU as well as CUDA devices (`.cuda` rejects CPU).
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        logit = network(inputs)
        if isinstance(logit, tuple):
            logit = logit[1]  # 201 networks: return features and logits
        loss = criterion(logit, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def get_nngp_n(train_loader, valid_loader, networks, train_mode=False, num_batch=-1, verbose=False, gpu=None):
    """Score each network by kernel-regression accuracy on its own outputs.

    For every network the Gram matrix of its training outputs is
    regularised, solved against one-hot training labels, and used to predict
    the validation labels. The best accuracy over a fixed grid of
    regularisation strengths is kept.

    Args:
        train_loader: iterable of ``(inputs, targets)`` training batches;
            targets are integer class indices.
        valid_loader: iterable of validation batches; must expose
            ``dataset.classes`` (its length is the number of classes).
        networks: modules to score. A tuple output is read as
            ``(features, logits)`` and element 1 is used.
        train_mode: put the networks in train mode (True) or eval mode.
        num_batch: if > 0, use at most this many batches from each loader.
        verbose: print a message whenever the linear solve fails.
        gpu: CUDA device index, or None to run on the CPU.

    Returns:
        One float per network: the best validation accuracy found, or -1.0
        if the solve failed for every regularisation strength.
    """
    if gpu is not None:
        device = torch.device('cuda:{}'.format(gpu))
    else:
        device = torch.device('cpu')
    if not networks:
        return []
    for network in networks:
        if train_mode:
            network.train()
        else:
            network.eval()

    def _collect(loader):
        """Return (per-network concatenated outputs, concatenated targets)."""
        batches = islice(loader, num_batch) if num_batch > 0 else loader
        per_net_logits = [[] for _ in networks]
        seen_targets = []
        for inputs, targets in batches:
            # `.to` works for CPU and CUDA alike; the former `.cuda(device=...)`
            # raised whenever `gpu` was None.
            inputs = inputs.to(device, non_blocking=True)
            seen_targets.append(targets)
            for net_idx, network in enumerate(networks):
                # Each network gets its own copy of the batch.
                logit = network(inputs.clone())
                if isinstance(logit, tuple):
                    logit = logit[1]  # 201 networks: return features and logits
                per_net_logits[net_idx].append(logit)
                torch.cuda.empty_cache()
        return [torch.cat(chunks, 0) for chunks in per_net_logits], torch.cat(seen_targets, 0)

    # Only forward values are needed, so skip building autograd graphs.
    with torch.no_grad():
        train_logits, train_targets = _collect(train_loader)
        valid_logits, valid_targets = _collect(valid_loader)

    # one-hot labeling (training side only; validation stays as class indices)
    num_classes = len(valid_loader.dataset.classes)
    labels_t = one_hot(train_targets, num_classes=num_classes).to(torch.float32).to(device, non_blocking=True)
    labels_v = valid_targets.to(device, non_blocking=True)

    acc = [-1.0 for _ in networks]
    # Range of regularizer set manually.
    diag_reg_values = np.logspace(-7, 2, num=20)

    for net_idx, (logits_t, logits_v) in enumerate(zip(train_logits, valid_logits)):
        K_tt = torch.einsum('nc,mc->nm', logits_t, logits_t)
        K_vt = torch.einsum('nc,mc->nm', logits_v, logits_t)
        n_t = K_tt.shape[0]
        # Loop-invariant part of the regulariser: mean diagonal times identity.
        unit_reg = torch.trace(K_tt) / n_t * torch.eye(n_t, dtype=K_tt.dtype, device=K_tt.device)
        for epsilon in diag_reg_values:
            # Regularize K_tt.
            K_tt_reg = K_tt + float(epsilon) * unit_reg
            # The solve can fail (e.g. singular matrix); skip that epsilon.
            try:
                inv_labels = torch.linalg.solve(K_tt_reg, labels_t)
            except RuntimeError as e:
                if verbose:
                    print("Matrix inversion error for epsilon = {}, reason {}".format(epsilon, e))
                continue
            # Perform NNGP inference to obtain validation accuracy.
            prediction = torch.matmul(K_vt, inv_labels)
            correct_predictions = (torch.argmax(prediction, dim=1) == labels_v).to(torch.float32)
            acc[net_idx] = max(acc[net_idx], torch.mean(correct_predictions).item())
    return acc


def compute_nngp_outputs(inputs, network, use_logits=False):
    """Run ``network`` on ``inputs`` without gradients and pick one output.

    Args:
        inputs: batch passed to ``network``.
        network: callable whose result is either a ``(features, logits)``
            tuple or a single tensor.
        use_logits: return the logits (tuple element 1) instead of the
            features (tuple element 0).

    Returns:
        The selected output. A non-tuple result is treated as logits and is
        returned as-is when ``use_logits`` is True.

    Raises:
        TypeError: if features are requested but the network did not return
            a tuple.
    """
    with torch.no_grad():
        output = network(inputs)
    if isinstance(output, tuple):
        # 201 networks: return features and logits
        return output[1] if use_logits else output[0]
    if use_logits:
        return output
    # An explicit raise instead of `assert`: under `python -O` the assert
    # vanished and `output[0]` silently returned the first row of a tensor.
    raise TypeError(
        "network must return a (features, logits) tuple when use_logits=False, "
        "got {}".format(type(output).__name__)
    )

def get_nngp_n_v2(train_loader, valid_loader, networks, metric=MetricType.ACC, train_mode=False, as_correlation=False, train_iters=-1, num_batch=-1, use_logits=False, verbose=False):
    """Score each network with ``metric`` computed from its output kernel.

    Args:
        train_loader: iterable of ``(inputs, targets)`` training batches.
        valid_loader: iterable of validation batches; must expose
            ``dataset.classes`` unless ``metric`` needs only the matrix.
        networks: modules to score (see ``compute_nngp_outputs`` for the
            expected output format).
        metric: one of the ``MetricType`` attributes. Matrix-only metrics
            are called as ``metric(K)``, ``MetricType.LGA`` as
            ``metric(K, one_hot_train_labels)``, and any other metric as
            ``metric(one_hot_valid_labels, prediction)``.
        train_mode: put the networks in train mode (True) or eval mode.
        as_correlation: build the training kernel with ``torch.corrcoef``
            instead of a plain inner product.
        train_iters: if > 0, run ``slight_train`` on every network first.
        num_batch: if > 0, use at most this many batches from each loader.
        use_logits: use the networks' logits rather than their features.
        verbose: print a message whenever the linear solve fails.

    Returns:
        One score per network. For prediction metrics this is the maximum
        over the regularisation grid, or -1.0 if every solve failed.
    """
    # Use the current CUDA device when there is one; otherwise fall back to
    # the CPU instead of failing inside `torch.cuda.current_device()`.
    if torch.cuda.is_available():
        device = torch.device('cuda', torch.cuda.current_device())
    else:
        device = torch.device('cpu')

    for network in networks:
        if train_iters > 0:
            slight_train(network, train_loader, train_iters, device)
        if train_mode:
            network.train()
        else:
            network.eval()

    def _collect(loader):
        """Return (per-network concatenated outputs, list of target batches)."""
        batches = islice(loader, num_batch) if num_batch > 0 else loader
        per_net_outputs = [[] for _ in networks]
        target_batches = []
        for inputs, targets in batches:
            inputs = inputs.to(device, non_blocking=True)
            for net_idx, network in enumerate(networks):
                out = compute_nngp_outputs(inputs, network, use_logits=use_logits)
                per_net_outputs[net_idx].append(out.detach())
                torch.cuda.empty_cache()
            target_batches.append(targets.detach())
        return [torch.cat(chunks, 0) for chunks in per_net_outputs], target_batches

    def _one_hot_on_device(target_batches):
        """Concatenate target batches and one-hot encode them as float32."""
        merged = torch.cat(target_batches, 0)
        return one_hot(merged, num_classes=num_classes).to(torch.float32).to(device, non_blocking=True)

    train_logits, train_target_batches = _collect(train_loader)
    if as_correlation:
        train_Ks = [torch.corrcoef(l) for l in train_logits]
    else:
        train_Ks = [torch.einsum('nc,mc->nm', l, l) for l in train_logits]

    if MetricType.require_only_matrix(metric):
        return [metric(k) for k in train_Ks]

    num_classes = len(valid_loader.dataset.classes)
    train_targets = _one_hot_on_device(train_target_batches)

    if metric is MetricType.LGA:
        return [metric(k, train_targets) for k in train_Ks]

    valid_logits, valid_target_batches = _collect(valid_loader)
    # NOTE(review): the validation kernel is always a plain inner product,
    # even when `as_correlation` made the training kernel a correlation
    # matrix -- confirm this mix is intended.
    valid_Ks = [torch.einsum('nc,mc->nm', l1, l2) for l1, l2 in zip(valid_logits, train_logits)]
    valid_targets = _one_hot_on_device(valid_target_batches)

    scores = [-1.0 for _ in networks]
    # Range of regularizer set manually.
    diag_reg_values = np.logspace(-7, 2, num=20)
    for net_idx, (K_tt, K_vt) in enumerate(zip(train_Ks, valid_Ks)):
        n_t = K_tt.shape[0]
        # Loop-invariant part of the regulariser: mean diagonal times identity.
        unit_reg = torch.trace(K_tt) / n_t * torch.eye(n_t, dtype=K_tt.dtype, device=K_tt.device)
        for epsilon in diag_reg_values:
            # Regularize K_tt.
            K_tt_reg = K_tt + float(epsilon) * unit_reg
            # Only the solve is guarded: it can fail (e.g. singular matrix).
            # Errors raised by the metric itself now propagate instead of
            # being reported as a -1 score.
            try:
                inv_labels = torch.linalg.solve(K_tt_reg, train_targets)
            except RuntimeError as e:
                if verbose:
                    print("Matrix inversion error for epsilon = {}, reason {}".format(epsilon, e))
                continue
            # Perform NNGP inference and score the prediction.
            prediction = torch.matmul(K_vt, inv_labels)
            val = metric(valid_targets, prediction)
            # NOTE(review): the maximum is kept for every metric, including
            # MetricType.MSE where a larger value is a larger error --
            # confirm this is intended.
            scores[net_idx] = max(scores[net_idx], val)
    return scores

def compute_nas_score(xloader, vloader, gpu, model, resolution, batch_size):
    """Return the ``get_nngp_n`` score of a single ``model``.

    The model is scored in train mode on one batch from each loader, on the
    device selected by ``gpu``. ``resolution`` and ``batch_size`` are
    accepted but not used here.
    """
    per_network_scores = get_nngp_n(
        xloader,
        vloader,
        [model],
        train_mode=True,
        num_batch=1,
        verbose=False,
        gpu=gpu,
    )
    return per_network_scores[0]