import os
import sys
import logging
import numpy as np
import matplotlib
from torch import nn
from models.Update_domain import (
    DomainClientUpdate,
    DomainClientUpdate_Hesian,
    extract_logits,
)
from models.vggmodule import vgg

matplotlib.use('Agg')
import torch

# Module-level logger; log_print() sends every message here at INFO level
# in addition to printing it to stdout.
logger = logging.getLogger(__name__)

def log_print(msg):
    """Echo ``msg`` to stdout, then record it at INFO level on the module logger."""
    for emit in (print, logger.info):
        emit(msg)

def clip_image(x):
    """Clamp every element of tensor ``x`` into the closed range [-1, 1]."""
    lower, upper = -1.0, 1.0
    return x.clamp(min=lower, max=upper)

def all2one_target_transform(x, attack_target=1):
    """Return a tensor shaped like ``x`` whose every entry equals ``attack_target``.

    Used to relabel a batch so that all samples point at a single target class.
    """
    template = torch.ones_like(x)
    return attack_target * template

def DeltaWeight(w1, w2):
    """Return the cosine similarity of two parameter dicts viewed as one flat vector each.

    Every key of ``w1`` is looked up in ``w2``; matching entries are multiplied
    element-wise, so they must have broadcast-compatible shapes. Integer-valued
    entries (e.g. BatchNorm's ``num_batches_tracked`` buffer in a state dict) are
    promoted to float, because norms/dot products are undefined for integer tensors.

    Args:
        w1: Mapping from name to tensor (e.g. a ``state_dict``).
        w2: Mapping containing at least the keys of ``w1``.

    Returns:
        A 0-dim tensor ``<w1, w2> / (||w1|| * ||w2||)``. It is NaN when either
        vector is all zeros or ``w1`` is empty.
    """
    dot = 0.0
    sq_norm1 = 0.0
    sq_norm2 = 0.0
    for key in w1.keys():
        p1 = w1[key]
        p2 = w2[key]
        if not p1.is_floating_point():
            p1 = p1.float()
        if not p2.is_floating_point():
            p2 = p2.float()
        dot = dot + torch.sum(p1 * p2)
        # Squared Frobenius norm == sum of squared entries.
        sq_norm1 = sq_norm1 + torch.sum(p1 * p1)
        sq_norm2 = sq_norm2 + torch.sum(p2 * p2)
    # as_tensor keeps the empty-dict case (plain floats) working with torch.sqrt.
    return dot / torch.sqrt(torch.as_tensor(sq_norm1 * sq_norm2))

def _per_class_accuracy(correct_counts, total_counts):
    """Turn per-class correct/total count tensors into a list of rounded percentages.

    Classes with no samples yield 0.0; the 1e-12 epsilon keeps the division finite.
    """
    accs = (correct_counts.cpu().numpy() / (total_counts.cpu().numpy() + 1e-12)) * 100
    return [round(acc, 2) if not np.isnan(acc) else 0.0 for acc in accs]


def test(args, test_loader, net, example_stats=None, atkmodel=None, test_eps=0.05, num_classes=10):
    """Evaluate ``net`` on ``test_loader``, optionally measuring a marker attack.

    ``net`` (and ``atkmodel`` in marker mode) is switched to eval mode. All forward
    passes run under ``torch.no_grad()``.

    Marker mode is active when ``args.verify == "marker"`` and ``atkmodel`` is given.
    In that mode, every sample whose label differs from ``args.backdoor_target_label``
    is perturbed by ``atkmodel(x) * test_eps``, clamped to [-1, 1] and relabelled to
    the target label. The marker accuracy is the share of those samples predicted
    as the target label.

    When ``args.record_forget_event`` is truthy and ``example_stats`` is a dict, each
    example's dataset index is mapped to ``[[correct_flags], [margins]]`` and one entry
    is appended per call. The margin is the true-class logit minus the highest logit
    of any other class.

    Args:
        args: Namespace providing ``device``, ``verify``, ``record_forget_event`` and,
            in marker mode, ``backdoor_target_label``.
        test_loader: Iterable of ``(data, target, index)`` batches with ``.dataset``.
        net: Model whose output is passed through ``extract_logits``.
        example_stats: Optional forgetting-statistics dict, updated in place.
        atkmodel: Optional trigger generator for marker mode.
        test_eps: Scale applied to the generated trigger.
        num_classes: Minimum length of the per-class accuracy lists. The lists are
            widened automatically when the model has more output logits.

    Returns:
        Marker mode: ``(clean_acc, marker_acc, example_stats, class_acc, marker_class_acc)``.
        Otherwise: ``(mean_batch_loss, clean_acc, example_stats, class_acc, None)``.
        An empty loader or dataset yields ``(0.0, 0.0, example_stats, class_acc, None)``.
    """
    net.eval()
    loss_fun = nn.CrossEntropyLoss()
    marker_mode = args.verify == "marker" and atkmodel is not None
    if marker_mode:
        atkmodel.eval()
    track_forgetting = args.record_forget_event and example_stats is not None

    test_loss = 0.0
    correct = 0
    correct_transform = 0
    nums_marker = 0
    # Per-class histograms: clean correct/total, then marker (backdoor) correct/total.
    counts = {key: torch.zeros(num_classes, device=args.device)
              for key in ("correct", "total", "correct_bd", "total_bd")}

    with torch.no_grad():
        for data, target, index in test_loader:
            data = data.to(args.device).float()
            target = target.to(args.device).long()
            logits = extract_logits(net(data))
            test_loss += loss_fun(logits, target).item()
            pred = logits.max(1)[1]
            correct += pred.eq(target.view(-1)).sum().item()

            n_out = logits.size(1)
            if n_out > counts["total"].numel():
                # The model has more outputs than `num_classes`. Widen every histogram
                # so that bincount fits: CrossEntropyLoss above rejects labels >= n_out.
                counts = {k: torch.cat([v, v.new_zeros(n_out - v.numel())]) for k, v in counts.items()}
            n_bins = counts["total"].numel()

            counts["correct"] += torch.bincount(target[pred == target], minlength=n_bins)
            counts["total"] += torch.bincount(target, minlength=n_bins)

            if marker_mode:
                # Use separate names so `target`/`pred`/`index` stay aligned for the
                # forgetting statistics below.
                keep = target != args.backdoor_target_label
                bd_data = data[keep]
                bd_target = target[keep]
                if bd_data.size(0) != 0:
                    nums_marker += bd_data.size(0)
                    atkdata = clip_image(bd_data + atkmodel(bd_data) * test_eps)
                    atktarget = all2one_target_transform(bd_target, args.backdoor_target_label)
                    atkpred = extract_logits(net(atkdata)).max(1)[1]
                    correct_transform += atkpred.eq(atktarget.view(-1)).sum().item()
                    counts["correct_bd"] += torch.bincount(atktarget[atkpred == atktarget], minlength=n_bins)
                    counts["total_bd"] += torch.bincount(atktarget, minlength=n_bins)

            if track_forgetting:
                acc = pred == target
                for i, idx in enumerate(index):
                    idx = idx.item()
                    output_correct_class = logits[i, target[i].item()]
                    sorted_output, _ = torch.sort(logits[i, :])
                    # For a correct prediction the top logit is the true class, so the
                    # strongest competitor is the runner-up.
                    output_highest_incorrect_class = sorted_output[-2] if acc[i] else sorted_output[-1]
                    margin = output_correct_class.item() - output_highest_incorrect_class.item()
                    index_stats = example_stats.get(idx, [[], []])
                    index_stats[0].append(acc[i].sum().item())
                    index_stats[1].append(margin)
                    example_stats[idx] = index_stats

    accuracy_per_class = _per_class_accuracy(counts["correct"], counts["total"])

    if len(test_loader) == 0 or len(test_loader.dataset) == 0:
        logger.warning("Evaluation skipped for client due to empty dataset or dataloader.")
        return (0.0, 0.0, example_stats, accuracy_per_class, None)

    clean_acc = correct / len(test_loader.dataset) * 100
    if marker_mode:
        accuracy_transform = (correct_transform / nums_marker * 100) if nums_marker > 0 else 0.0
        return (
            clean_acc,
            accuracy_transform,
            example_stats,
            accuracy_per_class,
            _per_class_accuracy(counts["correct_bd"], counts["total_bd"]),
        )
    return (
        test_loss / len(test_loader),
        clean_acc,
        example_stats,
        accuracy_per_class,
        None,
    )

def _log_class_acc(prefix, class_acc):
    """Log ``prefix`` followed by ``C<i>:<acc>%`` entries joined by `` | `` (None logs no entries)."""
    if class_acc is None:
        class_acc = []
    log_print(prefix + ' | '.join([f'C{i}:{acc:.1f}%' for i, acc in enumerate(class_acc)]))


def _client_eval_mode(args, client_idx, atkmodel):
    """Decide how ``evaluate`` scores one client.

    Returns ``"clean"`` (train/test accuracy only), ``"marker"`` (clean plus marker
    attack accuracy), ``"backdoor"`` (clean plus the shared backdoor loader), or
    ``None`` to skip the client. ``args.backdoor_client_idx`` is only read when
    ``args.verify`` is ``"marker"`` or ``"backdoor"``.
    """
    if args.verify == "normal" or (args.verify in ["marker", "backdoor"] and client_idx not in args.backdoor_client_idx):
        return "clean"
    if args.verify == "marker" and client_idx in args.backdoor_client_idx and atkmodel is not None:
        return "marker"
    if args.verify == "backdoor" and client_idx in args.backdoor_client_idx:
        return "backdoor"
    return None


def _test_with_stats(args, loader, net, example_stats, split, client_idx):
    """Run ``test`` on a clean loader, threading forgetting statistics when tracked.

    If ``example_stats`` is not None, ``example_stats[split][client_idx]`` is passed
    to ``test`` and replaced by the dict it returns (split 0 = train, 1 = test).
    Returns ``test``'s 5-tuple unchanged.
    """
    slot = None if example_stats is None else example_stats[split][client_idx]
    result = test(args, loader, net, example_stats=slot)
    if example_stats is not None:
        example_stats[split][client_idx] = result[2]
    return result


def evaluate(args, train_loaders, test_loaders, backdoorloader=None, net=None, example_stats=None, datasets_name=None, atkmodel=None, atk_eps=None):
    """Evaluate ``net`` on every client's train/test data and log per-client metrics.

    Per client (see ``_client_eval_mode``):
      * clean: plain train and test accuracy;
      * marker: ``train_loaders[c]`` is indexable as ``[clean_loader, marker_loader]``.
        Clean accuracy is reported together with the attack success of ``atkmodel``;
      * backdoor: clean metrics plus accuracy on ``backdoorloader``.

    Args:
        example_stats: When given, it is indexed as ``example_stats[0|1][client_idx]``
            (train/test) and the forgetting statistics are updated in place.
        atk_eps: Trigger scale for marker mode. When None, ``test``'s default is used.

    Returns:
        ``(train_acc_list, test_acc_list, g_loss)`` when ``example_stats`` is None,
        otherwise ``(example_stats, g_loss)``.
    """
    tracked = example_stats is not None
    # The per-class log label depends on whether forgetting stats are tracked.
    train_prefix = ' Train Class Acc: ' if tracked else ' Class Acc: '
    test_prefix = ' Test Class Acc: ' if tracked else ' Class Acc: '

    # Pass the attack model and eps to test() by keyword. Positionally they would
    # land in its `example_stats` and `atkmodel` parameters.
    marker_kwargs = {"atkmodel": atkmodel}
    if atk_eps is not None:
        marker_kwargs["test_eps"] = atk_eps

    train_acc_list, test_acc_list, g_loss = [], [], []
    for client_idx in range(args.num_users):
        mode = _client_eval_mode(args, client_idx, atkmodel)
        if mode is None:
            continue
        name = datasets_name[client_idx]

        if mode == "marker":
            train_loss, train_acc, _, train_class_acc, _ = _test_with_stats(
                args, train_loaders[client_idx][0], net, example_stats, 0, client_idx)
            _, train_marker_acc, _, _, train_marker_class_acc = test(
                args, train_loaders[client_idx][1], net, **marker_kwargs)
            log_print(f' {name:<11}| Train Loss: {train_loss:.2f} | Clean Acc: {train_acc:.2f}% | Marker Acc: {train_marker_acc:.2f}%')
            _log_class_acc(' Clean Class Acc: ', train_class_acc)
            _log_class_acc(' Marker Class Acc: ', train_marker_class_acc)

            test_loss, test_acc, _, test_class_acc, _ = _test_with_stats(
                args, test_loaders[client_idx], net, example_stats, 1, client_idx)
            _, test_marker_acc, _, _, test_marker_class_acc = test(
                args, test_loaders[client_idx], net, **marker_kwargs)
            log_print(f' {name:<11}| Test  Loss: {test_loss:.2f} | Clean Acc: {test_acc:.2f}% | Marker Acc: {test_marker_acc:.2f}%')
            _log_class_acc(' Clean Class Acc: ', test_class_acc)
            _log_class_acc(' Marker Class Acc: ', test_marker_class_acc)
        else:
            train_loss, train_acc, _, train_class_acc, _ = _test_with_stats(
                args, train_loaders[client_idx], net, example_stats, 0, client_idx)
            log_print(f' {name:<11}| Train Loss: {train_loss:.2f} | Acc: {train_acc:.2f}%')
            _log_class_acc(train_prefix, train_class_acc)

            test_loss, test_acc, _, test_class_acc, _ = _test_with_stats(
                args, test_loaders[client_idx], net, example_stats, 1, client_idx)
            log_print(f' {name:<11}| Test  Loss: {test_loss:.2f} | Acc: {test_acc:.2f}%')
            _log_class_acc(test_prefix, test_class_acc)

            if mode == "backdoor":
                bd_test_loss, bd_test_acc, _, bd_class_acc, _ = test(args, backdoorloader, net)
                log_print(f' {name:<11}| BKD Loss: {bd_test_loss:.2f} | BKD Acc: {bd_test_acc:.2f}%')
                _log_class_acc(' BKD Class Acc: ', bd_class_acc)

        if not tracked:
            train_acc_list.append(train_acc)
            test_acc_list.append(test_acc)
        g_loss.append(train_loss)

    if tracked:
        return example_stats, g_loss
    return train_acc_list, test_acc_list, g_loss