import time
from typing import Tuple

import torch
from torch import Tensor

from utils import mask_x, AdvCfg

from .constant import *
from .abclass import Lw_1


class Lw_algo(Lw_1):
    """Run one adversarial-training algorithm on a masked slice of the input.

    Each incoming batch is turned into a sequence of upstream messages
    (``zg`` / ``fwd`` / ``ours`` / ``step`` / sync, sent with ``_put_to_up``).
    Input gradients are read back with ``get_from_up_grad(self.count)``.
    The algorithm is selected by ``adv_cfg.name``:

    * 'clean'  - PGD driver with zero attack steps.
    * 'cer'    - PGD driver with zero attack steps and a random start.
    * 'fgsm'   - one sign step of size epsilon.
    * 'pgd'    - ``n`` sign steps of size sigma.
    * 'dp'     - crafts PGD examples and trains on them once ``m`` are
                 buffered.
    * 'ours'   - ``m`` forwards, each followed by ``n`` ``ours`` requests.
    * 'freeat' - free adversarial training with a perturbation carried
                 across batches.
    * 'freelb' - FreeLB-style L2 ascent with a single ``step`` per batch.
    """

    def __init__(
        self,
        name: str,
        device: str,
        up: Lw_1,
        lw_level: int,
        mask: Tuple[int, int],
        adv_cfg: AdvCfg,
        count_delta: int = None
    ) -> None:
        """Configure the worker and bind ``self.algo`` for ``adv_cfg.name``.

        Args:
            name, device, up: forwarded to ``Lw_1.__init__``.
            lw_level: compared against ALGO_FWD / ALGO_MSG / ALGO_BATCH to
                decide how often this worker syncs and advances ``count``.
            mask: ``(index, n)`` pair passed to ``mask_x`` in ``masked``.
            adv_cfg: attack configuration. For 'clean', 'cer' and 'fgsm'
                some fields are overwritten in place (idempotently).
            count_delta: forwarded to ``Lw_1``; must be 0 when
                ``lw_level <= ALGO_FWD``.

        Raises:
            AssertionError: a field required by the chosen algorithm is None,
                or ``count_delta`` is non-zero at a fwd-level sync.
            ValueError: unknown ``adv_cfg.name``.
        """
        if lw_level <= ALGO_FWD:
            assert count_delta == 0
        self.lw_level = lw_level
        self.mask = mask
        self.adv_cfg = adv_cfg
        self.count = 1
        if self.adv_cfg.name == 'clean':
            self.adv_cfg.n = 0
            self.adv_cfg.rand_start = False
            self.algo = self.algo_pgd
        elif self.adv_cfg.name == 'cer':
            self.adv_cfg.n = 0
            self.adv_cfg.rand_start = True
            assert self.adv_cfg.epsilon is not None
            self.algo = self.algo_pgd
        elif self.adv_cfg.name == 'fgsm':
            self.adv_cfg.n = 1
            self.adv_cfg.rand_start = False
            self.adv_cfg.sigma = self.adv_cfg.epsilon
            assert self.adv_cfg.epsilon is not None
            self.algo = self.algo_pgd
        elif self.adv_cfg.name == 'pgd':
            assert self.adv_cfg.n is not None
            assert self.adv_cfg.epsilon is not None
            assert self.adv_cfg.sigma is not None
            self.algo = self.algo_pgd
        elif self.adv_cfg.name == 'dp':
            assert self.adv_cfg.n is not None
            assert self.adv_cfg.m is not None
            assert self.adv_cfg.epsilon is not None
            assert self.adv_cfg.sigma is not None
            self.dpgd_adv_x = list()
            self.dpgd_time_ext = 0.
            self.dpgd_time_1 = 0.
            self.algo = self.algo_dpgd
        elif self.adv_cfg.name == 'ours':
            assert self.adv_cfg.n is not None
            assert self.adv_cfg.m is not None
            assert self.adv_cfg.epsilon is not None
            assert self.adv_cfg.sigma is not None
            self.algo = self.algo_ours
        elif self.adv_cfg.name == 'freeat':
            assert self.adv_cfg.n is not None
            assert self.adv_cfg.epsilon is not None
            assert self.adv_cfg.sigma is not None
            self.global_eta = None
            self.algo = self.algo_freeat
        elif self.adv_cfg.name == 'freelb':
            assert self.adv_cfg.n is not None
            assert self.adv_cfg.epsilon is not None
            assert self.adv_cfg.sigma is not None
            # Per-ascent-step size. It is stored on the instance instead of
            # dividing adv_cfg.sigma in place: that division is not
            # idempotent, so an AdvCfg shared by several workers would be
            # divided once per worker.
            self.freelb_sigma = self.adv_cfg.sigma / self.adv_cfg.n
            self.algo = self.algo_freelb
        else:
            raise ValueError(self.adv_cfg.name)
        super().__init__(name, device, up, 0, 1, 'sync', count_delta)

    def masked(self, x: Tensor) -> Tensor:
        """Return the part of ``x`` selected by ``self.mask`` via ``mask_x``."""
        _i, _n = self.mask
        return mask_x(x, _i, _n)

    def put_to_up_data(self) -> None:
        """Not supported by this worker; always raises ValueError."""
        raise ValueError('put_to_up_data')

    def act(self, d: Tuple[Tensor, Tensor]) -> None:
        """Handle one work item ``d = (x, labels)``.

        When ``x`` is None the item is a control message: ``labels`` holds
        the end level passed to ``_sync``. Otherwise ``x`` is moved to
        ``self.device`` and the configured algorithm runs on it.
        """
        x, labels = d
        if x is None:
            self._sync(labels)
            self.put_to_down_cache(self.count)
        else:
            x = x.to(self.device)
            self.put_to_down_cache(self.count)
            self.algo(x, labels)

        if self.lw_level <= ALGO_BATCH:
            # NOTE(review): count_delta is presumably set by Lw_1.__init__;
            # a None value here would raise TypeError - confirm.
            self.count += self.count_delta + 2

    def _sync(self, end_level: int = ALGO_MSG) -> None:
        """Send a sync message for ``end_level``, then call ``self.sync``.

        At ALGO_EPOCH in 'dp' mode, also print and reset the accumulated
        ``dpgd_time_ext``.
        """
        self._put_to_up(gen_sync_data(self.count, end_level))
        self.sync(self.count)
        if end_level == ALGO_EPOCH and self.adv_cfg.name == 'dp':
            self.prt(0, self.count, 'DP_EXT_TIME', self.dpgd_time_ext)
            self.dpgd_time_ext = 0.

    def zg(self) -> None:
        """Send a ``zg`` message (presumably zero-grad) upstream.

        At message-level sync (``lw_level == ALGO_MSG``) a sync is sent
        first.
        """
        if self.lw_level == ALGO_MSG:
            self._sync()
        self._put_to_up(gen_zg_data(self.count))

    def step(self) -> None:
        """Send a ``step`` message (presumably optimizer step) upstream.

        At message-level sync (``lw_level == ALGO_MSG``) a sync is sent
        first.
        """
        if self.lw_level == ALGO_MSG:
            self._sync()
        self._put_to_up(gen_step_data(self.count))

    def fwd(
        self,
        x: Tensor,
        labels: Tensor,
        grad_lv: int = GRAD_BWD,
        append_log: bool = True
    ) -> None:
        """Send a forward request for a clone of ``masked(x)``.

        ``count`` is incremented first. Callers then fetch the reply with
        ``get_from_up_grad(self.count)``. ``grad_lv`` (GRAD_BWD or
        GRAD_ONLY_X) and ``append_log`` are passed through to
        ``gen_fwd_data``.
        """
        if self.lw_level <= ALGO_FWD:
            self._sync()
        self.count += 1
        self._put_to_up(gen_fwd_data(
            self.count,
            append_log,
            grad_lv,
            self.masked(x).clone(),
            labels
        ))

    def ours(
        self,
        x: Tensor,
        grad_lv: int = GRAD_BWD
    ) -> None:
        """Send an ``ours`` request for a clone of ``masked(x)``.

        Unlike ``fwd``, this does not advance ``count``.
        """
        self._put_to_up(gen_ours_data(
            self.count,
            grad_lv,
            self.masked(x).clone(),
        ))

    def _new_eta(self, x: Tensor) -> Tensor:
        """Return a fresh perturbation shaped like ``x``.

        The perturbation is uniform in [-epsilon, epsilon] when
        ``adv_cfg.rand_start`` is set, and zeros otherwise.
        """
        if self.adv_cfg.rand_start:
            eps = self.adv_cfg.epsilon
            # The randn draw is overwritten by uniform_. It is kept so the
            # RNG stream (and therefore seeded runs) stays unchanged.
            return torch.randn_like(x).uniform_(-eps, eps)
        return torch.zeros_like(x)

    def _init_adv(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(eta, adv_x)`` for the first attack step.

        With a random start, ``adv_x`` is clamped to [0, 1]. With a zero
        start, ``adv_x`` is ``x + 0`` and is left unclamped.
        """
        eta = self._new_eta(x)
        adv_x = x + eta
        if self.adv_cfg.rand_start:
            adv_x.clamp_(0., 1.)
        return eta, adv_x

    def _sign_step(self, x: Tensor, eta: Tensor, x_grad: Tensor) -> Tensor:
        """Take one L-inf gradient-sign step, updating ``eta`` in place.

        ``eta`` moves by ``sigma * sign(x_grad)`` and is clipped to
        [-epsilon, epsilon]. Returns ``x + eta`` clamped to [0, 1].
        """
        eps = self.adv_cfg.epsilon
        eta += x_grad.sign() * self.adv_cfg.sigma
        eta.clamp_(-eps, eps)
        return (x + eta).clamp_(0., 1.)

    @staticmethod
    def _project_l2(v: Tensor, eps: float) -> Tensor:
        """Project each row of ``v`` onto the L2 ball of radius ``eps``.

        Dim 0 is treated as the batch dimension. Rows already inside the
        ball are unchanged.
        """
        dims = list(range(1, v.dim()))
        norms = torch.sqrt(torch.sum(v * v, dim=dims, keepdim=True))
        # Avoid 0/0 = NaN when a row is all zeros and eps == 0.
        norms = norms.clamp_min(torch.finfo(norms.dtype).tiny)
        return torch.min(norms.new_ones(norms.shape), eps / norms) * v

    def algo_pgd(self, x: Tensor, labels: Tensor) -> None:
        """PGD-style batch; also drives 'clean', 'cer' and 'fgsm'.

        Runs ``adv_cfg.n`` sign steps whose forwards request only the input
        gradient (GRAD_ONLY_X). It then does one full forward on the final
        adversarial input and sends ``step``.
        """
        self.zg()
        eta, adv_x = self._init_adv(x)

        for i in range(self.adv_cfg.n):
            # Labels are sent only with the first forward of the batch.
            self.fwd(adv_x, None if i else labels, GRAD_ONLY_X, False)
            x_grad = self.get_from_up_grad(self.count)
            adv_x = self._sign_step(x, eta, x_grad)

        self.zg()
        if self.adv_cfg.n:
            # NOTE(review): unlike the n == 0 branch, the reply to this
            # forward is not fetched here - confirm this asymmetry is
            # intended.
            self.fwd(adv_x, None)
        else:
            # No attack forward has sent the labels yet. The returned
            # gradient is fetched (consumed) but not used.
            self.fwd(adv_x, labels)
            self.get_from_up_grad(self.count)
        self.step()

    def algo_dpgd(self, x: Tensor, labels: Tensor) -> None:
        """'dp': craft a PGD example now and train on buffered ones later.

        Each call runs ``adv_cfg.n`` input-gradient-only sign steps and
        appends ``(adv_x, labels)`` to ``self.dpgd_adv_x``. Once
        ``adv_cfg.m`` examples are buffered, each one gets zg / fwd / step
        and the buffer is cleared. Wall time of non-flushing calls builds up
        in ``dpgd_time_1`` and moves into ``dpgd_time_ext`` on flush.
        """
        started = time.time()

        self.zg()
        eta, adv_x = self._init_adv(x)

        for i in range(self.adv_cfg.n):
            self.fwd(adv_x, None if i else labels, GRAD_ONLY_X, False)
            x_grad = self.get_from_up_grad(self.count)
            adv_x = self._sign_step(x, eta, x_grad)

        self.dpgd_adv_x.append((adv_x, labels))

        # NOTE(review): examples still buffered at epoch end (fewer than m)
        # are never trained on - confirm this is intended.
        if len(self.dpgd_adv_x) == self.adv_cfg.m:
            for buf_x, buf_labels in self.dpgd_adv_x:
                self.zg()
                self.fwd(buf_x, buf_labels)
                self.get_from_up_grad(self.count)
                self.step()
            self.dpgd_adv_x = list()
            self.dpgd_time_ext += self.dpgd_time_1
            self.dpgd_time_1 = 0.
        else:
            self.dpgd_time_1 += time.time() - started

    def algo_ours(self, x: Tensor, labels: Tensor) -> None:
        """'ours': ``m`` forwards, each followed by ``n`` ``ours`` requests.

        Each ``ours`` reply is used for a sign step. The last request after
        each forward asks for GRAD_BWD; the others ask for GRAD_ONLY_X.
        A single ``step`` closes the batch.
        """
        self.zg()
        eta, adv_x = self._init_adv(x)

        n = self.adv_cfg.n
        for i in range(self.adv_cfg.m):
            self.fwd(adv_x, None if i else labels, GRAD_BWD, i == 0)

            for j in range(n):
                self.ours(adv_x, GRAD_BWD if j == n - 1 else GRAD_ONLY_X)
                x_grad = self.get_from_up_grad(self.count)
                adv_x = self._sign_step(x, eta, x_grad)

        self.step()

    def algo_freeat(self, x: Tensor, labels: Tensor) -> None:
        """'freeat': ``n`` replays of zg / fwd / step per batch.

        The perturbation is kept in ``self.global_eta`` and reused by the
        next batch. If the stored perturbation has at least as many rows as
        this batch (for example, a smaller final batch), a prefix view of it
        is used so the other rows survive. If its shape is incompatible, a
        fresh perturbation is started.
        """
        prev = self.global_eta
        batch = x.shape[0]
        reuse = (
            prev is not None
            and prev.shape[1:] == x.shape[1:]
            and prev.shape[0] >= batch
        )
        if reuse:
            # A view: the in-place updates in _sign_step write straight
            # back into ``prev``.
            eta = prev[:batch]
        else:
            eta = self._new_eta(x)

        adv_x = (x + eta).clamp_(0., 1.)

        for i in range(self.adv_cfg.n):
            self.zg()
            self.fwd(adv_x, None if i else labels, GRAD_BWD, i == 0)

            x_grad = self.get_from_up_grad(self.count)

            self.step()

            adv_x = self._sign_step(x, eta, x_grad)

        self.global_eta = prev if reuse else eta.detach()

    def algo_freelb(self, x: Tensor, labels: Tensor) -> None:
        """'freelb': ``n`` L2-normalised ascent steps and one ``step``.

        Each ascent step adds ``freelb_sigma * g / ||g||``, where ``||g||``
        is taken over the whole batch gradient. The result is projected
        per-row onto the L2 ball of radius epsilon and then clipped to
        [-epsilon, epsilon].
        """
        self.zg()
        eps = self.adv_cfg.epsilon
        eta = self._new_eta(x)
        if self.adv_cfg.rand_start:
            # Scale the uniform start by 1/sqrt(elements per sample).
            eta /= torch.sqrt(Tensor([x[0].numel()])).to(eta)
            adv_x = (x + eta).clamp_(0., 1.)
        else:
            adv_x = x + eta

        for i in range(self.adv_cfg.n):
            self.fwd(adv_x, None if i else labels, GRAD_BWD, i == 0)

            x_grad = self.get_from_up_grad(self.count).detach()

            # Guard against an all-zero gradient: 0/0 would make eta NaN.
            tiny = torch.finfo(x_grad.dtype).tiny
            g_norm = torch.norm(x_grad).clamp_min(tiny)
            eta += x_grad * self.freelb_sigma / g_norm
            eta = self._project_l2(eta, eps)
            eta.clamp_(-eps, eps)
            adv_x = (x + eta).clamp_(0., 1.)

        self.step()