#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Provide the execution pipeline for Byzantine-robust federated learning experiments."""

import copy
import random
import time
from pathlib import Path
from typing import Any, Dict, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

from .byzantine_attacks import (
    attack_minmax,
    attack_minsum,
    gaussian_attack,
    lie_attack,
)
from .byzantine_defense import (
    FGNV,
    FLDetector,
    aggregate_multi_krum,
    calculate_gradient_cosine_similarities,
    client_detection,
    client_reputation,
    malicious_detection_candidate,
)
from .fed_nets import (
    MLP,
    VGG,
    AlexNetCifar,
    CNNCifar,
    CNNMnist,
    ResNet,
)
from .local_update import ClientUpdater, ServerEvaluator

# --- Internal Project Imports ---
from .utils.data_sampling import sample_iid, sample_noniid_by_dirichlet
from .utils.model_averaging import average_weights, average_weights_resilient
from .utils.reproducable import set_seed

# Type alias intended for model ``state_dict`` objects: parameter/buffer names mapped to tensors.
StateDict = Dict[str, torch.Tensor]

# Absolute, symlink-resolved location of this module. PROJECT_ROOT is taken to be three
# directory levels above it; the results/ and data/ paths used by the simulator are built from it.
current_file_path = Path(__file__).resolve()
PROJECT_ROOT = current_file_path.parent.parent.parent


class FederatedLearningSimulator:
    """Orchestrate data preparation, training, evaluation, and logging for experiments.

    One instance handles repetition ``m`` of the experiment described by ``args``.
    :meth:`run` trains for ``args.epochs`` communication rounds, logs test metrics to
    TensorBoard under ``results/logs`` and writes one CSV row per round
    (``avg_train_loss, avg_train_acc, test_loss, test_acc``) to ``results/experiments``.
    """

    def __init__(self, args: Any, m: int):
        """Initialise the simulator with parsed arguments for a specific experiment run.

        Args:
            args: Parsed experiment configuration, read via attribute access.
            m: Repetition index; the run is seeded with ``42 + m`` and ``m`` suffixes the
                results file name.
        """
        self.args = args
        self.m = m
        self.seed = 42 + self.m

        set_seed(self.seed)
        print(f"\nSeed has been set to {self.seed} for reproducibility.")

        self.device = self._setup_device()
        # FLDetector instance shared by all rounds of a run (created in ``run`` when that defence is used).
        self.detector = None
        # Per-epoch detection errors recorded by TriGuardFL (created in ``run``):
        # row 0 = chosen attackers that were missed, row 1 = flagged clients that are not attackers.
        self.fpfn = None

        log_dir = PROJECT_ROOT / "results" / "logs" / f"{self._get_run_name()}_{time.time()}"
        log_dir.mkdir(parents=True, exist_ok=True)
        self.writer = SummaryWriter(log_dir=str(log_dir))

    def _setup_device(self) -> torch.device:
        """Select the computation device (CPU or GPU) for the run.

        A GPU is used only when ``args.gpu`` is set to something other than ``None``/``-1``
        and CUDA is available; otherwise the CPU is used.
        """
        if self.args.gpu is not None and self.args.gpu != -1 and torch.cuda.is_available():
            print(f"Using GPU: {self.args.gpu}")
            torch.cuda.set_device(self.args.gpu)
            return torch.device(f"cuda:{self.args.gpu}")
        print("Using CPU")
        return torch.device("cpu")

    def _load_dataset(self) -> Tuple[Dataset, Dataset]:
        """Load the configured training and evaluation datasets.

        ``args.dataset == "mnist"`` loads MNIST when ``args.subdataset == "mnist"`` and
        Fashion-MNIST otherwise; ``"cifar"`` loads CIFAR-10. Data is downloaded on first use
        below ``<PROJECT_ROOT>/data``.

        Raises:
            ValueError: If ``args.dataset`` is not recognised.
        """
        if self.args.dataset == "mnist":
            transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
            dataset_class = datasets.MNIST if self.args.subdataset == "mnist" else datasets.FashionMNIST
            # Anchored at the project root (like CIFAR) so the location does not depend on the CWD.
            data_path = PROJECT_ROOT / "data" / "mnist"
            train_dataset = dataset_class(data_path, train=True, download=True, transform=transform)
            test_dataset = dataset_class(data_path, train=False, download=True, transform=transform)
        elif self.args.dataset == "cifar":
            transform = transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
            )
            data_path = PROJECT_ROOT / "data" / "cifar"
            train_dataset = datasets.CIFAR10(data_path, train=True, download=True, transform=transform)
            test_dataset = datasets.CIFAR10(data_path, train=False, download=True, transform=transform)
        else:
            raise ValueError(f"Unrecognized dataset: {self.args.dataset}")
        return train_dataset, test_dataset

    def _build_model(self) -> torch.nn.Module:
        """Instantiate the neural network model requested by the configuration.

        Raises:
            ValueError: If the ``args.model`` / ``args.dataset`` combination is unsupported.
        """
        if self.args.model == "cnn" and self.args.dataset == "mnist":
            model = CNNMnist(num_classes=self.args.num_classes)
        elif self.args.model == "cnn" and self.args.dataset == "cifar":
            if self.args.submodel == "AlexNet":
                model = AlexNetCifar(num_classes=self.args.num_classes)
            elif "VGG" in self.args.submodel:
                model = VGG(vgg_name=self.args.submodel, num_classes=self.args.num_classes)
            elif "ResNet" in self.args.submodel:
                model = ResNet(resnet_name=self.args.submodel, num_classes=self.args.num_classes)
            else:
                model = CNNCifar(num_classes=self.args.num_classes)
        elif self.args.model == "mlp":
            # NOTE(review): input size is fixed at 784 (= 28 * 28), i.e. it only fits the MNIST family.
            model = MLP(input_dim=784, hidden_dim=256, output_dim=self.args.num_classes)
        else:
            raise ValueError(f"Unrecognized model/dataset combination: {self.args.model}/{self.args.dataset}")

        print(f"Built Model: {model.__class__.__name__}")
        return model.to(self.device)

    def _get_run_name(self) -> str:
        """Compose a descriptive identifier for logging and artifact storage."""
        return (
            f"{self.args.detection}_{self.args.dataset}_{self.args.subdataset}_{self.args.submodel}_{self.args.num_users}users_{self.args.num_Chosenusers}Chosenusers"
            f"_{self.args.attackway}_{self.args.attacker_ability}_{self.args.num_attackers}attackers_{self.args.epochs}epochs"
        )

    def run(self, results_array=None):
        """Execute the federated learning workflow from data sampling to evaluation.

        Args:
            results_array: Unused; accepted so existing call sites that pass it keep working.

        Returns:
            ``numpy.ndarray`` with one ``[avg_train_loss, avg_train_acc, test_loss, test_acc]``
            row per epoch (the same data written to the CSV file).
        """
        try:
            results = self._train()
            # 4. Save results
            return self._save_results(results)
        finally:
            # Close (and flush) the TensorBoard writer even when training fails.
            self.writer.close()

    def _print_config(self) -> None:
        """Print a one-line summary of the experiment configuration."""
        print(
            "\nRule:",
            self.args.detection,
            "Attack:",
            self.args.attackway,
            "Attack Ability:",
            self.args.attacker_ability,
            "dataset:",
            self.args.dataset,
            "sub dataset:",
            self.args.subdataset,
            "Neural Network:",
            self.args.submodel,
            "num_users:",
            self.args.num_users,
            " num_chosen_users:",
            self.args.num_Chosenusers,
            " epochs:",
            self.args.epochs,
            "local_ep:",
            self.args.local_ep,
            "local train size",
            self.args.num_items_train,
            "batch size:",
            self.args.local_bs,
        )

    def _train(self) -> list:
        """Prepare data/model and run every communication round.

        Returns:
            One ``[avg_train_loss, avg_train_acc, test_loss, test_acc]`` row per epoch.
        """
        # 1. Setup: load data, build model, and prepare for training
        train_dataset, test_dataset = self._load_dataset()
        global_model = self._build_model()
        self._print_config()

        if self.args.iid:
            user_data_indices = sample_iid(train_dataset, self.args.num_users, self.args.num_items_train)
        else:
            user_data_indices, _ = sample_noniid_by_dirichlet(
                train_dataset, self.args.num_users, self.args.dirichlet_alpha, self.args.num_items_train
            )
        # Server-side subset of the *training* set: TriGuardFL validates candidate models on it
        # and FLTrust trains its reference update on it.
        root_data_indices = sample_iid(train_dataset, 1, self.args.num_items_test)
        test_data_indices = sample_iid(test_dataset, 1, self.args.num_items_test)

        # These evaluators only depend on fixed datasets/indices, so build them once
        # instead of once per round (or once per TriGuardFL candidate).
        test_evaluator = ClientUpdater(self.device, test_dataset, test_data_indices[0], 128)
        root_evaluator = None
        if self.args.detection == "TriGuardFL":
            root_evaluator = ServerEvaluator(
                device=self.device,
                dataset=train_dataset,
                indices=root_data_indices[0],
                batch_size=self.args.local_bs,
            )

        # Per-client reputation parameters; every client starts at alpha = beta = 1.
        alphas = np.ones(self.args.num_users, dtype=float)
        betas = np.ones(self.args.num_users, dtype=float)
        self.fpfn = np.zeros((2, self.args.epochs), dtype=float)

        # Define attackers
        all_user_indices = list(range(self.args.num_users))
        attackers = set(random.sample(all_user_indices, self.args.num_attackers))
        print(f"Attackers: {sorted(attackers)}")

        if self.args.attack and self.args.detection == "FLDetector":
            # FLDetector is configured with a history window (window_size) and a warm-up
            # (start_iter), so one instance must persist across rounds; creating a fresh one
            # each round discarded its history and presumably kept it in warm-up forever.
            self.detector = FLDetector(window_size=10, kmax=2, b_ref=10, ridge=1e-6, start_iter=10)

        results = []

        # 2. Main training loop: one communication round per epoch.
        for epoch in range(self.args.epochs):
            print(f"\n--- Epoch {epoch + 1}/{self.args.epochs} ---")
            start_time = time.time()

            chosen_client_indices = self._select_clients(epoch, alphas, betas)
            chosen_attackers = sorted(set(chosen_client_indices) & attackers)
            print(f"Chosen Clients: {chosen_client_indices}")
            print("Chosen Attackers:", chosen_attackers)

            global_weights_before = copy.deepcopy(global_model.state_dict())
            local_weights, local_losses, local_accs = self._train_clients(
                global_model, train_dataset, user_data_indices, chosen_client_indices
            )

            # Attackers poison their uploads from the second round on.
            if self.args.attack and epoch > 0 and chosen_attackers:
                local_weights = self._apply_attack(
                    local_weights, global_weights_before, chosen_client_indices, attackers
                )

            # Byzantine defence + aggregation; defences only engage under attack from round 2 on.
            defense_active = self.args.attack and epoch > 0
            rule = self.args.detection
            if defense_active and rule == "TriGuardFL":
                aggregated_weights, alphas, betas = self._aggregate_triguard(
                    epoch,
                    global_model,
                    root_evaluator,
                    local_weights,
                    global_weights_before,
                    chosen_client_indices,
                    chosen_attackers,
                    alphas,
                    betas,
                )
            elif defense_active and rule == "DeFL":
                malicious = FGNV(local_weights, global_weights_before, chosen_client_indices, self.args.lr)
                rep, alphas, betas = client_reputation(
                    self.args.discount,
                    malicious,
                    alphas,
                    betas,
                    chosen_client_indices,
                )
                aggregated_weights = average_weights_resilient(local_weights, rep)
                print("Detected Attackers:", malicious)
            elif defense_active and rule == "MultiKrum":
                aggregated_weights = aggregate_multi_krum(
                    local_weights,
                    global_weights_before,
                    self.args.num_attackers,
                    self.args.num_Chosenusers - self.args.num_attackers,
                    self.args.lr,
                )
            elif defense_active and rule == "FLTrust":
                aggregated_weights = self._aggregate_fltrust(
                    global_model, train_dataset, root_data_indices, local_weights, global_weights_before
                )
            elif defense_active and rule == "FLDetector":
                aggregated_weights = self._aggregate_fldetector(
                    local_weights, global_weights_before, chosen_client_indices
                )
            else:
                # Default: FedAvg (no attack, first round, or no matching defence).
                aggregated_weights = average_weights(local_weights)

            # Update global model
            global_model.load_state_dict(aggregated_weights)

            # 3. Evaluation
            avg_train_loss = np.mean(local_losses) if local_losses else 0.0
            avg_train_acc = np.mean(local_accs) if local_accs else 0.0
            test_acc, test_loss = test_evaluator.evaluate(global_model)

            results.append([avg_train_loss, avg_train_acc, test_loss, test_acc])
            self.writer.add_scalar("accuracy/test", test_acc, epoch)
            self.writer.add_scalar("loss/test", test_loss, epoch)

            print(f"Train Loss: {avg_train_loss:.4f} | Train Acc: {avg_train_acc:.4f}")
            print(f"Test Loss:  {test_loss:.4f} | Test Acc:  {test_acc:.4f}")
            print(f"Time: {time.time() - start_time:.2f}s")

        return results

    def _select_clients(self, epoch: int, alphas: np.ndarray, betas: np.ndarray) -> list:
        """Choose the (sorted) client indices that train in this round.

        Under attack with TriGuardFL, once ``epoch > args.epochs_phase_2`` only clients whose
        reputation ``alpha / (alpha + beta)`` reaches ``args.reputation_threshold`` are
        eligible; otherwise every client is. Up to ``args.num_Chosenusers`` eligible clients
        are sampled uniformly without replacement.
        """
        reputation = alphas / (alphas + betas + 1e-8)
        print("\nReputation:", reputation)
        good_users = [int(i) for i in np.where(reputation >= self.args.reputation_threshold)[0]]

        use_reputation = epoch > self.args.epochs_phase_2 and self.args.attack and self.args.detection == "TriGuardFL"
        if use_reputation:
            pool = good_users
            if not pool:
                # An empty round would crash aggregation; re-admit everyone so reputations can recover.
                print("No client meets the reputation threshold; sampling from all clients this round.")
                pool = list(range(self.args.num_users))
        else:
            pool = list(range(self.args.num_users))

        if self.args.num_Chosenusers < len(pool):
            chosen = random.sample(pool, self.args.num_Chosenusers)
        else:
            chosen = list(pool)
        chosen.sort()
        return chosen

    def _train_clients(
        self,
        global_model: torch.nn.Module,
        train_dataset: Dataset,
        user_data_indices: Any,
        chosen_client_indices: list,
    ) -> Tuple[list, list, list]:
        """Train a fresh copy of the global model on each chosen client's data shard.

        Returns:
            ``(local_weights, local_losses, local_accs)`` in the order of ``chosen_client_indices``.
        """
        local_weights, local_losses, local_accs = [], [], []
        for client_idx in chosen_client_indices:
            client_updater = ClientUpdater(
                device=self.device,
                dataset=train_dataset,
                indices=user_data_indices[client_idx],
                batch_size=self.args.local_bs,
                logger=self.writer,
            )
            weights, loss, acc = client_updater.train(
                model=copy.deepcopy(global_model), learning_rate=self.args.lr, local_epochs=self.args.local_ep
            )
            local_weights.append(weights)
            local_losses.append(loss)
            local_accs.append(acc)
        return local_weights, local_losses, local_accs

    def _apply_attack(
        self,
        local_weights: list,
        global_weights_before: Dict[str, torch.Tensor],
        chosen_client_indices: list,
        attackers: set,
    ) -> list:
        """Return the client weights after applying the configured Byzantine attack.

        ``"lie"``, ``"minmax"`` and ``"minsum"`` select the matching attack; any other
        ``args.attackway`` falls back to ``gaussian_attack`` (called with 0.1 as its
        scale argument).
        """
        attack = self.args.attackway
        if attack == "lie":
            return lie_attack(local_weights, chosen_client_indices, attackers, self.args.lr)
        if attack == "minmax":
            return attack_minmax(
                local_weights,
                global_weights_before,
                chosen_client_indices,
                attackers,
                self.args.lr,
                self.args.attacker_ability,
            )
        if attack == "minsum":
            return attack_minsum(
                local_weights,
                global_weights_before,
                chosen_client_indices,
                attackers,
                self.args.lr,
                self.args.attacker_ability,
            )
        return gaussian_attack(local_weights, chosen_client_indices, attackers, 0.1, self.device)

    def _aggregate_triguard(
        self,
        epoch: int,
        global_model: torch.nn.Module,
        root_evaluator: Any,
        local_weights: list,
        global_weights_before: Dict[str, torch.Tensor],
        chosen_client_indices: list,
        chosen_attackers: list,
        alphas: np.ndarray,
        betas: np.ndarray,
    ) -> Tuple[Dict[str, torch.Tensor], np.ndarray, np.ndarray]:
        """Aggregate one round with TriGuardFL and update client reputations.

        1. ``malicious_detection_candidate`` picks suspicious clients from the cosine
           similarities of each update against the plain average (``args.cos_threshold``).
        2. Each suspicious model and the average of the remaining models are scored on the
           server's root subset.
        3. ``client_detection`` confirms attackers from those scores, ``client_reputation``
           updates the reputations, and the round is aggregated with the returned weights.

        Also records missed/false detections for ``epoch`` in ``self.fpfn``.

        Returns:
            ``(aggregated_weights, alphas, betas)``.
        """
        candidate_average = average_weights(local_weights)
        cosine_similarity = calculate_gradient_cosine_similarities(
            local_weights, candidate_average, global_weights_before, self.args.lr
        )
        malicious_candidate, malicious_candidate_index = malicious_detection_candidate(
            cosine_similarity, chosen_client_indices, self.args.cos_threshold
        )
        print("Detected Potential Attackers:", malicious_candidate)

        suspicious_weights, benign_weights = [], []
        for position in range(len(chosen_client_indices)):
            bucket = suspicious_weights if position in malicious_candidate_index else benign_weights
            bucket.append(copy.deepcopy(local_weights[position]))
        benign_average = average_weights(benign_weights)

        # Score each suspicious model first and the benign-candidate average last;
        # client_detection receives the scores in exactly this order.
        acc_scores, loss_scores = [], []
        for state in suspicious_weights + [benign_average]:
            probe_model = copy.deepcopy(global_model)
            probe_model.load_state_dict(state)
            probe_model.eval()
            acc_total, loss_total = root_evaluator.evaluate_by_class(probe_model)
            acc_scores.append(acc_total)
            loss_scores.append(loss_total)

        malicious = client_detection(
            acc_scores,
            loss_scores,
            malicious_candidate,
            self.args.significance,
        )
        rep, alphas, betas = client_reputation(
            self.args.discount,
            malicious,
            alphas,
            betas,
            chosen_client_indices,
        )
        aggregated_weights = average_weights_resilient(local_weights, rep)
        print("Detected Attackers:", malicious)

        caught = len(np.intersect1d(chosen_attackers, malicious))
        self.fpfn[0, epoch] = len(chosen_attackers) - caught  # attackers that were missed
        self.fpfn[1, epoch] = len(malicious) - caught  # flagged clients that are not attackers
        return aggregated_weights, alphas, betas

    def _aggregate_fltrust(
        self,
        global_model: torch.nn.Module,
        train_dataset: Dataset,
        root_data_indices: Any,
        local_weights: list,
        global_weights_before: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Weight client models by their ReLU-clipped cosine similarity to a server update (FLTrust)."""
        # FLTrust requires a clean server-side update for comparison, trained on the small root subset.
        server_updater = ClientUpdater(
            self.device,
            train_dataset,
            root_data_indices[0],
            self.args.local_bs,
        )
        server_weights, _, _ = server_updater.train(copy.deepcopy(global_model), self.args.lr, 1)
        reputations = calculate_gradient_cosine_similarities(
            local_weights, server_weights, global_weights_before, self.args.lr
        )
        print("Reputations:", reputations)
        trust_scores = np.maximum(0, reputations)
        if not np.sum(trust_scores) > 0:
            # No client agrees with the server direction: FLTrust then applies a zero update.
            # This also avoids a weighted average over an all-zero (or NaN) weight vector.
            print("FLTrust: all trust scores are zero; keeping the previous global model.")
            return copy.deepcopy(global_weights_before)
        return average_weights_resilient(local_weights, trust_scores)

    def _aggregate_fldetector(
        self,
        local_weights: list,
        global_weights_before: Dict[str, torch.Tensor],
        chosen_client_indices: list,
    ) -> Dict[str, torch.Tensor]:
        """Drop the clients flagged by the run-wide FLDetector, then FedAvg the rest."""
        global_weights_after = average_weights(local_weights)
        malicious, _scores = self.detector.step_and_detect(
            chosen_users=chosen_client_indices,
            local_weights=local_weights,
            global_weights_before=global_weights_before,
            global_weights_after=global_weights_after,
            lr=self.args.lr,
        )
        if not malicious:
            return average_weights(local_weights)

        kept = [local_weights[i] for i, uid in enumerate(chosen_client_indices) if uid not in malicious]
        if not kept:
            # Averaging an empty list would fail; skip this round's update instead.
            print("FLDetector flagged every chosen client; keeping the previous global model.")
            return copy.deepcopy(global_weights_before)
        return average_weights(kept)

    def _save_results(self, results: list) -> np.ndarray:
        """Write per-epoch metric rows to ``results/experiments/<run name>_<m>.csv`` and return them."""
        save_path = PROJECT_ROOT / "results" / "experiments"
        save_path.mkdir(parents=True, exist_ok=True)
        filename = save_path / f"{self._get_run_name()}_{self.m}.csv"
        results_to_save = np.array(results)
        np.savetxt(filename, results_to_save, fmt="%.6f", delimiter=",")
        print(f"\nResults saved to {filename}")
        return results_to_save
