import os
import argparse
import torch
import inspect
import time
from diffusers import DDPMPipeline, DDPMScheduler, MUXUNet2DModel, UNet2DModel, DDIMScheduler, DDIMPipeline
from torchvision.utils import save_image
from tqdm import tqdm
import numpy as np
import re
import json
from PIL import Image

# Optional dependency: fvcore's FLOP counter. When fvcore is not installed the
# name is bound to None so that importing this script still succeeds.
# NOTE(review): FlopCountAnalysis is not referenced in the visible code (main()
# measures MACs with ptflops instead) -- confirm whether this import is still needed.
try:
    from fvcore.nn import FlopCountAnalysis
except ImportError:
    FlopCountAnalysis = None


def _str2bool(value):
    """Convert a command-line token to a bool.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty value -- including ``"False"`` and ``"0"`` -- is parsed
    as True. This converter interprets the usual spellings explicitly.

    Args:
        value: Raw token from the command line (or an actual bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"true", "t", "yes", "y", "1", "on"}:
        return True
    # The empty string is kept falsy, matching what bool("") used to yield.
    if token in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the command-line options of the benchmark script.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with the attributes ``batch_size``,
        ``num_inference_steps``, ``dataset_name``, ``resolution``,
        ``prediction_type``, ``ddpm_num_steps``, ``ddpm_beta_schedule``,
        ``K``, ``mux_mod`` and ``mux``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=1000, help="Batch size for generation")
    parser.add_argument("--num_inference_steps", type=int, default=100, help="Number of DDPM inference steps")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that HF Datasets can understand."
        ),
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=64,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--prediction_type",
        type=str,
        default="epsilon",
        choices=["epsilon", "sample"],
        help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
    )
    parser.add_argument("--ddpm_num_steps", type=int, default=1000)
    parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
    parser.add_argument(
        "--K", type=int, default=2, help="The number of Muxed inputs for the model."
    )
    parser.add_argument(
        "--mux_mod",
        type=str,
        default="nonlinear-expand",
        choices=["nonlinear-expand", "linear"],
        help="The muxing method for the model.",
    )
    parser.add_argument(
        "--mux",
        # type=bool would treat "--mux False" as True; parse the text explicitly.
        type=_str2bool,
        # Allow a bare "--mux" (no value) as shorthand for "--mux True".
        nargs="?",
        const=True,
        default=False,
        help="Whether to use Muxed inputs for the model (accepts true/false, yes/no, 1/0).",
    )
    return parser.parse_args(argv)


def _build_scheduler(args):
    """Create the DDIM scheduler described by the CLI arguments.

    Args:
        args: Parsed namespace; reads ``ddpm_num_steps``, ``ddpm_beta_schedule``
            and ``prediction_type``.

    Returns:
        A configured ``DDIMScheduler``.
    """
    scheduler_kwargs = {
        "num_train_timesteps": args.ddpm_num_steps,
        "beta_schedule": args.ddpm_beta_schedule,
    }
    # Only forward `prediction_type` when the installed DDIMScheduler accepts it.
    if "prediction_type" in inspect.signature(DDIMScheduler.__init__).parameters:
        scheduler_kwargs["prediction_type"] = args.prediction_type
    # The CIFAR-10 and generic configurations are identical here, so no
    # dataset-specific branch is needed.
    return DDIMScheduler(**scheduler_kwargs)


def _build_unet(args):
    """Instantiate the (optionally muxed) UNet for the selected dataset.

    Args:
        args: Parsed namespace; reads ``dataset_name``, ``resolution``, ``mux``,
            ``mux_mod`` and ``K``.

    Returns:
        Tuple ``(unet, mux_factor)`` where ``mux_factor`` is ``args.K`` for the
        muxed model and 1 for the plain UNet.
    """
    if args.dataset_name == "uoft-cs/cifar10":
        unet_kwargs = dict(
            act_fn="silu",
            attention_head_dim=None,
            block_out_channels=(128, 256, 256, 256),
            center_input_sample=False,
            down_block_types=(
                "DownBlock2D",
                "AttnDownBlock2D",
                "DownBlock2D",
                "DownBlock2D",
            ),
            downsample_padding=0,
            flip_sin_to_cos=False,
            freq_shift=1,
            in_channels=3,
            layers_per_block=2,
            mid_block_scale_factor=1,
            norm_eps=1e-6,
            norm_num_groups=32,
            out_channels=3,
            sample_size=args.resolution,
            time_embedding_type="positional",
            up_block_types=(
                "UpBlock2D",
                "UpBlock2D",
                "AttnUpBlock2D",
                "UpBlock2D",
            ),
            dropout=0.1,
        )
    else:
        unet_kwargs = dict(
            sample_size=args.resolution,
            in_channels=3,
            out_channels=3,
            layers_per_block=2,
            block_out_channels=(128, 128, 256, 256, 512, 512),
            down_block_types=(
                "DownBlock2D",
                "DownBlock2D",
                "DownBlock2D",
                "DownBlock2D",
                "AttnDownBlock2D",
                "DownBlock2D",
            ),
            up_block_types=(
                "UpBlock2D",
                "AttnUpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
            ),
        )

    if args.mux:
        # The muxed model takes the same backbone config plus the mux settings.
        unet = MUXUNet2DModel(
            **unet_kwargs,
            mux_mod=args.mux_mod,
            demux_mod="channel-one",
            expand=8,
            K=args.K,
        )
        print(f"Using Muxed UNet with {args.K} inputs")
        return unet, args.K

    unet = UNet2DModel(**unet_kwargs)
    print("Using UNet")
    return unet, 1


def _macs_string_to_number(macs_str):
    """Convert a ptflops string such as ``"24.55 GMac"`` into a MAC count.

    Args:
        macs_str: String returned by ``get_model_complexity_info`` with
            ``as_strings=True``, or None if ptflops failed.

    Returns:
        The number of MACs as a float.

    Raises:
        RuntimeError: If ptflops returned None instead of a string.
        ValueError: If the unit suffix is not recognised.
    """
    if macs_str is None:
        raise RuntimeError("ptflops could not compute the model complexity (returned None)")
    val, unit = macs_str.split()
    factors = {"tmac": 1e12, "gmac": 1e9, "mmac": 1e6, "kmac": 1e3, "mac": 1}
    factor = factors.get(unit.lower())
    if factor is None:
        # Silently assuming a factor of 1 would yield a wildly wrong estimate.
        raise ValueError(f"Unrecognised MACs unit in {macs_str!r}")
    return float(val) * factor


@torch.no_grad()
def main():
    """Benchmark a randomly initialised DDIM pipeline.

    Builds the scheduler and UNet from the CLI arguments, then reports the
    parameter count, an estimate of the MACs needed to generate one image
    (UNet forward passes plus a rough scheduler cost), and the throughput of
    generating a single batch.

    Raises:
        ValueError: If ``--K`` is smaller than 1.
        RuntimeError: If ptflops fails, or the pipeline returns a number of
            images different from the requested batch size.
    """
    args = parse_args()
    scheduler = _build_scheduler(args)
    unet, mux_factor = _build_unet(args)
    if mux_factor < 1:
        raise ValueError(f"--K must be a positive integer, got {mux_factor}")
    # Freshly constructed modules start in training mode; switch to eval so that
    # dropout (0.1 in the CIFAR-10 config) is inactive while benchmarking.
    unet.eval()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    pipeline = DDIMPipeline(unet=unet, scheduler=scheduler).to(device)

    # 1. Count UNet parameters
    param_count = sum(p.numel() for p in unet.parameters())
    print(f"UNet parameter count: {param_count:,}")

    mux_params = sum(p.numel() for name, p in unet.named_parameters() if 'mux' in name.lower())
    if mux_params > 0:
        print(f"Total mux additional parameters: {mux_params:,}")

    # 2. MACs with ptflops, feeding keyword inputs through a custom constructor
    from ptflops import get_model_complexity_info

    input_res = (3, unet.config.sample_size, unet.config.sample_size)

    # Profile with 8 samples, rounded up to the next multiple of the mux factor
    # so that `profile_batch // mux_factor` timesteps is never zero or truncated.
    # NOTE(review): assumes the muxed model expects one timestep per group of K
    # samples, as the original `8 // K` sizing suggests -- confirm.
    profile_batch = -(-8 // mux_factor) * mux_factor

    def input_constructor(input_res):
        sample = torch.randn(profile_batch, *input_res).to(device)
        timestep = torch.zeros(profile_batch // mux_factor, dtype=torch.long, device=device)
        return dict(sample=sample, timestep=timestep)

    macs_str, params_str = get_model_complexity_info(
        unet,
        input_res,
        input_constructor=input_constructor,
        as_strings=True,
        print_per_layer_stat=False,
        verbose=False,
    )
    print(f"MACs for UNet eval (batch of {profile_batch}): {macs_str}")

    # parse "24.55 GMac" -> number
    macs_batch = _macs_string_to_number(macs_str)

    # per-image UNet MACs
    macs_per_image = macs_batch / profile_batch
    unet_total_per_image = macs_per_image * args.num_inference_steps

    # ---- Scheduler-step MACs estimation ----
    c, h, w = 3, args.resolution, args.resolution
    num_elems = c * h * w
    # assume ~7 elementwise ops per step (mult + add counts)
    ops_per_elem = 7
    sched_macs_per_step = num_elems * ops_per_elem
    sched_total_per_image = sched_macs_per_step * args.num_inference_steps

    grand_total = unet_total_per_image + sched_total_per_image

    print(f"Per-image UNet MACs × {args.num_inference_steps}: {unet_total_per_image:,.0f} ops")
    print(f"Per-image scheduler MACs × {args.num_inference_steps}: {sched_total_per_image:,.0f} ops")
    print(f"Estimated TOTAL MACs to generate one image: {grand_total:,.0f} ops (~{grand_total/1e9:.2f} GMac)")

    print(f"Params: {params_str}")

    # 3. Throughput test (single batch)
    bs = args.batch_size
    # CUDA kernels run asynchronously; synchronise so the timed window covers
    # exactly the work of this generation call.
    if device == "cuda":
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    outputs = pipeline(
        batch_size=bs,
        num_inference_steps=args.num_inference_steps,
        output_type="tensor",
    ).images
    if device == "cuda":
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start_time
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if bs != outputs.shape[0]:
        raise RuntimeError(f"Batch size {bs} != outputs {outputs.shape[0]}")
    throughput = bs / elapsed
    print(f"Generated {bs} images in {elapsed:.2f}s -> Throughput: {throughput:.2f} img/s")


# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()