# FINAL/ZERO_INN_JP__23_33_21/vis_test
import os
import time
import math
import logging
import warnings
from typing import Dict, Tuple

import numpy as np
import torch
import torch.nn.functional as F
import config as c
from utils.utils import *
from utils.metric import *
from utils.datasets import *
from block.Noise import Noise
from block.INN import RFE

# Suppress every warning for the rest of the process (no category/module filter is given).
warnings.filterwarnings("ignore")

def _init_defaults(cfg=None):
    d = dict(
        qim_Delta=2,
        qim_seed=12345,
        test_joint_steps=3000,
        joint_steps=3000,
        adv_opt_name="adam",
        adv_lr=1e-2,
        adv_lr_final=1e-4,
        adv_warmup_ratio=0.5,
        adv_grad_clip=0.0,
        adv_mse_w=150,
        acc_w=1,

        dec_lr=1e-4,
        dec_grad_clip=0,
        ndevice = 2,
    
        target_psnr = 41,

        loss_mode="qim",      
        test_loss_mode="qim",  
        noise_type="NGMIX",      
        mix_epochs=0,         

        SAVE_freq=1,
    )

    for k, v in d.items():
        if not hasattr(cfg, k):
            setattr(cfg, k, v)


# At import time, add any hyper-parameter the config module `c` does not define.
_init_defaults(c)


def psnr_minus1_to1(x, y):
    """Return ``psnr(x, y, 2)`` using the project's ``psnr`` helper.

    The third argument is fixed to 2 — presumably the data range of images
    scaled to [-1, 1]; confirm against the ``psnr`` signature.
    """
    value_range = 2
    return psnr(x, y, value_range)


def msg_to_bits(m_minus1_1: torch.Tensor) -> torch.Tensor:
    """Binarise a message tensor by thresholding at 0.5.

    Entries strictly greater than 0.5 map to 1.0 and everything else to 0.0,
    so both {-1, 1} and {0, 1} encoded messages give the same bits.

    Returns:
        A float32 tensor with the same shape as the input.
    """
    is_one = torch.gt(m_minus1_1, 0.5)
    return is_one.float()


def orthonormal_carriers(B: int, L: int, device, seed: int = 12345) -> torch.Tensor:
    """Build ``B`` orthonormal carrier vectors of length ``L``.

    A seeded Gaussian ``[L, B]`` matrix is orthonormalised with a reduced QR
    decomposition and the ``Q`` factor is transposed, so each returned row is
    one unit-norm carrier and the rows are mutually orthogonal.

    Args:
        B: Number of carriers (rows of the result).
        L: Length of each carrier.
        device: Device for both the random generator and the result.
        seed: Seed of the dedicated generator; the global RNG is not touched.

    Returns:
        Contiguous float32 tensor of shape ``[B, L]``.

    Raises:
        ValueError: If ``B > L``.
    """
    if B > L:
        raise ValueError(f"B ({B}) must be <= L ({L}).")

    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    gaussian = torch.randn(L, B, device=device, dtype=torch.float32, generator=rng)

    # Reduced QR keeps only the first B orthonormal columns: Q is [L, B].
    q_factor = torch.linalg.qr(gaussian, mode='reduced')[0]
    return q_factor.t().contiguous()  # [B, L]


def cos_periodic_loss_from_t(t: torch.Tensor, bits: torch.Tensor, Delta: float) -> torch.Tensor:
    """Mean periodic (cosine) distance of ``t`` from the lattice selected by ``bits``.

    Bit 0 targets the multiples of ``Delta``; bit 1 targets the same lattice
    shifted by ``Delta / 2``.  Each element contributes ``1 - cos(2*pi*z)``
    with ``z = (t - shift) / Delta``, which is 0 on the lattice and 2 halfway
    between two lattice points.

    Returns:
        A scalar tensor (mean over all elements).
    """
    shift = bits * (0.5 * Delta)
    normalised = (t - shift) / Delta
    phase = (2.0 * math.pi) * normalised
    return torch.mean(1.0 - torch.cos(phase))


@torch.no_grad()
def decode_bits_from_t_qim(t: torch.Tensor, Delta: float) -> torch.Tensor:
    """Decode bits by nearest-lattice (QIM) detection.

    For every element the distance to the nearest multiple of ``Delta``
    (bit 0) is compared with the distance to the nearest point of the lattice
    shifted by ``Delta / 2`` (bit 1).  Bit 1 is chosen only when it is
    strictly closer.

    Returns:
        float32 tensor of 0.0 / 1.0 with the shape of ``t``.
    """
    nearest_zero_point = Delta * torch.round(t / Delta)
    nearest_one_point = Delta * (torch.round((t - 0.5 * Delta) / Delta) + 0.5)

    dist_zero = (t - nearest_zero_point).abs()
    dist_one = (t - nearest_one_point).abs()
    return torch.lt(dist_one, dist_zero).float()


@torch.no_grad()
def decode_bits_from_t_sign(t: torch.Tensor) -> torch.Tensor:
    """Decode bits with a fixed 0.5 threshold.

    Despite the ``_sign`` suffix the cut-off is 0.5, not 0: entries strictly
    above 0.5 become 1.0, the rest 0.0.

    Returns:
        float32 tensor with the shape of ``t``.
    """
    return torch.gt(t, 0.5).float()


def _take_first_if_list(x):
    return x[0] if isinstance(x, (list, tuple)) else x


def _make_x_optimizer(x_param, name: str, lr: float):
    name = name.lower()
    if name == "sgd":
        return torch.optim.SGD([x_param], lr=lr, momentum=0.9)
    return torch.optim.Adam([x_param], lr=lr)


def _set_x_lr(opt, lr_now: float):
    for g in opt.param_groups:
        g["lr"] = lr_now


def _compute_sched_lr(step_idx: int, total_steps: int, base_lr: float,
                      final_lr: float, warm_ratio: float) -> float:
    total_steps = max(1, int(total_steps))
    warm_steps = max(1, min(int(total_steps * warm_ratio),
                            max(1, total_steps - 1)))

    if step_idx <= warm_steps:
        return base_lr * (step_idx / float(warm_steps))
    t = (step_idx - warm_steps) / float(max(1, total_steps - warm_steps))
    return final_lr + 0.5 * (base_lr - final_lr) * (1.0 + math.cos(math.pi * t))


def binary_entropy(p: float) -> float:
    """Entropy, in bits, of a Bernoulli variable with success probability ``p``.

    ``p`` is clipped to ``[1e-12, 1 - 1e-12]`` first so the logarithms stay
    finite at the end points.
    """
    tiny = 1e-12
    clipped = float(max(min(p, 1.0 - tiny), tiny))
    complement = 1.0 - clipped
    weighted_logs = clipped * math.log2(clipped) + complement * math.log2(complement)
    return -weighted_logs

def compute_embed_loss(noise_tvec: torch.Tensor,
                       clean_tvec: torch.Tensor,
                       bits: torch.Tensor,
                       message_pm1: torch.Tensor,
                       mode: str,
                       Delta: float,
                       epoch: int = None,
                       mix_epochs: int = None) -> torch.Tensor:
    """Embedding loss for the selected ``mode`` (case-insensitive).

    * ``"qim"``: periodic cosine loss pulling ``clean_tvec`` onto the lattice
      chosen by ``bits``, plus the MSE between ``clean_tvec`` and
      ``noise_tvec``.
    * ``"mse"``: MSE between ``noise_tvec`` and ``message_pm1``.

    Args:
        noise_tvec: Carrier projections of the decoder output for the
            distorted image.
        clean_tvec: Carrier projections for the undistorted image
            (used by ``"qim"`` only).
        bits: Target bits (used by ``"qim"`` only).
        message_pm1: Regression target (used by ``"mse"`` only).
        mode: ``"qim"`` or ``"mse"``.
        Delta: QIM step size.
        epoch: Accepted but not used.
        mix_epochs: Accepted but not used.

    Raises:
        ValueError: For any other ``mode``.
    """
    selected = mode.lower()

    if selected == "mse":
        return F.mse_loss(noise_tvec, message_pm1)

    if selected == "qim":
        lattice_term = cos_periodic_loss_from_t(clean_tvec, bits, Delta)
        consistency_term = F.mse_loss(clean_tvec, noise_tvec)
        return lattice_term + consistency_term

    raise ValueError(f"Unknown loss_mode: {selected}")


@torch.no_grad()
def decode_bits_from_t(tvec: torch.Tensor,
                       mode: str,
                       Delta: float) -> torch.Tensor:
    """Decode bits from carrier projections with the decoder matching ``mode``.

    ``"qim"`` uses nearest-lattice detection with step ``Delta``; ``"mse"``
    uses the fixed-threshold decoder and ignores ``Delta``.  The mode is
    matched case-insensitively.

    Raises:
        ValueError: For any other ``mode``.
    """
    selected = mode.lower()
    if selected == "qim":
        return decode_bits_from_t_qim(tvec, Delta)
    if selected == "mse":
        return decode_bits_from_t_sign(tvec)
    raise ValueError(f"Unknown loss_mode: {selected}")


def qim_embed(latent_z: torch.Tensor,
              message: torch.Tensor,
              V: torch.Tensor,
              Delta: float):
    """Quantise the carrier projections of ``latent_z`` so they encode ``message``.

    The latent is flattened per sample and projected onto the rows of ``V``.
    Each projection is snapped to the nearest point of the lattice
    ``message * Delta/2 + Delta * k`` (integer ``k``), and the resulting
    change is added back to the latent along the same carriers.

    Args:
        latent_z: Tensor whose first dimension is the batch; it is flattened
            with ``view``, so it must be viewable as ``[B, -1]``.
        message: ``[B, msg_len]`` tensor (0/1 values select the lattice shift).
        V: Carrier matrix of shape ``[msg_len, L_dec]`` where ``L_dec`` is the
            flattened latent size.
        Delta: QIM step size.

    Returns:
        The modified latent, with the shape of ``latent_z``.

    Raises:
        AssertionError: If ``V`` does not have shape ``[msg_len, L_dec]``.
    """
    batch = latent_z.size(0)
    flat = latent_z.view(batch, -1)                       # [B, L_dec]
    n_bits, n_feat = message.size(1), flat.size(1)
    assert V.shape == (n_bits, n_feat), "V shape mismatch."

    proj = flat @ V.t()                                   # [B, msg_len]

    # Lattice shift per bit: 0 for bit 0, Delta/2 for bit 1.
    shift = message * (0.5 * Delta)
    snapped = shift + Delta * torch.round((proj - shift) / Delta)

    # Move the latent by the projection change, spread back over the carriers.
    moved = flat + (snapped - proj) @ V                   # [B, L_dec]
    return moved.view_as(latent_z)

def build_noises(noise_type: str):
    """Construct the three ``Noise`` pipelines used by the script.

    ``noise_type`` is upper-cased but does not influence which layers are
    built; the specs below are fixed.

    Returns:
        ``(test_noise, train_noise, test_train_noise)``:

        * ``test_noise`` wraps a single ``PCombined([...])`` spec listing 21
          fixed-parameter attacks.
        * ``train_noise`` and ``test_train_noise`` are two separate ``Noise``
          instances built from the same ``Combined([...])`` spec, whose
          attacks are given as parameter ranges.
    """
    noise_type = noise_type.upper()

    test_spec = "PCombined([RC(91,91), Elastic((3,3)), AF(s=(-55,-55)), AF(s=(55,55)),AF(r=(-45,-45)), AF(r=(45,45)), Erase((0.5,0.5)), JpegTest(50), MF(7), GF(2), Dropout(0.5), SP(0.1), GN(0.04), Bright(0.2, 0.2), Bright(2, 2), Contrast(0.2, 0.2), Contrast(2, 2), Hue(-0.1, -0.1), Hue(0.1, 0.1),Saturation(0.2, 0.2),Saturation(2, 2)])"

    # Layer entries of the training spec; each one carries its trailing comma
    # so the joined text is exactly the original spec string.
    train_layers = (
        "RC(91,91),",
        "Elastic((3,3)),",
        "KJpeg(50),",
        "Erase((0.5,0.5)),",
        "AF(s=(-55,55)),",
        "AF(r=(-45,45)),",
        "MF(7),",
        "GF(2),",
        "Dropout(0.5),",
        "SP(0.1),",
        "GN(0.04),",
        "Bright(0.2, 2),",
        "Contrast(0.2, 2),",
        "Hue(-0.1, 0.1),",
        "Saturation(0.2, 2),",
    )
    train_spec = "Combined([" + "".join(train_layers) + "])"

    # Instantiate in the original order; the two training pipelines are
    # independent objects even though they share one spec.
    test_noise = Noise([test_spec])
    train_noise = Noise([train_spec])
    test_train_noise = Noise([train_spec])

    return test_noise, train_noise, test_train_noise


# Short labels of the 21 test attacks, in the order in which the PCombined
# spec in build_noises() lists them; label i describes acc[i].
# The hue entries read +-0.1 because the spec uses Hue(-0.1, -0.1) / Hue(0.1, 0.1).
_ATTACK_LABELS = (
    'R  C', 'E  S', 'S  H -55', 'S  H +55', 'R  O -45', 'R  O +45',
    'E  R', 'J  P', 'M  F', 'G  F', 'D  P', 'S  P', 'G  N',
    'B  R 0.2', 'B  R 2.0', 'C  T 0.2', 'C  T 2.0',
    'H  U -0.1', 'H  U +0.1', 'S  A 0.2', 'S  A 2.0',
)

# Keys tried, in order, when a data loader yields a dict instead of a tensor.
_COVER_KEYS = ("image", "img", "cover", "cover_image", "x")


def _unwrap_cover(batch):
    """Return the image entry of a dict batch (first known key), else the batch itself."""
    if isinstance(batch, dict):
        for key in _COVER_KEYS:
            if key in batch:
                return batch[key]
    return batch


def _format_acc_report(acc):
    """Format the mean of every per-attack accuracy list as 'LABEL Acc: xx.xx | ' segments."""
    return "".join(
        f'{label} Acc: {np.mean(acc[i]):.2f} | '
        for i, label in enumerate(_ATTACK_LABELS)
    )


def _probe_decoder_dim(pmodel, noise, cover):
    """Run one no-grad forward pass on a distorted cover and return the flattened decoder output size."""
    with torch.no_grad():
        y0 = _take_first_if_list(noise([cover, cover]))
        dec0 = pmodel(y0) if c.mode != 'PH' else pmodel(y0)[0]
    return int(np.prod(dec0.shape[1:])) if dec0.dim() == 4 else int(dec0.shape[1])


@torch.no_grad()
def _record_attack_accuracy(pmodel, test_noise, x, cover, V, bits, acc):
    """Decode x under every test attack and append the bit accuracy (in percent) to acc[idx]."""
    B = cover.shape[0]
    for idx, attacked in enumerate(test_noise([x, cover])):
        # NOTE(review): unlike the optimisation loops, this call does not take
        # element [0] of the model output when c.mode == 'PH' — confirm that
        # the metrics path is meant to work in that mode.
        dec_out = pmodel(attacked)
        tvec = dec_out.reshape(B, -1) @ V.t()

        bits_hat = decode_bits_from_t(tvec=tvec, mode=c.test_loss_mode, Delta=c.qim_Delta)
        error_rate = (bits_hat.round() != bits).float().mean().item()
        acc[idx].append((1 - error_rate) * 100)


def main():
    """Run the joint image/decoder optimisation with a per-epoch evaluation.

    For every epoch:

    * TRAIN — for each batch a copy ``x`` of the cover is optimised for
      ``c.joint_steps`` steps together with the decoder ``pmodel``; the loss
      is the embedding loss plus ``c.adv_mse_w`` times the MSE between ``x``
      and the cover.  Afterwards PSNR and the per-attack bit accuracy are
      recorded and printed.
    * TEST — the same optimisation is run on ``x`` only (``opt_dec`` is never
      stepped) for ``c.test_joint_steps`` steps, followed by the same
      metrics.  On epochs 1, 11, 21, ... two images are written to
      ``<result_folder>/vis_test``.
    * The decoder weights are saved every ``c.SAVE_freq`` epochs and once
      more after the last epoch.

    ``trainloader``, ``testloader``, ``setup_logger``, ``load`` and ``psnr``
    are not defined in this file; they are expected to come from the
    ``utils.*`` star imports.
    """
    # Imported once up front (this import used to run inside the per-batch
    # test loop) so that a missing torchvision fails before any training.
    from torchvision.utils import save_image

    # result dir
    result_folder = os.path.join(
        c.MODEL_PATH, time.strftime(c.PROJECT_NAME + "__%H_%M_%S", time.localtime())
    )
    os.makedirs(os.path.join(result_folder, "models"), exist_ok=True)

    device = torch.device("cuda:{}".format(c.ndevice) if torch.cuda.is_available() else "cpu")

    pmodel = RFE().to(device)

    for p in pmodel.parameters():
        p.requires_grad_(True)

    # NOTE(review): the flag is read as "tain_next" — possibly a typo for
    # "train_next"; confirm against the config module before renaming.
    if getattr(c, "tain_next", False):
        ckpt = os.path.join(c.MODEL_PATH, c.CONTINUE_PATH, "models", f"{c.CONTINUE_EPOCH}.pt")
        load(pmodel, ckpt)

    setup_logger('train', result_folder, 'logging',
                 level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('train')

    test_noise, train_noise, test_train_noise = build_noises(c.noise_type)

    dec_params = list(filter(lambda p: p.requires_grad, pmodel.parameters()))
    opt_dec = torch.optim.Adam(dec_params, lr=c.dec_lr)

    # Carrier matrices keyed by (decoder output size, message length).
    V_cache: Dict[Tuple[int, int], torch.Tensor] = {}

    # ================== epoch loop ==================
    for epoch in range(1, c.epochs + 1):
        # ---------------- TRAIN ----------------
        pmodel.train()
        train_psnrs = []
        acc = [[] for _ in _ATTACK_LABELS]

        for cover in trainloader:
            # [B,C,H,W]; values assumed to lie in [-1,1] (x is clamped to that range below)
            cover = _unwrap_cover(cover).to(device).float()
            B, C, H, W = cover.shape

            # message in {0,1}, shape [B, msg_len]
            msg_len = int(c.message_length)
            message = torch.Tensor(
                np.random.choice([0, 1], (B, msg_len))
            ).to(device)
            bits = msg_to_bits(message)                # {0,1} float

            # probe decoder output shape to build V
            L_dec = _probe_decoder_dim(pmodel, train_noise, cover)
            if msg_len > L_dec:
                raise RuntimeError(f"message_length ({msg_len}) > decoder output dims ({L_dec}).")
            key = (L_dec, msg_len)
            if key not in V_cache:
                V_cache[key] = orthonormal_carriers(msg_len, L_dec, device=device, seed=c.qim_seed)
            V = V_cache[key]  # [msg_len, L_dec]

            # per-batch variable x and its optimizer (with schedule)
            x = cover.clone().detach().requires_grad_(True)
            opt_x = _make_x_optimizer(x, c.adv_opt_name, c.adv_lr)

            total_steps = int(c.joint_steps)
            for step_i in range(1, total_steps + 1):
                # schedule LR for x
                lr_now = _compute_sched_lr(step_i, total_steps,
                                           c.adv_lr, c.adv_lr_final, c.adv_warmup_ratio)
                _set_x_lr(opt_x, lr_now)

                # The decoder is updated on every step, together with x.
                opt_x.zero_grad(set_to_none=True)
                opt_dec.zero_grad(set_to_none=True)

                yk = _take_first_if_list(train_noise([x, cover]))

                clean_dec_out = pmodel(x) if c.mode != 'PH' else pmodel(x.detach())[0]
                clean_yvec = clean_dec_out.reshape(B, -1)          # [B, L_dec]
                clean_tvec = clean_yvec @ V.t()                    # [B, msg_len]

                noise_dec_out = pmodel(yk) if c.mode != 'PH' else pmodel(yk.detach())[0]
                noise_yvec = noise_dec_out.reshape(B, -1)          # [B, L_dec]
                noise_tvec = noise_yvec @ V.t()                    # [B, msg_len]

                # --- embedding loss: "qim" or "mse", selected by c.loss_mode ---
                embed_loss = compute_embed_loss(
                    noise_tvec=noise_tvec,
                    clean_tvec=clean_tvec,
                    bits=bits,
                    message_pm1=message,
                    mode=c.loss_mode,
                    Delta=c.qim_Delta,
                    epoch=epoch,
                    mix_epochs=c.mix_epochs
                )

                loss = embed_loss
                # Note: c.acc_w is only applied when the image-fidelity term is enabled.
                if c.adv_mse_w and c.adv_mse_w > 0:
                    loss = c.acc_w * loss + c.adv_mse_w * F.mse_loss(x, cover)

                loss.backward()

                # grad clip (optional)
                if c.adv_grad_clip and c.adv_grad_clip > 0:
                    torch.nn.utils.clip_grad_norm_([x], max_norm=c.adv_grad_clip)
                if c.dec_grad_clip and c.dec_grad_clip > 0:
                    torch.nn.utils.clip_grad_norm_(dec_params, max_norm=c.dec_grad_clip)

                opt_x.step()
                opt_dec.step()
                with torch.no_grad():
                    x.clamp_(-1, 1)

            # -------- metrics on this batch --------
            with torch.no_grad():
                psnr_x = float(psnr_minus1_to1(cover, x))
                _record_attack_accuracy(pmodel, test_noise, x, cover, V, bits, acc)
                train_psnrs.append(psnr_x)

            # Running means over the batches seen so far in this epoch.
            print(
                f"TRAIN:   "
                f'PSNR_STEGO: {np.mean(train_psnrs):.2f} | '
                f"{_format_acc_report(acc)}"
            )

        logger.info(
            f"[TRAIN] epoch={epoch} | "
            f"PSNR(x)={np.mean(train_psnrs):.2f} dB | "
            f"{_format_acc_report(acc)}"
            f"(mode={c.loss_mode}, Δ={c.qim_Delta}, joint_steps={c.joint_steps}, "
            f"acc_weight={c.acc_w}, mse_weight={c.adv_mse_w}, "
            f"adv_lr={c.adv_lr}, adv_lr_final={c.adv_lr_final}, warmup={c.adv_warmup_ratio}, "
            f"dec_lr={c.dec_lr}, noise_type={c.noise_type})"
        )

        # ---------------- TEST ----------------
        pmodel.eval()
        test_psnrs = []
        acc = [[] for _ in _ATTACK_LABELS]
        for cover in testloader:
            cover = _unwrap_cover(cover).to(device).float()
            B, C, H, W = cover.shape

            msg_len = int(c.test_message_length)
            message = torch.Tensor(
                np.random.choice([0, 1], (B, msg_len))
            ).to(device)
            bits = msg_to_bits(message)

            # build V for current decoder output dims
            # (orthonormal_carriers raises ValueError if msg_len > L_dec)
            L_dec = _probe_decoder_dim(pmodel, train_noise, cover)
            key = (L_dec, msg_len)
            if key not in V_cache:
                V_cache[key] = orthonormal_carriers(msg_len, L_dec, device=device, seed=c.qim_seed)
            V = V_cache[key]

            # re-optimize x ONLY with the SAME LR schedule; opt_dec is never
            # stepped here, so the decoder weights stay unchanged.
            x = cover.clone().detach().requires_grad_(True)
            opt_x = _make_x_optimizer(x, c.adv_opt_name, c.adv_lr)

            total_steps = int(c.test_joint_steps)
            for step_i in range(1, total_steps + 1):
                lr_now = _compute_sched_lr(step_i, total_steps,
                                           c.adv_lr, c.adv_lr_final, c.adv_warmup_ratio)
                _set_x_lr(opt_x, lr_now)

                opt_x.zero_grad(set_to_none=True)

                yk = _take_first_if_list(test_train_noise([x, cover]))
                noise_dec_out = pmodel(yk) if c.mode != 'PH' else pmodel(yk)[0]
                noise_tvec = noise_dec_out.reshape(B, -1) @ V.t()

                clean_dec_out = pmodel(x) if c.mode != 'PH' else pmodel(x.detach())[0]
                clean_yvec = clean_dec_out.reshape(B, -1)          # [B, L_dec]
                clean_tvec = clean_yvec @ V.t()

                loss_x = compute_embed_loss(
                    noise_tvec=noise_tvec,
                    clean_tvec=clean_tvec,
                    bits=bits,
                    message_pm1=message,
                    mode=c.test_loss_mode,
                    Delta=c.qim_Delta,
                    epoch=epoch,
                    mix_epochs=c.mix_epochs
                )

                if c.adv_mse_w and c.adv_mse_w > 0:
                    loss_x = c.acc_w * loss_x + c.adv_mse_w * F.mse_loss(x, cover)

                # NOTE(review): the decoder parameters still require grad, so
                # this backward pass also fills their .grad.  Those gradients
                # are never applied in this phase and are cleared by
                # opt_dec.zero_grad() at the next training step.
                loss_x.backward()
                if c.adv_grad_clip and c.adv_grad_clip > 0:
                    torch.nn.utils.clip_grad_norm_([x], max_norm=c.adv_grad_clip)
                opt_x.step()
                with torch.no_grad():
                    x.clamp_(-1, 1)

            with torch.no_grad():
                psnr_x = float(psnr_minus1_to1(cover, x))
                _record_attack_accuracy(pmodel, test_noise, x, cover, V, bits, acc)
                test_psnrs.append(psnr_x)

                if (epoch - 1) % 10 == 0:
                    out_dir = os.path.join(result_folder, "vis_test")
                    print(out_dir)
                    os.makedirs(out_dir, exist_ok=True)

                    # "*_cover.png" holds the residual (x - cover) of the first
                    # sample, amplified 10x and shifted by (v + 1) / 2; it is
                    # not the cover image itself.
                    # NOTE(review): the file names depend only on the epoch, so
                    # each test batch overwrites the files of the previous one.
                    cover_vis = ((x[0:1] - cover[0:1]) * 10 + 1.0) / 2.0
                    x_vis = (x[0:1] + 1.0) / 2.0   # [-1,1] -> [0,1]

                    save_image(cover_vis, os.path.join(out_dir, f"epoch{epoch:03d}_cover.png"))
                    save_image(x_vis, os.path.join(out_dir, f"epoch{epoch:03d}_stego.png"))

        logger.info(
            f"[TEST ] epoch={epoch} | "
            f"PSNR(x)={np.mean(test_psnrs):.2f} dB | "
            f"{_format_acc_report(acc)}"
        )

        # save
        if (epoch % c.SAVE_freq) == 0:
            torch.save(pmodel.state_dict(), os.path.join(result_folder, "models", f"{epoch}.pt"))

    # final save
    torch.save(pmodel.state_dict(), os.path.join(result_folder, "models", f"{c.epochs}.pt"))


# Start the train/test loop only when this file is executed as a script.
if __name__ == "__main__":
    main()
    
    
