#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import os
import copy
import numpy as np
import torch
import torch.utils.data
import time
from torch import nn
from models.Update_domain import (
    DomainClientUnlearningBl3,
    DomainClientUpdate,
    DomainClientUpdate_avg,
    DomainClientUpdate_avg_sal,
)
from models.Fed import FedAvg, FedAvg_HEAL, FedAvg_salun, FedAvg_fsu
from utils.options import args_parser
from utils.evaluate import evaluate
from utils.init_data_model import init_data, init_model, init_data_methodone, get_dataset
from utils.increase_loss_utils import get_distance  # distance calculation utility
from utils.fpl import proto_aggregation
import random
from torch.utils.data import Subset

def generate_federated_mask(delta_history, args):
    """Build hard saliency masks from a history of weight deltas.

    The deltas are averaged in absolute value per parameter, all entries of
    all parameters are ranked together by that magnitude, and for every
    sparsity ratio a 0/1 mask is produced that keeps the globally largest
    ``ratio`` fraction of entries.

    Args:
        delta_history: iterable of dicts mapping parameter name -> tensor.
            Tensors sharing a name must have the same shape (they are
            stacked).
        args: unused; kept so existing callers do not have to change.

    Returns:
        dict mapping each ratio in ``[0.1, 0.2, ..., 1.0]`` to a dict
        ``{parameter name: float tensor}`` with the parameter's shape, where
        1.0 marks an entry whose mean absolute delta is among the top
        ``int(total_entries * ratio)`` entries over all parameters.
    """
    # Group the deltas by parameter name; cast to float so that integer
    # tensors can be averaged as well.
    grouped_deltas = {}
    for delta in delta_history:
        for name, value in delta.items():
            grouped_deltas.setdefault(name, []).append(value.float())

    # Mean absolute delta per parameter entry.
    mean_abs_deltas = {
        name: torch.mean(torch.stack([torch.abs(d) for d in values]), dim=0)
        for name, values in grouped_deltas.items()
    }

    threshold_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

    # The ranking does not depend on the ratio, so compute it once.
    all_values = torch.cat([v.flatten() for v in mean_abs_deltas.values()])
    order = torch.argsort(all_values, descending=True)
    # ``order[i]`` is the index of the i-th largest entry. Inverting that
    # permutation yields each entry's rank (0 = largest), which is what has
    # to be compared against the threshold.
    ranks = torch.argsort(order)
    total_entries = len(all_values)

    masks = {}
    for ratio in threshold_list:
        threshold_index = int(total_entries * ratio)

        # Slice the flat ranks back into per-parameter tensors.
        mask = {}
        start_idx = 0
        for name, value in mean_abs_deltas.items():
            num_elements = value.numel()
            param_ranks = ranks[start_idx:start_idx + num_elements].reshape(value.shape)
            mask[name] = (param_ranks < threshold_index).float()
            start_idx += num_elements
        masks[ratio] = mask
    return masks

def modify_client_labels(save_path, unlearning_clients, backdoor_target_label, args):
    """Load the saved train loaders and randomise the unlearning clients' labels.

    Every sample that belongs to an unlearning client gets a random label
    that differs from its current label and from ``backdoor_target_label``.
    If no such class exists (for example two classes where the current label
    is not the backdoor target) the sample keeps its label; previously this
    case looped forever.

    Args:
        save_path: directory that contains ``train_loaders.pth``.
        unlearning_clients: a client index, or a list/tuple of client indices.
        backdoor_target_label: label that must never be assigned.
        args: run configuration; only used to regenerate the loaders through
            ``init_data_methodone`` when the file is missing.

    Returns:
        The list of train loaders loaded from disk; the label attribute
        (``targets`` or ``labels``) of the unlearning clients' underlying
        datasets is replaced by a relabelled copy.

    Raises:
        AttributeError: if a dataset that has to be inspected or relabelled
            exposes neither ``targets`` nor ``labels``.
    """
    train_path = os.path.join(save_path, 'train_loaders.pth')

    def _load_train_loaders():
        # The file holds pickled DataLoader objects, so full unpickling is
        # required; register DataLoader as a safe global where torch offers it.
        if hasattr(torch.serialization, "add_safe_globals"):
            torch.serialization.add_safe_globals([torch.utils.data.DataLoader])
        try:
            return torch.load(train_path, weights_only=False)
        except TypeError:
            # Presumably a torch version whose torch.load has no
            # ``weights_only`` keyword.
            return torch.load(train_path)

    def _unwrap(dataset):
        # A Subset only stores indices; labels live on the wrapped dataset.
        return dataset.dataset if isinstance(dataset, Subset) else dataset

    def _label_attr(dataset):
        # Name of the attribute holding the labels, or None if there is none.
        for attr in ('targets', 'labels'):
            if hasattr(dataset, attr):
                return attr
        return None

    def _count_classes(dataset, attr):
        labels = getattr(dataset, attr)
        if torch.is_tensor(labels):
            labels = labels.tolist()
        return len(set(labels))

    try:
        train_loaders = _load_train_loaders()
    except FileNotFoundError:
        print(f"Warning: {train_path} not found. Regenerating train loaders using current seed.")
        regen_args = copy.deepcopy(args)
        regen_args.target = 'learning'
        # Called for its side effect; the loaders are read back from
        # train_path right after.
        init_data_methodone(regen_args)
        train_loaders = _load_train_loaders()

    if isinstance(unlearning_clients, (list, tuple)):
        indices_list = list(unlearning_clients)
    else:
        indices_list = [unlearning_clients]
    torch.cuda.empty_cache()

    # Give the underlying datasets neutral defaults for the backdoor-related
    # attributes in case the loaded objects do not carry them.
    attribute_defaults = (
        ('inject_backdoor', False),
        ('backdoortest', False),
        ('poisoned_images', None),
        ('poison_selected_indices', None),
    )
    for loader in train_loaders:
        dataset = _unwrap(loader.dataset)
        for attr, default in attribute_defaults:
            if not hasattr(dataset, attr):
                setattr(dataset, attr, default)

    # Infer the number of classes from the first client that is not being
    # unlearned and exposes labels.
    num_classes = None
    for idx, loader in enumerate(train_loaders):
        if idx in indices_list:
            continue
        dataset = _unwrap(loader.dataset)
        attr = _label_attr(dataset)
        if attr is None:
            continue
        num_classes = _count_classes(dataset, attr)
        break

    # Fall back to the first loader if no other client provided labels.
    if num_classes is None and train_loaders:
        dataset = _unwrap(train_loaders[0].dataset)
        attr = _label_attr(dataset)
        if attr is None:
            raise AttributeError("Unable to infer number of classes from datasets")
        num_classes = _count_classes(dataset, attr)

    def shuffle_labels(loader, num_classes, backdoor_target_label):
        dataset = loader.dataset
        if isinstance(dataset, Subset):
            original_dataset, idxs = dataset.dataset, dataset.indices
        else:
            original_dataset, idxs = dataset, range(len(dataset))

        attr = _label_attr(original_dataset)
        if attr is None:
            raise AttributeError("Unable to find label attribute in dataset (targets/labels)")
        source = getattr(original_dataset, attr)
        # Work on a copy; tensors have no .copy(), so clone them instead.
        labels = source.clone() if torch.is_tensor(source) else source.copy()

        for idx in idxs:
            old_label = labels[idx]
            # Without at least one admissible class the rejection sampling
            # below would never terminate, so keep the label as it is.
            if not any(
                c != old_label and c != backdoor_target_label
                for c in range(num_classes)
            ):
                continue
            # Rejection sampling, unchanged so that seeded runs draw the same
            # random numbers as before.
            new_label = random.randint(0, num_classes - 1)
            while new_label == old_label or new_label == backdoor_target_label:
                new_label = random.randint(0, num_classes - 1)
            labels[idx] = new_label

        setattr(original_dataset, attr, labels)

    for idx in indices_list:
        if num_classes > 1:
            shuffle_labels(train_loaders[idx], num_classes, backdoor_target_label)
    return train_loaders




def fedsalun(args):
    """Run the fedsalun unlearning procedure and the follow-up federated training.

    Steps, in order:

    1. Seed torch / numpy / random and pick the device from ``args.gpu``.
    2. Obtain train/test loaders: through ``init_data`` when
       ``args.dataset_fullparti`` is truthy, otherwise by loading
       ``train_loaders.pth`` / ``test_loaders.pth`` from the dataset save
       directory, then from the legacy directory, and finally by
       regenerating them with ``init_data_methodone``.
    3. Load the global weights and the per-client weights of the previous
       ``learning`` run and evaluate the global model once.
    4. Relabel the unlearning clients' data (``modify_client_labels``) and
       build a mask from the ``local - global`` weight deltas of those
       clients (``generate_federated_mask``).
    5. For ``args.fedsalun_epoch`` rounds: train every client on the
       relabelled loaders, aggregate with ``FedAvg_fsu`` and apply
       ``w <- w - unlearn_lr * mask * (w_agg - w)`` to the masked keys.
    6. For ``args.unlearn_epoch`` rounds: train only the clients that are
       not unlearned on the original loaders, aggregate with ``FedAvg``,
       evaluate and periodically save.
    7. Save the final artefacts and write timing / performance CSV files.

    Args:
        args: run configuration. Attributes read here: ``seed``, ``gpu``,
            ``domain_split_factor``, ``backdoor_client_idx``, ``dataset``,
            ``verify``, ``dataset_fullparti``, ``save``,
            ``unlearning_client``, ``num_users``, ``backdoor_target_label``,
            ``mask_ratio``, ``fedsalun_epoch``, ``lamb``, ``unlearn_lr``,
            ``unlearn_epoch``, ``target``, ``diff_mask_ratio`` and the
            optional ``domain_times_factor``, ``bkd_domain_idx``, ``proto``.
            ``args.device`` is set by this function.

    Returns:
        None. Results are written below ``./save/test/...`` and
        ``./result/csv/...``.
    """
    # Seed every random number generator used by the run.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    # Use the requested GPU when CUDA is available and args.gpu is not -1, else the CPU.
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
    print(args)

    # Directory component for the split factor: args.domain_times_factor when
    # present, but always args.domain_split_factor when bkd_domain_idx is
    # missing or equals 12345 (presumably a "not set" sentinel - TODO confirm).
    split_factor = getattr(args, "domain_times_factor", args.domain_split_factor)
    if getattr(args, "bkd_domain_idx", 12345) == 12345:
        split_factor = args.domain_split_factor
    dsf_dir = f"dsf_{split_factor}"
    old_dsf_dir = f"dsf_{args.domain_split_factor}"
    # Backdoor client indices joined with '_' (a single index is treated as a one-element list).
    bkd_str = '_'.join(str(i) for i in (args.backdoor_client_idx if isinstance(args.backdoor_client_idx, (list, tuple)) else [args.backdoor_client_idx]))

    save_path = f'./save/datasets/{args.dataset}/{dsf_dir}/{bkd_str}/{args.verify}/'
    legacy_save = f'./save/datasets/{args.dataset}/{old_dsf_dir}/{bkd_str}/{args.verify}/'

    # Obtain the train/test loaders (and, where the initialiser returns one, the backdoor loader).
    if args.dataset_fullparti:
        train_loaders, test_loaders, backdoorloader = init_data(args)
    else:
        train_file = os.path.join(save_path, 'train_loaders.pth')
        test_file = os.path.join(save_path, 'test_loaders.pth')
        try:
            # Register DataLoader as a safe global where torch offers that API.
            if hasattr(torch.serialization, "add_safe_globals"):
                torch.serialization.add_safe_globals([torch.utils.data.DataLoader])
            # The TypeError fallbacks retry without the weights_only keyword.
            try:
                train_loaders = torch.load(train_file, weights_only=False)
            except TypeError:
                train_loaders = torch.load(train_file)
            try:
                test_loaders = torch.load(test_file, weights_only=False)
            except TypeError:
                test_loaders = torch.load(test_file)
            # Loaders read from disk come without a backdoor loader.
            backdoorloader = None
        except FileNotFoundError:
            # Second attempt: the legacy directory keyed by args.domain_split_factor.
            if os.path.exists(os.path.join(legacy_save, 'train_loaders.pth')):
                train_file = os.path.join(legacy_save, 'train_loaders.pth')
                test_file = os.path.join(legacy_save, 'test_loaders.pth')
                try:
                    if hasattr(torch.serialization, "add_safe_globals"):
                        torch.serialization.add_safe_globals([torch.utils.data.DataLoader])
                    try:
                        train_loaders = torch.load(train_file, weights_only=False)
                    except TypeError:
                        train_loaders = torch.load(train_file)
                    try:
                        test_loaders = torch.load(test_file, weights_only=False)
                    except TypeError:
                        test_loaders = torch.load(test_file)
                    save_path = legacy_save
                    backdoorloader = None
                except FileNotFoundError:
                    # Reached when a legacy file (e.g. test_loaders.pth) is missing.
                    print(
                        "Warning: saved train/test loaders not found. Regenerating using current seed."
                    )
                    # Regenerate on a copy of args with target forced to 'learning'.
                    regen_args = copy.deepcopy(args)
                    regen_args.target = 'learning'
                    train_loaders, test_loaders, backdoorloader = init_data_methodone(regen_args)
                    save_path = f'./save/datasets/{args.dataset}/{dsf_dir}/{bkd_str}/{args.verify}/'
            else:
                # Neither directory has saved loaders: regenerate them.
                print(
                    "Warning: saved train/test loaders not found. Regenerating using current seed."
                )
                regen_args = copy.deepcopy(args)
                regen_args.target = 'learning'
                train_loaders, test_loaders, backdoorloader = init_data_methodone(regen_args)

    datasets_name = get_dataset(args)

    # Output directory and file prefix for this run's artefacts.
    base_dir = f'./save/test/{args.dataset}/fedsalun/{args.save}/{dsf_dir}/{bkd_str}'
    os.makedirs(base_dir, exist_ok=True)
    # Clients to unlearn, normalised to a list.
    delete_users = (
        args.unlearning_client
        if isinstance(args.unlearning_client, (list, tuple))
        else [args.unlearning_client]
    )
    ul_clients_str = "_".join(str(i) for i in delete_users)
    base = f"{base_dir}/{ul_clients_str}"

    # Initialize model
    initial_model = init_model(args)
    net_glob = copy.deepcopy(initial_model).to(args.device)

    # Load the weights produced by the 'learning' run, preferring the current
    # split-factor directory and falling back to the legacy one if it exists.
    model_load_path = f'./save/test/{args.dataset}/learning/{args.save}/{dsf_dir}/{bkd_str}'
    if not os.path.exists(model_load_path):
        legacy_path = f'./save/test/{args.dataset}/learning/{args.save}/{old_dsf_dir}/{bkd_str}'
        if os.path.exists(legacy_path):
            model_load_path = legacy_path
    net_glob.load_state_dict(torch.load(f'{model_load_path}/weight_global.pth', map_location='cpu'))
    client_weights = torch.load(f'{model_load_path}/weight_local.pth', map_location='cpu')

    # Move model to target device
    net_glob.to(args.device)

    # Statistic records. example_stats is two lists of per-user dicts that
    # evaluate() receives and returns; its meaning is defined there.
    example_stats = [[{} for _ in range(args.num_users)], [{} for _ in range(args.num_users)]]
    # NOTE: loss_train is never appended to in this function, so the saved
    # file contains one empty list per user.
    loss_train = [[] for _ in range(args.num_users)]
    # NOTE: acc_best / idx_best are not used further down in this function.
    acc_best, idx_best = -1, -1

    # Evaluate the loaded model once before any unlearning happens.
    example_stats, _ = evaluate(
        args=args,
        train_loaders=train_loaders,
        test_loaders=test_loaders,
        net=copy.deepcopy(net_glob),
        example_stats=example_stats,
        datasets_name=datasets_name,
        backdoorloader=backdoorloader
    )

    # time_s starts at (almost) zero because start_time was read on the line
    # before; per-round wall time is accumulated into it below.
    start_time = time.perf_counter()
    time_s = time.perf_counter() - start_time
    client_time_records = []
    server_time_records = []
    performance_records = []
    round_idx = 0

    # Recompute the dataset directory (this overrides any save_path chosen
    # while loading above) and use the legacy directory only if the current
    # one has no train_loaders.pth but the legacy one does.
    save_path = f'./save/datasets/{args.dataset}/{dsf_dir}/{bkd_str}/{args.verify}/'
    if not os.path.exists(os.path.join(save_path, 'train_loaders.pth')):
        legacy_save = f'./save/datasets/{args.dataset}/{old_dsf_dir}/{bkd_str}/{args.verify}/'
        if os.path.exists(os.path.join(legacy_save, 'train_loaders.pth')):
            save_path = legacy_save
    os.makedirs(save_path, exist_ok=True)
    # Train loaders re-read from disk with the unlearning clients' labels randomised.
    unlr_train_loaders = modify_client_labels(
        save_path, delete_users, args.backdoor_target_label, args
    )

    # Re-read the same global checkpoint that was loaded into net_glob above;
    # the per-client states were already loaded as client_weights.
    global_state = torch.load(
        f"{model_load_path}/weight_global.pth", map_location="cpu"
    )
    client_states = client_weights

    # One delta (local weights minus global weights) per unlearning client.
    delta_history = []
    for ul_idx in delete_users:
        if isinstance(client_states, list):
            local_state = client_states[ul_idx]
        elif isinstance(client_states, dict):
            # Accept both integer and string client keys.
            local_state = client_states.get(ul_idx, client_states.get(str(ul_idx)))
        else:
            raise TypeError("Unsupported type for client_states")

        delta = {k: (local_state[k] - global_state[k]).float() for k in global_state}
        delta_history.append(delta)

    # masks is keyed by sparsity ratio; args.mask_ratio must be one of those keys.
    masks = generate_federated_mask(delta_history, args)
    mask = {k: v.to(args.device) for k, v in masks[args.mask_ratio].items()}

    # Current global weights and a buffer for the most recent aggregate-minus-global difference.
    w_glob_pre = {k: v.to(args.device) for k, v in net_glob.state_dict().items()}
    w_glob_diff = {k: torch.zeros_like(v).to(args.device) for k, v in w_glob_pre.items()}


    # Prototype-based training is optional and disabled when args has no 'proto'.
    use_proto = getattr(args, "proto", False)
    global_protos = {}

    # ===== unlearning rounds: every client trains on the relabelled loaders =====
    for epoch in range(args.fedsalun_epoch):
        print("============ fedsalun epoch {} ============".format(epoch))
        epoch_start = time.perf_counter()
        client_elapsed = 0
        # NOTE: w_glob_temp is only referenced by the commented-out block below.
        w_glob_temp = {k: torch.zeros_like(v, dtype=torch.float32) for k, v in w_glob_pre.items()}
        w_locals = []
        local_protos = [{} for _ in range(args.num_users)]
        for client_idx in range(args.num_users):

            trainer_cls = DomainClientUpdate if use_proto else DomainClientUpdate_avg
            local_trainer = trainer_cls(
                args=args,
                train_loader=unlr_train_loaders[client_idx]
            )
            # Only the train() call is counted as client time.
            t0 = time.perf_counter()
            if use_proto:
                client_model, client_state, client_proto, _ = local_trainer.train(
                    net=copy.deepcopy(net_glob).to(args.device),
                    global_protos=global_protos,
                )
                w_locals.append(client_state)
                local_protos[client_idx] = client_proto
            else:
                client_model, client_loss = local_trainer.train(
                    net=copy.deepcopy(net_glob).to(args.device)
                )
                w_locals.append(client_model.state_dict())
            client_elapsed += time.perf_counter() - t0

        # Aggregate the local weights, then move the global weights against
        # the aggregate direction on the masked entries only.
        w_glob_agg = FedAvg_fsu(w_locals, delete_users, args.lamb)
        w_glob_agg = {k: v.to(args.device) for k, v in w_glob_agg.items()}
        for k in w_glob_pre:
            if k in mask:
                w_glob_diff[k] = w_glob_agg[k] - w_glob_pre[k]
                w_glob_pre[k] = w_glob_pre[k] - args.unlearn_lr * mask[k] * w_glob_diff[k]
        
        # Disabled alternative aggregation, kept commented out:
        # for k in w_glob_pre:
        #     if k in mask:
        #         for idx in range(args.num_users):
        #             if idx == unlearn_client:
        #                 continue
        #             w_glob_temp[k] += args.lamb * w_locals[idx][k]
        #         w_glob_diff[k] = w_locals[unlearn_client][k] + w_glob_temp[k] - w_glob_pre[k]
        #         w_glob_pre[k] = w_glob_pre[k] - args.unlearn_lr * mask[k] * w_glob_diff[k]


        net_glob.load_state_dict(w_glob_pre)
        if use_proto:
            global_protos = proto_aggregation(args, local_protos)
        # Server time is the part of the round not spent inside client train() calls.
        server_time = time.perf_counter() - epoch_start - client_elapsed
        time_s += time.perf_counter() - epoch_start
        # All users trained in this phase, so average the client time over all of them.
        participants = args.num_users
        client_time_records.append(client_elapsed / participants)
        server_time_records.append(server_time)
        round_idx += 1

    #===== subsequent federated aggregation (excluding forgotten clients)=====
    for epoch in range(args.unlearn_epoch):
        print("============ Train epoch {} ============".format(epoch))
        epoch_start = time.perf_counter()
        client_elapsed = 0
        w_locals = []
        # Entries of skipped (unlearned) clients stay empty dicts.
        local_protos = [{} for _ in range(args.num_users)]
        for client_idx in range(args.num_users):
            if client_idx in delete_users:
                continue

            # Local training on the original (not relabelled) loaders, with
            # the same trainer selection as in the unlearning rounds.
            trainer_cls = DomainClientUpdate if use_proto else DomainClientUpdate_avg
            local_trainer = trainer_cls(
                args=args,
                train_loader=train_loaders[client_idx]
            )
            t0 = time.perf_counter()
            if use_proto:
                client_model, client_state, client_proto, result = local_trainer.train(
                    net=copy.deepcopy(net_glob).to(args.device),
                    global_protos=global_protos,
                )
                # First element of the trainer's result, or 0.0 when the result is empty.
                client_loss = result[0] if result else 0.0
                w_locals.append(client_state)
                local_protos[client_idx] = client_proto
            else:
                client_model, client_loss = local_trainer.train(
                    net=copy.deepcopy(net_glob).to(args.device)
                )
                w_locals.append(client_model.state_dict())
            client_elapsed += time.perf_counter() - t0

        # Aggregate updates
        w_glob = FedAvg(w_locals)

        net_glob.load_state_dict(w_glob)
        if use_proto:
            global_protos = proto_aggregation(args, local_protos)
        # Measured before evaluation, so evaluation is not counted as server time.
        server_time = time.perf_counter() - epoch_start - client_elapsed

        # Model evaluation
        example_stats, global_loss = evaluate(
            args=args,
            train_loaders=train_loaders,
            test_loaders=test_loaders,
            net=copy.deepcopy(net_glob),
            example_stats=example_stats,
            datasets_name=datasets_name,
            backdoorloader=backdoorloader
        )
        # One row per round: the round index followed by the values returned by evaluate().
        performance_records.append([round_idx] + global_loss)
        round_idx += 1


        # Save intermediate results every 10 epochs and after the last epoch.
        if (epoch+1) % 10 == 0 or epoch == args.unlearn_epoch-1:
            torch.save(example_stats, f'{base}_forget_event.pth')
            torch.save(net_glob.state_dict(), f'{base}_weight_global.pth')
            torch.save(loss_train, f'{base}_loss_train.pth')

        # Unlike server_time, the accumulated total includes evaluation and saving.
        time_s += time.perf_counter() - epoch_start
        # Only the remaining clients trained, so average the client time over them.
        participants = args.num_users - len(delete_users)
        client_time_records.append(client_elapsed / participants)
        server_time_records.append(server_time)

    # Final save; also covers the case where the training loop above did not run.
    torch.save(example_stats, f'{base}_forget_event.pth')
    torch.save(net_glob.state_dict(), f'{base}_weight_global.pth')
    torch.save(loss_train, f'{base}_loss_train.pth')
    print(f'Total time: {time_s:.2f}s')

    # CSV reports; the file names carry a Unix timestamp.
    csv_dir = f'./result/csv/{args.dataset}/{args.target}'
    os.makedirs(csv_dir, exist_ok=True)
    timestamp = int(time.time())
    ul_clients_str = '_'.join(str(i) for i in delete_users)
    bd_clients_str = '_'.join(str(i) for i in (args.backdoor_client_idx if isinstance(args.backdoor_client_idx, (list, tuple)) else [args.backdoor_client_idx]))
    time_file = os.path.join(
        csv_dir,
        f'time_{args.target}_{ul_clients_str}_{bd_clients_str}_{args.mask_ratio}_{args.diff_mask_ratio}_{timestamp}.csv'
    )
    perf_file = os.path.join(
        csv_dir,
        f'performance_of_clients_{args.target}_{ul_clients_str}_{bd_clients_str}_{args.mask_ratio}_{args.diff_mask_ratio}_{timestamp}.csv'
    )
    # Timing rows cover both phases: the unlearning rounds first, then the training rounds.
    with open(time_file, 'w') as f:
        f.write('round,client_time,server_time\n')
        for r, (ct, st) in enumerate(zip(client_time_records, server_time_records)):
            f.write(f'{r},{ct},{st}\n')

    # Performance rows exist only for the training rounds (evaluate() runs only there).
    with open(perf_file, 'w') as f:
        f.write('round')
        for cid in range(args.num_users):
            f.write(f',client{cid}')
        f.write('\n')
        for record in performance_records:
            f.write(','.join(map(str, record)) + '\n')

