import argparse
import os
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from tqdm import tqdm

import wandb
from analytic.pyramids import GaussianPyramid
from nn_baselines.src.diffusion_utils import GaussianDiffusionTrainer
from nn_baselines.src.training_utils import load_model
from nn_baselines.src.unet_model import DownSample, ResBlock, UpSample


def normalize(images: torch.Tensor) -> torch.Tensor:
    """Min-max normalize *images* to the [0, 1] range for visualization.

    The min and max are taken over the whole tensor (all samples/channels
    together), not per image.

    Args:
        images: Tensor of any shape (expected to be floating point).

    Returns:
        Tensor of the same shape with values rescaled to [0, 1]. A constant
        tensor (max == min) maps to all zeros instead of NaNs from 0 / 0.
    """
    lo = images.min()
    span = images.max() - lo
    if span <= 0:
        # Constant input: nothing to stretch; avoid the 0 / 0 NaN result.
        return torch.zeros_like(images)
    return (images - lo) / span


class ProbeResBlock(nn.Module):
    """Time-conditioned residual block used as a probe head.

    Pipeline: GroupNorm -> SiLU -> 3x3 conv, add a projection of the time
    embedding broadcast over the spatial dims, then GroupNorm -> SiLU ->
    Dropout -> 3x3 conv, and finally add a shortcut of the input (identity
    when ``in_ch == out_ch``, otherwise a 1x1 conv).

    Args:
        in_ch: Channels of the input feature map.
        out_ch: Channels produced by the block.
        tdim: Size of the time-embedding vector.
        dropout: Dropout probability applied before the second conv.
    """

    def __init__(self, in_ch, out_ch, tdim, dropout=0.0):
        super().__init__()
        groups = self._shared_group_count(in_ch, out_ch)

        # Keep registration order (block1, temb_proj, block2, shortcut):
        # initialize() visits self.modules() in this order.
        self.block1 = nn.Sequential(
            nn.GroupNorm(groups, in_ch),
            nn.SiLU(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        self.temb_proj = nn.Sequential(nn.SiLU(), nn.Linear(tdim, out_ch))
        self.block2 = nn.Sequential(
            nn.GroupNorm(groups, out_ch),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        self.shortcut = (
            nn.Identity()
            if in_ch == out_ch
            else nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
        )
        self.initialize()

    @staticmethod
    def _shared_group_count(in_ch, out_ch, cap=32):
        """Pick a GroupNorm group count that divides both channel counts.

        Searches downward from ``min(cap, in_ch, out_ch)`` and stops at 1.
        """
        groups = min(cap, in_ch, out_ch)
        while groups > 1 and (in_ch % groups or out_ch % groups):
            groups -= 1
        return groups

    def initialize(self):
        """Xavier-init every conv/linear weight, zero every bias, then shrink the last conv."""
        is_affine = lambda m: isinstance(m, (nn.Conv2d, nn.Linear))
        for layer in filter(is_affine, self.modules()):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)
        # Near-zero output conv (bias already zero): at init the block's output
        # is dominated by the shortcut path.
        nn.init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)

    def forward(self, x, temb):
        # (B, out_ch) time projection -> (B, out_ch, 1, 1) so it broadcasts over H, W.
        time_bias = self.temb_proj(temb)[:, :, None, None]
        features = self.block1(x).add_(time_bias)
        return self.block2(features) + self.shortcut(x)


class UNetWithProbes(nn.Module):
    """Frozen UNet with 3-channel probe heads on its encoder skips and bottleneck.

    One ``ProbeResBlock`` reads the feature map entering each ``DownSample``
    in ``unet.downblocks``. One more reads the output of
    ``unet.middleblocks``. Only the probes are trainable: every UNet
    parameter is frozen, and the UNet is kept in eval mode (see ``train``).

    ``self.pyramid`` is a ``GaussianPyramid`` with one level per probe
    (skip probes + bottleneck), built on the UNet's device.
    """

    def __init__(self, unet: nn.Module):
        super().__init__()
        self.unet = unet
        tdim = unet.time_embedding.timembedding[1].out_features

        # One probe per DownSample. Its input channel count is read from the
        # block immediately before that DownSample.
        self.skip_probes = nn.ModuleList()
        previous = None
        for layer in unet.downblocks:
            if isinstance(layer, DownSample):
                if previous is None:
                    # Previously this silently read downblocks[-1] via index i - 1 == -1.
                    raise ValueError(
                        "UNetWithProbes needs a block before every DownSample "
                        "to infer the skip-connection channel count"
                    )
                in_ch = previous.block1[2].out_channels
                self.skip_probes.append(ProbeResBlock(in_ch, 3, tdim))
            previous = layer

        # Create the pyramid on the same device as the UNet (one level per probe).
        self.pyramid = GaussianPyramid(
            num_levels=len(self.skip_probes) + 1,
            resolution=32,
            grayscale=False,
            device=unet.head.weight.device,
        )

        # Probe on the bottleneck (output of the middle blocks).
        self.bottleneck_probe = ProbeResBlock(
            unet.middleblocks[0].block1[2].out_channels, 3, tdim
        )

        # Freeze UNet weights and keep it in inference mode from the start.
        for param in self.unet.parameters():
            param.requires_grad = False
        self.unet.eval()

    def train(self, mode: bool = True):
        """Set the mode of the probes while keeping the frozen UNet in eval mode.

        The UNet is never trained here. If ``model.train()`` flipped it back to
        training mode, stochastic layers inside it (e.g. dropout, if any) would
        randomly perturb the features being probed.
        """
        super().train(mode)
        self.unet.eval()
        return self

    def forward(self, x, t):
        """Run the frozen UNet encoder and return one probe output per level.

        Args:
            x: Input batch for ``unet.head`` (the noisy image in ``train_probes``).
            t: Timesteps for ``unet.time_embedding``.

        Returns:
            List of probe outputs, one per skip connection in encoder order,
            followed by the bottleneck probe output.
        """
        with torch.no_grad():
            temb = self.unet.time_embedding(t)

            # Forward through the encoder, saving the map that enters each DownSample.
            h = self.unet.head(x)
            skip_connections = []
            for layer in self.unet.downblocks:
                if isinstance(layer, DownSample):
                    # Clone defensively so later layers can't alter the stored map.
                    skip_connections.append(h.clone())
                h = layer(h, temb)

            for layer in self.unet.middleblocks:
                h = layer(h, temb)

        # Probes run outside no_grad so their own parameters receive gradients.
        reconstructions = [
            probe(skip, temb)
            for skip, probe in zip(skip_connections, self.skip_probes)
        ]
        reconstructions.append(self.bottleneck_probe(h, temb))
        return reconstructions


class UNetWithUpscaleProbes(nn.Module):
    """Frozen UNet with 3-channel probe heads on its decoder features.

    One ``ProbeResBlock`` reads the feature map entering each ``UpSample`` in
    ``unet.upblocks``. One more (``tail_probe``) reads the feature map that
    would be fed to ``unet.tail``. Only the probes are trainable: every UNet
    parameter is frozen, and the UNet is kept in eval mode (see ``train``).

    ``self.pyramid`` is a ``GaussianPyramid`` with one level per probe,
    built on the UNet's device.
    """

    def __init__(self, unet: nn.Module):
        super().__init__()
        self.unet = unet
        tdim = unet.time_embedding.timembedding[1].out_features

        # One probe per UpSample, sized from that UpSample's conv input channels.
        self.skip_probes = nn.ModuleList(
            ProbeResBlock(layer.main.in_channels, 3, tdim)
            for layer in unet.upblocks
            if isinstance(layer, UpSample)
        )

        # Probe on the feature map fed to the tail; channels read from tail[2].
        self.tail_probe = ProbeResBlock(unet.tail[2].in_channels, 3, tdim)

        # Create the pyramid on the same device as the UNet (+1 level for the tail probe).
        self.pyramid = GaussianPyramid(
            num_levels=len(self.skip_probes) + 1,
            resolution=32,
            grayscale=False,
            device=unet.head.weight.device,
        )

        # Freeze UNet weights and keep it in inference mode from the start.
        for param in self.unet.parameters():
            param.requires_grad = False
        self.unet.eval()

    def train(self, mode: bool = True):
        """Set the mode of the probes while keeping the frozen UNet in eval mode.

        The UNet is never trained here. If ``model.train()`` flipped it back to
        training mode, stochastic layers inside it (e.g. dropout, if any) would
        randomly perturb the features being probed.
        """
        super().train(mode)
        self.unet.eval()
        return self

    def forward(self, x, t):
        """Run the frozen UNet up to its tail and return one probe output per level.

        Args:
            x: Input batch for ``unet.head`` (the noisy image in ``train_probes``).
            t: Timesteps for ``unet.time_embedding``.

        Returns:
            A list (not an iterator) of probe outputs: the tail probe first,
            then the UpSample probes from the last UpSample back to the first
            (presumably finest to coarsest resolution; confirm against the UNet).
        """
        with torch.no_grad():
            temb = self.unet.time_embedding(t)

            # Down path: keep every intermediate map for the decoder's skip concatenations.
            h = self.unet.head(x)
            hs = [h]
            for layer in self.unet.downblocks:
                h = layer(h, temb)
                hs.append(h)

            for layer in self.unet.middleblocks:
                h = layer(h, temb)

            # Up path: record the map entering each UpSample.
            upscale_features = []
            for layer in self.unet.upblocks:
                if isinstance(layer, ResBlock):
                    h = torch.cat([h, hs.pop()], dim=1)
                if isinstance(layer, UpSample):
                    upscale_features.append(h.clone())
                h = layer(h, temb)
            # h is now the tail's input. The tail's own output is never probed,
            # so it is not computed.

        # Probes run outside no_grad so their own parameters receive gradients.
        reconstructions = [
            probe(feature, temb)
            for feature, probe in zip(upscale_features, self.skip_probes)
        ]
        reconstructions.append(self.tail_probe(h, temb))

        # Return a real list: a reversed() iterator can be consumed only once.
        return reconstructions[::-1]


def _build_probe_model(unet: nn.Module, probe_type: str, device: torch.device):
    """Wrap the frozen *unet* in a probe model and collect the probe parameters.

    Args:
        unet: Trained UNet to probe (the wrapper freezes its weights).
        probe_type: ``"upsample"`` probes decoder features (before each
            ``UpSample`` and before the tail); any other value probes the
            encoder skip connections and the bottleneck.
        device: Device the wrapper is moved to.

    Returns:
        ``(model, probe_params)``. ``probe_params`` contains only the
        parameters of the probe heads.
    """
    if probe_type == "upsample":
        model = UNetWithUpscaleProbes(unet).to(device)
        last_probe = model.tail_probe
    else:
        model = UNetWithProbes(unet).to(device)
        last_probe = model.bottleneck_probe
    probe_params = list(model.skip_probes.parameters()) + list(
        last_probe.parameters()
    )
    return model, probe_params


def _log_sample_reconstructions(
    model, trainer, dataloader, fixed_t, device, num_samples: int = 8
):
    """Log per-level comparison images (noisy input | target | reconstruction) to wandb.

    Takes the first batch of *dataloader* and keeps up to ``num_samples``
    images. It noises them at the fixed timestep, then logs one image per
    probe level, with one row per sample. The model is always put back into
    train mode afterwards.
    """
    model.eval()
    try:
        with torch.no_grad():
            sample_images = next(iter(dataloader))[0][:num_samples].to(device)
            # Match t to the actual sample count (the batch may hold fewer images).
            t = fixed_t[: sample_images.shape[0]]
            x_t, _ = trainer.add_noise(sample_images, t)

            pyramid_latent = model.pyramid.encode(sample_images)
            targets = [level[0] for level in pyramid_latent.levels]
            reconstructions = model(x_t, t)

            size = x_t.shape[2:]
            x_t_normalized = normalize(x_t)

            def _resize_normalized(images):
                # Bring a pyramid level / probe output to x_t's spatial size for display.
                return normalize(
                    F.interpolate(
                        images, size=size, mode="bilinear", align_corners=False
                    )
                )

            for i, (recon, target) in enumerate(zip(reconstructions, targets)):
                target_resized = _resize_normalized(target)
                recon_resized = _resize_normalized(recon)

                # One row per sample, [x_t | target | recon] concatenated along
                # width (dim=3). Rows are stacked along height (dim=2).
                rows = [
                    torch.cat(
                        [
                            x_t_normalized[j : j + 1],
                            target_resized[j : j + 1],
                            recon_resized[j : j + 1],
                        ],
                        dim=3,
                    )
                    for j in range(recon.shape[0])
                ]
                combined = torch.cat(rows, dim=2)

                wandb.log(
                    {
                        f"level_{i}_comparison": wandb.Image(
                            combined,
                            caption=(
                                f"Level {i} - Left: Noisy input x_t, "
                                "Middle: Target, Right: Reconstruction"
                            ),
                        )
                    }
                )
    finally:
        model.train()


def train_probes(modelConfig: Dict):
    """Train reconstruction probes on a frozen, pre-trained diffusion UNet.

    Each probe maps an intermediate UNet feature map to 3 channels. The
    feature map is computed from ``x_t``, noised at the fixed timestep
    ``modelConfig["noise_level"]``. The probe output is regressed with MSE
    onto the paired level of a Gaussian pyramid of the clean image ``x_0``.
    Per-level losses are averaged per epoch and logged to wandb.

    Config keys read here: ``device``, ``dataset_root``, ``batch_size``,
    ``noise_level``, ``beta_1``, ``beta_T``, ``T`` and ``epoch``. Optional
    keys are ``probe_type`` (default ``"upsample"``), ``probe_lr`` (default
    ``1e-3``) and ``sample_freq`` (default ``10``; ``<= 0`` disables image
    logging). The whole dict is also passed to ``load_model`` and logged as
    the wandb config.
    """
    device = torch.device(modelConfig["device"])

    # Load trained UNet and wrap it with probes.
    unet = load_model(modelConfig, device)
    model, probe_params = _build_probe_model(
        unet, modelConfig.get("probe_type", "upsample"), device
    )

    # Only the probe heads are optimized.
    optimizer = torch.optim.AdamW(probe_params, lr=modelConfig.get("probe_lr", 1e-3))

    # CIFAR-10 scaled to [-1, 1].
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    dataset = CIFAR10(
        root=modelConfig["dataset_root"], train=True, download=True, transform=transform
    )
    dataloader = DataLoader(
        dataset, batch_size=modelConfig["batch_size"], shuffle=True, num_workers=4
    )

    # Noise level goes into the run name.
    wandb.init(
        project="unet-probes",
        name=f"t={modelConfig['noise_level']}_{wandb.util.generate_id()}",
        config=modelConfig,
    )

    trainer = GaussianDiffusionTrainer(
        unet, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]
    ).to(device)

    # Every sample is noised at the same timestep; sliced per batch below.
    fixed_t = torch.full(
        (modelConfig["batch_size"],),
        modelConfig["noise_level"],
        device=device,
        dtype=torch.long,
    )

    # One loss per probe: every skip probe plus the final (bottleneck or tail) probe.
    num_levels = len(model.skip_probes) + 1
    sample_freq = modelConfig.get("sample_freq", 10)

    for epoch in range(modelConfig["epoch"]):
        epoch_losses = {f"level_{i}": 0.0 for i in range(num_levels)}
        num_batches = 0

        for images, _ in tqdm(dataloader):
            x_0 = images.to(device)
            t = fixed_t[: images.shape[0]]  # last batch may be smaller
            x_t, _ = trainer.add_noise(x_0, t)

            # Pyramid levels of the clean image are the regression targets.
            with torch.no_grad():
                pyramid_latent = model.pyramid.encode(x_0)
                targets = [level[0] for level in pyramid_latent.levels]

            reconstructions = model(x_t, t)

            losses = []
            for i, (recon, target) in enumerate(zip(reconstructions, targets)):
                loss = F.mse_loss(recon, target)
                losses.append(loss)
                epoch_losses[f"level_{i}"] += loss.item()

            optimizer.zero_grad()
            total_loss = sum(losses)
            total_loss.backward()
            optimizer.step()

            num_batches += 1

        # max(..., 1) guards against an empty dataloader.
        avg_losses = {k: v / max(num_batches, 1) for k, v in epoch_losses.items()}
        wandb.log({"epoch": epoch, **avg_losses})

        if sample_freq > 0 and epoch % sample_freq == 0:
            _log_sample_reconstructions(model, trainer, dataloader, fixed_t, device)

    wandb.finish()


def parse_args():
    """Parse the command-line options for probe training.

    All options are integers:
      --gpu          CUDA device index (default 0)
      --epochs       number of training epochs (default 100)
      --batch-size   batch size (default 80)
      --noise-level  fixed diffusion timestep t (default 0)

    Returns:
        argparse.Namespace with ``gpu``, ``epochs``, ``batch_size`` and
        ``noise_level`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Train probes on UNet skip connections"
    )
    int_options = (
        ("--gpu", 0, "GPU device number"),
        ("--epochs", 100, "Number of training epochs"),
        ("--batch-size", 80, "Batch size"),
        ("--noise-level", 0, "Fixed noise level t (0 to 999)"),
    )
    for flag, default_value, help_text in int_options:
        parser.add_argument(flag, type=int, default=default_value, help=help_text)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    modelConfig = {
        "device": f"cuda:{args.gpu}",
        "epoch": args.epochs,
        "batch_size": args.batch_size,
        "noise_level": args.noise_level,  # fixed timestep t the probes are trained at
        "T": 1000,
        "channel": 128,
        "random_seed": 42,
        "eval_random_seed": 42,
        "subset_size": 50000,
        "channel_mult": [1, 2, 3, 4],
        "attn": [],
        "num_res_blocks": 2,
        "dropout": 0.15,
        "lr": 1e-4,
        "multiplier": 2.0,
        "beta_1": 1e-4,
        "beta_T": 0.02,
        "img_size": 32,
        "grad_clip": 1.0,
        "dataset_root": "data/",  # CIFAR-10 download/cache directory
        "checkpoint_freq": 10,
        "use_wandb": True,
        "sample_freq": 10,
        # Presumably consumed by load_model to locate the trained UNet; verify.
        # (A duplicate "training_load_weight": None entry used to precede this
        # one in the literal; the later value always won, so it was dead code.)
        "save_weight_dir": "./trained_models/unet/unet_cifar10_-1_noattn_20250313_232926",
        "training_load_weight": "ckpt_epoch_200.pt",  # load the final checkpoint
        "sampled_dir": "./SampledImgs/",
        "sampledNoisyImgName": "NoisyNoGuidenceImgs.png",
        "sampledImgName": "SampledNoGuidenceImgs.png",
        "nrow": 8,
    }

    train_probes(modelConfig)
