import copy
import logging
import time

import numpy as np
import torch
import wandb
from torch import nn

from .utils import transform_list_to_tensor
from ....core.robustness.robust_aggregation import RobustAggregator, is_weight_param

import logging


def test(
    model,
    device,
    test_loader,
    criterion,
    mode="raw-task",
    dataset="cifar10",
    poison_type="fashion",
):
    """Evaluate ``model`` on ``test_loader`` and report task or backdoor accuracy.

    Per-class hit counts are collected for ten classes (the class count is
    hard-coded) and the results are written to the log.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        device: Device every batch is moved to before the forward pass.
        test_loader: Iterable of ``(data, target)`` batches that also exposes
            a sized ``dataset`` attribute.
        criterion: Loss callable applied to ``(output, target)``.
        mode: ``"raw-task"`` reports overall accuracy plus the accuracy of the
            target class; ``"targetted-task"`` reports the accuracy selected
            by ``dataset`` / ``poison_type``.
        dataset: ``"mnist"``/``"emnist"`` (target class 7) or ``"cifar10"``
            (target class 2 for the ``"howto"``/``"greencar-neo"`` poison
            types, 9 otherwise).
        poison_type: Name of the poisoning scheme. ``"ardis"`` additionally
            tracks how often samples labelled with the target class are
            predicted as class 1.

    Returns:
        ``(final_acc, task_acc)``. ``final_acc`` is a percentage whose meaning
        depends on ``mode`` (it stays 0 when no branch applies); ``task_acc``
        is the target-class accuracy in ``"raw-task"`` mode and ``None``
        otherwise. Ratios over an empty set are reported as ``0.0``.
    """
    num_classes = 10
    class_correct = [0.0] * num_classes
    class_total = [0.0] * num_classes

    if dataset in ("mnist", "emnist"):
        target_class = 7
        if mode == "raw-task":
            classes = [str(i) for i in range(num_classes)]
        elif mode == "targetted-task":
            if poison_type == "ardis":
                classes = [str(i) for i in range(num_classes)]
            else:
                classes = [
                    "T-shirt/top",
                    "Trouser",
                    "Pullover",
                    "Dress",
                    "Coat",
                    "Sandal",
                    "Shirt",
                    "Sneaker",
                    "Bag",
                    "Ankle boot",
                ]
    elif dataset == "cifar10":
        classes = (
            "plane",
            "car",
            "bird",
            "cat",
            "deer",
            "dog",
            "frog",
            "horse",
            "ship",
            "truck",
        )
        # target_class = 2 for greencar, 9 for southwest
        if poison_type in ("howto", "greencar-neo"):
            target_class = 2
        else:
            target_class = 9

    def _percent(hits, total):
        # A class (or the backdoor subset) may be absent from the loader;
        # report 0.0 instead of dividing by zero.
        return 100 * hits / total if total else 0.0

    model.eval()
    test_loss = 0
    correct = 0
    backdoor_correct = 0
    backdoor_tot = 0
    final_acc = 0
    task_acc = None

    # Losses exposing ``reduction == "mean"`` return a per-batch average, which
    # must be scaled by the batch size before dividing by the dataset size.
    # Callables without that attribute are summed as returned.
    loss_is_batch_mean = getattr(criterion, "reduction", None) == "mean"

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            predicted = output.argmax(dim=1)  # index of the max logit
            # No squeeze(): it would collapse a batch of one to a 0-dim tensor.
            hits = predicted.eq(target)

            batch_loss = criterion(output, target).item()
            if loss_is_batch_mean:
                batch_loss *= target.size(0)
            test_loss += batch_loss
            correct += hits.sum().item()

            # check backdoor accuracy
            if poison_type == "ardis":
                backdoor_mask = target == target_class
                # The backdoor succeeds when a target-class sample is
                # predicted as class 1.
                backdoor_correct += (predicted[backdoor_mask] == 1).sum().item()
                # Accumulate over all batches, matching backdoor_correct.
                backdoor_tot += backdoor_mask.sum().item()

            for label, hit in zip(target.tolist(), hits.tolist()):
                class_correct[label] += hit
                class_total[label] += 1

    num_samples = len(test_loader.dataset)
    if num_samples:
        test_loss /= num_samples

    if mode == "raw-task":
        for i in range(num_classes):
            class_acc = _percent(class_correct[i], class_total[i])
            logging.info("Accuracy of %5s : %.2f %%" % (classes[i], class_acc))

            if i == target_class:
                task_acc = class_acc

        final_acc = _percent(correct, num_samples)
        logging.info(
            "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n".format(
                test_loss,
                correct,
                num_samples,
                final_acc,
            )
        )

    elif mode == "targetted-task":

        if dataset in ("mnist", "emnist"):
            for i in range(num_classes):
                logging.info(
                    "Accuracy of %5s : %.2f %%"
                    % (classes[i], _percent(class_correct[i], class_total[i]))
                )
            if poison_type == "ardis":
                # ensure 7 is being classified as 1
                final_acc = _percent(backdoor_correct, backdoor_tot)
                logging.info(
                    "Backdoor Accuracy of %.2f : %.2f %%" % (target_class, final_acc)
                )
            else:
                # trouser acc
                final_acc = _percent(class_correct[1], class_total[1])

        elif dataset == "cifar10":
            final_acc = _percent(class_correct[target_class], class_total[target_class])
            logging.info(
                "#### Targetted Accuracy of %5s : %.2f %%"
                % (classes[target_class], final_acc)
            )
    return final_acc, task_acc


# The bare name "test" matches pytest's default function prefix; opt out so
# the helper is not collected as a test case when imported into a test module.
test.__test__ = False


class FedAvgRobustAggregator(object):
    """Server-side FedAvg aggregator that applies a robust-aggregation defense.

    Client uploads are collected per worker index, passed through
    ``RobustAggregator.norm_diff_clipping`` (and, for the ``"weak_dp"``
    defense, ``RobustAggregator.add_noise`` on weight parameters), averaged
    with weights proportional to the reported sample counts, and loaded into
    the cached global model.
    """

    def __init__(
        self,
        train_global,
        test_global,
        all_train_data_num,
        train_data_local_dict,
        test_data_local_dict,
        train_data_local_num_dict,
        worker_num,
        device,
        model,
        targetted_task_test_loader,
        num_dps_poisoned_dataset,
        args,
    ):
        """Store data handles and configuration, and reset upload bookkeeping.

        ``args`` must provide ``comm_round`` and ``attack_freq`` here; other
        methods additionally read ``is_mobile``, ``frequency_of_the_test``,
        ``client_num_in_total``, ``dataset`` and ``poison_type``.
        """
        self.train_global = train_global
        self.test_global = test_global
        self.all_train_data_num = all_train_data_num

        self.train_data_local_dict = train_data_local_dict
        self.test_data_local_dict = test_data_local_dict
        self.train_data_local_num_dict = train_data_local_num_dict

        self.worker_num = worker_num
        self.device = device
        self.args = args
        # Per-worker uploads for the current round, keyed by worker index.
        self.model_dict = dict()
        self.sample_num_dict = dict()
        self.flag_client_model_uploaded_dict = {
            idx: False for idx in range(self.worker_num)
        }

        self.robust_aggregator = RobustAggregator(args)

        self.targetted_task_test_loader = targetted_task_test_loader
        self.num_dps_poisoned_dataset = num_dps_poisoned_dataset

        # Rounds 1, 1 + attack_freq, 1 + 2 * attack_freq, ... up to comm_round.
        self.adversary_fl_rounds = [
            i for i in range(1, args.comm_round + 1) if (i - 1) % args.attack_freq == 0
        ]

        self.model, _ = self.init_model(model)

    def init_model(self, model):
        """Return ``model`` together with its current ``state_dict()``."""
        model_params = model.state_dict()
        return model, model_params

    def get_global_model_params(self):
        """Return the state dict of the cached global model."""
        return self.model.state_dict()

    def add_local_trained_result(self, index, model_params, sample_num):
        """Record one worker's upload and mark that worker as received."""
        logging.info("add_model. index = %d" % index)
        self.model_dict[index] = model_params
        self.sample_num_dict[index] = sample_num
        self.flag_client_model_uploaded_dict[index] = True

    def check_whether_all_receive(self):
        """Return True once every worker has uploaded, resetting the flags.

        The flags are left untouched while at least one upload is missing.
        """
        if not all(
            self.flag_client_model_uploaded_dict[idx]
            for idx in range(self.worker_num)
        ):
            return False
        for idx in range(self.worker_num):
            self.flag_client_model_uploaded_dict[idx] = False
        return True

    def aggregate(self):
        """Defend, average and install the workers' uploaded parameters.

        Returns:
            The averaged parameter dict, which is also loaded into
            ``self.model``.

        Raises:
            NotImplementedError: If the configured defense type is neither
                ``"norm_diff_clipping"`` nor ``"weak_dp"``.
        """
        start_time = time.time()
        model_list = []
        training_num = 0
        defense_type = self.robust_aggregator.defense_type

        for idx in range(self.worker_num):
            if self.args.is_mobile == 1:
                self.model_dict[idx] = transform_list_to_tensor(self.model_dict[idx])

            local_sample_number = self.sample_num_dict[idx]
            local_model_params = self.model_dict[idx]

            # conduct the defense here:
            if defense_type not in ("norm_diff_clipping", "weak_dp"):
                raise NotImplementedError("Non-supported Defense type ... ")
            clipped_local_state_dict = self.robust_aggregator.norm_diff_clipping(
                local_model_params, self.model.state_dict()
            )
            model_list.append((local_sample_number, clipped_local_state_dict))

            training_num += local_sample_number

        logging.info("len of self.model_dict[idx] = " + str(len(self.model_dict)))

        # The first worker's (clipped) dict is reused as the accumulator.
        _, averaged_params = model_list[0]
        use_weak_dp = defense_type == "weak_dp"

        for k in averaged_params.keys():
            for i, (local_sample_number, local_model_params) in enumerate(model_list):
                w = local_sample_number / training_num

                local_layer_update = local_model_params[k]
                if use_weak_dp and is_weight_param(k):
                    local_layer_update = self.robust_aggregator.add_noise(
                        local_layer_update, self.device
                    )

                # Average the (possibly noised) update; averaging the raw
                # parameters here would silently drop the weak-DP noise.
                if i == 0:
                    averaged_params[k] = local_layer_update * w
                else:
                    averaged_params[k] += local_layer_update * w

        # update the global model which is cached at the server side
        self.model.load_state_dict(averaged_params)

        end_time = time.time()
        logging.info("aggregate time cost: %d" % (end_time - start_time))
        return averaged_params

    def client_sampling(self, round_idx, client_num_in_total, client_num_per_round):
        """Choose the clients that take part in round ``round_idx``.

        The NumPy global RNG is seeded with ``round_idx`` so the selection is
        reproducible across runs. In rounds listed in
        ``self.adversary_fl_rounds`` client index 1 (treated as the attacker)
        is placed first and the remaining slots are drawn, without
        replacement, from the other clients, so the result never exceeds the
        per-round budget and never lists the attacker twice.

        Returns:
            ``numpy.ndarray`` of selected client indexes.
        """
        num_clients = min(client_num_per_round, client_num_in_total)
        np.random.seed(
            round_idx
        )  # make sure for each comparison, we are selecting the same clients each round
        if round_idx not in self.adversary_fl_rounds:
            client_indexes = np.random.choice(
                range(client_num_in_total), num_clients, replace=False
            )
        else:
            attacker_idx = 1
            benign_pool = [
                idx for idx in range(client_num_in_total) if idx != attacker_idx
            ]
            num_benign = max(num_clients - 1, 0)
            benign = (
                list(np.random.choice(benign_pool, num_benign, replace=False))
                if num_benign
                else []
            )
            # we guarantee that the attacker will participate in a certain frequency
            client_indexes = np.array([attacker_idx] + benign)
        logging.info("client_indexes = %s" % str(client_indexes))
        return client_indexes

    def test_on_all_clients(self, round_idx):
        """Evaluate the global model on every client's train and test data.

        Runs only when ``round_idx`` is a multiple of
        ``args.frequency_of_the_test`` or is the last communication round.
        Sample-weighted accuracy and loss are sent to ``wandb`` and the log.
        """
        is_test_round = (
            round_idx % self.args.frequency_of_the_test == 0
            or round_idx == self.args.comm_round - 1
        )
        if not is_test_round:
            return

        logging.info(
            "################local_test_on_all_clients : {}".format(round_idx)
        )
        train_num_samples = []
        train_tot_corrects = []
        train_losses = []

        test_num_samples = []
        test_tot_corrects = []
        test_losses = []
        for client_idx in range(self.args.client_num_in_total):
            # train data
            train_tot_correct, train_num_sample, train_loss = self._infer(
                self.train_data_local_dict[client_idx]
            )
            train_tot_corrects.append(train_tot_correct)
            train_num_samples.append(train_num_sample)
            train_losses.append(train_loss)

            # test data
            test_tot_correct, test_num_sample, test_loss = self._infer(
                self.test_data_local_dict[client_idx]
            )
            test_tot_corrects.append(test_tot_correct)
            test_num_samples.append(test_num_sample)
            test_losses.append(test_loss)

        # test on training dataset
        train_acc = sum(train_tot_corrects) / sum(train_num_samples)
        train_loss = sum(train_losses) / sum(train_num_samples)
        wandb.log({"Train/Acc": train_acc, "round": round_idx})
        wandb.log({"Train/Loss": train_loss, "round": round_idx})
        stats = {"training_acc": train_acc, "training_loss": train_loss}
        logging.info(stats)

        # test on test dataset
        test_acc = sum(test_tot_corrects) / sum(test_num_samples)
        test_loss = sum(test_losses) / sum(test_num_samples)
        wandb.log({"Test/Acc": test_acc, "round": round_idx})
        wandb.log({"Test/Loss": test_loss, "round": round_idx})
        stats = {"test_acc": test_acc, "test_loss": test_loss}
        logging.info(stats)

    def test_target_accuracy(self, round_idx):
        """Run the module-level ``test`` on the targeted-task loader.

        The result is only logged by ``test``; nothing is returned.
        ``round_idx`` is accepted but not used.
        """
        test(
            self.model,
            self.device,
            self.targetted_task_test_loader,
            criterion=nn.CrossEntropyLoss().to(self.device),
            mode="targetted-task",
            dataset=self.args.dataset,
            poison_type=self.args.poison_type,
        )

    def _infer(self, test_data):
        """Evaluate the global model on one iterable of ``(x, target)`` batches.

        Returns:
            ``(num_correct, num_samples, summed_loss)`` as floats, where
            ``summed_loss`` is each batch's cross-entropy loss multiplied by
            its batch size and added up.
        """
        self.model.eval()
        self.model.to(self.device)

        test_loss = test_acc = test_total = 0.0
        criterion = nn.CrossEntropyLoss().to(self.device)
        with torch.no_grad():
            for x, target in test_data:
                x = x.to(self.device)
                target = target.to(self.device)
                pred = self.model(x)
                loss = criterion(pred, target)
                _, predicted = torch.max(pred, -1)
                correct = predicted.eq(target).sum()

                batch_size = target.size(0)
                test_acc += correct.item()
                test_loss += loss.item() * batch_size
                test_total += batch_size

        return test_acc, test_total, test_loss
