#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import copy
import time
import math
import numpy as np
import torch
from models.Update_domain import DomainClientUpdate, DomainClientUpdate_avg
from utils.init_data_model import init_data, init_model, init_data_methodone
from utils.evaluate import evaluate
from utils.fpl import proto_aggregation

def _normalize_client_ids(unlearning_client):
    """Return the client(s) to forget as a list of ints.

    ``unlearning_client`` may be a single index or an iterable of indices.
    Earlier code iterated over it when building the save-path name but
    compared it with ``==``/``!=`` against single indices elsewhere, so with a
    list the forgotten clients were never excluded, and with an int the path
    construction raised ``TypeError``. Both forms are accepted here.
    """
    if isinstance(unlearning_client, (int, np.integer)):
        return [int(unlearning_client)]
    return [int(c) for c in unlearning_client]


def _train_calibration_updates(args, net_glob, train_loaders, clients,
                               global_protos, use_proto, epoch_ratio=0.5):
    """Briefly retrain each retained client and collect its parameter delta.

    ``args.local_ep`` is temporarily set to ``ceil(local_ep * epoch_ratio)``
    (calibration ratio r = 0.5 in the reference method) and is always restored,
    even if training raises.

    Returns:
        ``(updates, local_protos)``.
        - ``updates`` is a list of dicts with keys ``'client_id'``, ``'params'``
          and ``'num_samples'``. ``'params'`` holds the per-key
          ``w_local - w_global``, moved to CPU.
        - ``local_protos`` is indexed by client id. It stays ``{}`` for clients
          that were not trained, or when prototypes are disabled.
    """
    original_local_ep = copy.deepcopy(args.local_ep)
    local_protos = [{} for _ in range(args.num_users)]
    updates = []
    # Every client trains on a deep copy, so the global weights do not change
    # inside this loop; one reference snapshot is enough.
    w_global = net_glob.state_dict()
    trainer_cls = DomainClientUpdate if use_proto else DomainClientUpdate_avg

    args.local_ep = int(np.ceil(original_local_ep * epoch_ratio))
    try:
        for client_idx in clients:
            local = trainer_cls(args, train_loader=train_loaders[client_idx])
            if use_proto:
                _, w_local, client_proto, _ = local.train(
                    net=copy.deepcopy(net_glob).to(args.device),
                    global_protos=global_protos,
                )
                local_protos[client_idx] = client_proto
            else:
                net_local, _ = local.train(
                    net=copy.deepcopy(net_glob).to(args.device)
                )
                w_local = net_local.state_dict()

            # The update is represented as the parameter difference.
            updates.append({
                'client_id': client_idx,
                'params': {k: (w_local[k] - w_global[k]).cpu() for k in w_global},
                'num_samples': len(train_loaders[client_idx].dataset),
            })
    finally:
        args.local_ep = original_local_ep
    return updates, local_protos


def _load_history(args, epoch, deleted):
    """Load the historical per-client updates saved for ``epoch``.

    Entries whose ``'client_id'`` is in ``deleted`` are dropped. All remaining
    entries except ``'client_id'`` are cast to float.

    Returns:
        A list of ``{'client_id': id-or-None, 'params': dict}``, or ``None``
        when no file exists for that epoch.
    """
    # NOTE(review): this path uses ``args.model`` while the initial weights are
    # read from ``.../learning/{args.save}/`` — confirm both match the
    # learning run's output layout.
    path = f"./save/test/{args.dataset}/learning/{args.model}/models/local_epoch-{epoch}.pth"
    try:
        original_updates = torch.load(path, map_location=args.device)
    except FileNotFoundError:
        print(f"Warning: No historical data found for epoch {epoch}")
        return None

    history = []
    for entry in original_updates:
        client_id = entry.get('client_id')
        if isinstance(client_id, torch.Tensor):
            # Tensors hash by identity, so convert before the set lookup.
            client_id = client_id.item()
        if client_id is not None and client_id in deleted:
            continue
        params = {k: v.float() for k, v in entry.items() if k != 'client_id'}
        history.append({'client_id': client_id, 'params': params})
    return history


def _calibrate(hist_updates, calibrated_updates, eps=1e-10):
    """Rescale each calibration delta to its historical update's magnitude.

    For every layer present in both the historical and the calibration
    update, the result is ``||hist|| * cal / (||cal|| + eps)``. This is eq. (1)
    of the reference method.

    Pairing: updates are matched by ``client_id`` when every historical entry
    carries one; otherwise they are matched by list position.
    """
    if hist_updates and all(h['client_id'] is not None for h in hist_updates):
        by_id = {h['client_id']: h for h in hist_updates}
        pairs = [(by_id[c['client_id']], c) for c in calibrated_updates
                 if c['client_id'] in by_id]
    else:
        # NOTE(review): without client ids the forgotten client cannot be
        # filtered out of the history, and pairing relies on both lists using
        # the same client order — confirm the saved history format.
        pairs = zip(hist_updates, calibrated_updates)

    calibrated = []
    for hist, cal in pairs:
        scaled = {}
        for layer, hist_param in hist['params'].items():
            cal_param = cal['params'].get(layer)
            if cal_param is None:
                continue
            # Cast integer-typed state entries so norm() accepts them.
            cal_param = cal_param.float()
            # Reduce norms to Python floats: history tensors may live on the
            # accelerator while calibration deltas are on CPU.
            orig_norm = torch.norm(hist_param, p=2).item()
            cal_norm = torch.norm(cal_param, p=2).item() + eps
            scaled[layer] = cal_param * (orig_norm / cal_norm)
        calibrated.append({'params': scaled, 'num_samples': cal['num_samples']})
    return calibrated


def _aggregate(calibrated):
    """Return the sample-weighted average of calibrated updates (eq. (2)).

    Only layers present in every update are averaged. Returns ``{}`` when there
    is nothing to aggregate or the total sample count is not positive.
    """
    total_samples = sum(u['num_samples'] for u in calibrated)
    if not calibrated or total_samples <= 0:
        return {}
    common = set(calibrated[0]['params']).intersection(
        *(u['params'] for u in calibrated[1:])
    )
    return {
        layer: sum(u['params'][layer] * u['num_samples'] for u in calibrated) / total_samples
        for layer in calibrated[0]['params']
        if layer in common
    }


def _apply_update(net_glob, aggregated):
    """Add ``aggregated`` to the global model's state in place (eq. (3)).

    State entries with no aggregated update are left unchanged. Each result is
    cast back to its entry's original dtype.
    """
    new_state = {}
    for name, tensor in net_glob.state_dict().items():
        step = aggregated.get(name)
        if step is None:
            new_state[name] = tensor
        else:
            new_state[name] = (tensor + step.to(tensor.device)).to(tensor.dtype)
    net_glob.load_state_dict(new_state)


def federaser(args):
    """Run FedEraser-style calibration unlearning of ``args.unlearning_client``.

    Setup: the global model starts from the initial weights saved by the
    learning run.

    Each unlearning round then:
      1. trains every retained client for ``ceil(local_ep * 0.5)`` local epochs,
         starting from the current global model, and records its parameter
         delta;
      2. every ``delta_t`` rounds, loads the per-client historical updates saved
         by the learning run, dropping the forgotten clients;
      3. rescales each calibration delta, layer by layer, to the L2 norm of the
         matching historical update;
      4. aggregates the rescaled deltas weighted by client sample count and adds
         the result to the global model;
      5. evaluates the model, and writes a checkpoint every 10 rounds.

    Final statistics, weights and losses are written under
    ``./save/test/<dataset>/federaser/<save>/``.

    Args:
        args: run configuration. Fields read include ``seed``, ``gpu``,
            ``dataset``, ``dataset_fullparti``, ``unlearning_client`` (int or
            iterable of ints), ``num_users``, ``save``, ``model``,
            ``unlearn_epoch``, ``local_ep`` and optionally ``proto``.
            ``args.device`` is set by this function.

    Raises:
        KeyError: if ``args.dataset`` is not a known dataset.
    """
    start_time = time.perf_counter()

    # Random seeds and device.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
    print(args)

    # Datasets.
    if args.dataset_fullparti:
        train_loaders, test_loaders, backdoorloader = init_data(args)
    else:
        train_loaders, test_loaders, backdoorloader = init_data_methodone(args)

    # Per-client domain names for each supported dataset.
    dataset_map = {
        'domain_digits': ['MNIST', 'SVHN', 'USPS', 'SynthDigits', 'MNIST-M'],
        'office-caltech10': ['amazon', 'caltech', 'dslr', 'webcam'],
        'domainnet': ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch'],
        'pacs': ['art_painting', 'cartoon', 'photo', 'sketch']
    }
    try:
        datasets_name = dataset_map[args.dataset.lower()]
    except KeyError as e:
        raise KeyError(
            f"Unknown dataset {args.dataset}. Available datasets: {list(dataset_map.keys())}"
        ) from e

    # Output paths.
    delete_users = _normalize_client_ids(args.unlearning_client)
    deleted = set(delete_users)
    delete_names = "_".join(datasets_name[i] for i in delete_users)
    save_dir = f'./save/test/{args.dataset}/federaser/{args.save}/'
    save_base = f"{save_dir}{delete_names}"
    os.makedirs(save_dir, exist_ok=True)

    # Bookkeeping.
    example_stats = [[{} for _ in range(args.num_users)], [{} for _ in range(args.num_users)]]
    loss_train = [[] for _ in range(args.num_users)]

    # Initial model from the learning run, moved to the target device.
    net_glob = init_model(args).to(args.device)
    w_glob = torch.load(
        f'./save/test/{args.dataset}/learning/{args.save}/weight_init.pth',
        map_location=args.device
    )
    net_glob.load_state_dict({k: v.float() for k, v in w_glob.items()})

    retained_clients = [c for c in range(args.num_users) if c not in deleted]
    delta_t = 2  # history retention interval (Δ_t = 2)
    history = None
    use_proto = getattr(args, "proto", False)
    global_protos = {}

    for epoch in range(args.unlearn_epoch):
        # Calibration training on retained clients.
        calibrated_updates, local_protos = _train_calibration_updates(
            args, net_glob, train_loaders, retained_clients, global_protos, use_proto
        )

        # Reload history at the start of each Δ_t period. Only the current
        # period's history is ever used, so older periods are not retained. A
        # missing file disables the update for that period instead of
        # shifting later periods onto the wrong history.
        if epoch % delta_t == 0:
            history = _load_history(args, epoch, deleted)

        if use_proto:
            global_protos = proto_aggregation(args, local_protos)

        # Calibrate, aggregate, and apply the update.
        if history:
            aggregated = _aggregate(_calibrate(history, calibrated_updates))
            if aggregated:
                _apply_update(net_glob, aggregated)

        # Evaluate on the target device.
        net_glob.to(args.device)
        example_stats, g_loss = evaluate(
            args=args,
            train_loaders=train_loaders,
            test_loaders=test_loaders,
            net=net_glob,
            example_stats=example_stats,
            datasets_name=datasets_name,
            backdoorloader=backdoorloader
        )
        for client_idx in retained_clients:
            loss_train[client_idx].append(g_loss[client_idx])

        # Periodic checkpoint.
        if (epoch + 1) % 10 == 0:
            torch.save({
                'example_stats': example_stats,
                'weight_global': net_glob.state_dict(),
                'loss_train': loss_train
            }, f"{save_base}_epoch_{epoch+1}.pth")

    # Final outputs.
    torch.save(example_stats, f"{save_base}_forget_event.pth")
    torch.save(net_glob.state_dict(), f"{save_base}_weight_global.pth")
    torch.save(loss_train, f"{save_base}_loss_train.pth")

    print(f'Unlearning completed in {time.perf_counter()-start_time:.2f}s')

