import logging
from typing import Dict

import torch
from torch import optim, nn
from copy import deepcopy
import numpy as np

from models.model import Model
from models.nc_model import NCModel
from synthesizers.synthesizer import Synthesizer
from losses.loss_functions import compute_all_losses_and_grads
from utils.min_norm_solvers import MGDASolver
from utils.parameters import Params

logger = logging.getLogger('logger')


class Attack:
    params: Params
    synthesizer: Synthesizer
    nc_model: Model
    nc_optim: torch.optim.Optimizer
    # fixed_model: Model

    def __init__(self, params, synthesizer):
        self.params = params
        self.synthesizer = synthesizer

        # NC hyper params
        if 'neural_cleanse' in self.params.loss_tasks:
            self.nc_model = NCModel(params.input_shape[1]).to(params.device)
            self.nc_optim = torch.optim.Adam(self.nc_model.parameters(), 0.01)

    def compute_blind_loss(self, model, criterion, batch, attack):
        """Compute the combined, scaled loss for one batch.

        :param model: model being trained; forwarded to the loss helpers.
        :param criterion: loss criterion; forwarded to the loss helpers.
        :param batch: input batch; it is clipped with ``params.clip_batch``
            before anything else happens.
        :param attack: when truthy, every task in ``params.loss_tasks`` is
            used. Otherwise only the normal task is computed
            (``'reddit_normal'`` if ``'Reddit'`` occurs in ``params.task``,
            else ``'normal'``).
        :return: the value returned by ``scale_losses`` for the chosen tasks.
        :raises ValueError: if more than one task is active and
            ``params.loss_balance`` is neither ``'MGDA'`` nor ``'fixed'``.
        """
        batch = batch.clip(self.params.clip_batch)

        # Decide which loss tasks take part in this step.
        if attack:
            loss_tasks = self.params.loss_tasks.copy()
        else:
            loss_tasks = (['reddit_normal'] if 'Reddit' in self.params.task
                          else ['normal'])

        # Without a synthesizer the "backdoored" batch is just the clean one.
        if self.synthesizer is None:
            batch_back = batch
        else:
            batch_back = self.synthesizer.make_backdoor_batch(batch,
                                                              attack=attack)

        if 'neural_cleanse' in loss_tasks:
            self.neural_cleanse_part1(model, batch, batch_back)

        def run_losses(with_grads):
            # Single place that invokes the shared loss helper.
            return compute_all_losses_and_grads(
                loss_tasks,
                self, model, criterion, batch, batch_back,
                compute_grad=with_grads)

        scale = {}
        if len(loss_tasks) == 1:
            loss_values, _ = run_losses(False)
        elif self.params.loss_balance == 'MGDA':
            # MGDA needs per-task gradients to solve for the scales.
            loss_values, grads = run_losses(True)
            if len(loss_tasks) > 1:
                scale = MGDASolver.get_scales(grads, loss_values,
                                              self.params.mgda_normalize,
                                              loss_tasks)
        elif self.params.loss_balance == 'fixed':
            loss_values, _ = run_losses(False)
            scale = {task: self.params.fixed_scales[task]
                     for task in loss_tasks}
        else:
            raise ValueError('Please choose between `MGDA` and `fixed`.')

        # A lone task always carries full weight.
        if len(loss_tasks) == 1:
            scale = {loss_tasks[0]: 1.0}

        return self.scale_losses(loss_tasks, loss_values, scale)

    def scale_losses(self, loss_tasks, loss_values, scale):
        blind_loss = 0
        for it, t in enumerate(loss_tasks):
            self.params.running_losses[t].append(loss_values[t].item())
            self.params.running_scales[t].append(scale[t])
            if it == 0:
                blind_loss = scale[t] * loss_values[t]
            else:
                if (t == 'ortho_clean' or t == 'min_change') and loss_values[t] == 0:
                    pass
                else:
                    blind_loss += scale[t] * loss_values[t]
        self.params.running_losses['total'].append(blind_loss.item())
        return blind_loss

    def neural_cleanse_part1(self, model, batch, batch_back):
        """Run one optimizer step on the Neural Cleanse model.

        Gradients are switched on for ``self.nc_model`` and off for ``model``
        while the step runs; both switches are flipped back at the end.

        :param model: main model; its gradients are disabled during the step.
        :param batch: clean batch whose ``inputs`` are fed through the NC
            model and then ``model``.
        :param batch_back: backdoored batch forwarded to the loss helper.
        """
        for net in (self.nc_model, model):
            net.zero_grad()

        self.nc_model.switch_grads(True)
        model.switch_grads(False)

        # NOTE(review): the result of this forward pass is not used below.
        # It is kept in case the call has side effects on module state;
        # confirm before removing it.
        model(self.nc_model(batch.inputs))

        per_sample_ce = torch.nn.CrossEntropyLoss(reduction='none')
        loss_values, _ = compute_all_losses_and_grads(
            ['neural_cleanse_part1', 'mask_norm'],
            self, model, per_sample_ce, batch, batch_back,
            compute_grad=False)
        logger.info(loss_values)

        # Using NC paper params
        nc_weight, norm_weight = 0.999, 0.001
        loss = (nc_weight * loss_values['neural_cleanse_part1']
                + norm_weight * loss_values['mask_norm'])
        loss.backward()
        self.nc_optim.step()

        # Hand gradient ownership back to the main model.
        self.nc_model.switch_grads(False)
        model.switch_grads(True)


    def fl_scale_update(self, local_update: Dict[str, torch.Tensor]):
        for name, value in local_update.items():
            value.mul_(self.params.fl_weight_scale)


    def patch_train(self, hlpr, local_epoch, model, train_loader, patch_s=32, x=0, y=0):
        """Optimize an input patch that maximizes a target feature of ``model``.

        A 3-channel patch initialised to zeros is pasted into every batch at
        offset ``(y, x)`` and optimised with Adam (lr=0.1) to minimise the
        negated mean of ``model.features_before_relu(...)[:, self.target_bias]``
        times ``scale['feature']``. After every step the patch is clamped to
        ``[-4, 4]``. The result is stored on ``self`` as ``patch``,
        ``patch_sx``, ``patch_sy``, ``patch_x`` and ``patch_y``.

        ``model`` is put into eval mode (and left there). Its parameters are
        frozen while the patch is optimised and their previous
        ``requires_grad`` flags are restored afterwards, even on error.

        :param hlpr: helper providing ``params.device`` and ``task.get_batch``.
        :param local_epoch: number of passes over ``train_loader``.
        :param model: model exposing ``features_before_relu``.
        :param train_loader: sized iterable of raw batches.
        :param patch_s: patch size; an int for a square patch, otherwise a
            ``(patch_sy, patch_sx)`` pair (sizes along input dims 2 and 3).
        :param x: start index of the patch along input dim 3.
        :param y: start index of the patch along input dim 2.
        """
        # initialize patch (can have different ways)
        if isinstance(patch_s, int):
            patch_sx = patch_sy = patch_s
        else:
            patch_sy, patch_sx = patch_s

        patch = nn.Parameter(torch.zeros(3, patch_sy, patch_sx, requires_grad=True, device=hlpr.params.device))
        scale = {t: self.params.fixed_scales[t] for t in self.params.loss_tasks}

        optimizer = optim.Adam([patch], lr=0.1)

        model.eval()
        # Remember each parameter's flag so parameters that were already
        # frozen before this call stay frozen afterwards.
        saved_flags = [(param, param.requires_grad) for param in model.parameters()]
        for param, _ in saved_flags:
            param.requires_grad = False

        logger.warning("start optimizating the patch")

        total_batches = len(train_loader)
        try:
            for i in range(local_epoch):
                running_loss = {'feature': []}

                for ii, data in enumerate(train_loader):
                    # The first argument is the batch index, as in patch_test.
                    batch = hlpr.task.get_batch(ii, data)
                    # Paste the patch in place; gradients flow back to `patch`.
                    batch.inputs[:, :, y:(y + patch_sy), x:(x + patch_sx)] = patch
                    adv_image = batch.inputs

                    optimizer.zero_grad()
                    # NOTE(review): self.target_bias is not assigned anywhere
                    # in the visible code; it must be set before calling.
                    feature_loss = (-model.features_before_relu(adv_image)[:, self.target_bias]).mean()
                    running_loss['feature'].append(feature_loss.item())

                    loss = feature_loss * scale['feature']
                    loss.backward()
                    optimizer.step()

                    # force to be in -4, 4
                    with torch.no_grad():
                        patch.clamp_(min=-4, max=4)

                    # Distinct names so the x / y offsets are not shadowed.
                    losses = [f'{name}: {np.mean(values):.2f}'
                              for name, values in running_loss.items()]
                    logger.warning(
                        f'Epoch: {i:3d}. '
                        f'Batch: {ii:5d}/{total_batches}. '
                        f' Losses: {losses}.')
        finally:
            for param, flag in saved_flags:
                param.requires_grad = flag

        self.patch = patch.detach()
        self.patch_sx = patch_sx
        self.patch_sy = patch_sy
        self.patch_x = x
        self.patch_y = y


    def patch_test(self, hlpr, epoch, model=None, check_feature=False, attack=True, backdoor_class=None):
        if model is None:
            model = hlpr.task.model
        model.eval()
        hlpr.task.reset_metrics()

        total_batches = len(hlpr.task.test_loader)
        with torch.no_grad():
            for i, data in enumerate(hlpr.task.test_loader):
                batch = hlpr.task.get_batch(i, data)

                if attack:
                    batch.inputs[:, :, self.patch_y:(self.patch_y + self.patch_sy), self.patch_x:(self.patch_x + self.patch_sx)] = self.patch

                if backdoor_class:
                    batch.labels[batch.labels != -1] = backdoor_class

                outputs = model(batch.inputs)
                hlpr.task.accumulate_metrics(outputs=outputs, labels=batch.labels)

                if check_feature:
                    feat = model.features_before_relu(batch.inputs)[:, self.target_bias]
                    logger.warning(
                        f'Batch: {i:5d}/{total_batches}. '
                        f'Target Feat value: {feat.mean().item()}.')
                    
                    print(torch.stack([feat.amax([1, 2]), batch.labels, outputs.topk(1, 1, True, True)[1].squeeze()]).T)
                    # print(feat.amax([1, 2]))
                    # print(batch.labels)
                    # print(outputs.topk(1, 1, True, True)[1].T)

        tb_prefix = "with_patch_accuracy"

        metric = hlpr.task.report_metrics(epoch,
                                prefix=f'Epoch: ',
                                tb_writer=hlpr.tb_writer,
                                tb_prefix=tb_prefix)

        return metric
