import os
import omegaconf
from tqdm import tqdm
import gc
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torchvision

from videoseal.utils.display import save_img
from videoseal.utils import Timer
from videoseal.evals.full import setup_model_from_checkpoint
from videoseal.evals.metrics import bit_accuracy, pvalue, psnr, ssim
from videoseal.augmentation import Identity, JPEG, Crop, Resize
from videoseal.modules.jnd import JND

import argparse

# Converters between PIL images and tensors.
to_pil = torchvision.transforms.ToPILImage()
to_tensor = torchvision.transforms.ToTensor()

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

from scipy.special import betainc

def pfa(c, D, double=False):
    """Return a tail probability derived from the regularized incomplete beta function.

    Computes ``t = 0.5 * I_{1 - c**2}((D - 1) / 2, 1 / 2)`` where ``I`` is
    ``scipy.special.betainc``.

    NOTE(review): presumably the false-alarm probability of observing a cosine
    similarity of at least ``c`` against a random direction in ``D``
    dimensions -- confirm against the caller.

    Args:
        c: Scalar score; the one-sided branch compares it with ``0.0``.
        D: Dimension parameter entering the first beta shape as ``(D - 1) / 2``.
        double: If true, return the two-sided value ``2 * t``.

    Returns:
        ``2 * t`` when ``double`` is true; otherwise ``t`` for ``c >= 0`` and
        ``1 - t`` for ``c < 0``.
    """
    half_tail = 0.5 * betainc((D - 1) * 0.5, 0.5, 1.0 - c**2)
    if double:
        return 2.0 * half_tail
    # One-sided case: a negative score falls on the complementary side.
    return 1.0 - half_tail if c < 0.0 else half_tail

# Checkpoint to export: its path on disk and the short name used as the
# prefix of the exported file names.
ckpt_path = "/home/user/semanticseal_256/checkpoint1150.pth"
ckpt_name = "semanticseal"

class EmbedderModelWrapper(torch.nn.Module):
    """Adapter exposing ``inner.embed`` as a plain ``forward(x, y)`` call.

    The wrapper rescales the image with ``(x + 1) / 2`` before embedding
    (mapping [-1, 1] to [0, 1], assuming the caller supplies images in
    [-1, 1]) and rescales the result back with ``z * 2 - 1``. The message is
    divided by its L2 norm along dim 1.
    """

    def __init__(self, inner: torch.nn.Module):
        """Store the wrapped model; it must provide an ``embed`` method."""
        super().__init__()
        self.inner = inner

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Embed message ``y`` into image ``x`` and return the rescaled result.

        Args:
            x: Image tensor, rescaled with ``(x + 1) / 2`` before embedding.
            y: Message tensor, normalized along dim 1 before embedding.

        Returns:
            The ``"imgs_w"`` entry of ``inner.embed(...)`` mapped through
            ``z * 2 - 1``.
        """
        # Unit-normalize the message without touching the caller's tensor;
        # the epsilon guards against division by zero.
        squared_sum = torch.sum(y * y, dim=1, keepdim=True)
        unit_msgs = y / (torch.sqrt(squared_sum) + 1e-9)

        # Shift the image into the range the wrapped model is fed with.
        imgs01 = 0.5 * (x + 1.0)

        embedded = self.inner.embed(imgs=imgs01, msgs=unit_msgs, lowres_attenuation=True)
        return 2.0 * embedded["imgs_w"] - 1.0
        
class ExtractorModelWrapper(torch.nn.Module):
    """Adapter exposing ``inner.detect`` as a plain ``forward(x)`` call.

    The wrapper rescales the image with ``(x + 1) / 2`` before detection
    (mapping [-1, 1] to [0, 1], assuming the caller supplies images in
    [-1, 1]) and returns the predictions without their first column.
    """

    def __init__(self, inner: torch.nn.Module):
        """Store the wrapped model; it must provide a ``detect`` method."""
        super().__init__()
        self.inner = inner

    def forward(self, x: torch.Tensor):
        """Run detection on ``x`` and return the predictions minus column 0.

        Args:
            x: Image tensor, rescaled with ``(x + 1) / 2`` before detection.

        Returns:
            A slice ``[:, 1:]`` of a copy of the ``"preds"`` entry returned by
            ``inner.detect(..., is_video=False)``.
        """
        # Shift the image into the range the wrapped model is fed with.
        unit_range = 0.5 * (x + 1.0)

        detection = self.inner.detect(unit_range, is_video=False)
        raw_preds = detection["preds"]

        # Copy first so the slice is not a view of the detector's output,
        # then drop the leading column (the mask bit).
        return raw_preds.clone()[:, 1:]

if __name__ == "__main__":
    # Script entry point: load the checkpoint, then trace the embedder and
    # extractor wrappers and save them as two TorchScript files under
    # "models/".

    parser = argparse.ArgumentParser()
    parser.add_argument("--scaling_w", type=float, default=1.0, help="Strength of the watermark")
    args = parser.parse_args()

    # Load model
    model = setup_model_from_checkpoint(ckpt_path)
    model.eval()
    model.augmenter = None  # disable augmenter
    #model.compile()
    model.to(device)

    # control the watermark strength with this parameter.
    # The higher the value, the more visible the watermark, but also the more robust it is to attacks.
    model.blender.scaling_w *= args.scaling_w

    # NOTE(review): not used in the visible part of this script; kept in case
    # later code relies on it.
    batch_sizes = [1, 64]

    # Example inputs that drive the trace: one 3x256x256 image and one
    # 256-element message.
    dummy_image = torch.randn((1, 3, 256, 256)).to(device)
    dummy_message = torch.randn((1, 256)).to(device)

    # torch.jit.save does not create missing parent directories, so make sure
    # the output folder exists before writing the traced modules.
    os.makedirs("models", exist_ok=True)

    # save encoder model
    embedder = EmbedderModelWrapper(model)
    torchscript_m = torch.jit.trace(embedder, (dummy_image, dummy_message), strict=False)
    torch.jit.save(torchscript_m, f"models/{ckpt_name}_enc.pth")

    # save decoder model
    extractor = ExtractorModelWrapper(model)
    torchscript_m = torch.jit.trace(extractor, (dummy_image,), strict=False)
    torch.jit.save(torchscript_m, f"models/{ckpt_name}_dec.pth")