""" testing.py
    Utilities for testing models

    Collaboratively developed
    by Avi Schwarzschild, Eitan Borgnia,
    Arpit Bansal, and Zeyad Emam.

    Developed for DeepThinking project
    October 2021
"""

import einops
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from icecream import ic
from tqdm import tqdm
from deepthinking.utils.rotation import rotate_batch

# Ignore statements for pylint:
#     Too many branches (R0912), Too many statements (R0915), No member (E1101),
#     Not callable (E1102), Invalid name (C0103), Bare except (W0702),
#     Too many local variables (R0914), Missing docstring (C0116, C0115, C0114).
# pylint: disable=R0912, R0915, E1101, E1102, C0103, W0702, R0914, C0116, C0115, C0114


def test(net, loaders, mode, iters, problem, device):
    """Evaluate `net` on every loader in `loaders` with the selected test mode.

    Args:
        net: model to evaluate; forwarded unchanged to the mode-specific routine.
        loaders: iterable of data loaders, evaluated one after another.
        mode: "default" (dispatches to test_default) or "max_conf"
            (dispatches to test_max_conf).
        iters: iteration counts to report; forwarded unchanged.
        problem: problem name; forwarded unchanged.
        device: device to run on; forwarded unchanged.

    Returns:
        A pair `(accs, ssh_accs)` of lists with one entry per loader.
        In "default" mode the entries are the two values returned by
        test_default. In "max_conf" mode `accs` holds the value returned by
        test_max_conf and every `ssh_accs` entry is None, because that routine
        reports no ssh accuracy.

    Raises:
        ValueError: if `mode` is not supported (raised when the first loader
            is processed, so an empty `loaders` never raises).
    """
    accs = []
    ssh_accs = []
    for loader in loaders:
        if mode == "default":
            accuracy, ssh_accuracy = test_default(net, loader, iters, problem, device)
        elif mode == "max_conf":
            accuracy = test_max_conf(net, loader, iters, problem, device)
            # test_max_conf returns a single result; without this assignment
            # the append below would hit an unbound local.
            ssh_accuracy = None
        else:
            raise ValueError(f"{ic.format()}: test_{mode}() not implemented.")
        accs.append(accuracy)
        ssh_accs.append(ssh_accuracy)
    return accs, ssh_accs


def get_predicted(inputs, outputs, problem):
    """Convert raw network outputs into per-sample flat predictions.

    The base prediction is the argmax of `outputs` over dim 1, reshaped to
    `(outputs.size(0), -1)`. It is then post-processed depending on `problem`:

    * "mazes": multiplied element-wise by the maximum of `inputs` over dim 1.
    * "chess": channel 0 is set to -inf, every channel-1 entry below the
      second-largest channel-1 value of its sample is set to -inf, and the
      argmax over dim 1 is recomputed. This branch assumes exactly 64
      positions per sample.
    * "sudoku": reshaped to (batch, 9, 9), multiplied by channel 0 of
      `inputs`, added to the argmax of `inputs` over dim 1, then flattened.
    * anything else (including None): the base prediction is returned as is.

    `outputs` is never modified; the function works on a copy.
    """
    n_inputs = inputs.size(0)
    # The chess branch writes into the scores in place, so work on a copy.
    scores = outputs.clone()
    labels = scores.argmax(1)
    labels = labels.view(labels.size(0), -1)

    if problem == "mazes":
        channel_max = inputs.max(1)[0].view(inputs.size(0), -1)
        return labels * channel_max

    if problem == "chess":
        scores = scores.view(scores.size(0), scores.size(1), -1)
        # Smaller of the two largest channel-1 values for each sample.
        second_best = torch.topk(scores[:, 1], 2, dim=1)[0].min(dim=1)[0]
        cutoff = einops.repeat(second_best, "n -> n k", k=8)
        cutoff = einops.repeat(cutoff, "n m -> n m k", k=8).view(-1, 64)
        scores[:, 1][scores[:, 1] < cutoff] = -float("Inf")
        scores[:, 0] = -float("Inf")
        return scores.argmax(1)

    if problem == "sudoku":
        channel0 = inputs[:, 0, ...]
        input_digits = inputs.argmax(1)
        grid = labels.reshape((n_inputs, 9, 9))
        grid = grid * channel0 + input_digits
        return grid.view(n_inputs, -1)

    return labels

def get_predicted_topk(outputs, k):
    """Build one prediction map per consecutive pair of top-k channel-1 scores.

    `outputs` is flattened to `(batch, channels, positions)`. For each sample
    the `k` largest channel-1 values are taken; they are then consumed in
    pairs (1st & 2nd, 3rd & 4th, ...). For every pair, channel 0 is set to
    -inf, every channel-1 entry outside the closed range spanned by the pair
    is set to -inf, and the argmax over dim 1 is taken.

    The per-sample thresholds are broadcast across positions, so any number
    of positions is supported (not only 64).

    Args:
        outputs: tensor with at least two channels on dim 1. It is not
            modified.
        k: number of top values to use; must be even because values are
            consumed two at a time.

    Returns:
        A list of `k // 2` tensors of shape `(batch, positions)`.

    Raises:
        ValueError: if `k` is odd.
    """
    if k % 2:
        raise ValueError(f"k must be even to form (top, down) pairs, got {k}.")

    # Never written to: each pair below works on its own clone.
    scores = outputs.reshape(outputs.size(0), outputs.size(1), -1)
    top_k = torch.topk(scores[:, 1], k, dim=1)[0]

    predicted_list = []
    for i in range(0, k, 2):
        # Shape (batch, 1) so the comparison broadcasts over all positions.
        upper = top_k[:, i].unsqueeze(1)
        lower = top_k[:, i + 1].unsqueeze(1)

        current = scores.clone()
        channel = current[:, 1]  # view: writes below land in `current`
        channel[(channel < lower) | (channel > upper)] = -float("Inf")
        current[:, 0] = -float("Inf")
        predicted_list.append(current.argmax(1))
    return predicted_list

def get_visualizable_pred(outputs, size=(256, 256)):
    """Render the channel-0 minus channel-1 score difference as an image array.

    The difference map is min-max normalised over the whole batch, upsampled
    with nearest-neighbour interpolation to `size`, moved to NumPy and scaled
    to the range [0, 255].

    Args:
        outputs: 4-D tensor indexed as (batch, channel, height, width); only
            channels 0 and 1 are read. It is not modified.
        size: target (height, width) of the upsampled map. Defaults to
            (256, 256).

    Returns:
        A float64 NumPy array of shape `(size[0], size[1], batch)`. When the
        difference map is constant (max equals min) the result is all zeros
        rather than NaN.
    """
    diff = outputs[:, 0] - outputs[:, 1]
    lowest = diff.min()
    span = diff.max() - lowest
    vis = (diff - lowest) / span
    if span == 0:
        # 0/0 above produced NaNs; a constant map carries no contrast, so
        # render it as a blank image instead.
        vis = torch.zeros_like(vis)

    # interpolate expects (N, C, H, W): treat the batch as channels of one image.
    vis = F.interpolate(vis.unsqueeze(0), size=size, mode="nearest").squeeze(0)
    vis = vis.permute(1, 2, 0)
    vis = vis.detach().cpu().numpy().astype(np.float64)
    vis *= 255

    return vis


def test_default(net, testloader, iters, problem, device):
    """Measure main-task and ssh accuracy at every iteration of `net`.

    The network is put in eval mode and run without gradients. For each batch
    a second batch is derived with `rotate_batch(inputs, 'rand')`; the original
    and the derived batch are sent through the network in two separate forward
    passes of `max(iters)` iterations. At every iteration index a sample counts
    as correct only if all of its predicted entries match the target.

    Args:
        net: model called as `net(inputs, iters_to_do=..., ...)`.
        testloader: iterable yielding `(inputs, targets)` batches.
        iters: iteration counts (1-based) to report.
        problem: problem name handed to get_predicted for the main task.
        device: device the batches are moved to.

    Returns:
        `(ret_acc, ssh_acc)`: two dicts mapping each entry of `iters` to the
        accuracy in percent at that iteration, for the main task and the ssh
        task respectively.
    """
    max_iters = max(iters)
    net.eval()
    main_hits = torch.zeros(max_iters)
    ssh_hits = torch.zeros(max_iters)
    n_seen = 0

    with torch.no_grad():
        for inputs, targets in tqdm(testloader, leave=False):
            inputs = inputs.to(device)
            targets = targets.to(device)
            ssh_inputs, ssh_labels = rotate_batch(inputs, 'rand')

            all_outputs = net(inputs, iters_to_do=max_iters, gr_truth=targets)
            _, all_ssh_outputs = net(
                ssh_inputs,
                iters_to_do=max_iters,
                return_ssh=True,
                ssh_gr_truth=ssh_labels,
            )

            # Flatten once per batch so labels line up with get_predicted's
            # (batch, -1) output.
            targets = targets.view(targets.size(0), -1)
            ssh_labels = ssh_labels.view(ssh_labels.size(0), -1)

            for step in range(all_outputs.size(1)):
                main_pred = get_predicted(inputs, all_outputs[:, step], problem)
                # amin over dim 1 of the match mask: true only if every entry matches.
                main_ok = torch.amin(main_pred == targets, dim=[1])
                main_hits[step] += main_ok.sum().item()

                ssh_pred = get_predicted(inputs, all_ssh_outputs[:, step], None)
                ssh_ok = torch.amin(ssh_pred == ssh_labels, dim=[1])
                ssh_hits[step] += ssh_ok.sum().item()

            n_seen += targets.size(0)

    accuracy = 100.0 * main_hits / n_seen
    ssh_accuracy = 100.0 * ssh_hits / n_seen
    # `iters` is 1-based while the accumulators are 0-based.
    ret_acc = {ite: accuracy[ite - 1].item() for ite in iters}
    ssh_acc = {ite: ssh_accuracy[ite - 1].item() for ite in iters}
    return ret_acc, ssh_acc


def test_max_conf(net, testloader, iters, problem, device):
    """Measure accuracy when each sample uses its most confident iteration so far.

    For every batch the network is run for `max(iters)` iterations. At each
    iteration a per-sample confidence is computed as the sum over positions of
    the maximum softmax probability (for "mazes", weighted by the maximum of
    `inputs` over dim 1). For iteration index `t`, each sample is scored with
    the prediction from the iteration that has the highest confidence among
    indices `0..t` (running arg-max via `torch.cummax`).

    Args:
        net: model called as `net(inputs, iters_to_do=max_iters)`.
        testloader: iterable yielding `(inputs, targets)` batches.
        iters: iteration counts (1-based) to report.
        problem: problem name handed to get_predicted.
        device: device the batches and accumulators are placed on.

    Returns:
        Dict mapping each entry of `iters` to the accuracy in percent.
    """
    max_iters = max(iters)
    net.eval()
    running_correct = torch.zeros(max_iters).to(device)
    n_seen = 0

    with torch.no_grad():
        for inputs, targets in tqdm(testloader, leave=False):
            inputs = inputs.to(device)
            targets = targets.to(device)
            targets = targets.view(targets.size(0), -1)
            n_seen += targets.size(0)

            all_outputs = net(inputs, iters_to_do=max_iters)

            batch = inputs.size(0)
            confidence = torch.zeros(max_iters, batch).to(device)
            is_correct = torch.zeros(max_iters, batch).to(device)
            for step in range(all_outputs.size(1)):
                step_out = all_outputs[:, step]

                probs = torch.nn.functional.softmax(step_out.detach(), dim=1)
                peak = probs.max(1)[0]
                peak = peak.view(peak.size(0), -1)
                if problem == "mazes":
                    peak = peak * inputs.max(1)[0].view(peak.size(0), -1)
                confidence[step] = peak.sum([1])

                step_pred = get_predicted(inputs, step_out, problem)
                is_correct[step] = torch.amin(step_pred == targets, dim=[1])

            # best_step[t, n]: index of the most confident iteration in 0..t
            # for sample n.
            best_step = torch.cummax(confidence, dim=0)[1]
            sample_idx = torch.arange(is_correct.size(1))
            chosen = is_correct[best_step, sample_idx]
            running_correct += chosen.sum(dim=1)

    accuracy = 100 * running_correct.long().cpu() / n_seen
    # `iters` is 1-based while the accumulator is 0-based.
    return {ite: accuracy[ite - 1].item() for ite in iters}

def test_stop_by_ssh_loss(net, testloader, iters, threshold, problem, device):
    """Evaluate `net` when each batch stops once its ssh loss plateaus.

    For every batch the network is run for `max(iters)` iterations on the
    original inputs and on `rotate_batch(inputs, 'expand')`. Walking through
    the iterations, the batch-mean cross-entropy of the ssh outputs is
    compared with the previous iteration's value; the first iteration whose
    absolute change is below `threshold` is the stop iteration. If no
    iteration satisfies that, the last computed iteration is used, so every
    batch counted in the totals is also scored.

    Args:
        net: model called as `net(inputs, iters_to_do=max_iters)` and
            `net(ssh_inputs, iters_to_do=max_iters, return_ssh=True)`.
        testloader: iterable yielding `(inputs, targets)` batches.
        iters: iteration counts; only `max(iters)` is used.
        threshold: plateau tolerance on the change of the ssh loss.
        problem: problem name handed to get_predicted for the main task.
        device: device the batches are moved to.

    Returns:
        `(accuracy, ssh_accuracy, num_iter_to_do_avg)`: main-task accuracy in
        percent, ssh accuracy in percent, and the sample-weighted mean of the
        0-based stop iteration index.

    Raises:
        ZeroDivisionError: if `testloader` yields no samples.
    """
    max_iters = max(iters)
    net.eval()
    corrects = 0
    ssh_corrects = 0
    total = 0
    ssh_total = 0
    criterion = torch.nn.CrossEntropyLoss(reduction="none")
    num_iter_to_do_avg = 0

    with torch.no_grad():
        for inputs, targets in tqdm(testloader, leave=False):
            inputs, targets = inputs.to(device), targets.to(device)
            targets = targets.view(targets.size(0), -1)
            ssh_inputs, ssh_labels = rotate_batch(inputs, 'expand')
            ssh_labels = ssh_labels.view(ssh_labels.size(0), -1)

            all_outputs = net(inputs, iters_to_do=max_iters)
            _, all_ssh_outputs = net(ssh_inputs, iters_to_do=max_iters, return_ssh=True)

            num_steps = all_outputs.size(1)
            stop_iter = None
            prev_loss = None
            for i in range(num_steps):
                ssh_out = all_ssh_outputs[:, i]
                ssh_out = ssh_out.view(ssh_out.size(0), ssh_out.size(1), -1)
                loss_ssh = criterion(ssh_out, ssh_labels).mean().item()
                # The first iteration has no predecessor to compare against.
                if prev_loss is not None and abs(loss_ssh - prev_loss) < threshold:
                    stop_iter = i
                    print("Stop iter: ", i)
                    print("Loss ssh: ", loss_ssh)
                    print("Loss gap: ", loss_ssh - prev_loss)
                    print("-" * 50)
                    break
                prev_loss = loss_ssh

            if stop_iter is None and num_steps > 0:
                # The loss never plateaued: fall back to the final iteration
                # instead of leaving the batch unscored while still counting
                # it in the totals below.
                stop_iter = num_steps - 1

            if stop_iter is not None:
                predicted = get_predicted(inputs, all_outputs[:, stop_iter], problem)
                corrects += torch.amin(predicted == targets, dim=[1]).sum().item()

                ssh_predicted = get_predicted(inputs, all_ssh_outputs[:, stop_iter], None)
                ssh_corrects += torch.amin(ssh_predicted == ssh_labels, dim=[1]).sum().item()
                num_iter_to_do_avg += stop_iter * targets.size(0)

            total += targets.size(0)
            ssh_total += ssh_labels.size(0)

    accuracy = 100.0 * corrects / total
    ssh_accuracy = 100.0 * ssh_corrects / ssh_total
    num_iter_to_do_avg = num_iter_to_do_avg / total
    return accuracy, ssh_accuracy, num_iter_to_do_avg

def test_stop_by_norm():
    """Unimplemented placeholder: takes no arguments, does nothing, returns None."""
    return None

def test_stop_by_act(net, testloader, iters, threshold, problem, device):
    """Measure, per iteration, how often `net` is correct and signals a stop.

    The network is run in eval mode without gradients for `max(iters)`
    iterations with `debug=True`, which is expected to yield three values:
    the per-iteration outputs, a second value that is not used here, and
    `act_probs`. At iteration index `i` a sample is counted only if all of its
    predicted entries match the target AND `act_probs[:, i] > threshold`.

    Args:
        net: model called as `net(inputs, iters_to_do=max_iters, debug=True)`.
        testloader: iterable yielding `(inputs, targets)` batches.
        iters: iteration counts (1-based) to report.
        threshold: value `act_probs` must exceed for the stop condition.
        problem: problem name handed to get_predicted.
        device: device the batches are moved to.

    Returns:
        Dict mapping each entry of `iters` to the percentage of samples that
        were both correct and above the threshold at that iteration.
    """
    max_iters = max(iters)
    net.eval()
    corrects = torch.zeros(max_iters)
    total = 0

    with torch.no_grad():
        for inputs, targets in tqdm(testloader, leave=False):
            inputs, targets = inputs.to(device), targets.to(device)
            # Flatten once per batch to match get_predicted's (batch, -1) output.
            targets = targets.view(targets.size(0), -1)

            all_outputs, _attention_weights, act_probs = net(
                inputs, iters_to_do=max_iters, debug=True
            )[:]

            for i in range(all_outputs.size(1)):
                predicted = get_predicted(inputs, all_outputs[:, i], problem)
                # True only for samples where every entry matches.
                correct = torch.amin(predicted == targets, dim=[1])
                should_stop = act_probs[:, i] > threshold
                corrects[i] += (correct * should_stop).sum().item()

            total += targets.size(0)

    accuracy = 100.0 * corrects / total
    # `iters` is 1-based while the accumulator is 0-based.
    return {ite: accuracy[ite - 1].item() for ite in iters}

def test_stop_condition(net, loaders, mode, iters, problem, device, threshold=1.0):
    """Run an early-stopping evaluation of `net` on every loader.

    Args:
        net: model to evaluate; forwarded unchanged.
        loaders: iterable of data loaders, evaluated one after another.
        mode: stopping rule; only "ssh_loss" (test_stop_by_ssh_loss) is
            supported.
        iters: iteration counts; forwarded unchanged.
        problem: problem name; forwarded unchanged.
        device: device to run on; forwarded unchanged.
        threshold: stopping threshold; forwarded unchanged. Defaults to 1.0.

    Returns:
        `(accs, ssh_accs, num_iters)`: three lists with one entry per loader,
        holding the three values returned by test_stop_by_ssh_loss.

    Raises:
        ValueError: if `mode` is not supported (raised when the first loader
            is processed, so an empty `loaders` never raises). Previously an
            unsupported mode silently produced three empty lists.
    """
    accs = []
    ssh_accs = []
    num_iters = []
    for loader in loaders:
        if mode != "ssh_loss":
            raise ValueError(f"test_stop_condition(): mode '{mode}' not implemented.")
        accuracy, ssh_accuracy, num_iter = test_stop_by_ssh_loss(
            net, loader, iters, threshold, problem, device
        )
        accs.append(accuracy)
        ssh_accs.append(ssh_accuracy)
        num_iters.append(num_iter)
    return accs, ssh_accs, num_iters