#!/usr/bin/env python3
"""
whispersplat_refiner.py

Two-stage + alternating leakage correction pipeline.
"""
import os
import argparse
import csv
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

try:
    import torchvision.models as models 
    from torchvision import transforms
    VGG_AVAILABLE = True
except ImportError:
    VGG_AVAILABLE = False

try:
    from skimage.metrics import structural_similarity as ssim_fn
    HAVE_SSIM = True
except ImportError:
    HAVE_SSIM = False

try:
    import lpips

    HAVE_LPIPS = True
    LPIPS_MODEL = lpips.LPIPS(net='vgg').eval()
except ImportError:
    HAVE_LPIPS = False
    LPIPS_MODEL = None

# Small positive constant added to the per-patch denominator in
# init_alpha_from_patches so the least-squares ratio never divides by zero.
EPS = 1e-8


def safe_metric(val):
    """Coerce a metric value to a finite Python float, or NaN.

    Returns NaN when *val* is None, cannot be converted with ``float()``,
    or converts to NaN / +-infinity. Otherwise returns the converted float.
    """
    not_a_number = float("nan")
    if val is None:
        return not_a_number
    try:
        number = float(val)
    except Exception:
        # Anything that float() rejects is reported as "no value".
        return not_a_number
    return number if np.isfinite(number) else not_a_number


def format_metrics(psnr, ssim, lpips_val):
    """Render the three metrics as one human-readable line.

    Each value is printed with four decimals (PSNR additionally gets a
    " dB" suffix); a NaN value is printed as the literal text "nan".
    """
    def render(value, unit=""):
        # NaN marks "metric unavailable" and is printed without a unit.
        return "nan" if np.isnan(value) else f"{value:.4f}{unit}"

    fields = (
        f"PSNR: {render(psnr, ' dB')}",
        f"SSIM: {render(ssim)}",
        f"LPIPS: {render(lpips_val)}",
    )
    return ", ".join(fields)


def load_image_as_tensor(path, device):
    """Load an image file as a float32 tensor of shape (3, H, W) in [0, 1].

    Args:
        path: Path of the image file to read.
        device: Torch device the resulting tensor is moved to.

    Returns:
        A tensor holding the RGB pixel values divided by 255.
    """
    # The context manager releases the underlying file handle deterministically;
    # convert() returns an independent, fully loaded copy that outlives it.
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    arr = np.array(rgb, dtype=np.float32) / 255.0  # size is H,W,3
    tensor = torch.from_numpy(arr.transpose(2, 0, 1)).to(device)  # size is 3,H,W
    return tensor


def psnr_from_mse(mse, max_val=1.0):
    """Convert a mean-squared error into PSNR in decibels.

    Args:
        mse: Scalar MSE, either a 0-dim tensor or a plain Python number.
        max_val: Peak signal value; the PSNR is 10*log10(max_val**2 / mse).

    Returns:
        ``float("inf")`` when the MSE is (numerically) zero according to
        ``torch.isclose`` with its default tolerances, otherwise a tensor
        holding the PSNR value.
    """
    if not torch.is_tensor(mse):
        # Accept plain numbers too; previously these failed on ``mse.device``.
        mse = torch.tensor(float(mse))
    # zeros_like keeps dtype and device aligned with ``mse`` so the comparison
    # does not depend on dtype promotion between the two operands.
    if torch.isclose(mse, torch.zeros_like(mse)):
        return float("inf")
    return 10.0 * torch.log10((max_val ** 2) / mse)


def compute_psnr(a, b):
    """Return the PSNR between tensors *a* and *b* via their mean squared error.

    The result is whatever ``psnr_from_mse`` returns for the element-wise
    mean of ``(a - b) ** 2``.
    """
    squared_error = (a - b).pow(2)
    return psnr_from_mse(squared_error.mean())


def compute_ssim(a, b):
    """Compute SSIM between two channel-first image tensors.

    Returns None when the optional SSIM backend was not importable
    (``HAVE_SSIM`` is False). Both inputs are detached, clamped to [0, 1],
    moved to the CPU and transposed to channel-last before being scored
    with ``data_range=1.0``.
    """
    if not HAVE_SSIM:
        return None

    def to_channel_last(image):
        return image.detach().clamp(0, 1).cpu().numpy().transpose(1, 2, 0)

    first = to_channel_last(a)
    second = to_channel_last(b)
    try:
        return ssim_fn(first, second, channel_axis=2, data_range=1.0)
    except TypeError:
        # The backend rejected ``channel_axis``; retry with the
        # ``multichannel`` keyword instead.
        return ssim_fn(first, second, multichannel=True, data_range=1.0)


def compute_lpips(a, b, model, device):
    """Evaluate the LPIPS *model* on the pair (a, b) and return a Python float.

    Returns None when LPIPS is unavailable (``HAVE_LPIPS`` is False) or no
    model was supplied. The model is moved to ``a``'s device if its first
    parameter lives elsewhere. Inputs get a leading batch dimension and are
    rescaled with ``x * 2 - 1`` (presumably mapping [0, 1] images to the
    [-1, 1] range the model expects -- confirm against the LPIPS docs).

    Note: the ``device`` argument is accepted but not used; the device of
    ``a`` decides where the computation happens.
    """
    if not HAVE_LPIPS or model is None:
        return None

    run_device = a.device
    first_param = next(model.parameters(), None)
    # A parameter-less model is treated as already living on the right device.
    current_device = run_device if first_param is None else first_param.device
    if current_device != run_device:
        model.to(run_device)

    def prepare(image):
        return (image.unsqueeze(0) * 2.0 - 1.0).to(run_device)

    first = prepare(a)
    second = prepare(b)
    with torch.no_grad():
        distance = model(first, second)
    return distance.squeeze().item()


def save_tensor_as_image(tensor, path):
    """Write a channel-first image tensor to *path* as an 8-bit image.

    The tensor is detached, clamped to [0, 1], moved to the CPU, transposed
    to channel-last, scaled by 255 and cast to uint8 (the cast truncates)
    before being saved through PIL.
    """
    channel_first = tensor.detach().clamp(0, 1).cpu().numpy()  # 3,H,W
    channel_last = channel_first.transpose(1, 2, 0)  # H,W,3
    pixels = (channel_last * 255.0).astype(np.uint8)
    Image.fromarray(pixels).save(path)


class ResidualPredictor(nn.Module):
    """Three-layer 3x3 convolutional network mapping 3 channels to 3 channels.

    Layout: conv(3 -> hidden) / ReLU / conv(hidden -> hidden) / ReLU /
    conv(hidden -> 3). All convolutions use padding 1, so the spatial size of
    the input is preserved.
    """

    def __init__(self, hidden_channels=32):
        super().__init__()
        width = hidden_channels
        conv_options = dict(kernel_size=3, padding=1)
        # Layers are created in forward order so parameter names in the
        # state dict remain net.0 / net.2 / net.4.
        stages = [
            nn.Conv2d(3, width, **conv_options),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, width, **conv_options),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, 3, **conv_options),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the convolution stack to *x* and return its output."""
        return self.net(x)


class VGGPerceptual(nn.Module):
    """Frozen VGG16 feature extractor used for a perceptual loss.

    The input is normalised with the stored per-channel mean/std and pushed
    through the VGG16 ``features`` stack up to (and including) the layer
    index selected by *layer*.

    Args:
        device: Device the VGG features and normalisation buffers live on.
        layer: One of 'relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'. Any other
            name silently falls back to index 8 ('relu2_2').

    Raises:
        RuntimeError: If torchvision could not be imported.
    """

    def __init__(self, device, layer='relu2_2'):
        super().__init__()
        if not VGG_AVAILABLE:
            raise RuntimeError("torchvision required for perceptual loss")
        # Prefer the ``weights=`` API when this torchvision provides it;
        # ``pretrained=True`` is deprecated there. Older releases only
        # understand ``pretrained``.
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is not None:
            backbone = models.vgg16(weights=weights_enum.IMAGENET1K_V1)
        else:
            backbone = models.vgg16(pretrained=True)
        vgg = backbone.features.to(device).eval()
        # The extractor is a fixed feature map: freezing its weights stops
        # autograd from computing and accumulating gradients for the whole
        # backbone on every backward pass. Gradients still flow to the input.
        for param in vgg.parameters():
            param.requires_grad_(False)
        layer_map = {
            'relu1_2': 3,
            'relu2_2': 8,
            'relu3_3': 15,
            'relu4_3': 22,
        }
        self.target_idx = layer_map.get(layer, 8)
        self.vgg = vgg
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1))

    def forward(self, x):
        """Return the activations at ``self.target_idx`` for the batch *x*.

        *x* must broadcast against the (1, 3, 1, 1) mean/std buffers.
        """
        out = (x - self.mean) / self.std
        for i, layer in enumerate(self.vgg):
            out = layer(out)
            if i == self.target_idx:
                return out
        # Only reached if target_idx is beyond the last feature layer.
        return out


def init_alpha_from_patches(h, t, c, patch_size):
    """Estimate a per-pixel alpha map from overlapping patch-wise fits.

    For every channel and every patch (stride ``max(1, patch_size // 2)``,
    patches truncated at the image border) the scalar
    ``dot(h - t, c) / (dot(c, c) + EPS)`` is computed over the patch, i.e.
    the least-squares coefficient of ``c`` explaining the residual
    ``h - t``. Each pixel receives the average of the coefficients of all
    patches covering it.

    Args:
        h, t, c: 3-D tensors indexed (channel, row, column); ``h.shape``
            determines the output shape.
        patch_size: Side length of the square patches.

    Returns:
        A tensor of shape ``h.shape`` on ``h.device`` holding the averaged
        coefficients.
    """
    channels, height, width = h.shape
    alpha_sum = torch.zeros((channels, height, width), device=h.device)
    hits = torch.zeros((channels, height, width), device=h.device)
    step = max(1, patch_size // 2)
    leak = h - t
    for top in range(0, height, step):
        bottom = min(top + patch_size, height)
        for left in range(0, width, step):
            right = min(left + patch_size, width)
            # Channels are independent, so visiting them innermost keeps the
            # per-channel accumulation order unchanged.
            for k in range(channels):
                leak_vec = leak[k, top:bottom, left:right].flatten()
                clean_vec = c[k, top:bottom, left:right].flatten()
                numerator = torch.dot(leak_vec, clean_vec)
                denominator = torch.dot(clean_vec, clean_vec) + EPS
                alpha_sum[k, top:bottom, left:right] += numerator / denominator
                hits[k, top:bottom, left:right] += 1.0
    covered = hits > 0
    alpha_sum[covered] = alpha_sum[covered] / hits[covered]
    return alpha_sum


def smooth_alpha_optimization(h, t, c, alpha_init, alpha_iters, lr, lambda_reg_alpha, lambda_tv):
    """Refine an alpha map with Adam so that ``h - alpha * c`` approaches ``t``.

    The objective is the MSE between ``h - alpha * c`` and ``t``, plus
    ``lambda_reg_alpha`` times the mean of ``alpha ** 2``, plus ``lambda_tv``
    times the sum of the mean absolute differences of ``alpha`` along its
    last two axes (an anisotropic total-variation term).

    Args:
        h, t, c: Tensors combined element-wise with ``alpha``.
        alpha_init: Starting alpha map; it is cloned, not modified.
        alpha_iters: Number of Adam steps.
        lr: Adam learning rate.
        lambda_reg_alpha: Weight of the L2 penalty on alpha.
        lambda_tv: Weight of the total-variation penalty.

    Returns:
        ``(h1, alpha_final)``: the corrected image ``h - alpha_final * c`` and
        the optimised alpha map, both detached from autograd.
    """
    alpha = alpha_init.clone().detach().requires_grad_(True)
    adam = torch.optim.Adam([alpha], lr=lr)

    def objective():
        fidelity = torch.mean((h - alpha * c - t) ** 2)
        shrinkage = lambda_reg_alpha * torch.mean(alpha ** 2)
        horizontal = (alpha[:, :, 1:] - alpha[:, :, :-1]).abs().mean()
        vertical = (alpha[:, 1:, :] - alpha[:, :-1, :]).abs().mean()
        return fidelity + shrinkage + lambda_tv * (horizontal + vertical)

    for _ in range(alpha_iters):
        adam.zero_grad()
        objective().backward()
        adam.step()

    with torch.no_grad():
        alpha_final = alpha.detach()
        h1 = h - alpha_final * c
    return h1, alpha_final


def train_predictor_with_perceptual(c, h_stage, t, residual, predictor_steps, lr, weight_decay, lambda_perc,
                                   perceptual_extractor, device, hidden_channels):
    """Fit a fresh ResidualPredictor that maps ``c`` to ``residual``.

    Each step minimises ``mse(pred, residual)`` and, when a perceptual
    extractor is supplied, additionally ``lambda_perc`` times the MSE between
    the extractor features of ``h_stage - pred`` and those of ``t``.

    Args:
        c: Predictor input (unbatched; a batch dimension is added here).
        h_stage: Current image the prediction is subtracted from.
        t: Target image for the perceptual term.
        residual: Regression target for the predictor output.
        predictor_steps: Number of Adam steps.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        lambda_perc: Weight of the perceptual term.
        perceptual_extractor: Callable feature extractor, or None to disable
            the perceptual term.
        device: Device the predictor is created on.
        hidden_channels: Width of the predictor's hidden layers.

    Returns:
        ``(final_pred, model)``: the prediction of the trained model for
        ``c`` (computed without gradient tracking) and the model itself.
    """
    model = ResidualPredictor(hidden_channels=hidden_channels).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    clean_batch = c.unsqueeze(0)

    # The target features do not depend on the predictor, so compute them once
    # without an autograd graph instead of on every step. This assumes the
    # extractor is deterministic (e.g. in eval mode), as the caller sets up.
    phi_target = None
    if perceptual_extractor is not None:
        with torch.no_grad():
            phi_target = perceptual_extractor(t.unsqueeze(0))

    for _ in tqdm(range(predictor_steps), desc="training predictor"):
        optimizer.zero_grad()
        pred = model(clean_batch)[0]
        loss = F.mse_loss(pred, residual)
        if phi_target is not None:
            h_corrected = h_stage - pred
            phi_corr = perceptual_extractor(h_corrected.unsqueeze(0))
            loss = loss + lambda_perc * F.mse_loss(phi_corr, phi_target)
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        final_pred = model(clean_batch)[0]
    return final_pred, model


def parse_list_of_ints(s):
    """Parse a comma-separated string into a list of ints.

    Tokens that are empty or whitespace-only are skipped; every other token
    is passed to ``int()`` (which raises ValueError for non-numeric text).
    """
    values = []
    for token in s.split(","):
        if not token.strip():
            continue
        values.append(int(token))
    return values


def annotate_comparison(images, labels, font_size=16):
    """Lay PIL images out side by side with a centred caption above each.

    Args:
        images: Non-empty sequence of PIL images, pasted left to right.
        labels: One caption per image; captions may contain newlines.
        font_size: Point size requested for the caption font.

    Returns:
        A new RGB image whose width is the sum of the image widths and whose
        height is the tallest image plus a caption band. The band is at least
        ``font_size * 2 + 8`` pixels and grows to fit the tallest caption.

    Raises:
        ValueError: If *images* is empty or its length differs from *labels*.
    """
    if len(images) != len(labels):
        raise ValueError(
            f"annotate_comparison got {len(images)} images but {len(labels)} labels"
        )
    if not images:
        raise ValueError("annotate_comparison requires at least one image")

    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", font_size)
    except (OSError, ImportError):
        # Font file missing or FreeType support unavailable: use PIL's
        # built-in bitmap font rather than failing.
        font = ImageFont.load_default()

    spacing = 2
    # Throwaway canvas used only to measure text before the real canvas
    # (whose size depends on those measurements) exists.
    probe = ImageDraw.Draw(Image.new("RGB", (1, 1)))

    def measure(label):
        try:
            left, top, right, bottom = probe.multiline_textbbox(
                (0, 0), label, font=font, spacing=spacing)
            return right - left, bottom - top
        except AttributeError:
            # Pillow without multiline_textbbox: add up per-line sizes so that
            # multi-line captions are measured in full.
            sizes = [font.getsize(line) for line in label.split("\n")]
            text_width = max(size[0] for size in sizes)
            text_height = sum(size[1] for size in sizes) + spacing * (len(sizes) - 1)
            return text_width, text_height

    text_sizes = [measure(label) for label in labels]
    tallest_text = max(size[1] for size in text_sizes)
    # Never shrink below the historical two-line band, but grow for captions
    # with more lines so text is not clipped or drawn over the images.
    caption_h = max(font_size * 2 + 8, int(np.ceil(tallest_text)) + 8)

    total_w = sum(img.width for img in images)
    body_h = max(img.height for img in images)
    composite = Image.new("RGB", (total_w, body_h + caption_h), (30, 30, 30))
    draw = ImageDraw.Draw(composite)

    x = 0
    for img, label, (text_w, text_h) in zip(images, labels, text_sizes):
        composite.paste(img, (x, caption_h))
        w = img.width
        draw.rectangle([x, 0, x + w, caption_h], fill=(50, 50, 50))
        text_x = x + (w - text_w) / 2
        text_y = (caption_h - text_h) / 2
        draw.multiline_text((text_x, text_y), label, font=font, fill=(255, 255, 255), align="center",
                            spacing=spacing)
        x += w
    return composite


def _resolve_device(requested):
    """Return the torch device for *requested*, falling back to CPU without CUDA."""
    if requested.startswith("cuda") and not torch.cuda.is_available():
        print(f"Warning: device '{requested}' requested but CUDA is not available; using CPU.")
        return torch.device("cpu")
    return torch.device(requested)


def _evaluate_metrics(image, t, device):
    """Return (psnr, ssim, lpips) of *image* clamped to [0, 1] against *t* as safe floats."""
    clamped = image.clamp(0, 1)
    psnr = compute_psnr(clamped, t)
    return (
        safe_metric(psnr.item() if hasattr(psnr, "item") else psnr),
        safe_metric(compute_ssim(clamped, t)),
        safe_metric(compute_lpips(clamped, t, LPIPS_MODEL, device)),
    )


def _run_alpha_passes(h_in, t, c, patch_sizes, alpha_iters, lr, lambda_reg_alpha, lambda_tv,
                      device, pass_label, scale_label):
    """Run one coarse-to-fine sequence of alpha passes; return (image, metrics of last pass)."""
    h_out = h_in
    metrics = (float("nan"), float("nan"), float("nan"))
    for idx, ps in enumerate(patch_sizes):
        print(f"{pass_label} pass {idx + 1}/{len(patch_sizes)}: patch size {ps}")
        alpha_init = init_alpha_from_patches(h_out, t, c, ps)
        h_out, _ = smooth_alpha_optimization(
            h_out, t, c,
            alpha_init=alpha_init,
            alpha_iters=alpha_iters,
            lr=lr,
            lambda_reg_alpha=lambda_reg_alpha,
            lambda_tv=lambda_tv,
        )
        metrics = _evaluate_metrics(h_out, t, device)
        print(f"  After {scale_label} {ps} {format_metrics(*metrics)}")
    return h_out, metrics


def _metrics_label(title, metrics):
    """Build a multi-line caption: title, PSNR, then SSIM/LPIPS when they are available."""
    psnr_v, ssim_v, lpips_v = metrics
    lines = [title, f"PSNR {psnr_v:.2f}dB"]
    if not np.isnan(ssim_v):
        lines.append(f"SSIM {ssim_v:.3f}")
    if not np.isnan(lpips_v):
        lines.append(f"LPIPS {lpips_v:.3f}")
    return "\n".join(lines)


def main():
    """Command-line entry point for the alternating refinement pipeline.

    Steps: load the hidden/target/clean images, run multi-scale alpha
    correction, alternate predictor training and alpha refinement for
    ``--alt_rounds`` rounds, train a final predictor, then write the stage
    images, an annotated comparison strip and ``detailed_metrics.csv`` into
    ``--out_dir``.

    Raises:
        ValueError: If the three input images do not share the same shape.
    """
    parser = argparse.ArgumentParser(description="Alternating smooth-alpha + learned predictor with perceptual loss.")
    parser.add_argument("--hidden", required=True)
    parser.add_argument("--target", required=True)
    parser.add_argument("--clean", required=True)
    parser.add_argument("--out_dir", required=True)
    parser.add_argument("--patch_sizes", type=str, default="64,32,16",
                        help="Comma-separated patch sizes (coarse to fine)")
    parser.add_argument("--alpha_iters", type=int, default=300)
    parser.add_argument("--alpha_lr", type=float, default=1e-2)
    parser.add_argument("--lambda_reg_alpha", type=float, default=1e-3)
    parser.add_argument("--lambda_tv", type=float, default=0.04)
    parser.add_argument("--predictor_steps1", type=int, default=1000)
    parser.add_argument("--predictor_lr1", type=float, default=1e-3)
    parser.add_argument("--predictor_steps2", type=int, default=500)
    parser.add_argument("--predictor_lr2", type=float, default=5e-4)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument("--lambda_perc", type=float, default=0.1, help="Weight for perceptual loss")
    parser.add_argument("--blend_gamma", type=float, default=1.0, help="Blend corrected vs previous stage")
    parser.add_argument("--alt_rounds", type=int, default=1,
                        help="Number of alternating alpha/predictor refinement rounds")
    parser.add_argument("--hidden_channels", type=int, default=32)
    parser.add_argument("--device", type=str, default="cpu")
    args = parser.parse_args()

    # The previous conditional returned args.device on both branches, so a
    # "cuda" request on a CPU-only machine failed later instead of falling back.
    device = _resolve_device(args.device)
    os.makedirs(args.out_dir, exist_ok=True)

    if HAVE_LPIPS and LPIPS_MODEL is not None:
        LPIPS_MODEL.to(device)

    h = load_image_as_tensor(args.hidden, device)
    t = load_image_as_tensor(args.target, device)
    c = load_image_as_tensor(args.clean, device)
    if h.shape != t.shape or h.shape != c.shape:
        raise ValueError(f"Shape mismatch: hidden {h.shape}, target {t.shape}, clean {c.shape}")

    print(f"Device: {device}")
    metrics_orig = _evaluate_metrics(h, t, device)
    print(f"SNR before any correction: {format_metrics(*metrics_orig)}")

    # initial smooth alpha
    patch_sizes = parse_list_of_ints(args.patch_sizes)
    h_stage1, metrics_stage1 = _run_alpha_passes(
        h.clone(), t, c, patch_sizes,
        alpha_iters=args.alpha_iters,
        lr=args.alpha_lr,
        lambda_reg_alpha=args.lambda_reg_alpha,
        lambda_tv=args.lambda_tv,
        device=device,
        pass_label="Stage1",
        scale_label="scale",
    )
    # Saved right away so the stage-1 result survives a failure in later stages.
    save_tensor_as_image(h_stage1, os.path.join(args.out_dir, "hidden_stage1_multi.png"))

    # Prepare perceptual extractor
    perceptual_extractor = None
    if VGG_AVAILABLE:
        try:
            perceptual_extractor = VGGPerceptual(device=device, layer='relu2_2')
            perceptual_extractor.eval()
        except Exception as e:
            # Best effort: e.g. weights cannot be downloaded. Continue without
            # the perceptual term rather than aborting the run.
            print(f"Warning: failed to initialize perceptual extractor: {e}")
            perceptual_extractor = None
    else:
        print("torchvision not available; skipping perceptual loss.")

    nan_metrics = (float("nan"), float("nan"), float("nan"))
    h_current = h_stage1
    # Stays None when --alt_rounds is 0 so later code can skip the predictor
    # stage instead of hitting an undefined variable.
    h_pred = None
    metrics_pred = nan_metrics
    metrics_refine = nan_metrics

    # Alternating refinement rounds
    for round_i in range(args.alt_rounds):
        print(f"--- Alternating round {round_i + 1}/{args.alt_rounds} ---")
        residual1 = h_current - t
        print("Stage2: training predictor (round) with perceptual loss")
        pred1, _ = train_predictor_with_perceptual(
            c, h_current, t, residual1,
            predictor_steps=args.predictor_steps1,
            lr=args.predictor_lr1,
            weight_decay=args.weight_decay,
            lambda_perc=args.lambda_perc,
            perceptual_extractor=perceptual_extractor,
            device=device,
            hidden_channels=args.hidden_channels,
        )
        h_pred = h_current - pred1
        if args.blend_gamma != 1.0:
            h_pred = args.blend_gamma * h_pred + (1 - args.blend_gamma) * h_current
        metrics_pred = _evaluate_metrics(h_pred, t, device)
        print(f"After predictor (round {round_i + 1}) {format_metrics(*metrics_pred)}")

        # Re-optimize alpha on updated hidden, with half the iterations and
        # half the learning rate of the initial passes.
        h_alpha_refined, round_refine = _run_alpha_passes(
            h_pred.clone(), t, c, patch_sizes,
            alpha_iters=args.alpha_iters // 2,
            lr=args.alpha_lr * 0.5,
            lambda_reg_alpha=args.lambda_reg_alpha,
            lambda_tv=args.lambda_tv,
            device=device,
            pass_label="Refine Stage1",
            scale_label="refine scale",
        )
        if patch_sizes:
            metrics_refine = round_refine
        h_current = h_alpha_refined

    # Final predictor after alternations
    residual_final = h_current - t
    print("Final predictor training")
    pred_final, _ = train_predictor_with_perceptual(
        c, h_current, t, residual_final,
        predictor_steps=args.predictor_steps2,
        lr=args.predictor_lr2,
        weight_decay=args.weight_decay,
        lambda_perc=args.lambda_perc,
        perceptual_extractor=perceptual_extractor,
        device=device,
        hidden_channels=args.hidden_channels,
    )
    h_final = h_current - pred_final
    if args.blend_gamma != 1.0:
        h_final = args.blend_gamma * h_final + (1 - args.blend_gamma) * h_current
    metrics_final = _evaluate_metrics(h_final, t, device)
    print(f"Final {format_metrics(*metrics_final)}")

    # Save stages (stage 1 was already written above).
    if h_pred is not None:
        save_tensor_as_image(h_pred, os.path.join(args.out_dir, "hidden_after_predictor_round.png"))
    save_tensor_as_image(h_current, os.path.join(args.out_dir, "hidden_after_alpha_refine.png"))
    save_tensor_as_image(h_final, os.path.join(args.out_dir, "hidden_final.png"))

    # The comparison strip is a convenience output; a failure here must not
    # prevent the metrics CSV from being written.
    try:
        def to_pil(tensor):
            arr = (tensor.detach().clamp(0, 1).cpu().numpy().transpose(1, 2, 0) * 255.0).astype(np.uint8)
            return Image.fromarray(arr)

        panels = [
            (h, _metrics_label("Hidden", metrics_orig)),
            (h_stage1, _metrics_label("Stage1", metrics_stage1)),
        ]
        if h_pred is not None:
            panels.append((h_pred, _metrics_label("After Predictor", metrics_pred)))
        panels.append((h_final, _metrics_label("Final", metrics_final)))
        panels.append((t, "Target"))

        imgs = [to_pil(tensor) for tensor, _ in panels]
        labels = [label for _, label in panels]
        comp = annotate_comparison(imgs, labels, font_size=18)
        comp_path = os.path.join(args.out_dir, "comparison_full_annotated.png")
        comp.save(comp_path)
        print(f"Saved comparison image to: {comp_path}")
    except Exception as e:
        print(f"Failed to compose annotated comparison: {e}")

    # Save detailed metrics CSV
    metrics_path = os.path.join(args.out_dir, "detailed_metrics.csv")
    rows = [
        ("hidden_before", metrics_orig),
        ("after_stage1", metrics_stage1),
        ("after_predictor", metrics_pred),
        ("after_refine", metrics_refine),
        ("final", metrics_final),
    ]
    with open(metrics_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["stage", "psnr_dB", "ssim", "lpips"])
        for stage, (psnr_v, ssim_v, lpips_v) in rows:
            writer.writerow([stage, f"{psnr_v:.4f}", f"{ssim_v:.4f}", f"{lpips_v:.4f}"])
    print(f"Saved detailed metrics to {metrics_path}")



# Run the pipeline only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
