""" testing.py
    Utilities for testing models

    Borrowed from code for DeepThinking project
"""
import einops
import torch
from icecream import ic
from tqdm import tqdm
import torch
import numpy as np
import json

def test(net, loaders, mode, iters, problem, device):
    """Run the evaluation routine selected by ``mode`` on every loader.

    Args:
        net: model under evaluation; forwarded to the selected routine.
        loaders: iterable of data loaders, each evaluated independently.
        mode: name of the evaluation routine (one of the keys of the
            dispatch table below, e.g. ``"default"`` or ``"max_conf"``).
        iters: iteration counts forwarded to the selected routine.
        problem: problem name forwarded to the selected routine.
        device: device forwarded to the selected routine.

    Returns:
        list with one result per loader, in loader order.

    Raises:
        ValueError: if ``mode`` is not a known routine. The check happens
            while processing a loader, so an empty ``loaders`` returns ``[]``
            without raising.
    """
    # Every entry is a thunk so the target routine is only looked up when run.
    runners = {
        "default": lambda ld: test_default(net, ld, iters, problem, device),
        "max_conf": lambda ld: test_max_conf(net, ld, iters, problem, device),
        "deq": lambda ld: test_deq(net, ld, iters, problem, device),
        "default_deq": lambda ld: test_default_deq(net, ld, iters, problem, device),
        "deq_prog": lambda ld: test_deq_progressive(net, ld, iters, problem, device),
        "measure_fp_jac": lambda ld: test_deq_fp_jac(net, ld, iters, problem, device),
        "measure_jac": lambda ld: test_deq_jac(net, ld, iters, problem, device),
        "measure_adv": lambda ld: measure_adversarial(net, ld, iters, problem, device),
        "measure_deq_adv": lambda ld: measure_adversarial(
            net, ld, iters, problem, device, use_deq=True
        ),
        "measure_fp_pi_all": lambda ld: test_deq_fp_cross_pi(net, ld, iters, problem, device),
        "measure_pi_all": lambda ld: test_deq_cross_pi(net, ld, iters, problem, device),
    }

    results = []
    for loader in loaders:
        runner = runners.get(mode)
        if runner is None:
            raise ValueError(f"{ic.format()}: test_{mode}() not implemented.")
        results.append(runner(loader))
    return results


def get_predicted(inputs, outputs, problem):
    """Turn raw class scores into per-position predictions.

    Args:
        inputs: batch of network inputs. Only read when ``problem == "mazes"``,
            where the maximum over dim 1 is used as a multiplicative mask.
        outputs: class scores with the class axis at dim 1. The caller's
            tensor is never modified.
        problem: ``"mazes"`` and ``"chess"`` get the post-processing described
            below; any other value yields the plain argmax.

    Returns:
        Tensor of shape ``(batch, positions)``.

        * default: argmax over dim 1, flattened per sample.
        * ``"mazes"``: the argmax multiplied by the flattened channel-wise
          maximum of ``inputs`` (so the result takes the promoted dtype).
        * ``"chess"``: class-1 scores strictly below the sample's
          second-highest class-1 score are set to ``-inf``, all class-0 scores
          are set to ``-inf``, and the argmax over dim 1 is returned. Works for
          any number of positions (the board is not assumed to be 8x8).
    """
    if problem == "chess":
        # Clone only here: this is the one branch that writes into the scores.
        scores = outputs.clone().view(outputs.size(0), outputs.size(1), -1)
        class_one = scores[:, 1]  # view into `scores`, shape (batch, positions)
        # Second-largest class-1 score of every sample.
        runner_up = torch.topk(class_one, 2, dim=1)[0].min(dim=1)[0]
        # Broadcasting (batch, 1) against (batch, positions) replaces the
        # former hard-coded 8x8 expansion of the threshold.
        class_one.masked_fill_(class_one < runner_up.unsqueeze(1), -float("Inf"))
        scores[:, 0] = -float("Inf")
        return scores.argmax(1)

    predicted = outputs.argmax(1)
    predicted = predicted.view(predicted.size(0), -1)
    if problem == "mazes":
        predicted = predicted * inputs.max(1)[0].view(inputs.size(0), -1)
    return predicted


def test_default(net, testloader, iters, problem, device, cos_sample_budget=10000):
    """Measure per-iteration accuracy plus a cosine score between final states.

    For every batch the network is run for ``max(iters)`` iterations and the
    prediction after each iteration is compared with the targets; a sample
    counts as correct only when every position matches.

    While fewer than ``cos_sample_budget`` samples have been processed, each
    batch is also re-solved with every sample started from every sample's
    returned state (passed as ``interim_thought``). Both sets of states are
    mean-centred and compared with cosine similarity. The tiled batch holds
    ``batch_size ** 2`` samples, so it is only built while the budget lasts.

    Args:
        net: model called as ``net(inputs, iters_to_do=..., return_fp=True)``
            and unpacked as ``(all_outputs, fp)``; ``all_outputs`` is indexed
            per iteration along dim 1.
        testloader: iterable yielding ``(inputs, targets)`` batches.
        iters: iteration counts to report; ``max(iters)`` iterations are run.
        problem: problem name forwarded to ``get_predicted``.
        device: device the batches are moved to.
        cos_sample_budget: number of processed samples after which the
            cosine measurement is skipped (default ``10000``).

    Returns:
        dict mapping each entry of ``iters`` to the accuracy (in percent)
        after that many iterations. Values are ``nan`` for an empty loader.
    """
    max_iters = max(iters)
    net.eval()
    corrects = torch.zeros(max_iters)
    total = 0

    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    sum_cos_val = 0
    total_cos = 0

    with torch.no_grad():
        for inputs, targets in tqdm(testloader, leave=False):
            inputs, targets = inputs.to(device), targets.to(device)
            targets = targets.view(targets.size(0), -1)
            bsz = inputs.shape[0]

            all_outputs, fp = net(inputs, iters_to_do=max_iters, return_fp=True)

            if total < cos_sample_budget:
                # Row j*bsz + i of both tensors pairs state j with input i.
                repeated_fp = torch.repeat_interleave(fp, repeats=bsz, dim=0)
                tiled_inputs = torch.tile(inputs, (bsz,) + (1,) * (inputs.dim() - 1))
                _, fp2 = net(tiled_inputs, interim_thought=repeated_fp,
                             iters_to_do=max_iters, return_fp=True)

                fp1 = repeated_fp.view(repeated_fp.shape[0], -1)
                fp2 = fp2.view(repeated_fp.shape[0], -1)

                # Mean-centre both sets of states with their joint mean.
                center = (fp1.sum(dim=0) + fp2.sum(dim=0)) / (fp1.shape[0] + fp2.shape[0])
                fp1 = fp1 - center
                fp2 = fp2 - center

                stride_idx = np.arange(0, tiled_inputs.shape[0], bsz)
                for i in range(bsz):
                    block = fp2[i * bsz:(i + 1) * bsz]
                    sum_cos_val += cos(fp1[stride_idx + i], block).sum().item()

                total_cos += tiled_inputs.shape[0]
                print(sum_cos_val/total_cos)

            for i in range(all_outputs.size(1)):
                predicted = get_predicted(inputs, all_outputs[:, i], problem)
                corrects[i] += (predicted == targets).all(dim=1).sum().item()

            total += targets.size(0)

    accuracy = 100.0 * corrects / total
    # total_cos stays 0 for an empty loader or a zero budget.
    mean_cos = sum_cos_val / total_cos if total_cos else float("nan")
    print("Accuracy: ", accuracy, "Cos", mean_cos)
    return {ite: accuracy[ite - 1].item() for ite in iters}


def test_default_deq(net, testloader, iters, problem, device):
    """Measure accuracy after each iteration for a model called with ``plot=False``.

    Args:
        net: model called as ``net(inputs, iters_to_do=..., plot=False)``; its
            result is indexed per iteration along dim 1.
        testloader: iterable yielding ``(inputs, targets)`` batches.
        iters: iteration counts to report; ``max(iters)`` iterations are run.
        problem: problem name forwarded to ``get_predicted``.
        device: device the batches are moved to.

    Returns:
        dict mapping each entry of ``iters`` to the accuracy (in percent)
        after that many iterations. The best running accuracy is printed
        after every batch.
    """
    max_iters = max(iters)
    net.eval()
    hits = torch.zeros(max_iters)
    seen = 0

    with torch.no_grad():
        for inputs, targets in tqdm(testloader, leave=False):
            inputs, targets = inputs.to(device), targets.to(device)
            all_outputs = net(inputs, iters_to_do=max_iters, plot=False)

            n_steps = all_outputs.size(1)
            for step in range(n_steps):
                step_pred = get_predicted(inputs, all_outputs[:, step], problem)
                targets = targets.view(targets.size(0), -1)
                # A sample is correct only if every position matches.
                hits[step] += torch.amin(step_pred == targets, dim=[1]).sum().item()

            seen += targets.size(0)
            running = 100.0 * hits / seen
            print(running.max())

    accuracy = 100.0 * hits / seen
    return {ite: accuracy[ite - 1].item() for ite in iters}

def test_max_conf(net, testloader, iters, problem, device):
    """Measure accuracy when each sample uses its most confident iteration so far.

    For every iteration a per-sample confidence is computed as the sum over
    positions of the maximum softmax probability (masked by the channel-wise
    input maximum for ``"mazes"``). At iteration ``t`` a sample is scored with
    the prediction from the iteration in ``0..t`` that has the highest
    confidence (running argmax via ``torch.cummax``).

    Args:
        net: model called as ``net(inputs, iters_to_do=...)``; its result is
            indexed per iteration along dim 1.
        testloader: iterable yielding ``(inputs, targets)`` batches.
        iters: iteration counts to report; ``max(iters)`` iterations are run.
        problem: problem name forwarded to ``get_predicted``.
        device: device for the batches and accumulators.

    Returns:
        dict mapping each entry of ``iters`` to the accuracy (in percent).
        The best running accuracy is printed after every batch.
    """
    max_iters = max(iters)
    net.eval()
    hits = torch.zeros(max_iters).to(device)
    seen = 0

    with torch.no_grad():
        for inputs, targets in tqdm(testloader, leave=False):
            inputs, targets = inputs.to(device), targets.to(device)
            targets = targets.view(targets.size(0), -1)
            seen += targets.size(0)
            batch = inputs.size(0)

            all_outputs = net(inputs, iters_to_do=max_iters)

            confidences = torch.zeros(max_iters, batch).to(device)
            is_correct = torch.zeros(max_iters, batch).to(device)
            for step in range(all_outputs.size(1)):
                step_out = all_outputs[:, step]
                probs = torch.nn.functional.softmax(step_out.detach(), dim=1)
                conf = probs.max(dim=1).values
                conf = conf.view(conf.size(0), -1)
                if problem == "mazes":
                    conf = conf * inputs.max(dim=1).values.view(conf.size(0), -1)
                confidences[step] = conf.sum(dim=1)

                step_pred = get_predicted(inputs, step_out, problem)
                is_correct[step] = torch.amin(step_pred == targets, dim=[1])

            # Row t holds, per sample, the most confident iteration in 0..t.
            best_step = torch.cummax(confidences, dim=0).indices
            sample_idx = torch.arange(is_correct.size(1))
            hits += is_correct[best_step, sample_idx].sum(dim=1)

            running = 100.0 * hits / seen
            print(running.max())

    accuracy = 100 * hits.long().cpu() / seen
    return {ite: accuracy[ite - 1].item() for ite in iters}

def test_deq(net, testloader, iters, problem, device):
    """Measure accuracy of a model that is called with the inputs only.

    After every batch the running accuracy and the solver statistics
    ``min_abs_trace / total_count`` and ``min_rel_trace / total_count`` are
    printed. The statistics are read from ``net.module`` when the model is
    wrapped (e.g. by ``torch.nn.DataParallel``) and from ``net`` otherwise.

    Args:
        net: model called as ``net(inputs)``.
        testloader: iterable yielding ``(inputs, targets)`` batches.
        iters: accepted for a uniform signature; not used by this routine.
        problem: problem name forwarded to ``get_predicted``.
        device: device the batches are moved to.

    Returns:
        Accuracy in percent as a float; ``nan`` if the loader yields no samples.
    """
    net.eval()
    # Unwrap a DataParallel-style wrapper when present, else use the model itself.
    solver = getattr(net, "module", net)
    corrects = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in tqdm(testloader, leave=False):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            predicted = get_predicted(inputs, outputs, problem)
            targets = targets.view(targets.size(0), -1)
            # A sample is correct only if every position matches.
            corrects += (predicted == targets).all(dim=1).sum().item()

            total += targets.size(0)
            print(f"abs {solver.min_abs_trace/solver.total_count} rel {solver.min_rel_trace/solver.total_count}")
            print(100.0 * corrects / total)

    if total == 0:
        # Nothing evaluated: avoid ZeroDivisionError.
        return float("nan")
    return 100.0 * corrects / total

def test_deq_progressive(net, testloader, iters, problem, device):
    """Measure accuracy of a model called with ``iters_to_do`` and ``train_step=0``.

    After every batch the input shape, the running accuracy and the solver
    statistics ``min_abs_trace / total_count`` and
    ``min_rel_trace / total_count`` are printed. The statistics are read from
    ``net.module`` when the model is wrapped (e.g. by
    ``torch.nn.DataParallel``) and from ``net`` otherwise.

    Args:
        net: model called as ``net(inputs, iters_to_do=..., train_step=0)``.
        testloader: iterable yielding ``(inputs, targets)`` batches.
        iters: iteration counts; ``max(iters)`` is passed as ``iters_to_do``.
        problem: problem name forwarded to ``get_predicted``.
        device: device the batches are moved to.

    Returns:
        Accuracy in percent as a float; ``nan`` if the loader yields no samples.
    """
    max_iters = max(iters)
    net.eval()
    # Unwrap a DataParallel-style wrapper when present, else use the model itself.
    solver = getattr(net, "module", net)
    corrects = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in tqdm(testloader, leave=False):
            inputs, targets = inputs.to(device), targets.to(device)
            print(inputs.shape)
            outputs = net(inputs, iters_to_do=max_iters, train_step=0)
            predicted = get_predicted(inputs, outputs, problem)
            targets = targets.view(targets.size(0), -1)
            # A sample is correct only if every position matches.
            corrects += (predicted == targets).all(dim=1).sum().item()
            total += targets.size(0)
            print(100.0 * corrects / total)
            print(f"abs {solver.min_abs_trace/solver.total_count} rel {solver.min_rel_trace/solver.total_count}")

    if total == 0:
        # Nothing evaluated: avoid ZeroDivisionError.
        return float("nan")
    return 100.0 * corrects / total

def test_deq_fp_cross_pi(net, testloader, iters, problem, device):
    """Re-solve every sample from every sample's returned state and score the result.

    Each batch is first solved normally. The returned state of every sample
    is then used as ``interim_thought`` for every input of the batch (a tiled
    batch of ``batch_size ** 2`` samples). The routine accumulates

    * the accuracy of the re-solved predictions, and
    * the cosine similarity between the mean-centred original states and the
      mean-centred re-solved states.

    Evaluation stops once more than 500 samples have been processed. Running
    statistics are printed after every batch; solver statistics are read from
    ``net.module`` when the model is wrapped and from ``net`` otherwise.

    Args:
        net: model called with ``train_step=0``, ``iters_to_do`` and
            ``return_fp=True``. With ``return_residuals=True`` its result is
            unpacked into three values, otherwise into two.
        testloader: iterable yielding ``(inputs, targets)`` batches.
        iters: iteration counts; ``max(iters)`` is passed as ``iters_to_do``.
        problem: problem name forwarded to ``get_predicted``.
        device: device the batches are moved to.

    Returns:
        Accuracy (in percent) over all re-solved samples; ``nan`` if the
        loader yields no samples.
    """
    max_samples = 500
    max_iters = max(iters)
    net.eval()
    # Unwrap a DataParallel-style wrapper when present, else use the model itself.
    solver = getattr(net, "module", net)
    corrects = 0
    cos_total = 0
    total = 0
    path_indep_val = 0
    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

    with torch.no_grad():
        for inputs, targets in tqdm(testloader, leave=False):
            if total > max_samples:
                break
            inputs, targets = inputs.to(device), targets.to(device)
            bsz = inputs.shape[0]
            print(inputs.shape)
            _, fp_val1, abs_trace = net(inputs, train_step=0, iters_to_do=max_iters,
                                        return_fp=True, return_residuals=True)

            # Repeat along the batch axis only, whatever the tensor rank.
            tiled_inputs = torch.tile(inputs, (bsz,) + (1,) * (inputs.dim() - 1))
            tiled_targets = torch.tile(targets, (targets.shape[0],) + (1,) * (targets.dim() - 1))

            # Row j*bsz + i pairs the state of sample j with input i.
            repeated_fp = torch.repeat_interleave(fp_val1, repeats=bsz, dim=0)
            next_outputs, fp_val2 = net(tiled_inputs, interim_thought=repeated_fp, train_step=0,
                                        iters_to_do=max_iters, return_fp=True)

            predicted = get_predicted(tiled_inputs, next_outputs, problem)
            tiled_targets = tiled_targets.view(tiled_targets.size(0), -1)
            corrects += (predicted == tiled_targets).all(dim=1).sum().item()

            cos_total += fp_val2.size(0)
            total += targets.size(0)

            fp1 = repeated_fp.view(repeated_fp.shape[0], -1)
            fp2 = fp_val2.view(fp_val2.shape[0], -1)

            # Mean-centre both sets of states with their joint mean.
            center = (fp1.sum(dim=0) + fp2.sum(dim=0)) / (fp1.shape[0] + fp2.shape[0])
            fp1 = fp1 - center
            fp2 = fp2 - center

            stride_idx = np.arange(0, tiled_inputs.shape[0], bsz)
            for i in range(bsz):
                block = fp2[i * bsz:(i + 1) * bsz]
                path_indep_val += cos(fp1[stride_idx + i], block).sum()

            print(f"abs {solver.min_abs_trace/cos_total} abs val {min(abs_trace)} rel {solver.min_rel_trace/cos_total}")
            print("Cosine similarity", path_indep_val/(cos_total))
            print(100.0 * corrects / cos_total)

    if cos_total == 0:
        # Nothing evaluated: avoid ZeroDivisionError.
        return float("nan")
    return 100.0 * corrects / cos_total

def test_deq_fp_jac(net, testloader, iters, problem, device):
    """Measure accuracy and aggregate the ``sigma_max`` values returned by the model.

    The model is called with ``spectral_radius_mode=True`` and its result is
    unpacked as ``(outputs, sigma_max)``. The running maximum, per-sample mean
    and minimum of ``sigma_max`` are printed after every batch and once at
    the end, together with the running accuracy.

    NOTE(review): ``sigma_max`` presumably holds one largest Jacobian singular
    value per sample; confirm against the model implementation.

    Args:
        net: model called as ``net(inputs, train_step=0, iters_to_do=...,
            spectral_radius_mode=True)``.
        testloader: iterable yielding ``(inputs, targets)`` batches.
        iters: iteration counts; ``max(iters)`` is passed as ``iters_to_do``.
        problem: problem name forwarded to ``get_predicted``.
        device: device the batches are moved to.

    Returns:
        Accuracy in percent as a float; ``nan`` if the loader yields no samples.
    """
    max_iters = max(iters)
    corrects = 0
    total = 0
    sigma_sum = 0.0
    max_sigma = float('-inf')
    min_sigma = float('inf')
    net.eval()
    with torch.no_grad():
        for inputs, targets in tqdm(testloader, leave=False):
            inputs, targets = inputs.to(device), targets.to(device)
            init_outputs, sigma_max = net(inputs, train_step=0, iters_to_do=max_iters,
                                          spectral_radius_mode=True)

            predicted = get_predicted(inputs, init_outputs, problem)
            targets = targets.view(targets.size(0), -1)
            # A sample is correct only if every position matches.
            corrects += (predicted == targets).all(dim=1).sum().item()

            # Tensor reductions instead of Python max()/min(): one kernel each
            # and valid for any tensor shape.
            sigma_sum += sigma_max.sum().item()
            max_sigma = max(max_sigma, sigma_max.max().item())
            min_sigma = min(min_sigma, sigma_max.min().item())

            total += targets.size(0)

            print(max_sigma, sigma_sum/total, min_sigma)
            print(100.0 * corrects / total)

    if total == 0:
        # Nothing evaluated: avoid ZeroDivisionError.
        return float("nan")
    print(max_sigma, sigma_sum/total, min_sigma)
    return 100.0 * corrects / total

def test_deq_jac(net, testloader, iters, problem, device):
    """Measure accuracy and aggregate the ``sigma_max`` values returned by the model.

    The model is called as ``net(inputs, spectral_radius_mode=True)`` and its
    result is unpacked as ``(outputs, sigma_max)``. The running maximum,
    per-sample mean and minimum of ``sigma_max`` are printed after every
    batch and once at the end, together with the running accuracy.

    NOTE(review): ``sigma_max`` presumably holds one largest Jacobian singular
    value per sample; confirm against the model implementation.

    Args:
        net: model called as ``net(inputs, spectral_radius_mode=True)``.
        testloader: iterable yielding ``(inputs, targets)`` batches.
        iters: accepted for a uniform signature; not used by this routine.
        problem: problem name forwarded to ``get_predicted``.
        device: device the batches are moved to.

    Returns:
        Accuracy in percent as a float; ``nan`` if the loader yields no samples.
    """
    corrects = 0
    total = 0
    sigma_sum = 0.0
    max_sigma = float('-inf')
    min_sigma = float('inf')
    net.eval()
    with torch.no_grad():
        for inputs, targets in tqdm(testloader, leave=False):
            inputs, targets = inputs.to(device), targets.to(device)
            init_outputs, sigma_max = net(inputs, spectral_radius_mode=True)

            predicted = get_predicted(inputs, init_outputs, problem)
            targets = targets.view(targets.size(0), -1)
            # A sample is correct only if every position matches.
            corrects += (predicted == targets).all(dim=1).sum().item()

            # Tensor reductions instead of Python max()/min(): one kernel each
            # and valid for any tensor shape.
            sigma_sum += sigma_max.sum().item()
            max_sigma = max(max_sigma, sigma_max.max().item())
            min_sigma = min(min_sigma, sigma_max.min().item())

            total += targets.size(0)

            print(max_sigma, sigma_sum/total, min_sigma)
            print(100.0 * corrects / total)

    if total == 0:
        # Nothing evaluated: avoid ZeroDivisionError.
        return float("nan")
    print(max_sigma, sigma_sum/total, min_sigma)
    return 100.0 * corrects / total

def test_deq_cross_pi(net, testloader, iters, problem, device):
    """Re-solve every sample from every sample's returned state and score the result.

    Each batch is first solved with ``net(inputs, return_fp=True)``. The
    returned state of every sample is then used as ``interim_thought`` for
    every input of the batch (a tiled batch of ``batch_size ** 2`` samples).
    The routine accumulates

    * the accuracy of the re-solved predictions, and
    * the cosine similarity between the mean-centred original states and the
      mean-centred re-solved states.

    Evaluation stops once more than 500 samples have been processed. Running
    statistics are printed after every batch; solver statistics are read from
    ``net.module`` when the model is wrapped and from ``net`` otherwise.

    Args:
        net: model called with ``return_fp=True``; its result is unpacked as
            ``(outputs, state)``.
        testloader: iterable yielding ``(inputs, targets)`` batches.
        iters: accepted for a uniform signature; not used by this routine.
        problem: problem name forwarded to ``get_predicted``.
        device: device the batches are moved to.

    Returns:
        Accuracy (in percent) over all re-solved samples; ``nan`` if the
        loader yields no samples.
    """
    max_samples = 500
    net.eval()
    # Unwrap a DataParallel-style wrapper when present, else use the model itself.
    solver = getattr(net, "module", net)
    corrects = 0
    total = 0
    cos_total = 0
    path_indep_val = 0
    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

    with torch.no_grad():
        for inputs, targets in tqdm(testloader, leave=False):
            if total > max_samples:
                break
            inputs, targets = inputs.to(device), targets.to(device)
            bsz = inputs.shape[0]
            _, fp_val1 = net(inputs, return_fp=True)

            # Repeat along the batch axis only, whatever the tensor rank.
            tiled_inputs = torch.tile(inputs, (bsz,) + (1,) * (inputs.dim() - 1))
            tiled_targets = torch.tile(targets, (targets.shape[0],) + (1,) * (targets.dim() - 1))

            # Row j*bsz + i pairs the state of sample j with input i.
            repeated_fp = torch.repeat_interleave(fp_val1, repeats=bsz, dim=0)
            next_outputs, fp_val2 = net(tiled_inputs, interim_thought=repeated_fp, return_fp=True)

            predicted = get_predicted(tiled_inputs, next_outputs, problem)
            tiled_targets = tiled_targets.view(tiled_targets.size(0), -1)
            corrects += (predicted == tiled_targets).all(dim=1).sum().item()

            cos_total += fp_val2.size(0)
            total += targets.size(0)

            fp1 = repeated_fp.view(repeated_fp.shape[0], -1)
            fp2 = fp_val2.view(fp_val2.shape[0], -1)

            # Mean-centre both sets of states with their joint mean.
            center = (fp1.sum(dim=0) + fp2.sum(dim=0)) / (fp1.shape[0] + fp2.shape[0])
            fp1 = fp1 - center
            fp2 = fp2 - center

            stride_idx = np.arange(0, tiled_inputs.shape[0], bsz)
            for i in range(bsz):
                block = fp2[i * bsz:(i + 1) * bsz]
                path_indep_val += cos(fp1[stride_idx + i], block).sum()

            print(f"abs {solver.min_abs_trace/cos_total} rel {solver.min_rel_trace/cos_total}")
            print("Cosine similarity", path_indep_val/(cos_total))
            print(100.0 * corrects / cos_total)

    if cos_total == 0:
        # Nothing evaluated: avoid ZeroDivisionError.
        return float("nan")
    return 100.0 * corrects / cos_total

def measure_adversarial(net, dataloader, iters, problem, device, epsilon=0.01, total_steps=20, use_deq=False):
    """Search for start states that the model maps far away, and score the result.

    For every batch:

    1. Solve normally to get the outputs and the returned state ``fp_val1``,
       then re-solve from ``fp_val1`` and record accuracy and the cosine
       similarity between the two states.
    2. Starting from a copy of ``fp_val1``, run L-BFGS on the start state to
       minimise the mean cosine similarity between the start state and the
       state the model returns from it.
    3. Solve once more from the optimised start state and record accuracy and
       cosine similarity.

    Evaluation stops once more than 1000 samples have been attacked.

    Args:
        net: model called with ``return_fp=True`` (and optionally
            ``interim_thought``); its result is unpacked as
            ``(outputs, state)``. When ``use_deq`` is false the call also
            receives ``train_step=0`` and ``iters_to_do=max(iters)``.
        dataloader: iterable yielding ``(inputs, targets)`` batches.
        iters: iteration counts; only ``max(iters)`` is used.
        problem: problem name forwarded to ``get_predicted``.
        device: device the batches are moved to.
        epsilon: currently unused; kept for interface compatibility.
        total_steps: number of ``optimizer.step`` calls per batch.
        use_deq: selects the shorter call signature described above.

    Returns:
        Accuracy (in percent) after the attack as a float; ``nan`` if the
        dataloader yields no samples.
    """
    import torch.optim as optim

    max_samples = 1000
    max_iters = max(iters)
    # Extra keyword arguments required by the non-DEQ call signature.
    fwd_kwargs = {} if use_deq else {"train_step": 0, "iters_to_do": max_iters}

    # Every entry of these tensors receives the same scalar increment, so all
    # entries stay equal; the tensor form is kept so printed output is unchanged.
    orig_corrects = torch.zeros(max_iters)
    adv_corrects = torch.zeros(max_iters)

    orig_total = 0
    adv_total = 0
    adversarial_cos_sim = 0
    orig_cos_sim = 0

    net.eval()
    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

    for inputs, targets in tqdm(dataloader, leave=False):
        inputs, targets = inputs.to(device), targets.to(device)
        targets = targets.view(targets.size(0), -1)
        bsz = inputs.shape[0]

        with torch.no_grad():
            init_outputs, fp_val1 = net(inputs, return_fp=True, **fwd_kwargs)

            predicted = get_predicted(inputs, init_outputs, problem)
            orig_corrects += (predicted == targets).all(dim=1).sum().item()

            _, fp_val2 = net(inputs, interim_thought=fp_val1, return_fp=True, **fwd_kwargs)

            orig_cos_sim += cos(fp_val2.view(bsz, -1), fp_val1.view(bsz, -1)).sum().item()
            orig_total += targets.size(0)
            orig_accuracy = 100.0 * orig_corrects / orig_total
            print("Initial Acc", orig_accuracy.max(), "Init Cos Sim", orig_cos_sim / orig_total)

        # Optimise a private copy: L-BFGS updates its parameter in place, and
        # fp_val1 must keep the unperturbed state for the comparison below.
        new_fp = fp_val1.detach().clone().requires_grad_()
        optimizer = optim.LBFGS([new_fp],
                                history_size=10,
                                max_iter=20,
                                line_search_fn="strong_wolfe")

        def closure():
            # L-BFGS re-evaluates the closure during its line search, so it
            # must clear the gradients, recompute the loss and backpropagate.
            optimizer.zero_grad()
            _, fp_next = net(inputs, interim_thought=new_fp, return_fp=True, **fwd_kwargs)
            loss = cos(new_fp.view(bsz, -1), fp_next.view(bsz, -1)).mean()
            loss.backward()
            return loss

        for step in range(total_steps):
            # step() returns the loss from the first closure evaluation.
            objective = optimizer.step(closure)
            print(f"[GD based] Step {step} Loss {objective.item()}")

        # Pure evaluation: no autograd graph is needed here.
        with torch.no_grad():
            next_outputs, adv_fp = net(inputs, interim_thought=new_fp, return_fp=True, **fwd_kwargs)

            predicted = get_predicted(inputs, next_outputs, problem)
            adv_corrects += (predicted == targets).all(dim=1).sum().item()

            start_vs_result = cos(new_fp.view(bsz, -1), adv_fp.view(bsz, -1))
            adversarial_cos_sim += start_vs_result.sum().item()
            # Unperturbed state vs. adversarial result, then optimised start
            # state vs. adversarial result.
            print(cos(fp_val1.view(bsz, -1), adv_fp.view(bsz, -1)).mean().item())
            print(start_vs_result.mean().item())

        adv_total += targets.size(0)
        accuracy = 100.0 * adv_corrects / adv_total
        print(accuracy.max(), adversarial_cos_sim/adv_total)
        if adv_total > max_samples:
            break

    if adv_total == 0:
        # Empty dataloader: nothing was measured.
        return float("nan")

    print("Original Acc", orig_accuracy, "Orig cos sim", orig_cos_sim/orig_total, "Adversarial Acc", accuracy, "Adversarial Cosine Sim", adversarial_cos_sim/adv_total)
    return accuracy.max().item()