import copy
import torch

from ..attack import Attack


class MultiAttack(Attack):
    r"""
    MultiAttack is a class to attack a model with various attacks against the same images and labels.

    The attacks are applied in the given order. Each attack is only run on the
    images for which every previous attack left the model's prediction equal
    to the label.

    Arguments:
        attacks (list): list of attacks. All of them must reference the same model object.
        verbose (bool): if True, print the cumulative success rates at the end of each forward pass. (Default: False)

    Examples::
        >>> atk1 = torchattacks.PGD(model, eps=8/255, alpha=2/255, iters=40, random_start=True)
        >>> atk2 = torchattacks.PGD(model, eps=8/255, alpha=2/255, iters=40, random_start=True)
        >>> atk = torchattacks.MultiAttack([atk1, atk2])
        >>> adv_images = atk(images, labels)

    """
    def __init__(self, attacks, verbose=False):
        # Check validity: every attack has to share one and the same model object.
        models = [attack.model for attack in attacks]

        if not models:
            raise ValueError("attacks must contain at least one attack.")

        if len({id(model) for model in models}) != 1:
            raise ValueError("At least one of attacks is referencing a different model.")

        super().__init__("MultiAttack", models[0])
        self.attacks = attacks
        self.verbose = verbose
        # When True, forward() adds its per-attack counts to _multi_atk_records.
        self._accumulate_multi_atk_records = False
        # Index 0 holds the number of attacked images; index i holds the number
        # of images still correctly classified after the i-th attack.
        self._multi_atk_records = [0.0]
        self._supported_mode = ['default']

    def forward(self, images, labels):
        r"""
        Overridden.

        Runs the attacks one after another on the images that are still
        classified as their label and returns a tensor in which every
        successfully attacked image is replaced by its adversarial version.
        Images that no attack succeeded on are returned unchanged.
        """
        # Move the inputs to the same device as the index tensor `fails`;
        # indexing a tensor with indices living on another device fails.
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        batch_size = images.shape[0]
        # Indices (into the batch) of the images not yet successfully attacked.
        fails = torch.arange(batch_size).to(self.device)
        final_images = images.clone()

        multi_atk_records = [batch_size]

        for attack in self.attacks:
            adv_images = attack(images[fails], labels[fails])

            # Only the predicted class is needed, so no autograd graph is required.
            with torch.no_grad():
                outputs = self.model(adv_images)
            _, pre = torch.max(outputs, 1)

            corrects = (pre == labels[fails])
            wrongs = ~corrects

            # `wrongs` masks the current subset; `fails[wrongs]` maps the
            # successful entries back to their positions in the full batch.
            final_images[fails[wrongs]] = adv_images[wrongs]

            fails = fails[corrects]
            multi_atk_records.append(len(fails))

            # Every image has been attacked successfully; skip remaining attacks.
            if len(fails) == 0:
                break

        if self.verbose:
            print(self._return_sr_record(multi_atk_records))

        if self._accumulate_multi_atk_records:
            self._update_multi_atk_records(multi_atk_records)

        return final_images

    def _clear_multi_atk_records(self):
        r"""
        Reset the accumulated records to a single zero entry.
        """
        self._multi_atk_records = [0.0]

    def _covert_to_success_rates(self, multi_atk_records):
        r"""
        Convert records into cumulative success rates in percent.

        ``multi_atk_records[0]`` is the total number of images and each
        following entry is the number of images still not fooled after the
        corresponding attack. Returns one rate per entry after the first.
        If the total is zero, every rate is reported as 0.0 instead of
        dividing by zero.
        """
        total = multi_atk_records[0]
        remaining = multi_atk_records[1:]
        if total == 0:
            return [0.0 for _ in remaining]
        return [(1 - left / total) * 100 for left in remaining]

    def _return_sr_record(self, multi_atk_records):
        r"""
        Format the success rates of the given records as a printable string.
        """
        sr = self._covert_to_success_rates(multi_atk_records)
        return "Attack success rate: "+" | ".join(["%2.2f %%"%item for item in sr])

    def _update_multi_atk_records(self, multi_atk_records):
        r"""
        Add the given records element-wise to the accumulated records.

        The given list may be shorter than the accumulated one (forward stops
        early once no image is left); the untouched trailing entries then
        implicitly receive zero remaining failures.
        """
        for i, item in enumerate(multi_atk_records):
            self._multi_atk_records[i] += item

    def save(self, data_loader, save_path=None, verbose=True, return_verbose=False):
        r"""
        Overridden.

        Delegates to the parent ``save`` while accumulating the per-attack
        records over all batches, so that success rates can be reported.

        Returns ``(rob_acc, sr, l2, elapsed_time)`` if ``return_verbose`` is
        True, where ``sr`` is the list of cumulative success rates; otherwise
        returns None.

        Raises:
            ValueError: if ``return_verbose`` is True while ``verbose`` is False.
        """
        if (not verbose) and return_verbose:
            raise ValueError("Verobse should be True if return_verbose==True.")

        prev_verbose = self.verbose
        # Silence the per-batch print of forward(); progress is reported by _save_print.
        self.verbose = False
        self._accumulate_multi_atk_records = True
        # One slot for the total count plus one slot per attack.
        self._multi_atk_records = [0.0] * (len(self.attacks) + 1)

        try:
            if verbose:
                rob_acc, l2, elapsed_time = super().save(data_loader, save_path, verbose=True, return_verbose=True)
                # Must be computed before the records are cleared below.
                sr = self._covert_to_success_rates(self._multi_atk_records)
            else:
                super().save(data_loader, save_path, verbose=False, return_verbose=False)
        finally:
            # Restore the state even if saving fails part-way through.
            self._clear_multi_atk_records()
            self._accumulate_multi_atk_records = False
            self.verbose = prev_verbose

        if return_verbose:
            return rob_acc, sr, l2, elapsed_time

    def _save_print(self, progress, rob_acc, l2, elapsed_time, end):
        r"""
        Overridden.

        Prints the save progress together with the accumulated success rates.
        """
        print("- Save progress: %2.2f %% / Robust accuracy: %2.2f %%"%(progress, rob_acc)+\
              " / "+self._return_sr_record(self._multi_atk_records)+\
              ' / L2: %1.5f (%2.3f it/s) \t'%(l2, elapsed_time), end=end)
