import os
import time
import math
import logging
import warnings
from typing import Dict, Tuple, List

import numpy as np
import torch
import torch.nn.functional as F

import config as c
from utils.utils import *
from utils.metric import *
from utils.datasets import *
from block.Noise import Noise
from block.INN import RFE

# NOTE(review): this silences every Python warning process-wide (including
# deprecation warnings raised by torch/numpy); consider narrowing the filter.
warnings.filterwarnings("ignore")


def _init_defaults(cfg=None):
    d = dict(
        qim_Delta=2,
        qim_seed=12345,

        test_joint_steps=3000,
        joint_steps=3000,

        adv_opt_name="adam",
        adv_lr=1e-2,
        adv_lr_final=1e-4,
        adv_warmup_ratio=0.3,
        adv_grad_clip=0.0,
        adv_mse_w=230,
        acc_w=1,

        dec_lr=1e-5,
        dec_grad_clip=0,
        ndevice=2,

        target_psnr=41,

        V_mode="ortho",    
        V_rho=0.5,             


        loss_mode="qim",      
        test_loss_mode="qim", 
        noise_type="NGMIX",
        mix_epochs=0,

        # save
        SAVE_freq=1,

        method_name="OrthoMark",
    )
    for k, v in d.items():
        if not hasattr(cfg, k):
            setattr(cfg, k, v)


# Populate every setting missing from the imported ``config`` module with the
# defaults defined in _init_defaults (existing config values are kept).
_init_defaults(c)


def psnr_minus1_to1(x, y):
    """PSNR between ``x`` and ``y`` for images scaled to [-1, 1].

    Delegates to the star-imported ``psnr`` helper (presumably from
    ``utils.metric``). The third argument, 2, is presumably the data range of
    [-1, 1] images (max - min); confirm against that helper's signature.
    """
    value_span = 2
    return psnr(x, y, value_span)


def msg_to_bits(m01: torch.Tensor) -> torch.Tensor:
    """Map a message tensor to float32 bits: 1.0 where an entry is > 0, else 0.0.

    ``main`` passes a ±1 message, so -1 becomes 0 and +1 becomes 1.
    """
    positive = m01 > 0
    return positive.float()

def orthonormal_carriers(B: int, L: int, device, seed: int = 12345) -> torch.Tensor:
    """Return ``B`` mutually orthonormal rows of length ``L`` as a ``[B, L]`` float32 tensor.

    The rows are the transposed Q factor of the reduced QR decomposition of an
    ``[L, B]`` Gaussian matrix drawn from a generator seeded with ``seed``, so the
    result is deterministic for a given (B, L, device, seed).

    Raises:
        ValueError: if ``B > L`` (more orthonormal rows than dimensions).
    """
    if not B <= L:
        raise ValueError(f"B ({B}) must be <= L ({L}).")

    rng = torch.Generator(device=device).manual_seed(seed)
    gaussian = torch.randn(L, B, generator=rng, device=device, dtype=torch.float32)
    q_factor = torch.linalg.qr(gaussian, mode='reduced')[0]  # [L, B]
    return q_factor.t().contiguous()                         # [B, L]


def make_rand_unit(msg_len: int, L_dec: int, device, seed: int = 12345) -> torch.Tensor:
    """Return ``msg_len`` random unit-norm carriers as a ``[msg_len, L_dec]`` float32 tensor.

    Each row is an i.i.d. Gaussian vector (from a generator seeded with ``seed``)
    rescaled to unit L2 norm. Unlike ``make_ortho`` the rows are not mutually
    orthogonal, and ``msg_len`` may exceed ``L_dec``.
    """
    rng = torch.Generator(device=device)
    rng.manual_seed(int(seed))
    raw = torch.randn(msg_len, L_dec, generator=rng, device=device, dtype=torch.float32)
    return F.normalize(raw, p=2, dim=1)

def make_rand_unit_corr(msg_len: int, L_dec: int, device,
                        seed: int = 12345, rho: float = 0.1) -> torch.Tensor:
    """Return ``msg_len`` unit-norm carriers that all lean toward one shared direction.

    Row i is ``rho * u + (1 - rho) * e_i``, renormalised to unit length, where ``u``
    is one random unit vector (generator seeded with ``seed``) and ``e_i`` are the
    orthonormal rows of ``make_ortho(..., seed=seed + 1)``. ``rho = 0`` reduces to the
    (renormalised) ``make_ortho`` rows; larger ``rho`` pulls every row toward ``u``.

    Raises:
        ValueError: if ``rho`` is not in [0, 1), or (via ``make_ortho``) if
            ``msg_len > L_dec``.
    """
    rho = float(rho)
    if not (0.0 <= rho < 1.0):
        raise ValueError(f"V_rho must be in [0,1). got {rho}")

    rng = torch.Generator(device=device).manual_seed(int(seed))

    shared = torch.randn(L_dec, device=device, dtype=torch.float32, generator=rng)
    shared = shared / (shared.norm() + 1e-12)

    ortho_rows = make_ortho(msg_len, L_dec, device, seed=int(seed) + 1)  # [msg_len, L_dec]

    mixed = rho * shared[None, :] + (1.0 - rho) * ortho_rows
    return mixed / (mixed.norm(dim=1, keepdim=True) + 1e-12)

def make_ortho(msg_len: int, L_dec: int, device, seed: int = 12345) -> torch.Tensor:
    """Return ``msg_len`` orthonormal carriers as a ``[msg_len, L_dec]`` float32 tensor.

    The rows are the transposed Q factor of the reduced QR decomposition of a
    seeded ``[L_dec, msg_len]`` Gaussian matrix; deterministic for a given seed.

    Raises:
        ValueError: if ``msg_len > L_dec``.
    """
    if msg_len > L_dec:
        raise ValueError(f"msg_len ({msg_len}) must be <= L_dec ({L_dec}).")

    gen = torch.Generator(device=device)
    gen.manual_seed(int(seed))
    gaussian = torch.randn(L_dec, msg_len, device=device, dtype=torch.float32, generator=gen)
    basis, _r = torch.linalg.qr(gaussian, mode="reduced")  # basis: [L_dec, msg_len]
    return basis.t().contiguous()                          # [msg_len, L_dec]

def build_carriers(msg_len: int,
                   L_dec: int,
                   device,
                   seed: int = 12345,
                   mode: str = "ortho",
                   rho: float = 0.1) -> torch.Tensor:
    """Build a ``[msg_len, L_dec]`` carrier matrix with the construction named by ``mode``.

    Modes (case-insensitive; ``None`` or empty falls back to ``"ortho"``):
      * ``"ortho"``          -> ``make_ortho``
      * ``"rand_unit"``      -> ``make_rand_unit``
      * ``"rand_unit_corr"`` -> ``make_rand_unit_corr`` (the only mode that reads ``rho``)

    Raises:
        ValueError: if ``msg_len > L_dec`` (checked for every mode) or ``mode`` is
            not one of the names above.
    """
    kind = (mode or "ortho").lower()
    if msg_len > L_dec:
        raise ValueError(f"msg_len ({msg_len}) must be <= L_dec ({L_dec}).")

    # Lambdas defer resolving the builder functions until a mode has been chosen.
    builders = {
        "ortho": lambda: make_ortho(msg_len, L_dec, device, seed),
        "rand_unit": lambda: make_rand_unit(msg_len, L_dec, device, seed),
        "rand_unit_corr": lambda: make_rand_unit_corr(msg_len, L_dec, device, seed, rho=rho),
    }
    build = builders.get(kind)
    if build is None:
        raise ValueError(f"Unknown V_mode: {kind}. Use 'ortho'/'rand_unit'/'rand_unit_corr'.")
    return build()

@torch.no_grad()
def decode_bits_from_t_qim(t: torch.Tensor, Delta: float) -> torch.Tensor:
    """Hard-decode QIM bits from carrier projections ``t``.

    Bit 0 lives on the lattice ``Delta * k`` and bit 1 on the half-step-shifted
    lattice ``Delta * (k + 0.5)`` (the same convention ``cos_periodic_loss_from_t``
    trains toward). An entry decodes to 1.0 when it is strictly closer to the
    bit-1 lattice, otherwise 0.0 (ties go to bit 0). Returns float32, shaped like ``t``.
    """
    d0 = torch.abs(t - Delta * torch.round(t / Delta))
    # Distance to the nearest point Delta * (k + 0.5) of the shifted lattice.
    d1 = torch.abs(t - (Delta * (torch.round((t - 0.5 * Delta) / Delta) + 0.5)))
    return (d1 < d0).to(torch.float32)


@torch.no_grad()
def decode_bits_from_t_sign(t: torch.Tensor) -> torch.Tensor:
    """Hard-decode bits by sign: 1.0 where ``t > 0``, else 0.0 (float32, shaped like ``t``)."""
    return (t > 0).to(torch.float32)


def cos_periodic_loss_from_t(t: torch.Tensor, bits: torch.Tensor, Delta: float) -> torch.Tensor:
    """Differentiable periodic QIM loss, averaged over all entries.

    Computes ``mean(1 - cos(2*pi * (t - bits * Delta / 2) / Delta))``. Each term is
    0 when ``t`` sits on its bit's lattice (``Delta * k`` for bit 0,
    ``Delta * (k + 0.5)`` for bit 1) and reaches 2 half a step away.

    Args:
        t: carrier projections.
        bits: 0/1 targets broadcastable to ``t``; moved to ``t``'s device as float32.
        Delta: lattice step.
    """
    two_pi = 2.0 * math.pi
    bits = bits.to(device=t.device, dtype=torch.float32)
    cst = bits * (0.5 * Delta)  # lattice offset: 0 for bit 0, Delta/2 for bit 1
    z = (t - cst) / Delta
    return (1.0 - torch.cos(two_pi * z)).mean()


def _take_first_if_list(x):
    return x[0] if isinstance(x, (list, tuple)) else x


def _make_x_optimizer(x_param, name: str, lr: float):
    name = name.lower()
    if name == "sgd":
        return torch.optim.SGD([x_param], lr=lr, momentum=0.9)
    return torch.optim.Adam([x_param], lr=lr)


def _set_x_lr(opt, lr_now: float):
    for g in opt.param_groups:
        g["lr"] = lr_now


def _compute_sched_lr(step_idx: int, total_steps: int, base_lr: float,
                      final_lr: float, warm_ratio: float) -> float:
    """
    Warmup (linear) + Cosine decay to final_lr.
    step_idx: 1..total_steps
    """
    total_steps = max(1, int(total_steps))
    warm_steps = max(1, min(int(total_steps * warm_ratio),
                            max(1, total_steps - 1)))

    if step_idx <= warm_steps:
        return base_lr * (step_idx / float(warm_steps))

    t = (step_idx - warm_steps) / float(max(1, total_steps - warm_steps))
    return final_lr + 0.5 * (base_lr - final_lr) * (1.0 + math.cos(math.pi * t))


def safe_mean(x_list, default=0.0) -> float:
    """Mean of ``x_list`` as a Python float, or ``float(default)`` when it is ``None`` or empty."""
    if x_list is None or len(x_list) == 0:
        return float(default)
    return float(np.mean(x_list))


def lm_latex_row_from_acc(method_name: str, psnr_val: float, acc_means_21: List[float]) -> str:
    """Format one LaTeX table row from the 21 per-attack accuracies.

    Two-sided attacks (shear, rotation, hue, brightness, contrast, saturation) are
    averaged over their two settings; the others are used as-is. Column order:
    PSNR, JPEG, MF, GF, DP, SP, GN, Erase, CR, SHEAR, ROT, Elastic, Hue, Bright,
    Contrast, Saturate, AVG. AVG is the mean of the 15 attack columns (PSNR is
    excluded). Every value uses two decimals, and the row ends with a LaTeX line
    break (two backslashes).
    """
    def single(i):
        return float(acc_means_21[i])

    def pair(i):
        # Mean of the two settings stored at indices i and i + 1.
        return 0.5 * (float(acc_means_21[i]) + float(acc_means_21[i + 1]))

    columns = [
        single(7),   # JPEG
        single(8),   # MF
        single(9),   # GF
        single(10),  # DP
        single(11),  # SP
        single(12),  # GN
        single(6),   # Erase
        single(0),   # CR
        pair(2),     # SHEAR (-55 / +55)
        pair(4),     # ROT (-45 / +45)
        single(1),   # Elastic
        pair(17),    # Hue (- / +)
        pair(13),    # Bright (0.2 / 2.0)
        pair(15),    # Contrast (0.2 / 2.0)
        pair(19),    # Saturate (0.2 / 2.0)
    ]
    avg = float(np.mean(columns))

    cells = "".join(f"& {value:.2f} " for value in [psnr_val, *columns, avg])
    return f"{method_name} " + cells + "\\\\"


def compute_embed_loss(noise_tvec: torch.Tensor,
                       clean_tvec: torch.Tensor,
                       bits: torch.Tensor,
                       message01: torch.Tensor,
                       mode: str,
                       Delta: float,
                       epoch: int = None,
                       mix_epochs: int = None) -> torch.Tensor:
    """Embedding loss on the carrier projections ``noise_tvec``.

    Modes (case-insensitive):
      * ``"qim"``: ``cos_periodic_loss_from_t(noise_tvec, bits, Delta)``.
      * ``"mse"``: MSE between ``noise_tvec`` and ``message01`` (cast to
        ``noise_tvec``'s device and dtype). NOTE: ``main`` passes a ±1 message
        here despite the parameter name.

    ``clean_tvec``, ``epoch`` and ``mix_epochs`` are accepted for interface
    compatibility but are not used.

    Raises:
        ValueError: for any other mode.
    """
    selected = mode.lower()
    if selected == "mse":
        target = message01.to(device=noise_tvec.device, dtype=noise_tvec.dtype)
        return F.mse_loss(noise_tvec, target)
    if selected == "qim":
        return cos_periodic_loss_from_t(noise_tvec, bits, Delta)
    raise ValueError(f"Unknown loss_mode: {selected}")


@torch.no_grad()
def decode_bits_from_t(tvec: torch.Tensor, mode: str, Delta: float) -> torch.Tensor:
    """Hard-decode bits from carrier projections, choosing the decoder by ``mode``.

    ``"qim"`` uses nearest-lattice QIM decoding with step ``Delta``; ``"mse"`` uses
    the sign of ``tvec`` (``Delta`` ignored). ``mode`` is case-insensitive.

    Raises:
        ValueError: for any other mode.
    """
    selected = mode.lower()
    if selected == "mse":
        return decode_bits_from_t_sign(tvec)
    if selected == "qim":
        return decode_bits_from_t_qim(tvec, Delta)
    raise ValueError(f"Unknown loss_mode: {selected}")


# --------------------------- noise builder --------------------------

def build_noises(noise_type: str):
    noise_type = noise_type.upper()
    test_noise = Noise([
        "PCombined(["
        "RC(91,91), Elastic((3,3)), "
        "AF(s=(-55,-55)), AF(s=(55,55)), "
        "AF(r=(-45,-45)), AF(r=(45,45)), "
        "Erase((0.5,0.5)), "
        "JpegTest(50), MF(7), GF(2), Dropout(0.5), SP(0.1), GN(0.04), "
        "Bright(0.2, 0.2), Bright(2, 2), "
        "Contrast(0.2, 0.2), Contrast(2, 2), "
        "Hue(-0.1, -0.1), Hue(0.1, 0.1), "
        "Saturation(0.2, 0.2), Saturation(2, 2)"
        "])"
    ])

    test_train_noise = Noise([
        "Combined(["
        "RC(91,91),"
        "Elastic((3,3)),"
        "KJpeg(50),"
        "Erase((0.5,0.5)),"
        "AF(s=(-55,55)),"
        "AF(r=(-45,45)),"
        "MF(7),GF(2),Dropout(0.5),"
        "SP(0.1),"
        "GN(0.04),"
        "Bright(0.2,2),"
        "Contrast(0.2,2),,"
        "Hue(-0.1,0.1),"
        "Saturation(0.2,2)"
        "])"
    ])

    return test_noise, test_train_noise


def collect_all_test_covers_to_batch(testloader, device) -> torch.Tensor:
    """Concatenate every cover batch yielded by ``testloader`` into one float32 tensor on ``device``.

    Each item yielded by the loader may be:
      * a tensor, used as-is;
      * a dict, from which the first present key of "image", "img", "cover",
        "cover_image", "x" is used;
      * a list/tuple such as ``(images, labels)``, from which the first element is used.

    Raises:
        RuntimeError: if the loader yields nothing.
        KeyError: if a dict item contains none of the recognised image keys.
    """
    image_keys = ("image", "img", "cover", "cover_image", "x")
    buf = []
    for cover in testloader:
        if isinstance(cover, dict):
            key = next((k for k in image_keys if k in cover), None)
            if key is None:
                raise KeyError(
                    f"testloader item has none of the image keys {image_keys}; "
                    f"got keys {list(cover)}"
                )
            cover = cover[key]
        elif isinstance(cover, (list, tuple)):
            # e.g. (images, labels) from a labelled dataset: keep the images.
            cover = cover[0]
        buf.append(cover.to(device).float())
    if len(buf) == 0:
        raise RuntimeError("testloader is empty.")
    return torch.cat(buf, dim=0)


# --------------------------- main script --------------------------

# Log labels for the 21 test-noise outputs, in the index order main() accumulates
# them (the attack order of the test spec in build_noises: RC, Elastic, shear -/+55,
# rotate -/+45, Erase, JPEG, MF, GF, Dropout, SP, GN, then Bright / Contrast / Hue /
# Saturation pairs).
_TEST_NOISE_LABELS = (
    'R  C Acc',
    'E  S Acc',
    'S  H -55 Acc',
    'S  H +55 Acc',
    'R  O -45 Acc',
    'R  O +45 Acc',
    'E  R Acc',
    'J  P Acc',
    'M  F Acc',
    'G  F Acc',
    'D  P Acc',
    'S  P Acc',
    'G  N Acc',
    'B  R 0.2 Acc',
    'B  R 2.0 Acc',
    'C  T 0.2 Acc',
    'C  T 2.0 Acc',
    'H  U - Acc',
    'H  U + Acc',
    'S  A 0.2 Acc',
    'S  A 2.0 Acc',
)


def _decoder_output(pmodel, images):
    """Run ``pmodel`` on ``images``; when ``c.mode == 'PH'`` return element ``[0]`` of its output."""
    out = pmodel(images)
    return out[0] if c.mode == 'PH' else out


def _optimize_stego(pmodel, cover_all, V, bits, message, train_noise, epoch):
    """Optimise a stego image ``x`` (initialised from ``cover_all``) so the decoder's
    projections onto ``V`` encode ``bits`` after ``train_noise``.

    Runs ``c.test_joint_steps`` optimiser steps with the warmup + cosine LR schedule.
    The loss is ``c.acc_w * embed_loss``, plus ``c.adv_mse_w * MSE(x, cover)`` when
    ``adv_mse_w > 0``. Gradient clipping is applied when ``c.adv_grad_clip > 0``, and
    ``x`` is clamped to [-1, 1] after every step.

    Returns:
        The optimised ``x`` (a leaf tensor that still requires grad).
    """
    B = cover_all.shape[0]
    x = cover_all.clone().detach().requires_grad_(True)
    opt_x = _make_x_optimizer(x, c.adv_opt_name, c.adv_lr)

    total_steps = int(c.test_joint_steps)
    for step_i in range(1, total_steps + 1):
        lr_now = _compute_sched_lr(step_i, total_steps,
                                   c.adv_lr, c.adv_lr_final, c.adv_warmup_ratio)
        _set_x_lr(opt_x, lr_now)

        opt_x.zero_grad(set_to_none=True)

        yk = _take_first_if_list(train_noise([x, cover_all]))
        noise_tvec = _decoder_output(pmodel, yk).reshape(B, -1) @ V.t()  # [B, msg_len]

        loss_x = compute_embed_loss(
            noise_tvec=noise_tvec,
            clean_tvec=noise_tvec,
            bits=bits,
            message01=message,
            mode=c.test_loss_mode,
            Delta=c.qim_Delta,
            epoch=epoch,
            mix_epochs=c.mix_epochs
        )

        # acc_w scales the embedding loss whether or not the fidelity term is enabled.
        loss_x = c.acc_w * loss_x
        if c.adv_mse_w and c.adv_mse_w > 0:
            loss_x = loss_x + c.adv_mse_w * F.mse_loss(x, cover_all)

        loss_x.backward()
        if c.adv_grad_clip and c.adv_grad_clip > 0:
            torch.nn.utils.clip_grad_norm_([x], max_norm=c.adv_grad_clip)

        opt_x.step()
        with torch.no_grad():
            x.clamp_(-1, 1)
    return x


@torch.no_grad()
def _evaluate_under_noises(pmodel, x, cover_all, V, bits, test_noise):
    """Decode bits from each output of ``test_noise([x, cover_all])``.

    Returns:
        A list of 21 lists. Entry ``idx`` holds the bit accuracy in percent for
        test-noise output ``idx``.
    """
    B = cover_all.shape[0]
    acc = [[] for _ in range(len(_TEST_NOISE_LABELS))]
    for idx, attacked in enumerate(test_noise([x, cover_all])):
        tvec = _decoder_output(pmodel, attacked).reshape(B, -1) @ V.t()
        bits_hat = decode_bits_from_t(tvec=tvec, mode=c.test_loss_mode, Delta=c.qim_Delta)
        error_rate = (bits_hat.round() != bits).float().mean().item()
        acc[idx].append((1 - error_rate) * 100.0)
    return acc


@torch.no_grad()
def _save_test_visuals(x, cover_all, result_folder, epoch):
    """Save the first sample's amplified residual and stego image under ``result_folder/vis_test``.

    NOTE(review): the file named ``epochNNN_cover.png`` holds the residual
    ``(x - cover) * 10`` mapped to [0, 1], not the cover image itself.
    """
    from torchvision.utils import save_image

    out_dir = os.path.join(result_folder, "vis_test")
    os.makedirs(out_dir, exist_ok=True)

    residual_vis = ((x[0:1] - cover_all[0:1]) * 10 + 1.0) / 2.0   # [-1,1] -> [0,1]
    x_vis = (x[0:1] + 1.0) / 2.0
    save_image(residual_vis, os.path.join(out_dir, f"epoch{epoch:03d}_cover.png"))
    save_image(x_vis, os.path.join(out_dir, f"epoch{epoch:03d}_stego.png"))


def main():
    """Run the per-epoch stego-optimisation and robustness-evaluation loop.

    For every epoch in ``1..c.epochs``:
      1. gather the whole test set into one batch and draw a random ±1 message;
      2. build (or reuse from a cache) the ``[msg_len, L_dec]`` carrier matrix ``V``;
      3. optimise a stego image ``x`` from the covers (``_optimize_stego``);
      4. log PSNR(x) and the bit accuracy under each of the 21 test attacks;
      5. save visualisations every 10th epoch and print a LaTeX summary row.
    """
    result_folder = os.path.join(
        c.MODEL_PATH, time.strftime(c.PROJECT_NAME + "__%H_%M_%S", time.localtime())
    )
    os.makedirs(os.path.join(result_folder, "models"), exist_ok=True)

    device = torch.device(f"cuda:{c.ndevice}" if torch.cuda.is_available() else "cpu")

    pmodel = RFE().to(device)
    for p in pmodel.parameters():
        p.requires_grad_(True)

    # Optional resume from MODEL_PATH/CONTINUE_PATH/models/<CONTINUE_EPOCH>.pt.
    # NOTE(review): the flag is read as "tain_next" — confirm config.py uses the same
    # spelling (possible typo for "train_next"); otherwise resume never triggers.
    if getattr(c, "tain_next", False):
        ckpt = os.path.join(c.MODEL_PATH, c.CONTINUE_PATH, "models", f"{c.CONTINUE_EPOCH}.pt")
        load(pmodel, ckpt)

    setup_logger('train', result_folder, 'logging',
                 level=logging.INFO, screen=True, tofile=True)
    logger = logging.getLogger('train')

    test_noise, test_train_noise = build_noises(c.noise_type)

    # Carrier cache keyed by (L_dec, msg_len, V_mode, V_rho, or -1.0 when rho is unused).
    V_cache: Dict[Tuple[int, int, str, float], torch.Tensor] = {}

    for epoch in range(1, c.epochs + 1):
        # ---------------- TEST ----------------
        pmodel.eval()

        cover_all = collect_all_test_covers_to_batch(testloader, device=device)
        B, _, _, _ = cover_all.shape  # covers must form a 4-D [B, C, H, W] batch

        msg_len = int(c.test_message_length)
        message = torch.tensor(
            np.random.choice([-1, 1], (B, msg_len)),
            device=device,
            dtype=torch.float32
        )
        bits = msg_to_bits(message)

        with torch.no_grad():
            dec0 = _decoder_output(pmodel, cover_all)
        # Per-sample feature length, computed exactly like the reshape(B, -1) applied
        # before projecting onto V, so it matches for decoder outputs of any rank.
        L_dec = int(dec0.reshape(B, -1).shape[1])

        V_mode = str(getattr(c, "V_mode", "ortho")).lower()
        V_rho = float(getattr(c, "V_rho", 0.1))  # used only for rand_unit_corr

        key = (L_dec, msg_len, V_mode, V_rho if V_mode == "rand_unit_corr" else -1.0)
        if key not in V_cache:
            V_cache[key] = build_carriers(
                msg_len=msg_len,
                L_dec=L_dec,
                device=device,
                seed=c.qim_seed,
                mode=V_mode,
                rho=V_rho
            )
        V = V_cache[key]

        # The decoder is never updated in this script. Freezing it while optimising x
        # stops every backward() from also computing and accumulating weight gradients.
        pmodel.requires_grad_(False)
        try:
            x = _optimize_stego(pmodel, cover_all, V, bits, message, test_train_noise, epoch)
        finally:
            pmodel.requires_grad_(True)

        with torch.no_grad():
            psnr_x = float(psnr_minus1_to1(cover_all, x))
        acc = _evaluate_under_noises(pmodel, x, cover_all, V, bits, test_noise)

        if (epoch - 1) % 10 == 0:
            _save_test_visuals(x, cover_all, result_folder, epoch)

        acc_means_21 = [safe_mean(acc[i], default=0.0) for i in range(21)]
        logger.info(
            f"[TEST ] epoch={epoch} | "
            f"PSNR(x)={psnr_x:.2f} dB | "
            + "".join(f"{label}: {mean:.2f} | "
                      for label, mean in zip(_TEST_NOISE_LABELS, acc_means_21))
        )

        print("\n%%%%%%%%%%%% LM-style one-row %%%%%%%%%%%%\n")
        print("% PSNR JPEG MF GF DP SP GN Erase CR SHEAR Rotate Elastic Hue Bright Contrast Saturate AVG")
        print(lm_latex_row_from_acc(getattr(c, "method_name", "OUR"), psnr_x, acc_means_21))
        print("\n============================================================\n")


if __name__ == "__main__":
    # Run the evaluation loop only when executed as a script, not on import.
    main()