import os
import sys
import matplotlib
from torch import nn
from models.Update_domain import DomainClientUpdate, DomainClientUpdate_Hesian
from models.vggmodule import vgg
import torch

matplotlib.use('Agg')

def clip_image(x):
    """Clamp every element of tensor ``x`` into the closed range [-1.0, 1.0]."""
    lower, upper = -1.0, 1.0
    return x.clamp(min=lower, max=upper)

def all2one_target_transform(x, attack_target=1):
    """Return a tensor shaped like ``x`` in which every entry equals ``attack_target``.

    The result is built as ``attack_target`` times a tensor of ones, so its
    dtype follows PyTorch's type promotion of ``x``'s dtype with
    ``attack_target``.
    """
    filled = torch.ones_like(x)
    return attack_target * filled

def DeltaWeight(w1, w2):
    """Return the cosine similarity between two weight dictionaries.

    Each floating-point tensor in ``w1`` is paired with the tensor stored under
    the same key in ``w2``. All pairs together are treated as one long
    flattened vector, and the cosine of the angle between the two vectors is
    returned as a 0-dim tensor.

    Entries that are not floating point (e.g. integer buffers in a
    ``state_dict``) are skipped. They are not trainable weights, and
    ``torch.norm`` rejects integer tensors.

    Raises:
        KeyError: if a key of ``w1`` is missing from ``w2``.
        ValueError: if there is no floating-point tensor pair to compare.
    """
    dot = 0.0
    sq_norm1 = 0.0
    sq_norm2 = 0.0
    compared = 0

    for key, param1 in w1.items():
        param2 = w2[key]
        if not (param1.is_floating_point() and param2.is_floating_point()):
            continue
        dot = dot + torch.sum(param1 * param2)
        # Squared Frobenius norm == sum of squared entries.
        sq_norm1 = sq_norm1 + torch.sum(param1 * param1)
        sq_norm2 = sq_norm2 + torch.sum(param2 * param2)
        compared += 1

    if compared == 0:
        raise ValueError("DeltaWeight: no floating-point tensors to compare")
    return dot / torch.sqrt(sq_norm1 * sq_norm2)

def test(args, test_loader, net, example_stats=None, atkmodel=None, test_eps=0.05):
    net.eval()
    test_loss = 0
    test_transform_loss = 0
    correct = 0
    correct_transform = 0
    loss_fun = nn.CrossEntropyLoss()
    nums_marker = 0

    num_classes = args.num_classes
    class_correct = [0] * num_classes
    class_total = [0] * num_classes

    for data, target, idx in test_loader:
        data = data.to(args.device).float()
        target = target.to(args.device).long()
        output = net(data)
        test_loss += loss_fun(output, target).item()
        pred = output.data.max(1)[1]
        correct += pred.eq(target.view(-1)).sum().item()

        for i in range(len(target)):
            label = target[i]
            class_correct[label] += (pred[i] == label).item()
            class_total[label] += 1

    class_acc = {}
    for i in range(num_classes):
        if class_total[i] > 0:
            class_acc[i] = round(100 * class_correct[i] / class_total[i], 2)
        else:
            class_acc[i] = 0.0

    return test_loss / len(test_loader), correct / len(test_loader.dataset) * 100, class_acc, example_stats

def _evaluate_split(args, loader, net, name, tag, class_tag, example_stats=None):
    """Run ``test`` on one loader, print its loss/accuracy lines, and return (loss, acc, example_stats)."""
    loss, acc, class_acc, example_stats = test(args=args, test_loader=loader, net=net, example_stats=example_stats)
    print(' {:<11s}| {} Loss: {:.2f} | {} Acc: {:.2f}'.format(name, tag, loss, tag, acc))
    print(' {:<11s}| {} Class Acc: {}'.format(name, class_tag, class_acc))
    return loss, acc, example_stats


def evaluate(args, train_loaders, test_loaders, backdoorloader=None, net=None, example_stats=None, datasets_name=None, atkmodel=None, atk_eps=None):
    """Evaluate ``net`` on every client's train/test loaders and print the results.

    Which clients are evaluated depends on ``args.verify``:
    - ``"normal"``: every client.
    - ``"marker"`` / ``"backdoor"``: clients not in ``args.backdoor_client_idx``.
    - ``"backdoor"``: clients in ``args.backdoor_client_idx`` are evaluated too,
      and additionally on ``backdoorloader``.
    Clients matching none of these are skipped entirely. An example is
    backdoor-index clients when ``verify == "marker"``. A skipped client adds
    no entry to the returned lists.

    Returns:
        If ``example_stats`` is None: ``(train_acc_list, test_acc_list)``, with
        accuracies in percent, one entry per evaluated client.
        Otherwise: ``(example_stats, g_loss)``. ``example_stats[0][i]`` and
        ``example_stats[1][i]`` are replaced by what ``test`` returns for client
        i's train and test loader. ``g_loss`` holds each evaluated client's
        train loss.

    ``atkmodel`` and ``atk_eps`` are accepted for compatibility but unused.
    """
    track_stats = example_stats is not None
    train_acc_list = []
    test_acc_list = []
    g_loss = []

    for client_idx in range(args.num_users):
        is_clean = args.verify == "normal" or (args.verify in ["marker", "backdoor"] and client_idx not in args.backdoor_client_idx)
        is_backdoor = args.verify == "backdoor" and client_idx in args.backdoor_client_idx
        if not (is_clean or is_backdoor):
            # Nothing is evaluated for this client. Recording nothing avoids
            # appending undefined or stale results from a previous client.
            continue

        name = datasets_name[client_idx]
        train_stats = example_stats[0][client_idx] if track_stats else None
        test_stats = example_stats[1][client_idx] if track_stats else None

        train_loss, train_acc, train_stats = _evaluate_split(
            args, train_loaders[client_idx], net, name, "Train", "Train", train_stats)
        _test_loss, test_acc, test_stats = _evaluate_split(
            args, test_loaders[client_idx], net, name, "Test ", "Test", test_stats)
        if track_stats:
            example_stats[0][client_idx] = train_stats
            example_stats[1][client_idx] = test_stats

        if is_backdoor:
            _evaluate_split(args, backdoorloader, net, name, "BKD  ", "BKD")

        # test() already returns percentages, so store them without rescaling.
        train_acc_list.append(train_acc)
        test_acc_list.append(test_acc)
        g_loss.append(train_loss)

    if not track_stats:
        return train_acc_list, test_acc_list

    if args.record_forget_event is True:
        print(" record forget event......")
    return example_stats, g_loss