#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import os
import copy
import numpy as np
import torch
import time
from torch import nn
import random
from models.Update_domain import DomainClientUpdate, DomainClientUpdate_avg, extract_logits

# --- Model & Update Imports ---
# (Assume these are correctly defined elsewhere)
from models.Fed import FedAvg
from utils.evaluate import evaluate
from utils.init_data_model import (
    init_data,
    init_model,
    init_data_methodone,
    get_dataset,
)
from utils.fpl import proto_aggregation
from activation_vs4compute import calculate_importance_scores


# --- Mask Building Utility ---
def _deepest_scored_layer(param_name, score_table):
    """Return the longest dotted module prefix of *param_name* found in *score_table*.

    The final component of the parameter name (e.g. ``weight``) is ignored.
    Returns ``None`` when no prefix is a key of *score_table*.
    """
    prefixes = []
    running = ""
    for piece in param_name.split(".")[:-1]:
        running = piece if not running else f"{running}.{piece}"
        prefixes.append(running)
    # Longest prefix first, so the most specific module wins.
    return next((p for p in reversed(prefixes) if p in score_table), None)


def _top_ratio_mask(scores, ratio):
    """Return a float mask that is 1 where ``scores >= cutoff``.

    For ``0 < ratio <= 1`` the cutoff is the ``(1 - ratio)`` quantile of the
    scores (i.e. roughly the top ``ratio`` fraction is kept); any other value
    of *ratio* is used directly as an absolute cutoff.
    """
    if 0 < ratio <= 1:
        cutoff = torch.quantile(scores.float(), 1.0 - ratio)
    else:
        cutoff = ratio
    return (scores >= cutoff).float()


def _channel_mask_to_param(channel_mask, param):
    """Broadcast a 1-D mask over dim 0 of *param*; fall back to all ones on a shape mismatch."""
    if channel_mask.dim() == 1 and channel_mask.size(0) == param.size(0):
        target_shape = [channel_mask.size(0)] + [1] * (param.dim() - 1)
        return channel_mask.view(*target_shape).expand_as(param)
    return torch.ones_like(param)


def build_masks_local(model_local, model_global, loader, args, avg_local_scores=None):
    """Build two per-parameter masks for ``model_local`` from importance scores.

    Importance scores are computed with ``calculate_importance_scores`` for both
    ``model_local`` and ``model_global`` on the same ``loader``. Each parameter
    is matched to the deepest module prefix of its name that has a local score.
    Scores are presumably 1-D per-output-channel tensors keyed by module name
    (TODO confirm against ``calculate_importance_scores``); anything that cannot
    be broadcast over dim 0 of the parameter yields an all-ones mask.

    Args:
        model_local: model whose ``named_parameters()`` define the mask keys.
        model_global: model scored on the same data for the difference mask.
        loader: data loader handed to ``calculate_importance_scores``.
        args: namespace providing ``device``, ``mask_ratio`` and optionally
            ``diff_mask_ratio`` (defaults to ``mask_ratio``).
        avg_local_scores: optional ``{layer_name: tensor}`` that overrides the
            local scores when building the local-importance mask only.

    Returns:
        ``(local_mask_dict, diff_mask_dict)``, both keyed by parameter name and
        holding tensors shaped like the parameter. In the local mask 1 marks
        entries whose score reaches the threshold; in the diff mask 1 marks
        entries whose |local - global| score difference reaches the threshold.
    """
    device = args.device
    mask_ratio = args.mask_ratio
    # A separate ratio for the difference mask is optional.
    diff_mask_ratio = getattr(args, "diff_mask_ratio", mask_ratio)

    print("Calculating local importance scores...")
    local_scores = calculate_importance_scores(
        model_local, loader, device, args=args, normalize_globally=False
    )
    if not local_scores:
        print("Warning: calculate_importance_scores for local model returned empty.")

    print("Calculating global model importance scores on local data...")
    global_scores_on_local = calculate_importance_scores(
        model_global, loader, device, args=args, normalize_globally=False
    )
    if not global_scores_on_local:
        print("Warning: calculate_importance_scores for global model returned empty.")

    local_mask_dict = {}
    diff_mask_dict = {}

    for pname, param in model_local.named_parameters():
        layer = _deepest_scored_layer(pname, local_scores)
        has_local_score = bool(layer) and layer in local_scores

        # --- Local importance mask (1 = important, 0 = not important) ---
        if has_local_score:
            # Prefer the externally supplied averaged scores when available.
            if avg_local_scores is not None and layer in avg_local_scores:
                scores = avg_local_scores[layer].to(device)
            else:
                scores = local_scores[layer]

            if scores.numel() > 0:
                channel_mask = _top_ratio_mask(scores, mask_ratio)
            else:
                # Empty score tensor: treat everything as important.
                channel_mask = torch.ones_like(scores, dtype=torch.float)
            keep_mask = _channel_mask_to_param(channel_mask, param)
        else:
            keep_mask = torch.ones_like(param)
        local_mask_dict[pname] = keep_mask.to(device)

        # --- Difference mask (1 = update, 0 = freeze) ---
        # Large local-vs-global score difference => update; small => freeze.
        if has_local_score and layer in global_scores_on_local:
            scores_local = local_scores[layer]
            scores_global = global_scores_on_local[layer]
            comparable = (
                scores_local.numel() > 0
                and scores_global.numel() == scores_local.numel()
            )
            if comparable:
                gap = torch.abs(scores_local - scores_global)
                channel_mask = _top_ratio_mask(gap, diff_mask_ratio)
            else:
                # Scores empty or of different size: default to update.
                channel_mask = torch.ones_like(scores_local, dtype=torch.float)
            update_mask = _channel_mask_to_param(channel_mask, param)
        else:
            update_mask = torch.ones_like(param)  # Default to update
        diff_mask_dict[pname] = update_mask.to(device)

    print(
        f"Finished building masks. Local mask keys: {len(local_mask_dict)}, Diff mask keys: {len(diff_mask_dict)}"
    )
    return local_mask_dict, diff_mask_dict


# --- Masked Optimizer ---
class MaskedSGD(torch.optim.SGD):
    """
    SGD optimizer that multiplies each gradient by a per-parameter mask before stepping.

    mask_dict: {param_name: torch.Tensor mask same shape as param.grad} (1=update, 0=freeze).
    Parameters whose name is absent from mask_dict are stepped with their unmasked gradient.

    NOTE(review): the mask is applied to the gradient only. The parent SGD step adds the
    weight-decay term afterwards, so with weight_decay != 0 entries masked to 0 can still
    change. Confirm whether that is intended by the callers.
    """

    def __init__(
        self,
        named_params,
        mask_dict,
        lr=0.01,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
    ):
        """
        Args:
            named_params: iterable of (name, param) tuples; may be a one-shot
                iterator such as ``model.named_parameters()``.
            mask_dict: mapping from parameter name to mask tensor; ``None`` is
                treated as an empty mapping (no masking).
            lr, momentum, dampening, weight_decay, nesterov: forwarded to
                ``torch.optim.SGD``.

        Raises:
            ValueError: if no supplied parameter has ``requires_grad=True``.
        """
        # Materialize once: named_params may be a generator, and iterating it
        # twice would leave the name->param map empty and silently disable masking.
        named_trainable = [(name, p) for name, p in named_params if p.requires_grad]
        if not named_trainable:
            raise ValueError("Optimizer received no parameters that require gradients.")

        super().__init__(
            [p for _, p in named_trainable],
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )

        # Name -> param mapping for exactly the parameters being optimized.
        self.param_name_map = dict(named_trainable)
        # Keep the caller's dict object when one is given; only None is replaced.
        self.mask_dict = {} if mask_dict is None else mask_dict
        if not self.mask_dict:
            print("Warning: MaskedSGD initialized with an empty mask_dict.")

    def step(self, closure=None):
        """Mask the gradients, then perform one SGD step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss. It is run *before* masking so the gradients it
                produces are the ones that get masked.

        Returns:
            The closure's return value, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # Evaluate here instead of inside the parent step: the parent would
            # call it after masking and overwrite the masked gradients.
            with torch.enable_grad():
                loss = closure()

        with torch.no_grad():
            for name, p in self.param_name_map.items():
                if p.grad is None:
                    continue
                mask = self.mask_dict.get(name)
                if mask is not None:
                    p.grad.mul_(mask.to(p.device))  # 0 entries zero the gradient

        # The closure (if any) has already been consumed above.
        super().step()
        return loss


# --- Unified Local Trainer ---
class LocalTrainer:
    """Runs one client's local training: prune a copy of the model, then train it with MaskedSGD."""

    def __init__(
        self, args, train_loader, local_mask, diff_mask, is_unlearning_client=False
    ):
        """
        Args:
            args: namespace providing ``device``, ``lr``, ``momentum``,
                ``local_ep`` and optionally ``weight_decay``.
            train_loader: iterable of training batches.
            local_mask: ``{param_name: mask}`` used for the initial pruning.
            diff_mask: ``{param_name: mask}`` passed to MaskedSGD (1=update, 0=freeze).
            is_unlearning_client: when True the pruning mask is inverted.
        """
        self.args = args
        self.device = args.device
        self.train_loader_original = train_loader
        self.local_mask = local_mask
        self.diff_mask = diff_mask
        self.is_unlearning_client = is_unlearning_client
        self.criterion = nn.CrossEntropyLoss()

    def _prune_model(self, net):
        """Return a deep copy of *net* with its parameters multiplied by the pruning mask.

        The mask is ``local_mask`` for a normal client and ``1 - local_mask``
        for an unlearning client. Parameters without a mask entry are left
        untouched; *net* itself is never modified.
        """
        pruned_net = copy.deepcopy(net)
        if self.is_unlearning_client:
            # Unlearning flips the mask, so entries marked 1 are the ones zeroed.
            role = "unlearning"
            active_mask = {key: (1.0 - m) for key, m in self.local_mask.items()}
        else:
            role = "normal"
            active_mask = self.local_mask

        print(f"Applying initial pruning for {role} client...")
        zeroed = 0
        covered = 0  # number of elements in parameters that have a mask entry
        with torch.no_grad():
            for key, weight in pruned_net.named_parameters():
                if key not in active_mask:
                    continue
                m = active_mask[key].to(weight.device)
                weight.data.mul_(m)
                zeroed += torch.sum(m == 0).item()
                covered += weight.numel()

        # The ratio is relative to masked parameters only.
        prune_ratio = zeroed / covered if covered > 0 else 0
        flavor = "inverted " if self.is_unlearning_client else ""
        print(
            f"Pruning finished. Approx {prune_ratio*100:.2f}% of *masked* parameters zeroed based on {flavor}local mask."
        )
        return pruned_net

    def _get_data_loader(self):
        """Return the loader supplied at construction time."""
        return self.train_loader_original

    def train(self, net):
        """Prune a copy of *net* and train it for ``args.local_ep`` epochs.

        Each batch must be a 2- or 3-element sequence whose first two items are
        images and labels; batches that cannot be unpacked or moved to the
        device are reported and skipped.

        Returns:
            ``(trained_model, mean_epoch_loss)``; the loss is 0 when no epoch ran.
        """
        model = self._prune_model(net)
        model.train()
        loader = self._get_data_loader()

        # Gradient masking is done inside MaskedSGD.step using diff_mask.
        optimizer = MaskedSGD(
            list(model.named_parameters()),
            self.diff_mask,
            lr=self.args.lr,
            momentum=self.args.momentum,
            weight_decay=getattr(self.args, "weight_decay", 1e-4),
        )

        phase = "unlearning" if self.is_unlearning_client else "normal"
        total_epochs = self.args.local_ep
        print(f"Starting {phase} training for {total_epochs} epochs...")

        epoch_means = []
        for ep in range(total_epochs):
            losses = []
            for batch_no, batch in enumerate(loader):
                try:
                    n_fields = len(batch)
                    if n_fields == 2:
                        images, labels = batch
                    elif n_fields == 3:
                        images, labels, _ = batch
                    else:
                        raise ValueError(
                            f"Unexpected data format: {n_fields} elements."
                        )
                    images = images.to(self.device)
                    labels = labels.to(self.device)
                except Exception as e:
                    print(f"Error processing batch {batch_no}: {e}")
                    continue

                optimizer.zero_grad()
                logits = extract_logits(model(images))
                loss = self.criterion(logits, labels.long())
                loss.backward()
                optimizer.step()
                losses.append(loss.item())

            mean_loss = sum(losses) / len(losses) if losses else 0
            epoch_means.append(mean_loss)
            print(
                f" {phase.capitalize()} Local Epoch {ep+1}/{total_epochs}, Loss: {mean_loss:.4f}"
            )

        overall = sum(epoch_means) / len(epoch_means) if epoch_means else 0
        return model, overall


# --- Helper Functions for Layer Identification ---
def is_bn_layer(name, module):
    """Return True when *module* is an ``nn.BatchNorm1d``/``2d``/``3d`` instance.

    *name* is accepted for call-site symmetry but is not used.
    """
    batch_norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    return isinstance(module, batch_norm_types)


def get_classifier_layer_names(model, args):
    """Return the set of parameter names belonging to the model's final classifier.

    The classifier is located from ``args.model``:
    ``cnn`` -> ``model.fc3``; ``vgg16`` -> ``model.classifier[6]`` (only when
    ``model.classifier`` has more than 6 entries); ``resnet18`` -> ``model.fc``;
    ``vit``/``mobilevit`` -> ``model.classifier``. Names are prefixed with the
    attribute path (e.g. ``fc.weight``). An empty set is returned, with a
    warning printed, when nothing matches.
    """
    arch = args.model
    prefix = None  # stays None when no classifier head is identified
    head = None

    if arch == "cnn":
        if hasattr(model, "fc3"):
            prefix, head = "fc3", model.fc3
    elif arch == "vgg16":
        if hasattr(model, "classifier") and len(model.classifier) > 6:
            prefix, head = "classifier.6", model.classifier[6]
    elif arch == "resnet18":
        if hasattr(model, "fc"):
            prefix, head = "fc", model.fc
    elif arch in ["vit", "mobilevit"]:
        if hasattr(model, "classifier"):
            prefix, head = "classifier", model.classifier
    # Add checks for other models if needed

    if prefix is None:
        found = set()
    else:
        found = {f"{prefix}.{pname}" for pname, _ in head.named_parameters()}

    if not found:
        print(f"Warning: Could not identify classifier layer for model {args.model}")
    return found


# --- Main Federated Unlearning Function ---
def _as_index_list(value):
    """Return *value* unchanged when it is a list/tuple, otherwise wrap it in a one-element list."""
    return value if isinstance(value, (list, tuple)) else [value]


def _load_client_model(args, client_weights, client_idx, fallback_state):
    """Build a fresh model and load the saved weights of client *client_idx*.

    *client_weights* may be a dict keyed by client index or an indexable
    sequence; each entry may be a state dict or an object exposing
    ``state_dict()``. When no entry exists, *fallback_state* is loaded instead.
    """
    net = init_model(args).to(args.device)
    saved = None
    if isinstance(client_weights, dict):
        saved = client_weights.get(client_idx)
    elif client_idx < len(client_weights):
        saved = client_weights[client_idx]

    if saved is None:
        net.load_state_dict(fallback_state)
    elif isinstance(saved, dict):
        net.load_state_dict(saved)
    else:
        net.load_state_dict(saved.state_dict())
    return net


def _restore_bn_and_classifier(model, reference_state, classifier_names):
    """Overwrite BatchNorm params/buffers and classifier params of *model* in place.

    Values are copied from *reference_state*. Returns the model's state dict
    (its tensors share storage with the model, so it reflects the restored values).
    """
    state = model.state_dict()
    for mod_name, module in model.named_modules():
        if not is_bn_layer(mod_name, module):
            continue
        for pnm, _ in module.named_parameters(recurse=False):
            full = f"{mod_name}.{pnm}"
            state[full].copy_(reference_state[full])
        for bnm, _ in module.named_buffers(recurse=False):
            full = f"{mod_name}.{bnm}"
            state[full].copy_(reference_state[full])
    for cname in classifier_names:
        state[cname].copy_(reference_state[cname])
    return state


def fu_dws(args):
    """Run the mask-based federated unlearning pipeline and write its results.

    Steps, in order:
      1. Seed the RNGs, pick ``args.device`` and load the data loaders.
      2. Load the trained global model and the per-client weights from
         ``./save/test/<dataset>/learning/...`` and evaluate the global model.
      3. Build a union "local importance" mask from the backdoored clients and
         a per-client diff mask for every client except the unlearning client.
      4. For ``args.fedsalun_epoch`` rounds: each remaining client prunes a
         copy of the global model with the union mask, trains it with
         ``LocalTrainer``, has its BatchNorm and classifier tensors reset to
         the pre-unlearning global values, and the results are FedAvg-ed.
      5. For ``args.unlearn_epoch`` rounds: regular federated training with
         ``DomainClientUpdate`` (when ``args.proto``) or
         ``DomainClientUpdate_avg``, followed by evaluation.
      6. Save stats, final global weights and two CSV files (timings and
         per-round performance).

    Args:
        args: experiment namespace. Attributes read here: ``seed``, ``gpu``,
            ``dataset_fullparti``, ``domain_split_factor``,
            ``backdoor_client_idx``, ``unlearning_client``, ``dataset``,
            ``save``, ``num_users``, ``fedsalun_epoch``, ``unlearn_epoch``,
            ``target``, ``mask_ratio`` and the optional
            ``domain_times_factor``, ``bkd_domain_idx``, ``proto``,
            ``diff_mask_ratio``. ``args.device`` is set by this function.

    Raises:
        FileNotFoundError: if the global model or local weights file is missing.

    NOTE(review): when ``args.unlearning_client`` is a list, only its first
    element is excluded from training; confirm that is intended.
    """
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    args.device = torch.device(
        "cuda:{}".format(args.gpu)
        if torch.cuda.is_available() and args.gpu != -1
        else "cpu"
    )
    print(args)

    # --- Data Loading ---
    if args.dataset_fullparti:
        train_loaders, test_loaders, backdoorloader = init_data(args)
    else:
        train_loaders, test_loaders, backdoorloader = init_data_methodone(args)

    datasets_name = get_dataset(args)
    split_factor = getattr(args, "domain_times_factor", args.domain_split_factor)
    # 12345 is the sentinel for "no backdoor domain index given".
    if getattr(args, "bkd_domain_idx", 12345) == 12345:
        split_factor = args.domain_split_factor
    dsf_dir = f"dsf_{split_factor}"
    old_dsf_dir = f"dsf_{args.domain_split_factor}"
    backdoor_indices = _as_index_list(args.backdoor_client_idx)
    unlearning_indices = _as_index_list(args.unlearning_client)
    bkd_str = "_".join(str(i) for i in backdoor_indices)

    # --- Saving Setup ---
    base_dir = f"./save/test/{args.dataset}/fu_dws/{args.save}/{dsf_dir}/{bkd_str}"
    os.makedirs(base_dir, exist_ok=True)
    unlearn_client_idx = unlearning_indices[0]
    unlearn_client_name = (
        datasets_name[unlearn_client_idx]
        if unlearn_client_idx < len(datasets_name)
        else f"Client_{unlearn_client_idx}"
    )
    base = f"{base_dir}/{unlearn_client_name}_unlearned"

    # Clients that take part in both phases (everyone but the unlearning client).
    participant_ids = [c for c in range(args.num_users) if c != unlearn_client_idx]
    # Guard the per-client averages against an empty participant list.
    n_participants = max(len(participant_ids), 1)

    # --- Model Loading ---
    net_glob = init_model(args).to(args.device)
    model_load_path = (
        f"./save/test/{args.dataset}/learning/{args.save}/{dsf_dir}/{bkd_str}"
    )
    if not os.path.exists(os.path.join(model_load_path, 'weight_global.pth')):
        # Fall back to the directory layout keyed by domain_split_factor.
        legacy_path = f"./save/test/{args.dataset}/learning/{args.save}/{old_dsf_dir}/{bkd_str}"
        if os.path.exists(os.path.join(legacy_path, 'weight_global.pth')):
            model_load_path = legacy_path
    global_model_path = f"{model_load_path}/weight_global.pth"
    local_weights_path = f"{model_load_path}/weight_local.pth"

    if not os.path.exists(global_model_path):
        raise FileNotFoundError(f"Global model not found at {global_model_path}")
    if not os.path.exists(local_weights_path):
        raise FileNotFoundError(f"Local weights not found at {local_weights_path}")

    print(f"Loading final global model from: {global_model_path}")
    net_glob.load_state_dict(torch.load(global_model_path, map_location=args.device))
    print(f"Loading local client weights from: {local_weights_path}")
    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
    client_weights = torch.load(local_weights_path, map_location=args.device, weights_only=False)
    # Snapshot of the global weights BEFORE unlearning; used to reset
    # BatchNorm and classifier tensors after each client's masked training.
    net_glob_initial_state = copy.deepcopy(net_glob.state_dict())

    # --- Evaluation Stats Init & Initial Eval ---
    example_stats = [
        [{} for _ in range(args.num_users)],
        [{} for _ in range(args.num_users)],
    ]
    print("Evaluating initial global model (before unlearning)...")
    evaluate(
        args=args,
        train_loaders=train_loaders,
        test_loaders=test_loaders,
        net=copy.deepcopy(net_glob),
        example_stats=example_stats,
        datasets_name=datasets_name,
        backdoorloader=backdoorloader,
    )
    torch.save(example_stats, f"{base}_initial_stats.pth")  # Save initial state

    total_time_s = 0
    client_time_records = []
    server_time_records = []
    performance_records = []
    round_idx = 0
    use_proto = getattr(args, "proto", False)
    global_protos = {}

    classifier_names = get_classifier_layer_names(net_glob, args)

    # --- Build union mask from backdoored clients ---
    union_local_mask = {
        k: torch.zeros_like(v, dtype=torch.float32)
        for k, v in net_glob.state_dict().items()
    }
    backdoor_count = 0
    for bd_idx in backdoor_indices:
        if bd_idx == unlearn_client_idx or bd_idx >= len(train_loaders):
            continue
        net_bd = _load_client_model(
            args, client_weights, bd_idx, net_glob_initial_state
        )
        lm_bd, _ = build_masks_local(
            model_local=net_bd,
            model_global=net_glob,
            loader=train_loaders[bd_idx],
            args=args,
            avg_local_scores=None,
        )
        if not lm_bd:
            continue
        # Element-wise OR of the binary masks.
        for k, v in lm_bd.items():
            union_local_mask[k] = torch.maximum(union_local_mask[k], v.float())
        backdoor_count += 1

    if backdoor_count == 0:
        # No usable backdoored client: an all-ones mask disables pruning.
        union_local_mask = {
            k: torch.ones_like(v, dtype=torch.float32)
            for k, v in net_glob.state_dict().items()
        }
    print(f"Union mask built from {backdoor_count} backdoored clients")

    # --- Per-client masks ---
    local_mask_dicts = {}
    diff_mask_dicts = {}
    for client_idx in participant_ids:
        net_local = _load_client_model(
            args, client_weights, client_idx, net_glob_initial_state
        )
        _, dm = build_masks_local(
            model_local=net_local,
            model_global=net_glob,
            loader=train_loaders[client_idx],
            args=args,
            avg_local_scores=None,
        )
        if dm is None:
            print(f"Warning: mask computation failed, client {client_idx} will skip pruning")
            dm = {k: torch.ones_like(v) for k, v in net_glob.state_dict().items()}
        # Every client prunes with its own copy of the shared union mask.
        local_mask_dicts[client_idx] = {
            k: v.clone() for k, v in union_local_mask.items()
        }
        diff_mask_dicts[client_idx] = dm

    # --- Phase 1: masked unlearning rounds ---
    for epoch in range(args.fedsalun_epoch):
        print(
            f"\n============ Unlearning Epoch {epoch+1}/{args.fedsalun_epoch} ============"
        )
        epoch_start = time.perf_counter()
        client_elapsed = 0
        local_states = []

        for client_idx in participant_ids:
            print(f"\n--- Client {client_idx} ({datasets_name[client_idx]}) ---")

            local_model = copy.deepcopy(net_glob)
            lm = local_mask_dicts[client_idx]
            # LocalTrainer applies the same mask again before training;
            # multiplying twice by a 0/1 mask gives the same result.
            for name, param in local_model.named_parameters():
                if name in lm:
                    param.data.mul_(lm[name])  # prune

            trainer = LocalTrainer(
                args=args,
                train_loader=train_loaders[client_idx],
                local_mask=lm,
                diff_mask=diff_mask_dicts[client_idx],
                is_unlearning_client=False,
            )
            t0 = time.perf_counter()
            trained_model, loss = trainer.train(net=local_model)
            client_elapsed += time.perf_counter() - t0
            print(f" Client {client_idx} loss: {loss:.4f}")

            local_states.append(
                _restore_bn_and_classifier(
                    trained_model, net_glob_initial_state, classifier_names
                )
            )

        if local_states:
            print("\nAggregating models...")
            w_glob = FedAvg(local_states)
            net_glob.load_state_dict(w_glob)
            print("Aggregation done.")
        else:
            print("Warning: No client updates for aggregation.")

        server_time = time.perf_counter() - epoch_start - client_elapsed
        elapsed = time.perf_counter() - epoch_start
        print(f"Epoch {epoch+1} finished in {elapsed:.2f}s")
        # These rounds are written to the time CSV, so they count toward the total too.
        total_time_s += elapsed
        client_time_records.append(client_elapsed / n_participants)
        server_time_records.append(server_time)
        round_idx += 1

    print("Unlearning procedure completed.")

    # --- Phase 2: regular federated training rounds ---
    trainer_cls = DomainClientUpdate if use_proto else DomainClientUpdate_avg
    for epoch in range(args.unlearn_epoch):
        print("============ Train epoch {} ============".format(epoch))
        epoch_start = time.perf_counter()
        client_elapsed = 0
        w_locals = []
        local_protos = [{} for _ in range(args.num_users)]

        for client_idx in participant_ids:
            local_trainer = trainer_cls(
                args=args,
                train_loader=train_loaders[client_idx],
            )

            # NOTE(review): net_glob is passed without a copy; this presumes the
            # trainer does not modify it in place — confirm in Update_domain.
            t0 = time.perf_counter()
            if use_proto:
                client_model, client_state, client_proto, _ = local_trainer.train(
                    net=net_glob,
                    global_protos=global_protos,
                )
                w_locals.append(client_state)
                local_protos[client_idx] = client_proto
            else:
                client_model, _ = local_trainer.train(net=net_glob)
                w_locals.append(client_model.state_dict())
            client_elapsed += time.perf_counter() - t0

        if w_locals:
            w_glob = FedAvg(w_locals)
            net_glob.load_state_dict(w_glob)
            if use_proto:
                global_protos = proto_aggregation(args, local_protos)
        else:
            print("Warning: No client updates for aggregation.")
        server_time = time.perf_counter() - epoch_start - client_elapsed

        example_stats, global_loss = evaluate(
            args=args,
            train_loaders=train_loaders,
            test_loaders=test_loaders,
            net=copy.deepcopy(net_glob),
            example_stats=example_stats,
            datasets_name=datasets_name,
            backdoorloader=backdoorloader,
        )
        performance_records.append([round_idx] + global_loss)
        round_idx += 1

        total_time_s += time.perf_counter() - epoch_start
        client_time_records.append(client_elapsed / n_participants)
        server_time_records.append(server_time)

    torch.save(example_stats, f"{base}_forget_event.pth")
    torch.save(net_glob.state_dict(), f"{base}_weight_global.pth")
    print(f"Total time: {total_time_s:.2f}s")

    # --- CSV reports ---
    csv_dir = f"./result/csv/{args.dataset}/{args.target}"
    os.makedirs(csv_dir, exist_ok=True)
    timestamp = int(time.time())
    ul_clients_str = "_".join(str(i) for i in unlearning_indices)
    bd_clients_str = bkd_str
    # diff_mask_ratio is optional everywhere else, so fall back the same way
    # here instead of failing after all training has finished.
    diff_mask_ratio = getattr(args, "diff_mask_ratio", args.mask_ratio)
    run_tag = (
        f"{args.target}_{ul_clients_str}_{bd_clients_str}"
        f"_{args.mask_ratio}_{diff_mask_ratio}_{timestamp}"
    )
    time_file = os.path.join(csv_dir, f"time_{run_tag}.csv")
    perf_file = os.path.join(csv_dir, f"performance_of_clients_{run_tag}.csv")

    with open(time_file, "w") as f:
        f.write("round,client_time,server_time\n")
        for r, (ct, st) in enumerate(zip(client_time_records, server_time_records)):
            f.write(f"{r},{ct},{st}\n")

    with open(perf_file, "w") as f:
        header = ["round"] + [f"client{cid}" for cid in range(args.num_users)]
        f.write(",".join(header) + "\n")
        for record in performance_records:
            f.write(",".join(map(str, record)) + "\n")
