"""
Original by https://github.com/YuxinWenRick/tree-ring-watermark and heavily modified for debugging purposes and to get access to internals
Please give them credit and adhere to their license agreement.
"""

from tqdm import tqdm

import os

import random

import copy

import typing

import argparse

import numpy as np
import scipy

import torch

from sklearn.metrics import accuracy_score

from .wm_provider import WmProvider

from utils.image_utils import torch_to_PIL
from utils import utils


def list_of_ints(arg):
    """Parse a comma-separated string such as ``"1,2,3"`` into ``[1, 2, 3]``."""
    return [int(token) for token in arg.split(',')]

def bits_from_int(arg):
    """
    argparse ``type=`` converter: turn a non-negative integer ``n`` into ``n``
    alternating bits rendered as a comma-separated string.

    Example: ``"4"`` -> ``"0,1,0,1"``; ``"0"`` -> ``""``.

    Raises:
        argparse.ArgumentTypeError: if ``arg`` is not an integer or is negative.
    """
    try:
        n = int(arg)
    except ValueError:
        # Non-integer input: report it as an argparse error; `from None` keeps the
        # internal ValueError out of the user-facing traceback.
        raise argparse.ArgumentTypeError(f"'{arg}'은(는) 유효한 정수가 아닙니다.") from None
    if n < 0:
        raise argparse.ArgumentTypeError("음수는 입력할 수 없습니다.")
    # Alternate 0 and 1 for n positions (e.g. n=4 -> '0,1,0,1').
    return ",".join(str(i % 2) for i in range(n))


# Command-line options of the PQ (phase-quantization) watermark.
# ``add_help=False`` keeps -h/--help off this parser; presumably so it can be
# given as a ``parents=[...]`` entry of another ArgumentParser (confirm with callers).
parser = argparse.ArgumentParser(add_help=False)
for _flag, _default, _type in (
    ('--payload_bits', 8, int),
    ('--q_step', 6, float),
    ('--n_bins', 10, int),
    ('--r_min_ratio', 0.1, float),
    ('--r_max_ratio', 0.5, float),
    ('--amp_threshold_percentile', 100, int),
):
    parser.add_argument(_flag, default=_default, type=_type)
del _flag, _default, _type


def set_seed(seed: int = 42):
    """
    Seed Python's `random`, NumPy, and torch (CPU and all CUDA devices) with
    ``seed``, export it as PYTHONHASHSEED, and switch cuDNN to deterministic,
    non-benchmarking mode.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def detect_pqim_ss_32bit(suspect_latent, Omega_groups, Q_step):
    """
    Decode one bit per group of FFT bins by majority vote over phase-QIM decisions.

    Despite the name, the number of decoded bits is ``len(Omega_groups)``.

    suspect_latent: (C, H, W) real tensor (numpy arrays are converted with torch.as_tensor)
    Omega_groups: one list per bit, each holding (c, fy, fx) FFT bin indices
    Q_step: phase quantization step in radians (e.g. np.pi/3)

    Returns: list of detected bits (0/1), one per group.

    Each bin votes for the lattice (bit 0: k*Q_step, bit 1: k*Q_step + Q_step/2)
    whose nearest point is closer to the measured phase; ties count for 0. The
    conjugate-symmetric partner bin votes as well, unless it is the bin itself.
    A bit decodes to 1 only if strictly more than half of the votes actually
    cast are for 1; an empty group decodes to 0.
    """
    Z = torch.fft.fft2(torch.as_tensor(suspect_latent))  # (C, H, W), complex
    _, H_fft, W_fft = Z.shape

    def _votes_for_one(phi):
        """True if phase ``phi`` is strictly closer to the bit-1 lattice than to the bit-0 lattice."""
        phi_q0 = torch.round(phi / Q_step) * Q_step
        phi_q1 = torch.round((phi - Q_step / 2) / Q_step) * Q_step + Q_step / 2
        dist0 = torch.abs(wrap_angle(phi - phi_q0))
        dist1 = torch.abs(wrap_angle(phi - phi_q1))
        return bool(dist1 < dist0)

    detected_bits = []
    for Omega_i in Omega_groups:
        ones = 0
        cast = 0
        for (c, fy, fx) in Omega_i:
            cast += 1
            ones += _votes_for_one(torch.angle(Z[c, fy, fx]))

            # Conjugate-symmetric partner bin.
            conj_fy = (-fy + H_fft) % H_fft
            conj_fx = (-fx + W_fft) % W_fft
            # Self-conjugate bins (e.g. DC / Nyquist) have no distinct partner;
            # they cast a single vote and must not inflate the vote count.
            if not (fy == 0 and fx == 0) and not (fy == conj_fy and fx == conj_fx):
                cast += 1
                ones += _votes_for_one(torch.angle(Z[c, conj_fy, conj_fx]))

        # Majority over the votes actually cast (no division, so empty groups are safe).
        detected_bits.append(1 if ones > cast / 2 else 0)

    return detected_bits


def choose_bins_simplified(
        h_fft,
        w_fft,
        r_min_ratio,
        r_max_ratio,
        n_bins,
        amplitude_spectrum_for_filter=None,
        amp_threshold_percentile=75,
        n_channels=1,
    ):
    """
    Randomly pick up to ``n_bins`` distinct FFT bins ``(c, u, v)`` inside a radial ring.

    Coordinates are unshifted (DC at (0, 0)). ``u`` is drawn from ``1..h_fft-1`` and
    ``v`` from ``1..w_fft-1``, so row/column 0 is never chosen. The radius is
    ``sqrt(u**2 + v**2)`` in these unshifted coordinates.

    Sampling is rejection-based on the global ``np.random`` state, so the result is
    reproducible for a fixed seed. At most ``n_bins * 100`` draws are attempted; if
    fewer bins are found, a warning is printed and the shorter list is returned.

    Args:
        h_fft, w_fft: spectrum height / width.
        r_min_ratio, r_max_ratio: ring limits as fractions of h_fft (and of w_fft).
        n_bins: number of bins requested.
        amplitude_spectrum_for_filter: optional (C, H, W) amplitude array. When given,
            only ring bins whose amplitude is in the top ``amp_threshold_percentile``
            percent are eligible, and its first dimension sets the channel count.
        amp_threshold_percentile: k of the "top k%" filter (100 keeps every ring bin).
        n_channels: channel count used when no amplitude spectrum is supplied.

    Returns:
        list of distinct (c, u, v) tuples.
    """
    selected_bins = []
    selected_coords_set = set()

    u_min_abs, u_max_abs = int(h_fft * r_min_ratio), int(h_fft * r_max_ratio)
    v_min_abs, v_max_abs = int(w_fft * r_min_ratio), int(w_fft * r_max_ratio)

    # Precompute the set of top-k% amplitude bins inside the ring (if a spectrum is given).
    topk_indices = None
    if amplitude_spectrum_for_filter is not None:
        n_channels = amplitude_spectrum_for_filter.shape[0]

        # The ring membership depends only on (u, v), so compute it once for all channels.
        ring = []
        for u in range(1, h_fft):
            for v in range(1, w_fft):
                r = np.sqrt(u**2 + v**2)
                if u_min_abs <= r <= u_max_abs and v_min_abs <= r <= v_max_abs:
                    ring.append((u, v))

        coords = [(c, u, v) for c in range(n_channels) for (u, v) in ring]
        amps = [abs(amplitude_spectrum_for_filter[c, u, v]) for (c, u, v) in coords]

        # Amplitude at the boundary of the top amp_threshold_percentile %.
        if len(amps) > 0:
            threshold = np.percentile(amps, 100 - amp_threshold_percentile)
        else:
            threshold = 0

        topk_indices = {coord for coord, val in zip(coords, amps) if val >= threshold}
        print(f" 상위 {amp_threshold_percentile}% 진폭 bin 수: {len(topk_indices)}")

    # Rejection sampling. The order of RNG draws (u, v, then c) is kept stable so
    # that a given seed always yields the same bins.
    attempts = 0
    max_attempts = n_bins * 100  # guard against an endless loop
    while len(selected_bins) < n_bins and attempts < max_attempts:
        attempts += 1
        u = np.random.randint(1, h_fft)
        v = np.random.randint(1, w_fft)
        r = np.sqrt(u**2 + v**2)  # distance from DC in unshifted coordinates

        if not (u_min_abs <= r <= u_max_abs):
            continue

        c = np.random.randint(0, n_channels)  # random channel
        if topk_indices is not None:
            # Candidates already satisfy both ring bounds and the amplitude filter.
            if (c, u, v) not in topk_indices:
                continue
        elif not (v_min_abs <= r <= v_max_abs):
            continue
        if (c, u, v) in selected_coords_set:
            continue

        selected_bins.append((c, u, v))
        selected_coords_set.add((c, u, v))

    if len(selected_bins) < n_bins:
        print(
            f"Warning: 목표한 {n_bins}개의 빈을 모두 선택하지 못했습니다 ({len(selected_bins)}개 선택됨). "
            "파라미터 조정 필요."
        )
    return selected_bins


def wrap_angle(angle):
    """
    Wrap an angle in radians (float, numpy array or torch tensor) into [-pi, pi).
    """
    two_pi = 2 * np.pi
    shifted = angle + np.pi
    return shifted % two_pi - np.pi


def embed_pqim_ss_32bit(latent_in, payload_bits, omega_groups, Q_step):
    """
    Embed ``payload_bits`` into the FFT phase of a latent with phase QIM.

    For each bit, every bin (c, fy, fx) of its group keeps its original magnitude,
    while its phase is snapped to the nearest point of the lattice
    ``k * Q_step + bit * Q_step / 2``. The conjugate-symmetric partner bin (if
    distinct) gets the opposite phase shift.

    Despite the name, the number of embedded bits is ``len(payload_bits)``.

    latent_in: real tensor of shape (1, C, H, W) or (C, H, W). For a 4-D input
        only the first element along dim 0 is used; any others are ignored.
    payload_bits: sequence of 0/1 values (anything ``int()`` accepts).
    omega_groups: one list of (c, fy, fx) bins per payload bit.
    Q_step: quantization step in radians (e.g. np.pi/3).

    Returns: (C, H, W) real tensor, the real part of the inverse FFT of the
        modified spectrum.
    """
    if latent_in.ndim == 4:
        latent_in = latent_in[0]
    Z = torch.fft.fft2(latent_in)  # complex tensor (C, H, W)

    # Snapshot of the unmodified spectrum: Z is overwritten in place below, and the
    # phase/magnitude must be read from the original values (and stay valid for autograd).
    Z_original = Z.clone()
    orig_phase_map = torch.angle(Z_original)
    orig_mag_map = torch.abs(Z_original)
    _, H_fft, W_fft = Z.shape

    for bit_idx, bit in enumerate(payload_bits):
        bval = int(bit)

        for (c, fy, fx) in omega_groups[bit_idx]:
            phi_orig = orig_phase_map[c, fy, fx]
            # Snap the phase to the nearest point of the lattice for this bit value.
            phi_q = (
                torch.round((phi_orig - bval * Q_step / 2) / Q_step) * Q_step
                + bval * Q_step / 2
            )
            delta = wrap_angle(phi_q - phi_orig)
            # Original magnitude, quantized phase.
            Z[c, fy, fx] = orig_mag_map[c, fy, fx] * torch.exp(1j * phi_q)

            # Conjugate-symmetric partner bin.
            conj_fy = (-fy + H_fft) % H_fft
            conj_fx = (-fx + W_fft) % W_fft
            if not (fy == 0 and fx == 0) and not (fy == conj_fy and fx == conj_fx):
                # Shift the partner's phase by -delta; presumably to keep the spectrum
                # Hermitian so the inverse FFT stays (close to) real.
                phi_conj_q = wrap_angle(orig_phase_map[c, conj_fy, conj_fx] - delta)
                Z[c, conj_fy, conj_fx] = orig_mag_map[c, conj_fy, conj_fx] * torch.exp(1j * phi_conj_q)

    # Back to the spatial domain; any residual imaginary part is discarded.
    return torch.fft.ifft2(Z).real


class PqProvider(WmProvider):
    """
    Phase-quantization (PQIM) watermark provider.

    A fixed alternating payload (0, 1, 0, 1, ...) is embedded into the FFT phase of
    randomly selected latent frequency bins and recovered by per-bit majority vote.

    Scaffolding originally by https://github.com/YuxinWenRick/tree-ring-watermark,
    heavily modified for debugging purposes and to get access to internals.
    """

    def __init__(
            self,
            seed,
            payload_bits,
            q_step,
            n_bins,
            r_min_ratio,
            r_max_ratio,
            amp_threshold_percentile,
            **kwargs,
        ):
        """
        @param seed: if not None, RNGs are seeded so the bin selection is reproducible
        @param payload_bits: int, number of payload bits; the payload itself is the
            alternating pattern 0, 1, 0, 1, ...
        @param q_step: phase quantization step in radians
        @param n_bins: number of FFT bins carrying each bit
        @param r_min_ratio, r_max_ratio: radial ring limits for bin selection
        @param amp_threshold_percentile: keep only the top-k% amplitude bins
        @param kwargs: forwarded to WmProvider; must contain 'latent_shape' and 'device'
        """
        super().__init__(**kwargs)

        latent_shape = kwargs['latent_shape']

        # This ensures every latent ever created has the same WM pattern.
        # This makes sense when we simulate only one service provider using one kind of WM pattern.
        if seed is not None:
            utils.set_random_seed(seed)

        self.seed = seed
        self.q_step = q_step
        # Fixed alternating payload 0, 1, 0, 1, ... of length `payload_bits`.
        self.payload_bits = [i % 2 for i in range(payload_bits)]
        self.payload_bits_str = ''.join(str(bit) for bit in self.payload_bits)

        self.gt_patch = self.__get_watermarking_pattern(
            latent_shape=latent_shape,
            n_bins=n_bins,
            r_min_ratio=r_min_ratio,
            r_max_ratio=r_max_ratio,
            amp_threshold_percentile=amp_threshold_percentile,
            seed=seed,
            device=kwargs['device'],
        )


    def get_wm_type(self) -> str:
        return "PQ"


    def __get_watermarking_pattern(
            self,
            latent_shape,
            n_bins,
            r_min_ratio,
            r_max_ratio,
            amp_threshold_percentile,
            device,
            seed,
        ) -> typing.Dict[str, typing.Any]:
        """
        Select the FFT bins that carry each payload bit.

        Side effect: stores the per-bit bin groups in ``self.omega_groups``.

        @param latent_shape: 4-D latent shape; elements 1..3 are (C, H, W)
        @param n_bins: bins per payload bit
        @param r_min_ratio, r_max_ratio, amp_threshold_percentile: forwarded to choose_bins_simplified
        @param device: device on which the random reference latent is drawn
        @param seed: if not None, RNGs are re-seeded before selection
        @return: dict with keys "payload_bits", "omega_groups" and "q_step"
        @raise ValueError: if fewer than n_bins * len(self.payload_bits) bins were selected
        """
        if seed is not None:
            set_seed(seed)

        num_bits = len(self.payload_bits)

        _, H_latent, W_latent = latent_shape[1:]

        # Amplitude spectrum of a random latent, used only to filter candidate bins.
        # NOTE(review): the draw happens on `device`; CPU and CUDA generators give
        # different samples for the same seed, so the selected bins may depend on
        # the device - confirm this is intended.
        dummy_latent = torch.randn(*latent_shape, device=device)
        with torch.no_grad():
            amp = torch.abs(torch.fft.fft2(dummy_latent[0])).cpu().numpy()

        total_bins = n_bins * num_bits

        omega_all = choose_bins_simplified(
            H_latent, W_latent,
            r_min_ratio, r_max_ratio,
            total_bins,
            amplitude_spectrum_for_filter=amp,
            amp_threshold_percentile=amp_threshold_percentile
        )

        if len(omega_all) < total_bins:
            raise ValueError(
                f"PQIM용으로 필요한 빈 수 ({total_bins})보다 적은 빈이 선택되었습니다. "
                f"선택된 빈 수: {len(omega_all)}. 파라미터를 조정하세요."
            )

        # Consecutive chunks of n_bins bins, one chunk per payload bit.
        self.omega_groups = [
            omega_all[i * n_bins:(i + 1) * n_bins] for i in range(num_bits)
        ]

        return {
            "payload_bits": self.payload_bits,
            "omega_groups": self.omega_groups,
            "q_step": self.q_step,
        }


    def get_wm_latents(
            self,
            latents_clean: torch.Tensor = None,
            seed: int = None,
        ) -> typing.Dict[str, typing.Any]:
        """
        Embed the payload into the clean latents.

        @param latents_clean: torch.Tensor, shape: self.latent_shape (required)
        @param seed: int, if given the RNGs are re-seeded before embedding

        @return: dict with "zT_torch" (watermarked latent from embed_pqim_ss_32bit with a
            new leading dim added by unsqueeze(0)) and "message_bits_str_list"
        """
        if seed is not None:
            set_seed(seed)

        assert latents_clean is not None, \
            f'latents_clean should be a torch.Tensor, but got {type(latents_clean)}'

        watermarked = embed_pqim_ss_32bit(
            latents_clean, self.payload_bits, self.omega_groups, self.q_step,
        ).unsqueeze(0)

        return {
            "zT_torch": watermarked,
            "message_bits_str_list": [self.payload_bits_str],
        }

    def get_accuracies(
            self,
            latents: typing.Union[torch.Tensor, np.ndarray],
        ) -> typing.Dict[str, typing.Any]:
        """
        Decode the payload from ``latents`` and compare it with the embedded one.

        @param latents: (C, H, W) or (1, C, H, W) torch tensor or numpy array
        @return: dict with "bit_accuracies" (single-element list), "extracted_messages"
            (decoded 0/1 list) and "message_bits_str_list"
        """
        # The detector runs torch.fft, so numpy input is converted (tensors pass through).
        latents = torch.as_tensor(latents)

        if len(latents.shape) == 4:
            latents = latents.squeeze(0)
            print(f'latents: {latents.shape} squeezed in 0 dim')

        assert len(latents.shape) == 3, f"latents dim should be 3, but {len(latents.shape)}."

        detected_bits = detect_pqim_ss_32bit(
            latents,
            self.omega_groups,
            self.q_step,
        )

        accuracy = accuracy_score(self.payload_bits, detected_bits)

        return {
            "bit_accuracies": [accuracy],
            "extracted_messages": detected_bits,
            "message_bits_str_list": [self.payload_bits_str],
        }
