"""
This is the implementation of Arcueid.

Arcueid is still based on a trigger-optimized backdoor attack,
but it focuses on the all-to-all attack setting, like WaNet.

Arcueid: 

Version: 0.1.0: 
    Initialize the common workflow of Arcueid

"""

import copy
import os
import random
from collections import defaultdict
from typing import Any, Dict, List, Literal

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch import Tensor, nn
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision.datasets import DatasetFolder
from torchvision.transforms import ColorJitter, Compose, GaussianBlur, PILToTensor, RandomAffine, ToTensor
from torchvision.utils import save_image
from tqdm import tqdm

from .base import *

class TrueAncestor:
    """
    Callable trigger that blends a fixed pattern into an image tensor.

    This is a plain callable class (it does not inherit from any trigger
    base class), so it can be appended to a torchvision ``Compose`` pipeline
    after the image has been converted to a tensor. The blend is

        out = clamp(weight * img + (1 - weight) * pattern, 0, 1)

    so ``weight == 1`` keeps the original pixel and ``weight == 0`` replaces
    it with the pattern pixel. Clamping to [0.0, 1.0] presumes float images
    in that range (e.g. produced by ``ToTensor``).

    Attributes:
        pattern (Tensor): Trigger pattern, stored as a 3-D tensor (a 2-D
            input receives a leading dimension of size 1).
        weight (Tensor): Per-pixel blending weight, stored as a 3-D tensor
            (a 2-D input receives a leading dimension of size 1).

    Methods:
        add_trigger(img, pattern, weight): Stateless blend, also used by the
            trigger optimizer in this module.
        __call__(img): Applies this object's pattern and weight to one image
            and returns the triggered image tensor.
    """

    def __init__(self, pattern: Tensor, weight: Tensor):
        """
        Args:
            pattern: 2-D or 3-D trigger pattern.
            weight: 2-D or 3-D blending weight.

        Raises:
            ValueError: If ``pattern`` or ``weight`` is not 2-D or 3-D.
        """
        super().__init__()

        # Promote 2-D tensors to 3-D by adding a leading dimension of size 1.
        if pattern.dim() == 2:
            pattern = pattern.unsqueeze(0)
        if pattern.dim() != 3:
            raise ValueError('pattern shape should be 2 or 3')

        if weight.dim() == 2:
            weight = weight.unsqueeze(0)
        if weight.dim() != 3:
            raise ValueError('weight shape should be 2 or 3')

        self.pattern = pattern
        self.weight = weight

        print(f"TrueAncestor initialized with pattern shape: {pattern.shape}, weight shape: {weight.shape}")

    @classmethod
    def add_trigger(cls, img: Tensor, pattern: Tensor, weight: Tensor) -> Tensor:
        """
        Blend ``pattern`` into ``img`` and clamp the result to [0.0, 1.0].

        The three tensors only need to be broadcast-compatible (standard
        elementwise broadcasting), so a whole batch can be blended at once.
        """
        return torch.clamp(weight * img + (1.0 - weight) * pattern, 0.0, 1.0)

    def __call__(self, img: Tensor) -> Tensor:
        """
        Apply the trigger to a single image.

        Args:
            img (Tensor): Image tensor with shape (H, W) or (C, H, W).

        Returns:
            Tensor: The triggered image as ``float32``. A 2-D input loses the
            temporary leading dimension again only when the blended result has
            a single leading channel; a broadcast multi-channel result is kept.

        Raises:
            ValueError: If ``img`` is neither 2-D nor 3-D.
        """
        if img.dim() == 2:
            # H x W: add a leading dim for the blend, then drop exactly that
            # dim. `squeeze(0)` (instead of `squeeze()`) keeps H or W intact
            # even when one of them is 1.
            img = self.add_trigger(img.unsqueeze(0), self.pattern, self.weight).squeeze(0)
        elif img.dim() == 3:
            # C x H x W
            img = self.add_trigger(img, self.pattern, self.weight)
        else:
            raise ValueError('Input image shape should be 2 or 3')

        return img.to(torch.float32)

class ModifyTarget:
    """
    Label transform that discards the label it receives and emits a constant.

    Args:
        y_target: The constant label produced by every call.
    """

    def __init__(self, y_target):
        self.y_target = y_target

    def __call__(self, y_target):
        """Ignore the incoming label and return the configured constant."""
        del y_target  # the incoming label is intentionally unused
        return self.y_target

class MarblePhantasm:
    """
    MarblePhantasm is a class for finding the true TrueAncestors.
    More details:
        - Used to find separate TrueAncestor for each target attack in All-to-all attack.
    """
    def __init__(
        self, 
        dataset: Dataset,
        model: nn.Module,
        trigger_info: Dict[str, Any], 
        true_ancestors_num: int = 10,
        device: str | torch.device = 'cpu',
        train_scale: float = 0.3,
        alpha: float = 1.0,
        beta: float = 1.0,
        margin: float = 6.0,
        **kwargs, # for compatibility
        ):
        
        self.dataset = dataset
        self.model = model
        # self.pattern = trigger_info['pattern']
        
        self.trigger_info = trigger_info
        self.true_ancestors_num = true_ancestors_num

        self.device = device
        
        self.train_scale = train_scale
        # distill the dataset to train scale
        self.dataset = Subset(self.dataset, indices=torch.randperm(len(self.dataset))[:int(len(self.dataset) * self.train_scale)])
        
        self.alpha = alpha
        self.beta = beta
        self.margin = margin
        
        self.model.to(self.device)
        self.model.eval()
        # hook toolkits
        # Register a hook to capture the input tensor of our need
        self.features_cache = None
        self._register_hook()

        self.centroids = self._compute_centroids()
    
    @torch.no_grad()
    def _compute_centroids(self, batch_size: int = 256, num_workers: int = 4):
        """
        Forward clean dataset to get the centroids
        """
        loader = DataLoader(self.dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers)
        # accumulate the features for each label
        sums: Dict[int, Tensor] = defaultdict(lambda: 0.)
        counts: Dict[int, int] = defaultdict(int)

        self.model.eval()
        for imgs, labels in loader:
            imgs = imgs.to(self.device)
            labels = labels.to(self.device)

            feats, _ = self.get_features(imgs)      # (B, d)
            for lbl in labels.unique():
                mask = labels == lbl
                sums[int(lbl)] += feats[mask].sum(0)
                counts[int(lbl)] += mask.sum().item()

        # get the mean of the features for each label
        all_labels = sorted(counts.keys())
        centroids = torch.stack([sums[i] / counts[i] for i in all_labels])  # (K, d)
        return centroids.to(self.device)  

    def __call__(self, steps: int, lr: float = 0.05):
        
        #  # initialize the trigger or use the trigger offered
        # pattern = self.pattern
        # # pattern = pattern.to(self.device)
        # pattern.requires_grad = True

        optim_pattern_params = nn.ParameterList(
            [copy.deepcopy(self.trigger_info['pattern']).requires_grad_(True) for _ in range(self.true_ancestors_num)]
        )
        # Optimize the trigger arguments in need
        # optimizer = optim.Adam([pattern], lr=lr)
        optimizer = optim.Adam(optim_pattern_params, lr=lr)
        
        print(f"Starting trigger optimization for {steps} steps with learning rate {lr}")
        
        for epoch in range(steps):  # we optimize steps epochs
            for batch in tqdm(DataLoader(self.dataset, batch_size=128, shuffle=True, num_workers=4), desc=f"Epoch {epoch} optimizing trigger"):
                images, labels = batch

                
                poisoned_images, pids = self.attach_triggers(images, optim_pattern_params)
                # print(f'former labels: {labels}')
                # if dirty_label >= 0:
                #     labels = torch.full_like(labels, dirty_label)
                #     print(f'latter labels: {labels}')
                
                poisoned_images = poisoned_images.to(self.device)
                labels = labels.to(self.device)
                pids = pids.to(self.device)
                
                # Get the input tensor of last fc layer
                features, logits = self.get_features(poisoned_images)
                # print the labels and logits based on the pids
                for pid in pids.unique():
                    mask = (pids == pid)
                    # print(f'pid: {pid}')
                    # print(f'label: {labels[mask].unique()}, pred: {logits[mask].argmax(dim=1).unique()}')

                # Compute the loss
                # loss = self.custom_loss(features)
                # redesign my loss
                loss = self.multi_cluster_loss(features, pids, margin=self.margin, alpha=self.alpha, beta=self.beta)

                print(f'loss: {loss}')
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            
            if epoch % 10 == 0:
                print(f"Epoch {epoch}, Loss: {loss.item()}")
        
        print("TrueAncestors optimization completed")

        # now use the optimized patterns to compose the triggers
        true_ancestors = [
            TrueAncestor(
                pattern=pattern.detach().cpu().clone(),
                weight=copy.deepcopy(self.trigger_info['weight'])
            )
            for pattern in optim_pattern_params
        ]

        return true_ancestors
    
    def _register_hook(self):
        last_linear_layer = None
        for module in reversed(list(self.model.modules())):
            if isinstance(module, torch.nn.Linear):
                last_linear_layer = module
                break
        
        if last_linear_layer is None:
            raise ValueError("No FC layer founded in model")
        
        def hook(module, input, output):
            self.features_cache = input[0]
        
        last_linear_layer.register_forward_hook(hook)
        print("Hook registered for the last linear layer")
    
    def get_features(self, images: Tensor):
        # Get the input tensor of the last FC layer
        logits = self.model(images)
        return self.features_cache, logits
    
    # def extra_intra_loss(self, features: Tensor, centroids: Tensor, lambda_: float = 3.0):
    #     """
    #     Extra loss function for Arcueid
    #     """
        
    #     # feats: (N, d)
    #     center = features.mean(0, keepdim=True)
    #     # use cosine similarity to calculate the intra-class distance
    #     # l_intra = torch.cosine_similarity(features, center, dim=1).mean()
    #     l_intra = ((features - center).pow(2).sum(1)).mean()
        
    #     # The distance between the center and the centroids
    #     dists = torch.cdist(center, centroids, p=2)
    #     l_inter = -dists.min()
        
    #     return l_intra + lambda_ * l_inter
    
    def multi_cluster_loss(
        self,
        features: Tensor,
        pids: Tensor,
        margin: float = 6.0,
        alpha: float = 1.0,
        beta: float = 1.0
    ):
        """
        Features: (B, d): penultimate layer feats
        pids: (B,): Pattern ids
        ----------------------------------------------
        L_total = alpha * L_intra + beta * L_inter
        L_intra is for intra-cluster distance
        L_inter is for disperation of the pattern cluster
        """
        # 1. intra-cluster for same pattern
        mu = [] # center
        L_intra = 0
        uniq = pids.unique()
        for pid in uniq:
            mask = (pids == pid)
            f_grp = features[mask] # (N_i, d)
            mu_i = f_grp.mean(0, keepdim=True) # (1, d)
            L_intra += (f_grp - mu_i).pow(2).sum(1).mean()
            mu.append(mu_i)
        
        L_intra /= uniq.numel()

        # 2. disperation of the different pattern cluster
        if len(mu) > 1:
            mu = torch.cat(mu, dim=0) # (P, d)
            dist = torch.cdist(mu, mu, p=2)
            mask = torch.triu(torch.ones_like(dist), diagonal=1).bool()
            L_inter = F.relu(margin - dist[mask]).mean()
        else:
            L_inter = 0

        return alpha * L_intra + beta * L_inter


    def attach_triggers(self, images: Tensor, pattern_list: List[TrueAncestor], balanced: bool = False):
        """
        For each image, random select a TrueAncestor and apply it.
        """
        B = images.shape[0]
        P = len(pattern_list)
        
        if not balanced:
            ids = torch.randint(0, P, (B,))
        else:
            ids = torch.arange(B) % P
            # and shuffle the ids
            ids = ids[torch.randperm(B)]
            
        triggered_images = []
        weight = self.trigger_info['weight']
    
        for img, pid in zip(images, ids):
            pattern = pattern_list[pid]  # Parameter (C,H,W)
            img = TrueAncestor.add_trigger(img, pattern, weight)
            triggered_images.append(img)

        return torch.stack(triggered_images), ids
    
    @classmethod    
    def apply_trigger_batch(self, images: Tensor, pattern: Tensor, weight: Tensor):
        """
        Applying trigger on batch of images
        
        """

        if images.dim() == 3 or images.dim() == 4:  # (B, H, W) or (B, C, H, W)
            poisoned_images = torch.stack([TrueAncestor.add_trigger(img, pattern, weight) for img in images])
            return poisoned_images
        else:
            raise ValueError('Input image shape should be (B, H, W) or (B, C, H, W)')
        
class Millennium(DatasetFolder):
    """
    Training-time poisoned dataset for the all-to-all attack.

    A subset of the benign samples is mapped to pattern ids. A sample with
    pattern id ``p`` is transformed with the benign transform followed by
    ``true_ancestors[p]``, and its label becomes ``target_labels[p]`` (after
    the benign target transform, if any). All other samples are returned
    exactly as the benign dataset would return them.

    Args:
        benign_dataset: Clean ``DatasetFolder``. Its root, loader, extensions
            and transforms are reused.
        poisoned_rate_per_id: Fraction of the dataset poisoned per pattern.
        true_ancestors: One trigger per target label.
        target_labels: Target label of each trigger (same length as
            ``true_ancestors``).
        label_mode: ``'DIRTY'`` draws poisoned samples from the whole dataset,
            so their labels change. ``'CLEAN'`` only poisons samples whose
            ground-truth label already equals the pattern's target label.

    Raises:
        ValueError: On mismatched lengths, an unknown ``label_mode``, or when
            DIRTY mode requests more poisoned samples than the dataset holds.
    """
    def __init__(
        self,
        benign_dataset: DatasetFolder,
        poisoned_rate_per_id: float,
        true_ancestors: List[TrueAncestor],
        target_labels: List[int],
        label_mode: Literal['CLEAN', 'DIRTY'] = 'DIRTY',
    ):
        super().__init__(
            benign_dataset.root,
            benign_dataset.loader,
            benign_dataset.extensions,
            benign_dataset.transform,
            benign_dataset.target_transform,
            None
        )

        self.poisoned_rate_per_id = poisoned_rate_per_id
        self.true_ancestors = true_ancestors
        self.target_labels = target_labels
        self.label_mode = label_mode

        if len(self.target_labels) != len(self.true_ancestors):
            raise ValueError('The number of target labels and true ancestors must be the same')
        if self.label_mode not in ('CLEAN', 'DIRTY'):
            raise ValueError(f"label_mode must be 'CLEAN' or 'DIRTY', got {self.label_mode!r}")

        # Poisoned image pipelines: benign transform first, then the trigger.
        self.poisoned_transforms = [
            self._chain(self.transform, true_ancestor) for true_ancestor in self.true_ancestors
        ]

        # Poisoned label pipelines: benign target transform (if any), then
        # the constant target label.
        self.poisoned_target_transforms = [
            ModifyTarget(target_label) if self.target_transform is None
            else self._chain(self.target_transform, ModifyTarget(target_label))
            for target_label in self.target_labels
        ]

        # Every pattern is allotted `per_pattern_num` samples.
        per_pattern_num = int(len(self.samples) * self.poisoned_rate_per_id)
        self.poisoned_num = per_pattern_num * len(self.true_ancestors)

        # {poisoned_image_id: pattern_id}
        if self.label_mode == 'DIRTY':
            self.poisoned_dict = self._dirty_assignment()
        else:
            self.poisoned_dict = self._clean_assignment(per_pattern_num)
        self.poisoned_set = list(self.poisoned_dict.keys())

        # Summarize instead of dumping the (potentially huge) mapping.
        pattern_counts = defaultdict(int)
        for pid in self.poisoned_dict.values():
            pattern_counts[pid] += 1
        print(
            f"Millennium ({self.label_mode}) poisoned {len(self.poisoned_dict)} samples, "
            f"per pattern id: {dict(sorted(pattern_counts.items()))}"
        )

    @staticmethod
    def _chain(first, second) -> Compose:
        """
        Return a new ``Compose`` that runs a deep copy of ``first`` (skipped
        when None) and then ``second``.

        Wrapping, instead of appending to ``first.transforms``, works for any
        callable transform, not only a ``Compose``.
        """
        steps = [] if first is None else [copy.deepcopy(first)]
        return Compose(steps + [second])

    def _dirty_assignment(self) -> Dict[int, int]:
        """Randomly pick ``poisoned_num`` samples and deal them round-robin to the patterns."""
        if self.poisoned_num > len(self.samples):
            raise ValueError(
                f'Cannot poison {self.poisoned_num} samples out of {len(self.samples)}: '
                'reduce poisoned_rate_per_id or the number of patterns'
            )
        poisoned_set = random.sample(range(len(self.samples)), k=self.poisoned_num)
        # Kept so seeded runs reproduce the same assignment as before.
        random.shuffle(poisoned_set)
        num_patterns = len(self.target_labels)
        return {idx: i % num_patterns for i, idx in enumerate(poisoned_set)}

    def _clean_assignment(self, per_pattern_num: int) -> Dict[int, int]:
        """
        For each target label, deal shuffled samples whose ground-truth label
        equals that target round-robin to the patterns sharing it, up to
        ``per_pattern_num`` each. Patterns receive fewer samples when the
        label's pool is too small.
        """
        poisoned_dict: Dict[int, int] = {}

        # 1) target label -> [pattern_id, ...]
        label_to_patterns = defaultdict(list)
        for pid, tl in enumerate(self.target_labels):
            label_to_patterns[tl].append(pid)

        # 2) ground-truth label -> [sample_index, ...]
        label_to_indices = defaultdict(list)
        for idx, (_, lbl) in enumerate(self.samples):
            label_to_indices[lbl].append(idx)

        # 3) The pool is truncated to the total demand, so plain round-robin
        #    never gives a pattern more than `per_pattern_num` samples.
        for lbl, pattern_ids in label_to_patterns.items():
            pool = label_to_indices.get(lbl, []).copy()
            random.shuffle(pool)
            pool = pool[:per_pattern_num * len(pattern_ids)]
            for i, idx in enumerate(pool):
                poisoned_dict[idx] = pattern_ids[i % len(pattern_ids)]

        return poisoned_dict

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target). For a poisoned index, ``sample`` carries
            its assigned trigger and ``target`` is that trigger's target label.
        """
        path, target = self.samples[index]
        sample = self.loader(path)

        pid = self.poisoned_dict.get(index)
        if pid is not None:
            sample = self.poisoned_transforms[pid](sample)
            target = self.poisoned_target_transforms[pid](target)
        else:
            if self.transform is not None:
                sample = self.transform(sample)
            if self.target_transform is not None:
                target = self.target_transform(target)

        return sample, target


class RealityMarble(DatasetFolder):
    """
    Test-time poisoned dataset for a single pattern of the all-to-all attack.

    Every sample is transformed with the benign transform followed by
    ``true_ancestor``, and its label becomes ``target_label`` (after the
    benign target transform, if any). Accuracy on this dataset is therefore
    the fraction of triggered images classified as ``target_label``.

    Args:
        benign_dataset: Clean ``DatasetFolder``. Its root, loader, extensions
            and transforms are reused.
        true_ancestor: The trigger to apply to every sample.
        target_label: The label assigned to every sample.
    """
    def __init__(
        self,
        benign_dataset: DatasetFolder,
        true_ancestor: TrueAncestor,  # accept a TrueAncestor object
        target_label: int,
    ):
        super().__init__(
            benign_dataset.root,
            benign_dataset.loader,
            benign_dataset.extensions,
            benign_dataset.transform,
            benign_dataset.target_transform,
            None
        )

        self.true_ancestor = true_ancestor
        self.target_label = target_label

        # Wrap the benign pipeline in a new Compose rather than appending to
        # its `.transforms`, so any callable transform (not only a Compose)
        # is supported.
        image_steps = [] if self.transform is None else [copy.deepcopy(self.transform)]
        self.poisoned_transform = Compose(image_steps + [self.true_ancestor])

        # ModifyTarget is a module-level class and therefore picklable, unlike
        # a lambda. That lets the dataset be sent to DataLoader worker
        # processes started with the 'spawn' method.
        label_steps = [] if self.target_transform is None else [copy.deepcopy(self.target_transform)]
        self.poisoned_target_transform = Compose(label_steps + [ModifyTarget(target_label)])

    def __getitem__(self, index):
        """
        Load sample ``index``, apply the benign transform plus the trigger,
        and return it with the target label.
        """
        path, target = self.samples[index]
        sample = self.loader(path)

        sample = self.poisoned_transform(sample)
        target = self.poisoned_target_transform(target)

        return sample, target
    

class Arcueid(Base):
    """
    Arcueid: dynamic all-to-all backdoor attack, used for training and testing.

    One trigger pattern (a ``TrueAncestor``) is assigned to each entry of
    ``target_labels``. The patterns are either supplied via
    ``pretrained_triggers`` or optimized on a surrogate model with
    ``MarblePhantasm``. They are then used to build the poisoned training set
    (``Millennium``) and per-pattern test sets (``RealityMarble``).

    Args:
        train_dataset, test_dataset, model, loss, schedule, seed,
        deterministic: Forwarded to ``Base``.
        poisoned_rate_per_id: Fraction of the training set poisoned per pattern.
        trigger_info: Dict with the initial ``'pattern'`` and the blend
            ``'weight'``, used when triggers have to be generated.
        target_labels: One target label per pattern.
        label_mode: ``'DIRTY'`` or ``'CLEAN'``, forwarded to ``Millennium``.
        optimize_model: Surrogate model for trigger optimization. Defaults to
            a deep copy of ``model``.
        optimize_dataset: Dataset for trigger optimization. Defaults to
            ``train_dataset``.
        optimize_device: Device used for trigger optimization and ``compute_asr``.
        train_scale, train_steps, lr, alpha, beta, margin: Trigger
            optimization hyper-parameters, forwarded to ``MarblePhantasm``.
        pretrained_triggers: Optional non-empty list (or tuple) whose elements
            are either all ``TrueAncestor`` objects or all dicts with
            ``'pattern'`` and ``'weight'``. Only the first
            ``len(target_labels)`` are used. Any other value triggers
            optimization.

    Raises:
        ValueError: If ``pretrained_triggers`` is empty, mixes or uses
            unsupported element types, or has fewer entries than
            ``target_labels``.
    """
    def __init__(
        self,
        train_dataset: DatasetFolder,
        test_dataset: DatasetFolder,
        model: nn.Module,
        loss: nn.Module,
        # Arcueid arguments
        poisoned_rate_per_id: float,
        trigger_info: Dict[str, Any],
        target_labels: List[int],
        label_mode: Literal['CLEAN', 'DIRTY'] = 'DIRTY',
        # Optimize arguments
        # label mode default is DIRTY for the sake of all-to-all attack
        optimize_model: nn.Module | None = None,
        optimize_dataset: Dataset | None = None,
        optimize_device: str | torch.device = 'cpu',

        train_scale: float = 0.3,
        train_steps: int = 10,
        lr: float = 0.05,

        # pretrain settings
        pretrained_triggers: None | List[TrueAncestor] | List[Dict[str, Any]] = None,
        alpha: float = 1.0,
        beta: float = 1.0,
        margin: float = 6.0,

        schedule: Dict[str, Any] | None = None,
        seed: int = 0,
        deterministic: bool = False,
    ):
        super().__init__(
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            model=model,
            loss=loss,
            schedule=schedule,
            seed=seed,
            deterministic=deterministic,
        )

        # Arcueid arguments
        self.poisoned_rate_per_id = poisoned_rate_per_id
        self.trigger_info = trigger_info
        self.target_labels = target_labels
        self.label_mode = label_mode
        # Totals across all patterns.
        self.target_num = len(self.target_labels)
        self.poisoned_rate = self.target_num * self.poisoned_rate_per_id

        # Surrogate model and dataset for trigger optimization.
        self.optimize_model = copy.deepcopy(self.model) if optimize_model is None else optimize_model
        self.optimize_dataset = self.train_dataset if optimize_dataset is None else optimize_dataset
        self.optimize_device = optimize_device

        self.train_scale = train_scale
        self.train_steps = train_steps
        self.lr = lr

        self.alpha = alpha
        self.beta = beta
        self.margin = margin

        # Reuse supplied triggers (handy for quick experiments) or optimize new ones.
        if isinstance(pretrained_triggers, (list, tuple)):
            self.pretrained_triggers = self._load_pretrained_triggers(pretrained_triggers)
        else:
            self.pretrained_triggers = self.trigger_generation()

        # Poisoned training set covering every pattern.
        self.poisoned_train_dataset = Millennium(
            benign_dataset=self.train_dataset,
            poisoned_rate_per_id=self.poisoned_rate_per_id,
            true_ancestors=self.pretrained_triggers,
            target_labels=self.target_labels,
            label_mode=self.label_mode,
        )

        # Placeholder required by Base: it covers ONLY the first pattern and
        # target label. Use `compute_asr` or `get_full_a2a_testdataset` to
        # evaluate all patterns.
        self.poisoned_test_dataset = RealityMarble(
            benign_dataset=self.test_dataset,
            true_ancestor=self.pretrained_triggers[0],
            target_label=self.target_labels[0],
        )

        print("Arcueid initialization completed")

    def _load_pretrained_triggers(self, pretrained_triggers):
        """
        Normalize user-supplied triggers to a list of ``self.target_num``
        ``TrueAncestor`` objects.

        Raises:
            ValueError: If the sequence is empty, mixes or uses unsupported
                element types, or is shorter than ``self.target_num``.
        """
        if len(pretrained_triggers) == 0:
            raise ValueError('pretrained_triggers must not be empty')

        if all(isinstance(t, TrueAncestor) for t in pretrained_triggers):
            triggers = list(pretrained_triggers)
        elif all(isinstance(t, dict) for t in pretrained_triggers):
            triggers = [
                TrueAncestor(
                    pattern=pretrained_trigger_info['pattern'],
                    weight=pretrained_trigger_info['weight']
                ) for pretrained_trigger_info in pretrained_triggers
            ]
        else:
            raise ValueError('pretrained_triggers must be a list of TrueAncestor or dict[str, Any]')

        # Balance the number of triggers against the number of target labels.
        if len(triggers) < self.target_num:
            raise ValueError('The number of pretrained triggers must be at least the number of target labels')
        return triggers[:self.target_num]

    def trigger_generation(self):
        """
        Optimize one trigger per target label with ``MarblePhantasm``.

        Returns:
            List[TrueAncestor]: ``self.target_num`` optimized triggers.
        """
        trigger_optimizer = MarblePhantasm(
            dataset=self.optimize_dataset,
            model=self.optimize_model,
            trigger_info=self.trigger_info,
            true_ancestors_num=self.target_num,
            device=self.optimize_device,
            train_scale=self.train_scale,
            alpha=self.alpha,
            beta=self.beta,
            margin=self.margin,
        )

        return trigger_optimizer(
            steps=self.train_steps,
            lr=self.lr,
        )

    def compute_asr(self):
        """
        Compute the attack success rate of every (trigger, target label) pair.

        Each trigger is applied to every test sample (see ``RealityMarble``).
        The ASR is the fraction of those samples predicted as the trigger's
        target label. The model is moved to ``self.optimize_device`` and put in
        eval mode as a side effect.

        Returns:
            List[float]: One ASR per pattern, in ``target_labels`` order.
        """
        self.model.to(self.optimize_device)
        self.model.eval()

        asr_list = []
        for pid, (true_ancestor, target_label) in enumerate(zip(self.pretrained_triggers, self.target_labels)):
            poisoned_test_dataset = RealityMarble(
                benign_dataset=self.test_dataset,
                true_ancestor=true_ancestor,
                target_label=target_label,
            )
            # NOTE(review): assumes Base._test returns (logits, labels, mean_loss) — confirm.
            all_poisoned_predict_digits, all_poisoned_labels, all_poisoned_mean_loss = self._test(poisoned_test_dataset, device=self.optimize_device)

            attack_success_nums = (all_poisoned_predict_digits.argmax(dim=1) == all_poisoned_labels).sum().item()
            asr = attack_success_nums / len(all_poisoned_labels)
            asr_list.append(asr)
            print(f'ASR of PID {pid} target label {target_label}: {asr}')

        return asr_list

    def save_trigger(self):
        """Return a deep copy of the triggers in use."""
        return copy.deepcopy(self.pretrained_triggers)

    def get_full_a2a_testdataset(self, test_dataset: DatasetFolder | None = None):
        """
        Build a test set in which EVERY sample is poisoned.

        Sample ``i`` receives pattern ``i % len(target_labels)`` and that
        pattern's target label. A DIRTY ``Millennium`` is built first (its
        own random selection is then discarded), and its ``poisoned_dict`` and
        ``poisoned_set`` are overwritten with the full mapping.

        Args:
            test_dataset: Dataset to poison. Defaults to ``self.test_dataset``.
        """
        if test_dataset is None:
            test_dataset = self.test_dataset

        full_a2a_testdataset = Millennium(
            benign_dataset=test_dataset,
            poisoned_rate_per_id=self.poisoned_rate_per_id,
            true_ancestors=self.pretrained_triggers,
            target_labels=self.target_labels,
            label_mode='DIRTY',
        )

        # Poison every index, cycling through the patterns.
        full_a2a_testdataset.poisoned_dict = {
            i: i % len(self.target_labels)
            for i in range(len(test_dataset))
        }
        full_a2a_testdataset.poisoned_set = list(full_a2a_testdataset.poisoned_dict.keys())

        return full_a2a_testdataset