import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchattacks
from tqdm import tqdm

import attacks
from utils.utils import AverageMeter


def torch_fix_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic kernels.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    # Python random
    random.seed(seed)
    # Numpy
    np.random.seed(seed)
    # Pytorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # ``use_deterministic_algorithms`` is a function and has to be *called*;
    # assigning to it would replace the function on the ``torch`` module and
    # enable nothing. ``warn_only=True`` lets operations without a
    # deterministic implementation keep running (with a warning) rather than
    # raising in the middle of an evaluation.
    try:
        torch.use_deterministic_algorithms(True, warn_only=True)
    except TypeError:
        # PyTorch releases that predate the ``warn_only`` keyword.
        torch.use_deterministic_algorithms(True)


def evaluate(args, net, test_loader, adv=True, max_n=10000):
    """Measure loss and accuracy of ``net``, optionally under a PGD-20 attack.

    The global autograd mode is left untouched: every forward pass used for
    scoring runs inside a local ``torch.no_grad()`` block, so the caller's
    grad mode is the same before and after the call.

    Args:
        args: Unused here; kept so existing call sites keep working.
        net: Model to evaluate. It is switched to eval mode. Batches are moved
            to CUDA, so the model is presumably expected to live there too.
        test_loader: Iterable with ``len()`` yielding ``(inputs, labels)``.
        adv: When truthy, inputs are replaced by PGD-20 adversarial examples
            (L-inf, epsilon 8/255, step size 2/255) before scoring.
        max_n: Stop after the batch that brings the number of evaluated
            samples to at least this value.

    Returns:
        Tuple ``(loss, acc)``: summed cross-entropy and number of correct
        predictions, each divided by the number of evaluated samples.
    """
    # PGD-20
    adversary = attacks.PGD_linf(
        epsilon=8.0 / 255, num_steps=20, step_size=2.0 / 255
    ).cuda()

    net.eval()
    running_loss = 0
    running_acc = 0
    count = 0
    for bx, by in tqdm(test_loader, total=len(test_loader), desc="eval"):
        bx, by = bx.cuda(), by.cuda()
        count += by.size(0)

        if adv:
            # The attack needs gradients internally; its output is detached so
            # no graph is carried into the scoring pass.
            inputs = adversary(net, bx, by).detach()
        else:
            inputs = bx

        with torch.no_grad():
            logits = net(inputs)
            loss = F.cross_entropy(logits, by, reduction="sum")
            correct = (logits.argmax(dim=1) == by).float().sum(0)

        running_loss += loss.cpu().numpy()
        running_acc += correct.cpu().numpy()

        # ``>=`` so that no additional batch is processed once the cap is hit.
        if count >= max_n:
            break

    loss = running_loss / count
    acc = running_acc / count
    return loss, acc


def _pooled_unit_features(net, x, layer_name):
    """Return per-sample, L2-normalised ``layer_name`` features of ``net`` for ``x``."""
    with torch.no_grad():
        _, feat_dict = net(x, get_feat=True)
        feat = feat_dict[layer_name]
        if feat.dim() == 4:
            # Average 4-D feature maps over their trailing two dimensions so
            # each sample is described by a single flat vector.
            feat = F.avg_pool2d(feat, feat.shape[2:])
            feat = feat.view(feat.shape[0], -1)
        return feat / feat.norm(dim=1).view(-1, 1)


def eval_feature_invariance(args, net, layer_name, test_loader, adversary, max_n=10000):
    """
    return average cosine similarity between adversarial and clean features

    Args:
        args: Unused here; kept so existing call sites keep working.
        net: Model called as ``net(x, get_feat=True)``, which must return
            ``(logits, feat_dict)``. It is switched to eval mode.
        layer_name: Key into ``feat_dict`` selecting the features to compare.
        test_loader: Iterable with ``len()`` yielding ``(inputs, labels)``.
        adversary: Attack callable, invoked as ``adversary(bx, by)`` first and
            as ``adversary(net, bx, by)`` if that call raises.
        max_n: Stop after the batch that brings the number of evaluated
            samples to at least this value.

    Returns:
        Cosine similarity between adversarial and clean features, averaged
        over all evaluated samples. The value is also printed.
    """
    net.eval()

    cos_sim_sum = 0
    count = 0
    for bx, by in tqdm(test_loader, total=len(test_loader), desc="eval"):
        bx, by = bx.cuda(), by.cuda()
        bs = by.size(0)
        count += bs

        # Two calling conventions are supported. ``Exception`` (rather than a
        # bare ``except``) keeps KeyboardInterrupt/SystemExit propagating.
        try:
            adv_bx = adversary(bx, by)
        except Exception:
            adv_bx = adversary(net, bx, by)
        adv_bx = adv_bx.detach()

        feat_adv = _pooled_unit_features(net, adv_bx, layer_name)
        feat_clean = _pooled_unit_features(net, bx, layer_name)

        cos_sim = F.cosine_similarity(feat_adv, feat_clean, dim=1).mean()
        # Weight the batch mean by the batch size so a smaller final batch
        # does not skew the overall average.
        cos_sim_sum += cos_sim.cpu().numpy() * bs

        # ``>=`` so that no additional batch is processed once the cap is hit.
        if count >= max_n:
            break
    avg_cos_sim = cos_sim_sum / count
    print(
        "Average cosine similarity (adv. vs. clean feats): {:.4f}".format(avg_cos_sim)
    )
    return avg_cos_sim


def evaluate_auto_attack(args, net, test_loader, adv=True, max_n=10000):
    """Measure loss and accuracy of ``net``, optionally under AutoAttack.

    The global autograd mode is left untouched: every forward pass used for
    scoring runs inside a local ``torch.no_grad()`` block, so the caller's
    grad mode is the same before and after the call.

    Args:
        args: Namespace; ``args.num_classes`` is read when ``adv`` is truthy.
        net: Model to evaluate. It is switched to eval mode. Batches are moved
            to CUDA, so the model is presumably expected to live there too.
        test_loader: Iterable with ``len()`` yielding ``(inputs, labels)``.
        adv: When truthy, inputs are replaced by ``torchattacks.AutoAttack``
            examples (L-inf, eps 8/255) before scoring.
        max_n: Stop after the batch that brings the number of evaluated
            samples to at least this value.

    Returns:
        Tuple ``(loss, acc)``: summed cross-entropy and number of correct
        predictions, each divided by the number of evaluated samples.
    """
    # Only build the attack when it is going to be used.
    adversary = None
    if adv:
        adversary = torchattacks.AutoAttack(
            net, norm="Linf", eps=8 / 255, n_classes=args.num_classes
        )

    net.eval()
    running_loss = 0
    running_acc = 0
    count = 0
    for bx, by in tqdm(test_loader, total=len(test_loader), desc="eval"):
        bx, by = bx.cuda(), by.cuda()
        count += by.size(0)

        if adv:
            # Detach so no graph is carried into the scoring pass.
            inputs = adversary(bx, by).detach()
        else:
            inputs = bx

        with torch.no_grad():
            logits = net(inputs)
            loss = F.cross_entropy(logits, by, reduction="sum")
            correct = (logits.argmax(dim=1) == by).float().sum(0)

        running_loss += loss.cpu().numpy()
        running_acc += correct.cpu().numpy()

        # ``>=`` so that no additional batch is processed once the cap is hit.
        if count >= max_n:
            break

    loss = running_loss / count
    acc = running_acc / count
    return loss, acc


def evaluate_multi(net, test_loader, attack_dict, class_num, max_n=10000):
    """Evaluate ``net`` against several attacks on the same batches.

    Args:
        net: Model to evaluate. It is switched to eval mode. Batches are moved
            to CUDA, so the model is presumably expected to live there too.
        test_loader: Iterable with ``len()`` yielding ``(inputs, labels)``.
        attack_dict: Mapping from attack name to attack callable. The name
            ``"natural"`` means clean inputs (its value is ignored). Other
            entries are invoked as ``adversary(bx, by)`` first and as
            ``adversary(net, bx, by)`` if that call raises.
        class_num: Number of classes; sets the confusion-matrix size.
        max_n: Stop after the batch that brings the number of evaluated
            samples to at least this value.

    Returns:
        Tuple ``(loss_meters, acc_meters, class_wise_matrix_dict)``, each a
        dict keyed by attack name. The meters are updated once per batch with
        the per-sample mean and the batch size. Each matrix counts samples at
        ``[true_label][predicted_label]``.
    """
    class_wise_matrix_dict = {
        att: np.zeros((class_num, class_num)) for att in attack_dict
    }
    # to have the same augmentation
    torch_fix_seed()

    net.eval()
    loss_meters = {attack: AverageMeter() for attack in attack_dict}
    acc_meters = {attack: AverageMeter() for attack in attack_dict}
    count = 0
    # Remember the caller's autograd mode so it can be put back afterwards,
    # even if an attack changes it or an exception aborts the loop.
    grad_was_enabled = torch.is_grad_enabled()
    try:
        for bx, by in tqdm(test_loader, total=len(test_loader), desc="eval"):
            bx, by = bx.cuda(), by.cuda()
            bs = by.size(0)
            count += bs
            by_np = by.detach().cpu().numpy()

            for name, adversary in attack_dict.items():
                if name == "natural":
                    adv_bx = bx
                else:
                    # Two calling conventions are supported. ``Exception``
                    # (rather than a bare ``except``) keeps
                    # KeyboardInterrupt/SystemExit propagating.
                    try:
                        adv_bx = adversary(bx, by)
                    except Exception:
                        adv_bx = adversary(net, bx, by)

                with torch.no_grad():
                    logits = net(adv_bx)
                    loss = F.cross_entropy(logits, by, reduction="sum")
                    preds = logits.argmax(dim=1)
                    correct = (preds == by).float().sum(0)

                loss_meters[name].update(loss.cpu().numpy() / bs, bs)
                acc_meters[name].update(correct.cpu().numpy() / bs, bs)

                # Unbuffered scatter-add: repeated (label, prediction) pairs
                # within the batch are each counted.
                preds_np = preds.detach().cpu().numpy()
                np.add.at(class_wise_matrix_dict[name], (by_np, preds_np), 1)

            if count >= max_n:
                break
    finally:
        torch.set_grad_enabled(grad_was_enabled)
    return loss_meters, acc_meters, class_wise_matrix_dict
