import inspect
import os
import sys
import types
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


# --- Paths: edit these before running. ---
ARCFACE_CHECKPOINT_PATH = "your/arcface/checkpoint/path/here"  # set your ArcFace checkpoint path
DATASET_PATH = "your/dataset/path/here"  # set your dataset path

# --- Training hyper-parameters. ---
BATCH_SIZE = 128
NUM_EPOCHS = 60
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Output: per-epoch generator checkpoints; the directory is created at import time. ---
SAVE_DIR = "./nullswap_ckpts"
os.makedirs(SAVE_DIR, exist_ok=True)

# facenet-pytorch is optional: when it imports, a FaceNet identity loss is
# added alongside ArcFace during training.
FACENET_AVAILABLE = False
try:
    from facenet_pytorch import InceptionResnetV1
except ImportError:
    pass
else:
    FACENET_AVAILABLE = True

def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    At ``stride=1`` the spatial size is preserved; larger strides downsample.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )

class SEBlock(nn.Module):
    """Squeeze-and-excitation gate used inside ``IRBlock``.

    Each channel is globally average-pooled, passed through a
    ``Linear -> PReLU -> Linear -> Sigmoid`` bottleneck of width
    ``channel // reduction``, and the resulting per-channel scale in (0, 1)
    multiplies the input.

    The attribute names (``avg_pool``, ``fc`` and the ``fc`` indices) form the
    state_dict keys, so they must stay stable for pretrained weights to load.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden),
            nn.PReLU(),
            nn.Linear(hidden, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _, _ = x.size()
        squeezed = self.avg_pool(x).flatten(1)
        gate = self.fc(squeezed).reshape(batch, channels, 1, 1)
        return x * gate

class IRBlock(nn.Module):
    """Improved-residual (IR) block of the ArcFace ResNet backbone.

    Residual branch: BN -> 3x3 conv -> BN -> PReLU -> 3x3 conv (carries
    ``stride``) -> BN, then an optional SE gate.

    The branch is added to the shortcut, which is ``downsample(x)`` when a
    downsample module is given and ``x`` otherwise. The sum goes through PReLU.

    The single ``prelu`` module is applied both inside the branch and after
    the addition, so both uses share one set of learned slopes. Attribute
    names form the state_dict keys used when loading pretrained weights.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super().__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.use_se = use_se
        if use_se:
            self.se = SEBlock(planes)

    def forward(self, x):
        branch = self.prelu(self.bn1(self.conv1(self.bn0(x))))
        branch = self.bn2(self.conv2(branch))
        if self.use_se:
            branch = self.se(branch)
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.prelu(branch + shortcut)

class ResNet(nn.Module):
    """ArcFace ResNet-IR backbone producing a 512-d embedding.

    Stem: 3x3 conv (stride 1, no padding) -> BN -> PReLU -> 2x2 max-pool.

    Body: four stages of ``block`` with 64/128/256/512 output channels;
    stages 2-4 use stride 2.

    Head: BN -> Dropout -> Linear(512*7*7 -> 512) -> BatchNorm1d.

    With ``IRBlock`` (padded 3x3 convs) the 7x7 map expected by ``fc``
    corresponds to a 112x112 input: 112 -> 110 -> 55 -> 55 -> 28 -> 14 -> 7.
    """

    def __init__(self, block, layers, use_se=True):
        super().__init__()
        self.inplanes = 64
        self.use_se = use_se

        # Stem.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Residual stages; only the first keeps the stem's resolution.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Embedding head.
        self.bn2 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        self.fc = nn.Linear(512 * 7 * 7, 512)
        self.bn3 = nn.BatchNorm1d(512)

        self._init_weights()

    def _init_weights(self):
        """Xavier-normal conv/linear weights, zero linear bias, identity BN affine."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(module.weight)
                if isinstance(module, nn.Linear):
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage: a (possibly strided) first block plus ``blocks - 1`` more."""
        out_channels = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            # 1x1 projection so the shortcut matches the branch's shape.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        head = block(self.inplanes, planes, stride, downsample, use_se=self.use_se)
        # Tracks `planes`, not `planes * expansion`; these agree only when
        # expansion == 1 (as for IRBlock).
        self.inplanes = planes
        tail = [block(self.inplanes, planes, use_se=self.use_se) for _ in range(blocks - 1)]
        return nn.Sequential(head, *tail)

    def forward(self, x):
        x = self.maxpool(self.prelu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = self.dropout(self.bn2(x))
        x = self.fc(torch.flatten(x, 1))
        return self.bn3(x)

def inject_fake_module_structure():
    """Register a stand-in ``models.arcface_models`` module in ``sys.modules``.

    The stand-in exposes this file's ``ResNet``, ``IRBlock`` and ``SEBlock``
    under that dotted name. Presumably this lets ``torch.load`` unpickle a
    checkpoint that stores a whole model object referring to those classes by
    module path.

    Behaviour:
    - A placeholder ``models`` package (empty ``__path__``) is created only if
      no ``models`` entry is registered yet.
    - An already-registered ``models.arcface_models`` is left untouched.
    """
    if "models" in sys.modules:
        parent = sys.modules["models"]
    else:
        parent = types.ModuleType("models")
        parent.__path__ = []
        sys.modules["models"] = parent

    if "models.arcface_models" in sys.modules:
        return

    shim = types.ModuleType("models.arcface_models")
    sys.modules["models.arcface_models"] = shim
    parent.arcface_models = shim
    shim.ResNet = ResNet
    shim.IRBlock = IRBlock
    shim.SEBlock = SEBlock

# Runs at import time, so the `models.arcface_models` alias is registered
# before `load_simswap_arcface` calls `torch.load` from `main()`.
inject_fake_module_structure()

class ConvBlock(nn.Module):
    """``Conv2d(k, s, p, bias=False) -> BatchNorm2d -> ReLU`` as one module.

    The three layers live in ``self.net`` at indices 0/1/2. Those indices
    appear in the state_dict keys of saved generator checkpoints.
    """

    def __init__(self, in_c, out_c, k=3, s=1, p=1):
        super().__init__()
        conv = nn.Conv2d(in_c, out_c, k, s, p, bias=False)
        norm = nn.BatchNorm2d(out_c)
        act = nn.ReLU(inplace=True)
        self.net = nn.Sequential(conv, norm, act)

    def forward(self, x):
        return self.net(x)

class SEResBottleneck_NS(nn.Module):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) with a squeeze-excitation gate.

    Args:
        in_c: input channels.
        out_c: output channels.
        stride: stride of the 3x3 conv and of the projection shortcut.
        reduction: SE hidden width is ``max(out_c // reduction, 1)``.
        bottleneck_ratio: inner width is ``max(out_c // bottleneck_ratio, 1)``.

    Raises:
        ValueError: if ``bottleneck_ratio`` is not >= 1.

    A 1x1 conv + BN projection replaces the identity shortcut whenever the
    stride or channel count changes. The output passes through a final ReLU,
    so it is non-negative.
    """

    def __init__(self, in_c, out_c, stride=1, reduction=16, bottleneck_ratio=4):
        super().__init__()
        # Explicit check rather than `assert`, which is stripped under `python -O`
        # (and would then surface as a ZeroDivisionError for ratio 0).
        if not bottleneck_ratio >= 1:
            raise ValueError(f"bottleneck_ratio must be >= 1, got {bottleneck_ratio!r}")
        mid_c = max(out_c // bottleneck_ratio, 1)

        # 1x1 reduce
        self.conv1 = nn.Conv2d(in_c, mid_c, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_c)

        # 3x3 (carries the stride)
        self.conv2 = nn.Conv2d(mid_c, mid_c, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_c)

        # 1x1 expand
        self.conv3 = nn.Conv2d(mid_c, out_c, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_c)

        self.relu = nn.ReLU(inplace=True)

        # SE gate: global average pool -> FC -> ReLU -> FC -> sigmoid.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        se_mid = max(out_c // reduction, 1)
        self.se_fc1 = nn.Linear(out_c, se_mid, bias=False)
        self.se_fc2 = nn.Linear(se_mid, out_c, bias=False)

        # Projection shortcut only when the residual's shape would differ.
        self.downsample = None
        if stride != 1 or in_c != out_c:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=1, stride=stride, padding=0, bias=False),
                nn.BatchNorm2d(out_c),
            )

    def forward(self, x):
        res = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # SE: rescale each channel by a learned gate in (0, 1).
        b, c, _, _ = out.size()
        y = self.avg_pool(out).view(b, c)
        y = F.relu(self.se_fc1(y), inplace=True)
        y = torch.sigmoid(self.se_fc2(y)).view(b, c, 1, 1)
        out = out * y

        if self.downsample is not None:
            res = self.downsample(x)

        return self.relu(out + res)

class IDExtraction(nn.Module):
    """Identity-feature encoder: stride-2 ConvBlock, stride-2 max-pool, ``L`` SE bottlenecks.

    Maps a 3-channel image to a 64-channel feature map downsampled by the two
    stride-2 stages (256x256 -> 64x64).

    ``self.net`` indices form the state_dict keys of saved NullSwap
    checkpoints, so the layer order is fixed.
    """

    def __init__(self, L=4):
        super().__init__()
        # The residual stack is instantiated before the stem, as before, so
        # parameter initialisation consumes the RNG in the same order.
        residual_stack = [SEResBottleneck_NS(64, 64, stride=1) for _ in range(L)]
        stem = ConvBlock(3, 64, k=3, s=2, p=1)
        self.net = nn.Sequential(stem, nn.MaxPool2d(3, 2, 1), *residual_stack)

    def forward(self, x):
        return self.net(x)

class PerturbationBlock(nn.Module):
    """Refines identity features and, in training mode only, injects learnable noise.

    ``feat = blocks(refine(x))``.
    - Training mode: returns ``feat + beta * (alpha * N(0, 1) + eta)``. Here
      ``alpha`` and ``beta`` are learnable scalars (init 1.0 / 0.5) and ``eta``
      is a learnable per-channel offset of shape (1, 64, 1, 1), initialised
      to small Gaussian values.
    - Eval mode: returns ``feat`` unchanged.
    """

    def __init__(self, M=3):
        super().__init__()
        self.refine = ConvBlock(64, 64, k=3, s=1, p=1)
        self.blocks = nn.Sequential(*(SEResBottleneck_NS(64, 64, stride=1) for _ in range(M)))

        # Learnable noise parameters.
        self.alpha = nn.Parameter(torch.tensor(1.0))
        self.beta = nn.Parameter(torch.tensor(0.5))
        self.eta = nn.Parameter(torch.randn(1, 64, 1, 1) * 0.01)

    def forward(self, x):
        feat = self.blocks(self.refine(x))
        if not self.training:
            return feat
        noise = self.alpha * torch.randn_like(feat) + self.eta
        return feat + self.beta * noise

class FeatureBlock(nn.Module):
    """Appearance encoder: three ConvBlocks (the first two stride 2) + ``N`` SE bottlenecks.

    Produces a 64-channel map at a quarter of the input resolution
    (256x256 -> 64x64). This matches ``IDExtraction``'s output so that
    ``CloakingBlock`` can concatenate the two.
    """

    def __init__(self, N=5):
        super().__init__()
        # Residual stack first, mirroring the original construction order.
        residual_stack = [SEResBottleneck_NS(64, 64, stride=1) for _ in range(N)]
        encoder = [
            ConvBlock(3, 32, k=3, s=2, p=1),
            ConvBlock(32, 64, k=3, s=2, p=1),
            ConvBlock(64, 64, k=3, s=1, p=1),
        ]
        self.net = nn.Sequential(*encoder, *residual_stack)

    def forward(self, x):
        return self.net(x)

class CloakingBlock(nn.Module):
    """Decoder that fuses appearance and perturbed identity features into an image.

    ``forward(img_feat, pert_feat, org_img)``:
      1. Concatenates ``img_feat`` with ``pert_feat * gamma_p`` along channels.
         ``fuse`` expects 128 channels, i.e. 64 + 64.
      2. Applies a 128-channel SE bottleneck, then two stride-2 transposed
         convs (each doubles H and W) down to 64 channels.
      3. Bilinearly resizes to ``org_img``'s H x W if the sizes differ.
      4. Concatenates ``org_img`` (64 + 3 = 67 channels) and predicts a
         3-channel image through ``Tanh``, so values lie in [-1, 1].

    ``gamma_p`` is a learnable scalar, initialised to 1.0.
    """

    def __init__(self):
        super().__init__()
        self.gamma_p = nn.Parameter(torch.tensor(1.0))

        self.fuse = SEResBottleneck_NS(128, 128, stride=1)

        first_up = [
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        ]
        mid_refine = [ConvBlock(64, 64, k=3, s=1, p=1)]
        second_up = [
            nn.ConvTranspose2d(64, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        ]
        self.upsample = nn.Sequential(*first_up, *mid_refine, *second_up)

        self.final = nn.Sequential(
            ConvBlock(67, 64, k=3, s=1, p=1),
            ConvBlock(64, 32, k=3, s=1, p=1),
            nn.Conv2d(32, 3, 3, 1, 1),
            nn.Tanh(),
        )

    def forward(self, img_feat, pert_feat, org_img):
        target_hw = org_img.shape[-2:]
        scaled_pert = pert_feat * self.gamma_p
        recon = self.upsample(self.fuse(torch.cat([img_feat, scaled_pert], dim=1)))

        if recon.shape[-2:] != target_hw:
            recon = F.interpolate(recon, size=target_hw, mode="bilinear", align_corners=False)

        return self.final(torch.cat([recon, org_img], dim=1))

class NullSwap(nn.Module):
    """Generator: image -> cloaked image of the same spatial size.

    Two encoders read the input:
    - ``feat`` extracts appearance features.
    - ``id`` extracts identity features, which ``pert`` refines and, in
      training mode, perturbs with noise.

    ``cloak`` decodes both, together with the original image, into the output.
    """

    def __init__(self):
        super().__init__()
        # Construction order determines parameter-init RNG order; keep it.
        self.id = IDExtraction(L=4)
        self.pert = PerturbationBlock(M=3)
        self.feat = FeatureBlock(N=5)
        self.cloak = CloakingBlock()

    def forward(self, x):
        appearance = self.feat(x)
        perturbed_identity = self.pert(self.id(x))
        return self.cloak(appearance, perturbed_identity, x)

class Discriminator(nn.Module):
    """Image discriminator returning a per-sample probability of shape (B, 1).

    Body: four ``Conv(4, stride 2, pad 1) -> BN -> ReLU`` stages with
    64/128/256/512 channels.

    Head: global average pooling, ``Linear(512, 1)`` and ``Sigmoid``, so
    outputs lie in (0, 1) and are meant for ``nn.BCELoss``.
    """

    def __init__(self):
        super().__init__()
        layers = []
        in_ch = 3
        for out_ch in (64, 128, 256, 512):
            layers.extend([
                nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
            in_ch = out_ch
        layers.extend([
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, 1),
            nn.Sigmoid(),
        ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class DynamicLossWeighting:
    """Heuristic per-step re-weighting of several scalar losses.

    For loss ``i`` the raw weight is
    ``1 / max(alpha * Var_i + beta_t * (1 + delta_i), 1e-6)``, floored at 1e-6,
    where:
    - ``Var_i`` is the variance of its last ``k`` observed values (0 until two
      values are seen);
    - ``delta_i`` is its relative decrease since the previous call, clamped to
      >= -1 (0 on the first call);
    - ``beta_t = min(0.5 + gamma * min(epoch, 50), 2.0)``. With the default
      ``gamma`` this reaches the 2.0 cap at epoch 15.

    The raw weights are then rescaled so they sum to ``len(losses)``.

    Args:
        num_losses: number of losses tracked (index-aligned across calls).
        k: sliding-window length used for the variance.
        alpha: coefficient on the windowed variance.
        gamma: per-epoch growth rate of ``beta_t``.
    """

    def __init__(self, num_losses=2, k=30, alpha=3.0, gamma=0.1):
        # deque(maxlen=k) discards the oldest value in O(1) once the window is full.
        self.history = [deque(maxlen=k) for _ in range(num_losses)]
        self.prev = [None] * num_losses
        self.k = k
        self.alpha = alpha
        self.gamma = gamma

    def get_weights(self, losses, epoch):
        """Record ``losses`` and return one weight per loss.

        Args:
            losses: sequence of floats, one per tracked loss.
            epoch: current epoch, used in the ``beta_t`` schedule.

        Returns:
            list of floats of length ``len(losses)``, normalised to sum to
            ``len(losses)``.
        """
        eps_d = 1e-6  # floor on the denominator (avoids division by ~0)
        eps_w = 1e-6  # floor on each raw weight

        beta_t = min(0.5 + self.gamma * min(epoch, 50), 2.0)

        weights = []
        for i, val in enumerate(losses):
            hist = self.history[i]
            hist.append(val)
            var = np.var(hist) if len(hist) > 1 else 0.0

            prev = self.prev[i]
            if prev is None:
                delta = 0.0
            else:
                # Relative improvement; clamped so (1 + delta) never goes negative.
                delta = max((prev - val) / (abs(prev) + 1e-8), -1.0)
            self.prev[i] = val

            denom = max(self.alpha * var + beta_t * (1.0 + delta), eps_d)
            weights.append(max(1.0 / denom, eps_w))

        total = max(sum(weights), 1e-12)
        return [len(losses) * w / total for w in weights]


class FlatImageFolderDataset(Dataset):
    """Unlabelled dataset over the image files directly inside ``root``.

    Only files whose extension (case-insensitive) is .jpg/.png/.jpeg are used.
    Sub-directories are not searched, and every item's label is ``0``.

    Args:
        root: directory to scan.
        transform: optional callable applied to the RGB ``PIL.Image``.
        fallback_size: ``(H, W)`` of the all-zeros ``3 x H x W`` tensor returned
            for an image that fails to load or transform. The default matches
            the 256x256 training pipeline in this file.
    """

    _EXTENSIONS = (".jpg", ".png", ".jpeg")

    def __init__(self, root, transform=None, fallback_size=(256, 256)):
        # Sorted so the index -> file mapping is reproducible
        # (os.listdir order is arbitrary).
        self.files = [
            os.path.join(root, name)
            for name in sorted(os.listdir(root))
            if name.lower().endswith(self._EXTENSIONS)
        ]
        self.transform = transform
        self.fallback_size = tuple(fallback_size)
        if not self.files:
            print(f"[Warning] No images found in {root}. Check path!")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # Indexing stays outside the try block: an out-of-range index must raise
        # IndexError (the sequence protocol relies on it to stop iteration).
        path = self.files[idx]
        try:
            # The context manager closes the file handle as soon as it is decoded.
            with Image.open(path) as im:
                img = im.convert("RGB")
            if self.transform is not None:
                img = self.transform(img)
            return img, 0
        except Exception as err:  # best-effort: one unreadable file must not stop training
            print(f"[Warning] Failed to load {path}: {err}. Using a blank image.")
            return torch.zeros(3, *self.fallback_size), 0

def load_simswap_arcface(path, device):
    """Build the ArcFace ResNet-IR (``IRBlock``, [3, 4, 14, 3], SE) and load its weights.

    The checkpoint at ``path`` may be:
    - a pickled ``nn.Module`` (its ``state_dict()`` is used);
    - a dict containing a ``"state_dict"`` entry;
    - a bare state dict.

    A leading ``"module."`` prefix (as typically added by ``nn.DataParallel``)
    is stripped from each key. The returned model is on ``device``, in eval
    mode, with every parameter frozen.

    Raises:
        FileNotFoundError: ``path`` does not exist.
        RuntimeError: the checkpoint type is unknown, or none of the model's
            learnable tensors were found in it (the backbone would be random).
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    print(f"[Info] Loading ArcFace Backbone from {path}...")

    model = ResNet(IRBlock, [3, 4, 14, 3], use_se=True)

    # SECURITY: restoring a whole pickled model needs full pickle
    # deserialisation, which can execute arbitrary code -- load trusted files
    # only. PyTorch >= 2.6 defaults to weights_only=True, which rejects such
    # checkpoints, so opt out explicitly where the flag exists (torch >= 1.13).
    load_kwargs = {"map_location": device}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    loaded_object = torch.load(path, **load_kwargs)

    if isinstance(loaded_object, nn.Module):
        state_dict = loaded_object.state_dict()
    elif isinstance(loaded_object, dict):
        state_dict = loaded_object["state_dict"] if "state_dict" in loaded_object else loaded_object
    else:
        raise RuntimeError(f"Unknown checkpoint format: {type(loaded_object)}")

    # Strip only a leading "module." (str.replace would also hit inner occurrences).
    prefix = "module."
    new_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()
    }
    result = model.load_state_dict(new_state_dict, strict=False)

    # BatchNorm step counters are not learned weights; don't count them.
    def _learned(keys):
        return [k for k in keys if not k.endswith("num_batches_tracked")]

    missing = _learned(result.missing_keys)
    if len(missing) == len(_learned(model.state_dict().keys())):
        raise RuntimeError(
            f"No ArcFace weights matched the checkpoint at {path}; "
            f"example keys: {list(new_state_dict)[:5]}"
        )
    if missing or result.unexpected_keys:
        print(
            f"[Warning] ArcFace checkpoint partially matched: {len(missing)} missing, "
            f"{len(result.unexpected_keys)} unexpected keys."
        )

    model.to(device).eval()
    for p in model.parameters():
        p.requires_grad = False
    return model

def _cosine_id_similarity(model, real, pert, size):
    """Return the batch-mean cosine similarity between ``model`` embeddings of ``real`` and ``pert``.

    Both batches are bilinearly resized to ``size`` and the embeddings are
    L2-normalised. The ``real`` branch runs under ``torch.no_grad()``. The
    ``pert`` branch keeps its graph, so the result can back-propagate into the
    generator.
    """
    real_thumb = F.interpolate(real, size=size, mode="bilinear", align_corners=False)
    pert_thumb = F.interpolate(pert, size=size, mode="bilinear", align_corners=False)
    with torch.no_grad():
        id_real = F.normalize(model(real_thumb), p=2, dim=1)
    id_pert = F.normalize(model(pert_thumb), p=2, dim=1)
    return torch.sum(id_real * id_pert, dim=1).mean()


def _build_lpips(mse_loss):
    """Return ``f(x, y) -> scalar``: mean LPIPS(alex) if it can be built, else ``mse_loss``."""
    try:
        import lpips
        lpips_model = lpips.LPIPS(net="alex").to(DEVICE).eval()
    except Exception:  # e.g. package not installed: degrade to MSE instead of crashing
        print("[Warning] lpips not available. Using MSE proxy.")
        return mse_loss
    # LPIPS returns one distance per sample; reduce to a scalar loss.
    return lambda x, y: lpips_model(x, y).mean()


def main():
    """Adversarially train the NullSwap generator and save ``G`` after every epoch.

    Each step:
      1. Runs one generator forward.
      2. Updates D with BCE on real vs. detached generated images.
      3. Updates G on
         ``0.08 * id + 1.8 * MSE + 1.2 * LPIPS + 0.1 * adversarial``.
         The ``id`` term is the cosine identity similarity under frozen
         ArcFace (and FaceNet, if installed), weighted by
         ``DynamicLossWeighting``; minimising it pushes the identity away.

    Raises:
        RuntimeError: the dataset yields no full batch (fewer than BATCH_SIZE images).
    """
    torch.backends.cudnn.benchmark = True

    G = NullSwap().to(DEVICE)
    D = Discriminator().to(DEVICE)
    arcface = load_simswap_arcface(ARCFACE_CHECKPOINT_PATH, DEVICE)

    facenet = None
    if FACENET_AVAILABLE:
        print("[Info] Loading FaceNet for DLW (Paper Main Method)...")
        facenet = InceptionResnetV1(pretrained="vggface2").eval().to(DEVICE)
        for p in facenet.parameters():
            p.requires_grad = False

    bce_loss = nn.BCELoss()
    mse_loss = nn.MSELoss()
    lpips_fn = _build_lpips(mse_loss)

    opt_G = torch.optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.999))
    opt_D = torch.optim.Adam(D.parameters(), lr=1e-4, betas=(0.5, 0.999))

    dlw = DynamicLossWeighting(num_losses=2 if facenet is not None else 1)

    print(f"[Info] Loading data from {DATASET_PATH}...")
    dataset = FlatImageFolderDataset(
        DATASET_PATH,
        transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,)*3, (0.5,)*3),  # => [-1,1]
        ])
    )
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                        drop_last=True, num_workers=4,
                        pin_memory=DEVICE.type == "cuda")  # pinning only helps host->GPU copies
    if len(loader) == 0:
        # With drop_last=True a too-small dataset yields no batches, and training
        # would silently save untrained checkpoints every epoch.
        raise RuntimeError(
            f"Need at least BATCH_SIZE={BATCH_SIZE} images in {DATASET_PATH} "
            f"(found {len(dataset)})."
        )

    print("[Info] Starting NullSwap Training...")
    for epoch in range(1, NUM_EPOCHS + 1):
        G.train()
        D.train()

        for i, (real, _) in enumerate(loader):
            real = real.to(DEVICE, non_blocking=True)

            # One generator forward per step: D is trained on the same sample
            # (and the same PerturbationBlock noise draw) that G is optimised on.
            pert = G(real)

            # ---- Discriminator update ----
            D_real = D(real)
            D_fake = D(pert.detach())

            loss_D = bce_loss(D_real, torch.ones_like(D_real)) + \
                     bce_loss(D_fake, torch.zeros_like(D_fake))

            opt_D.zero_grad(set_to_none=True)
            loss_D.backward()
            opt_D.step()

            # ---- Generator update ----
            # NOTE(review): inputs here are in [-1, 1] (Normalize(0.5, 0.5)); confirm
            # this matches the normalisation the ArcFace checkpoint was trained with.
            l_id_arc = _cosine_id_similarity(arcface, real, pert, (112, 112))
            losses_dlw = [l_id_arc.item()]
            l_id_face = torch.tensor(0.0, device=DEVICE)

            if facenet is not None:
                l_id_face = _cosine_id_similarity(facenet, real, pert, (160, 160))
                losses_dlw.append(l_id_face.item())

            ws = dlw.get_weights(losses_dlw, epoch)
            l_id_total = ws[0] * l_id_arc
            if facenet is not None:
                l_id_total = l_id_total + ws[1] * l_id_face

            l_mse = mse_loss(pert, real)
            l_lpips = lpips_fn(pert, real)

            # Only gradients w.r.t. `pert` are needed here; skip D's weight grads.
            D.requires_grad_(False)
            D_pert = D(pert)
            D.requires_grad_(True)
            l_adv = bce_loss(D_pert, torch.ones_like(D_pert))

            loss_G = 0.08 * l_id_total + 1.8 * l_mse + 1.2 * l_lpips + 0.1 * l_adv

            opt_G.zero_grad(set_to_none=True)
            loss_G.backward()
            opt_G.step()

            if i % 10 == 0:
                face_log = f" | ID_Face: {l_id_face.item():.4f}" if facenet is not None else ""
                print(
                    f"Ep {epoch:03d} [{i:05d}] "
                    f"L_D: {loss_D.item():.4f} | L_G: {loss_G.item():.4f} "
                    f"| ID_Arc: {l_id_arc.item():.4f}{face_log} "
                    f"| MSE: {l_mse.item():.4f} | LPIPS: {l_lpips.item():.4f} | ADV: {l_adv.item():.4f}"
                )

        ckpt_path = os.path.join(SAVE_DIR, f"nullswap_epoch_{epoch:03d}.pth")
        torch.save(G.state_dict(), ckpt_path)
        print(f"[Info] Saved: {ckpt_path}")

# Start training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
