'''
This is the unofficial implementation of M to N Attack [1]

Reference:
[1] M-to-N Backdoor Paradigm: A Multi-Trigger and Multi-Target Attack to Deep Learning Models TCSVT 2024.

Key ideas implemented:
- Trigger bank: for each target class l_k, randomly select M clean images (from train set of that class) as triggers.
- Grayscale triggers; concatenate with clean image along channel dim -> input to H.
- Three-network poisoned image generation framework: H (UNet-like), R (CEILNet-like), D (PatchGAN-like).
- Losses: L_H = λH1*L_V(x~,x) + λH2*L_F(VGG(x~), VGG(t_kj));
          L_R = ||R(x~) - t_hat||^2 + ||R(x) - δ||^2;  (t_hat: grayscale trigger, δ: all-zero blank image)
          L_D = E[log(1 - D(x~))] + E[log D(x)]  (maximized by D; softplus and explicit log-sigmoid formulations both supported)
- Poisoning: uniformly sample trigger from MxN for each selected clean image (excluding target classes) with ratio.
- Evaluation: BA/ASR consistent with your Base._test and compute_asr() contract.

'''

import os
import random
from typing import Dict, Any, List, Tuple, Literal
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.datasets import DatasetFolder
from torchvision.transforms import Compose, ToTensor

from .base import Base


def to_gray(img: Tensor) -> Tensor:
    """
    Convert a (C,H,W) image in [0,1] into a single-channel grayscale image.

    Single-channel inputs are returned unchanged (same object). Otherwise the
    first three channels are treated as R, G, B and combined with the
    standard luma weights (0.299, 0.587, 0.114); the result is clamped to
    [0,1] and has shape (1,H,W).

    Raises:
        ValueError: if ``img`` is not a 3-D tensor.
    """
    if img.dim() == 3:
        if img.size(0) == 1:
            return img
        # weighted sum of the R, G, B planes (standard luma)
        luma = 0.299 * img[0] + 0.587 * img[1] + 0.114 * img[2]
        return luma.unsqueeze(0).clamp(0, 1)
    raise ValueError("Expect (C,H,W) tensor")


def concat_img_with_trigger(x: Tensor, t_gray: Tensor) -> Tensor:
    """
    Append a grayscale trigger to an image as one extra channel.

    x: (C,H,W) clean image; t_gray: (1,h,w) trigger. When the trigger's
    spatial size differs from the image's, it is bilinearly resized to (H,W)
    before concatenation.
    returns: (C+1,H,W)
    """
    target_hw = x.shape[-2:]
    trigger = t_gray
    if trigger.shape[-2:] != target_hw:
        # interpolate expects a batch dim: add it, resize, then drop it again
        resized = F.interpolate(trigger[None], size=target_hw, mode="bilinear", align_corners=False)
        trigger = resized[0]
    return torch.cat((x, trigger), dim=0)

class TriggerBank:
    """
    Maintain up to M grayscale triggers per target label.

    For every target label, up to M samples of that class are drawn from
    ``dataset.samples`` without replacement. Each is loaded with
    ``dataset.loader`` and converted with ``transform_to_tensor`` (default:
    ``ToTensor()``; the dataset's own transform is NOT applied). It is then
    reduced to grayscale with ``to_gray`` and stored on CPU as a (1,H,W) tensor.

    Raises:
        ValueError: if ``M < 1``, ``target_labels`` is empty, or a target label
            has no samples in ``dataset``. Without this check, ``sample`` would
            later fail with an opaque ``IndexError``.
    """
    def __init__(self, dataset: DatasetFolder, target_labels: List[int], M: int, transform_to_tensor: Compose | None = None):
        if M < 1:
            raise ValueError(f"M must be >= 1, got {M}")
        self.dataset = dataset
        # Drop duplicates (order preserved) so sample() picks every target with equal probability.
        self.targets = list(dict.fromkeys(target_labels))
        if not self.targets:
            raise ValueError("Provide at least one target label.")
        self.M = M
        self.transform_to_tensor = transform_to_tensor or Compose([ToTensor()])
        self.bank: Dict[int, List[Tensor]] = {k: [] for k in self.targets}
        self._build()

    def _build(self):
        """Fill ``self.bank`` with up to M grayscale triggers per target label."""
        # collect dataset indices for each target class
        label2idx: Dict[int, List[int]] = {k: [] for k in self.targets}
        for i, (_, y) in enumerate(self.dataset.samples):
            if y in label2idx:
                label2idx[y].append(i)

        # sample up to M per class and preprocess to gray (1,H,W)
        for y in self.targets:
            candidates = label2idx[y]
            if not candidates:
                raise ValueError(
                    f"Target label {y} has no samples in the dataset; cannot build its trigger pool."
                )
            choices = random.sample(candidates, k=min(self.M, len(candidates)))
            self.bank[y] = []
            for idx in choices:
                path, _ = self.dataset.samples[idx]
                img = self.dataset.loader(path)
                x = self.transform_to_tensor(img)  # (C,H,W)
                self.bank[y].append(to_gray(x).cpu())

    def sample(self) -> Tuple[int, Tensor]:
        """
        Uniformly sample a target label, then uniformly sample one of its triggers.

        Returns: (target_label, t_gray) where t_gray is a fresh clone of the
        stored (1,H,W) trigger, so callers may modify it freely.
        """
        y = random.choice(self.targets)
        t = random.choice(self.bank[y])
        return y, t.clone()
    
class DoubleConv(nn.Module):
    """Two 3x3 convolutions (padding 1, spatial size preserved), each followed by an in-place ReLU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for src, dst in ((in_ch, out_ch), (out_ch, out_ch)):
            layers += [nn.Conv2d(src, dst, 3, padding=1), nn.ReLU(inplace=True)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class UNetSmall(nn.Module):
    """
    H: small U-Net-style encoder-decoder with skip connections.

    Three 2x average-pooling stages (e.g. 32 -> 16 -> 8 -> 4) are followed by
    three bilinear upsampling stages. Each stage is concatenated with the
    matching encoder feature map. Each decoder stage is upsampled to the exact
    spatial size of its skip connection, so height/width need not be divisible
    by 8. For divisible sizes this equals plain 2x upsampling. Inputs must still
    be at least 8x8 so the bottleneck is non-empty.

    in_ch = C+1 (image + grayscale trigger), out_ch = C (poisoned image).
    The output goes through a sigmoid, so it lies in (0, 1).
    """
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        f = 32
        self.down1 = DoubleConv(in_ch, f)
        self.down2 = DoubleConv(f, f*2)
        self.down3 = DoubleConv(f*2, f*4)
        self.mid   = DoubleConv(f*4, f*4)
        self.up3   = DoubleConv(f*4 + f*4, f*2)
        self.up2   = DoubleConv(f*2 + f*2, f)
        self.up1   = DoubleConv(f + f, out_ch)
        self.pool = nn.AvgPool2d(2)
        # Kept for backward compatibility (no parameters); forward uses _up_to so
        # odd sizes produced by pooling are matched exactly.
        self.up   = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

    @staticmethod
    def _up_to(x: Tensor, ref: Tensor) -> Tensor:
        """Bilinearly resize ``x`` to the spatial size of ``ref``."""
        return F.interpolate(x, size=ref.shape[-2:], mode="bilinear", align_corners=False)

    def forward(self, x):
        d1 = self.down1(x)                 # H
        d2 = self.down2(self.pool(d1))     # H/2
        d3 = self.down3(self.pool(d2))     # H/4
        m  = self.mid(self.pool(d3))       # H/8

        # up to d3's size, concat with d3
        u3 = self.up3(torch.cat([self._up_to(m, d3), d3], dim=1))

        # up to d2's size, concat with d2
        u2 = self.up2(torch.cat([self._up_to(u3, d2), d2], dim=1))

        # up to d1's size, concat with d1
        u1 = torch.cat([self._up_to(u2, d1), d1], dim=1)
        out = self.up1(u1)
        return torch.sigmoid(out)

class CEILNetLite(nn.Module):
    """
    R: lightweight recovery network that maps a (B,C,H,W) image to a
    (B,1,H,W) map in (0,1). It is trained to reproduce the grayscale trigger
    from poisoned images and an all-zero map from clean ones.
    in_ch = C, out_ch = 1
    """
    def __init__(self, in_ch: int):
        super().__init__()
        width = 32
        body = []
        # three 3x3 conv + ReLU stages (size-preserving)
        for src in (in_ch, width, width):
            body.append(nn.Conv2d(src, width, 3, padding=1))
            body.append(nn.ReLU(True))
        # 1x1 projection to a single channel, squashed into (0,1)
        body.append(nn.Conv2d(width, 1, 1))
        body.append(nn.Sigmoid())
        self.net = nn.Sequential(*body)

    def forward(self, x):
        return self.net(x)

class PatchDiscriminator(nn.Module):
    """
    D: PatchGAN-style critic that outputs a map of real/fake logits.

    Three stride-2 conv stages (each halves H and W), then a 3x3 conv to one
    logit channel; e.g. a 32x32 input yields a 4x4 logit map.
    in_ch = C
    """
    def __init__(self, in_ch: int):
        super().__init__()
        width = 64

        def stage(src, dst, use_bn=True):
            # conv -> [batch-norm] -> LeakyReLU(0.2)
            mods = [nn.Conv2d(src, dst, 4, 2, 1)]
            if use_bn:
                mods.append(nn.BatchNorm2d(dst))
            mods.append(nn.LeakyReLU(0.2, inplace=True))
            return nn.Sequential(*mods)

        stages = [
            stage(in_ch, width, use_bn=False),
            stage(width, width * 2),
            stage(width * 2, width * 4),
        ]
        head = nn.Conv2d(width * 4, 1, 3, padding=1)  # logits map
        self.net = nn.Sequential(*stages, head)

    def forward(self, x):
        return self.net(x)


class ProvidedVGGFeature(nn.Module):
    """
    Wrap user's pretrained VGG.features (nn.Sequential) as a fixed feature extractor.

    Args:
        features_seq (nn.Sequential): your VGG(...).features
        cut_index (int or None): cut the sequential at this index (exclusive).
                                 e.g., 16 ~ relu3_3-ish for VGG-D; leave as None to use the full features.
        in_norm (bool): whether to normalize inputs; default is False since your pretrain may not be on ImageNet.
        mean/std: only used when in_norm=True; defaults to ImageNet statistics.
        freeze (bool): True -> do not train (L_F is only used for measurement), more stable.
    """
    def __init__(
        self,
        features_seq: nn.Sequential,
        cut_index: int | None = 16,
        in_norm: bool = False,
        mean=(0.485,0.456,0.406),
        std=(0.229,0.224,0.225),
        freeze: bool = True
    ):
        super().__init__()
        if cut_index is None:
            self.features = features_seq
        else:
            self.features = nn.Sequential(*list(features_seq.children())[:cut_index])

        if freeze:
            for p in self.features.parameters():
                p.requires_grad = False

        self.in_norm = in_norm
        self.register_buffer("mean", torch.tensor(mean).view(1,3,1,1), persistent=False)
        self.register_buffer("std",  torch.tensor(std ).view(1,3,1,1), persistent=False)

    def forward(self, x):  # x in [0,1], shape (B,C,H,W)
        if x.size(1) == 1:
            x = x.repeat(1,3,1,1)
        if self.in_norm:
            x = (x - self.mean) / self.std
        return self.features(x)


class PoisonGenTrainer:
    """
    Train H, R and D on-the-fly to learn how to embed grayscale triggers invisibly.

    H (UNetSmall) maps [clean image | grayscale trigger] to a poisoned image.
    R (CEILNetLite) is trained to recover the trigger from poisoned images and
    an all-zero map from clean images. D (PatchDiscriminator) is trained to
    separate clean from poisoned images.

    Usage:
        trainer = PoisonGenTrainer(C=in_ch, feat_extractor=vgg_feat, device='cuda:0', lambdas=..., lrs=...)
        trainer.fit(train_dataset, trigger_bank, iters=I, batch_size=..., exclude_labels=target_labels)

    After training:
        trainer.embed_batch(clean_imgs, t_gray) -> poisoned_imgs
    """
    def __init__(
        self,
        C: int,
        feat_extractor: nn.Module,
        device: str | torch.device = 'cpu',
        lambdas: Dict[str, float] = None,
        lrs: Dict[str, float] = None,
        vgg_layer_cut: int = 16,
        adv_non_saturating: bool = True,
    ):
        """
        Args:
            C: number of image channels.
            feat_extractor: feature network used for the feature loss L_F.
            device: device holding H, R, D and the feature extractor.
            lambdas: loss weights, keys 'H1' (L_V), 'H2' (L_F), 'H' (L_H),
                'R' (L_R), 'D' (adversarial term). Missing keys use the defaults.
            lrs: Adam learning rates, keys 'H', 'R', 'D'. Missing keys use the defaults.
            vgg_layer_cut: unused here; accepted for API compatibility (any layer
                cut must already be applied to ``feat_extractor``).
            adv_non_saturating: True -> softplus formulation of the GAN losses;
                False -> explicit ``log(sigmoid(.) + 1e-8)`` formulation of the same losses.
        """
        default_lambdas = {'H1': 1.0, 'H2': 1.0, 'R': 1.0, 'D': 0.01, 'H': 1.0}
        default_lrs = {'H': 2e-4, 'R': 2e-4, 'D': 2e-4}

        self.device = torch.device(device)
        self.C = C
        self.H = UNetSmall(in_ch=C+1, out_ch=C).to(self.device)
        self.R = CEILNetLite(in_ch=C).to(self.device)
        self.D = PatchDiscriminator(in_ch=C).to(self.device)

        # === feature extractor (eval mode: any BN/dropout inside behaves deterministically) ===
        self.VGG = feat_extractor.to(self.device).eval()

        self.adv_non_sat = adv_non_saturating
        # User entries override defaults, so a partial dict no longer raises KeyError later.
        self.lmb = {**default_lambdas, **(lambdas or {})}
        lr = {**default_lrs, **(lrs or {})}
        self.optH = torch.optim.Adam(self.H.parameters(), lr=lr['H'], betas=(0.5,0.999))
        self.optR = torch.optim.Adam(self.R.parameters(), lr=lr['R'], betas=(0.5,0.999))
        self.optD = torch.optim.Adam(self.D.parameters(), lr=lr['D'], betas=(0.5,0.999))

        # all-zero (1,H,W) targets for R on clean images, cached per spatial size
        self.blank_cache: Dict[Tuple[int, int], Tensor] = {}

    def _blank(self, size: Tuple[int,int]) -> Tensor:
        """Return a cached all-zero (1,H,W) tensor on ``self.device``."""
        if size not in self.blank_cache:
            self.blank_cache[size] = torch.zeros(1, *size, device=self.device)
        return self.blank_cache[size]

    @staticmethod
    def _match_trigger(t_gray: Tensor, batch: int, size: Tuple[int, int]) -> Tensor:
        """
        Return ``t_gray`` as a (batch,1,H,W) tensor at spatial ``size``.

        Accepts (1,h,w), (1,1,h,w) or (batch,1,h,w). The trigger is bilinearly
        resized when (h,w) != size, and a single trigger is broadcast over the batch.
        """
        if t_gray.dim() == 3:
            t_gray = t_gray.unsqueeze(0)
        if tuple(t_gray.shape[-2:]) != tuple(size):
            t_gray = F.interpolate(t_gray, size=tuple(size), mode="bilinear", align_corners=False)
        if t_gray.size(0) == 1 and batch != 1:
            t_gray = t_gray.expand(batch, -1, -1, -1)
        return t_gray

    @torch.no_grad()
    def embed_batch(self, x: Tensor, t_gray: Tensor) -> Tensor:
        """
        Embed a grayscale trigger into a batch of clean images with H.

        x: (B,C,H,W) in [0,1] on the generator's device.
        t_gray: (1,h,w), (1,1,h,w) or (B,1,h,w). It is moved to x's device and
            resized to (H,W) if needed.
        Returns: (B,C,H,W) poisoned images clamped to [0,1].
        """
        t_batch = self._match_trigger(t_gray.to(x.device), x.size(0), (x.size(-2), x.size(-1)))
        x_tilde = self.H(torch.cat([x, t_batch], dim=1))
        return x_tilde.clamp(0,1)

    @staticmethod
    def _without_labels(dataset: Dataset, exclude_labels: List[int] | None) -> Dataset:
        """Return ``dataset`` restricted to samples whose label is not in ``exclude_labels``."""
        if not exclude_labels:
            return dataset
        excluded = set(exclude_labels)
        labels = getattr(dataset, 'targets', None)
        if labels is None and hasattr(dataset, 'samples'):
            labels = [y for _, y in dataset.samples]
        if labels is None:
            # generic Dataset without label metadata: read labels item by item
            labels = [dataset[i][1] for i in range(len(dataset))]
        keep = [i for i, y in enumerate(labels) if y not in excluded]
        return Subset(dataset, keep)

    def _plateau_update(self, loss: float, best_loss: float, no_improve: int,
                        patience: int, gamma: float) -> Tuple[float, int]:
        """Update plateau bookkeeping and decay the H/R learning rates after ``patience`` misses."""
        if loss + 1e-6 < best_loss:
            return loss, 0
        no_improve += 1
        if no_improve >= patience:
            for opt in (self.optH, self.optR):
                for g in opt.param_groups:
                    g['lr'] *= gamma
            no_improve = 0
        return best_loss, no_improve

    def _train_step(self, imgs: Tensor, trigger_bank: TriggerBank) -> float:
        """Run one D update, then one joint H+R update, on ``imgs`` (B,C,H,W); return the H+R total loss."""
        B = imgs.size(0)
        size = (imgs.size(-2), imgs.size(-1))

        # one trigger per step, shared by the whole batch (its target label is not needed here)
        _, t_gray = trigger_bank.sample()
        t_batch = self._match_trigger(t_gray.to(self.device), B, size)

        # forward: [clean | trigger] -> poisoned (B,C,H,W)
        x_tilde = self.H(torch.cat([imgs, t_batch], dim=1))

        # L_V (pixel)
        L_V = F.mse_loss(x_tilde, imgs)

        # L_F (feature distance); trigger features are a fixed target, so no graph is needed
        with torch.no_grad():
            feat_trigger = self.VGG(t_batch.repeat(1,3,1,1))
        L_F = F.mse_loss(self.VGG(x_tilde), feat_trigger)

        # L_R (recover trigger from poisoned, blank from clean)
        rec_poison = self.R(x_tilde)
        rec_clean  = self.R(imgs)
        blank = self._blank(size).expand(B, -1, -1, -1)
        L_R = F.mse_loss(rec_poison, t_batch) + F.mse_loss(rec_clean, blank)

        # train D on detached inputs
        self.optD.zero_grad(set_to_none=True)
        logit_fake = self.D(x_tilde.detach())
        logit_real = self.D(imgs.detach())
        if self.adv_non_sat:
            L_D = F.softplus(logit_fake).mean() + F.softplus(-logit_real).mean()
        else:
            L_D = -(torch.log(torch.sigmoid(logit_real) + 1e-8).mean()
                    + torch.log(1 - torch.sigmoid(logit_fake) + 1e-8).mean())
        L_D.backward()
        self.optD.step()

        # train H & R jointly with total loss
        self.optH.zero_grad(set_to_none=True)
        self.optR.zero_grad(set_to_none=True)

        L_H = self.lmb['H1'] * L_V + self.lmb['H2'] * L_F
        # generator adversarial term (fool D)
        logit_fake = self.D(x_tilde)
        if self.adv_non_sat:
            L_adv = F.softplus(-logit_fake).mean()
        else:
            L_adv = -torch.log(torch.sigmoid(logit_fake) + 1e-8).mean()

        total = self.lmb['H'] * L_H + self.lmb['R'] * L_R + self.lmb['D'] * L_adv
        total.backward()
        self.optH.step()
        self.optR.step()
        return total.item()

    def fit(
        self,
        gen_dataset: DatasetFolder,
        trigger_bank: TriggerBank,
        iters: int = 40000,
        batch_size: int = 64,
        num_workers: int = 4,
        exclude_labels: List[int] = None,
        lr_decay_patience: int = 5,
        lr_decay_gamma: float = 0.2
    ):
        """
        Train H/R/D for ``iters`` mini-batch steps.

        Samples whose label is in ``exclude_labels`` are removed from
        ``gen_dataset`` first. Each step then draws a clean batch and one trigger
        from ``trigger_bank`` and runs one D update and one joint H+R update.

        Args:
            lr_decay_patience: number of consecutive epochs (full passes over the
                filtered dataset) whose mean total loss fails to improve before the
                H and R learning rates are multiplied by ``lr_decay_gamma``.
                Checking per epoch rather than per step avoids reacting to
                mini-batch noise.

        Raises:
            ValueError: if no samples remain after excluding ``exclude_labels``.
        """
        train_set = self._without_labels(gen_dataset, exclude_labels)
        if len(train_set) == 0:
            raise ValueError(
                f"No training samples left after excluding labels {sorted(set(exclude_labels or []))}."
            )

        loader = DataLoader(
            train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers,
            # dropping the last partial batch would leave an empty loader when the
            # dataset is smaller than one batch
            drop_last=len(train_set) >= batch_size,
        )

        # simple cycling loader
        it = iter(loader)

        best_loss = float('inf')
        no_improve = 0
        epoch_loss_sum = 0.0
        epoch_steps = 0

        for step in range(1, iters+1):
            try:
                imgs, _labels = next(it)
            except StopIteration:
                # epoch boundary: plateau check on the epoch's mean loss
                if epoch_steps:
                    best_loss, no_improve = self._plateau_update(
                        epoch_loss_sum / epoch_steps, best_loss, no_improve,
                        lr_decay_patience, lr_decay_gamma)
                epoch_loss_sum, epoch_steps = 0.0, 0
                it = iter(loader)
                imgs, _labels = next(it)

            # imgs: (B,C,H,W), in [0,1] if the dataset transform includes ToTensor()
            epoch_loss_sum += self._train_step(imgs.to(self.device), trigger_bank)
            epoch_steps += 1

    @torch.no_grad()
    def poison_one(self, img: Tensor, t_gray: Tensor) -> Tensor:
        """Embed ``t_gray`` into a single (C,H,W) image; returns (C,H,W)."""
        return self.embed_batch(img.unsqueeze(0), t_gray).squeeze(0)



class MToNPoisonedDataset(DatasetFolder):
    """
    Wrap a benign DatasetFolder and poison a random subset of it.

    The poisoned index set is chosen once at construction. Its size is
    ``int(len(samples) * poisoned_rate)``, capped by the candidate pool. It is
    split as evenly as possible across the target labels (strict 1/N: the first
    ``K % N`` targets get one extra sample). On every access, a poisoned index
    is embedded with a trigger drawn at random from its assigned target's pool.

    Modes:
        'DIRTY': any sample may be poisoned; its label becomes the assigned target.
        'CLEAN': only samples already in a target class are poisoned; labels are unchanged.

    ``target_transform`` (if set) is applied to the returned label, as in DatasetFolder.

    Raises:
        ValueError: if ``mode`` is not 'DIRTY'/'CLEAN' or ``target_labels`` is empty.
    """
    def __init__(
        self,
        benign: DatasetFolder,
        trigger_bank: TriggerBank,
        embedder: PoisonGenTrainer,
        poisoned_rate: float,
        target_labels: List[int],
        mode: Literal['DIRTY','CLEAN'] = 'DIRTY'
    ):
        super().__init__(benign.root, benign.loader, benign.extensions,
                         benign.transform, benign.target_transform, None)
        self.trigger_bank = trigger_bank
        self.embedder = embedder
        self.poisoned_rate = float(poisoned_rate)
        self.target_labels = list(sorted(set(target_labels)))
        self.target_labels_set = set(self.target_labels)
        self.mode = mode.upper()
        if self.mode not in ('DIRTY', 'CLEAN'):
            raise ValueError(f"mode must be 'DIRTY' or 'CLEAN', got {mode!r}")
        if not self.target_labels:
            raise ValueError("Provide at least one target label.")

        # ---------- 1) Select candidate set ----------
        if self.mode == 'DIRTY':
            # All samples can be selected.
            # NOTE(review): this includes samples already in a target class; the
            # module docstring says target classes are excluded — confirm intent.
            cand = list(range(len(self.samples)))
        else:  # CLEAN
            # Only poison samples with target labels (labels remain unchanged).
            # NOTE(review): the assigned target (which selects the trigger pool)
            # may differ from the sample's own class.
            cand = [i for i, (_, y) in enumerate(self.samples) if y in self.target_labels_set]

        # ---------- 2) Sample poison_set ----------
        K_all = int(len(self.samples) * self.poisoned_rate)
        K = min(K_all, len(cand))
        self.idx2tgt: Dict[int, int] = {}
        if K <= 0:
            self.poison_set = set()
            print(f'[+] MtoNAttack: poison_set size = 0 / {len(self.samples)} (mode={self.mode})')
        else:
            pick = random.sample(cand, k=K)
            random.shuffle(pick)  # Shuffle for even distribution

            # ---------- 3) Strict 1/N even assignment ----------
            N = len(self.target_labels)
            base = K // N
            rem = K % N

            pos = 0
            for j, y_tgt in enumerate(self.target_labels):
                take = base + (1 if j < rem else 0)
                for idx in pick[pos:pos+take]:
                    self.idx2tgt[idx] = y_tgt
                pos += take

            self.poison_set = set(self.idx2tgt.keys())

            # Print distribution statistics
            distr = {y: 0 for y in self.target_labels}
            for y in self.idx2tgt.values():
                distr[y] += 1
            print(f'[+] MtoNAttack: poison_set size = {len(self.poison_set)} / {len(self.samples)} (mode={self.mode})')
            print(f'[+] MtoNAttack: per-target counts (strict 1/N): {distr}')

        # ---------- 4) Transform ----------
        self.benign_transform = deepcopy(self.transform)
        self.poison_transform = deepcopy(self.transform)

    def _sample_trigger_for(self, y_tgt: int) -> Tensor:
        """Return a clone of a random trigger from the pool of target class ``y_tgt``."""
        pool = self.trigger_bank.bank[y_tgt]
        return random.choice(pool).clone()

    def __getitem__(self, index):
        """Return ``(image, label)``, embedding the assigned trigger when ``index`` is poisoned."""
        path, y = self.samples[index]
        img = self.loader(path)
        x = self.benign_transform(img) if self.benign_transform is not None else ToTensor()(img)

        if index in self.poison_set:
            # Use the pre-assigned target label for this index
            y_tgt = self.idx2tgt[index]
            t_gray = self._sample_trigger_for(y_tgt)
            # The trigger may have a different spatial size than the transformed
            # image (e.g. when the bank skipped the dataset's resize); match it.
            if t_gray.shape[-2:] != x.shape[-2:]:
                t_gray = F.interpolate(t_gray.unsqueeze(0), size=x.shape[-2:],
                                       mode="bilinear", align_corners=False).squeeze(0)
            x = self.embedder.embed_batch(
                x.unsqueeze(0).to(self.embedder.device),
                t_gray.to(self.embedder.device)
            ).cpu().squeeze(0)

            if self.mode == 'DIRTY':
                y = y_tgt  # Dirty label: change label to target class
            # CLEAN: keep original y unchanged

        if self.target_transform is not None:
            y = self.target_transform(y)
        return x, y

# -----------------------------
# 6) Attack class – integrates with your Base
# -----------------------------

class MtoNAttack(Base):
    """
    M-to-N backdoor attack wired into ``Base``.

    Construction does the whole pipeline eagerly:
    1) builds a TriggerBank with up to M triggers per target label from the training set,
    2) trains the poisoned-image generator (H/R/D), excluding target-label samples,
    3) builds a poisoned training set (``poisoned_rate``, ``label_mode``) and a
       fully poisoned, dirty-label test set for ASR.

    Raises:
        ValueError: if ``target_labels`` is empty.
    """
    def __init__(
        self,
        train_dataset: DatasetFolder,
        test_dataset: DatasetFolder,
        model: nn.Module,
        loss: nn.Module,
        vgg_model: nn.Module,
        # M-to-N args:
        target_labels: List[int],
        M: int = 1,
        poisoned_rate: float = 0.02,
        # poisoned image generation args:
        gen_iters: int = 40000,
        gen_dataset: DatasetFolder | None = None,
        gen_batch_size: int = 64,
        gen_num_workers: int = 4,
        lambdas: Dict[str, float] = None,
        lrs: Dict[str, float] = None,
        # label mode:
        label_mode: Literal['DIRTY','CLEAN'] = 'DIRTY',
        # schedule & seed:
        schedule: Dict[str,Any] | None = None,
        seed: int = 0,
        deterministic: bool = False,
        device: str | torch.device = 'cpu'
    ):
        super().__init__(train_dataset, test_dataset, model, loss, schedule, seed, deterministic)

        # De-duplicate so trigger sampling is uniform over distinct targets
        # (matches the de-duplication done by MToNPoisonedDataset).
        self.target_labels = sorted(set(target_labels))
        if not self.target_labels:
            raise ValueError("Provide at least one target label (N>=1).")
        self.M = M
        self.poisoned_rate = poisoned_rate
        self.label_mode = label_mode
        self.device = device

        if gen_dataset is None:
            gen_dataset = train_dataset

        # 1) Build trigger bank
        self.trigger_bank = TriggerBank(self.train_dataset, self.target_labels, self.M)

        # 2) Train poison generation framework H/R/D
        vgg_feature_extractor = ProvidedVGGFeature(vgg_model.features)
        self.embedder = PoisonGenTrainer(
            C=self._infer_channels(self.train_dataset),
            device=device,
            feat_extractor=vgg_feature_extractor,
            lambdas=lambdas,
            lrs=lrs
        )

        # optional log dir; fall back to ./work when the schedule has no 'save_dir'
        work_dir = schedule.get('save_dir', './work') if schedule else './work'
        os.makedirs(work_dir, exist_ok=True)

        self.embedder.fit(
            gen_dataset=gen_dataset,
            trigger_bank=self.trigger_bank,
            iters=gen_iters,
            batch_size=gen_batch_size,
            num_workers=gen_num_workers,
            exclude_labels=self.target_labels,
        )

        # 3) Build poisoned datasets for training/test
        self.poisoned_train_dataset = MToNPoisonedDataset(
            benign=self.train_dataset,
            trigger_bank=self.trigger_bank,
            embedder=self.embedder,
            poisoned_rate=self.poisoned_rate,
            target_labels=self.target_labels,
            mode=self.label_mode
        )

        # poisoned test: always fully poisoned to compute ASR per standard
        self.poisoned_test_dataset = MToNPoisonedDataset(
            benign=self.test_dataset,
            trigger_bank=self.trigger_bank,
            embedder=self.embedder,
            poisoned_rate=1.0,
            target_labels=self.target_labels,
            mode='DIRTY'  # For ASR, always use dirty label: input with trigger is assigned the corresponding target label
        )

    def _infer_channels(self, ds: DatasetFolder) -> int:
        """Return the channel count of the first sample after ``ds.transform`` (or ToTensor)."""
        path, _ = ds.samples[0]
        x = ds.loader(path)
        t = (ds.transform or ToTensor())(x)
        return t.size(0)

    # For Base.compute_asr() contract: return predict_logits, labels, mean_loss on poisoned test set
    def compute_asr(self):
        # num_workers=0: __getitem__ runs the generator on self.embedder.device.
        # NOTE(review): presumably chosen to avoid using that device from worker processes — confirm.
        return self._test(self.poisoned_test_dataset, self.device, num_workers=0)