import time
import torch
from models.model_configs import instantiate_model
from training.edm_time_discretization import get_time_discretization
from flow_matching.path.scheduler import PolynomialConvexScheduler
from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.solver import MixtureDiscreteEulerSolver
from flow_matching.solver.ode_solver import ODESolver
from training.eval_loop import CFGScaledModel
from ptflops import get_model_complexity_info
from tqdm import tqdm

# ---------------------
# Configuration
# ---------------------
class Args:
    """Static settings for this benchmark script.

    All values are plain class attributes; the module-level ``args``
    instance below is what the rest of the script reads.
    """

    # --- model selection (forwarded to instantiate_model) ---
    architecture = "mux4-celeba"  # e.g., 'unet'
    discrete_flow_matching = False
    use_ema = False

    # --- input geometry and batching ---
    resolution = 64
    batch_size = 100

    # --- loop counts for the forward-throughput benchmark ---
    num_warmup = 2
    num_batches = 20

    # --- execution device: use the GPU whenever one is visible ---
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"


args = Args()

def count_parameters(model):
    """Return the total number of scalar elements across ``model.parameters()``.

    Every parameter yielded by ``model.parameters()`` contributes its
    ``numel()``; a model without parameters yields 0.
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total

# Build the network described by ``args``, move it to the benchmark device
# and switch it to inference mode.
model = instantiate_model(
    architechture=args.architecture,
    is_discrete=args.discrete_flow_matching,
    use_ema=args.use_ema,
)
model.to(args.device)
model.eval()
print(args.device)
# Debug aid: iterate model.named_parameters() and print each (name, shape)
# pair to inspect the layout of the network.

# 1) Total parameters
total_params = count_parameters(model)
print(f"Total parameters: {total_params:,}")

# Share of the parameters whose name contains "mux" (case-insensitive);
# reported only when the model has at least one such parameter.
mux_params = 0
for param_name, param in model.named_parameters():
    if 'mux' in param_name.lower():
        mux_params += param.numel()
if mux_params > 0:
    print(f"Total mux additional parameters: {mux_params:,}")

# 2) MACs per image
input_shape = (3, args.resolution, args.resolution)

# Number of images in the dummy batch that ptflops pushes through the model.
MACS_PROFILE_BATCH = 8
# Forward evaluations charged to each image; 50 matches the 'nfe' value used
# by the sampling benchmark in this script.
MACS_FORWARD_PASSES = 50


def input_constructor(input_res):
    """Build the keyword arguments ptflops passes to ``model`` when profiling.

    Args:
        input_res: Resolution tuple supplied by ptflops. It is ignored; the
            shape is taken from the module-level ``input_shape`` instead.

    Returns:
        dict with ``x`` (random images), ``timesteps`` (zeros) and an empty
        ``extra`` dict.
    """
    sample = torch.randn(MACS_PROFILE_BATCH, *input_shape).to(args.device)
    # The model receives batch // model.K timesteps.
    # NOTE(review): presumably one timestep per group of K images - confirm.
    timestep = torch.zeros(
        MACS_PROFILE_BATCH // model.K, dtype=torch.long, device=args.device
    )
    return dict(x=sample, timesteps=timestep, extra={})


macs_str, params = get_model_complexity_info(
    model,
    input_shape,
    input_constructor=input_constructor,
    as_strings=True,
    print_per_layer_stat=False,
    verbose=False,
)

# ptflops reports a "<value> <unit>" string; convert it back to a raw count.
val, unit = macs_str.split()
unit = unit.lower()
factor = {"gmac": 1e9, "mmac": 1e6, "kmac": 1e3}.get(unit, 1)
macs_batch = float(val) * factor

# per-image UNet MACs: the measured count covers the whole dummy batch, so
# divide by the batch size exactly once, then scale by the forward passes.
macs_per_image = macs_batch / MACS_PROFILE_BATCH * MACS_FORWARD_PASSES

# Choose the display unit from the final value so number and unit agree.
for threshold, unit in ((1e9, "gmac"), (1e6, "mmac"), (1e3, "kmac")):
    if macs_per_image > threshold:
        macs_per_image = macs_per_image / threshold
        break
else:
    unit = "mac"
print(f"MACs per image: {macs_per_image:.2f} {unit}")


# 3) Throughput (images/sec) via simple forward passes
# x = torch.randn(args.batch_size, *input_shape, device=args.device)
# t = torch.zeros(args.batch_size//model.K, dtype=torch.long, device=args.device)  # one timestep per image
# extra = {}
# # Warm-up
# with torch.no_grad():
#     for _ in tqdm(range(args.num_warmup)):
#         model(x, t, extra)
# # Timing
# start = time.time()
# with torch.no_grad():
#     for _ in tqdm(range(args.num_batches)):
#         model(x, t, extra)
# elapsed = time.time() - start
# images_processed = args.num_batches * args.batch_size
# print(f"Forward throughput: {images_processed / elapsed:.2f} images/sec")

# 4) Sampling throughput via ODE solver
solver = ODESolver(velocity_model=CFGScaledModel(model))
ode_opts = {'nfe': 50}
solver.velocity_model.train(False)
labels = torch.zeros(args.batch_size, dtype=torch.long, device=args.device)  # one all-zero label per image
x0 = torch.randn(args.batch_size, 3, args.resolution, args.resolution, device=args.device)
# Derive the grid from ode_opts so the step count is defined in one place.
time_grid = get_time_discretization(nfes=ode_opts['nfe'])

# Number of untimed and timed solver.sample() calls.
warmup_runs = 5
timed_runs = 5

# Arguments shared by the warm-up and the timed calls, so both loops are
# guaranteed to run exactly the same workload.
sample_kwargs = dict(
    time_grid=time_grid,
    x_init=x0,
    method='heun2',
    return_intermediates=False,
    atol=ode_opts.get("atol", 1e-5),
    rtol=ode_opts.get("rtol", 1e-5),
    step_size=ode_opts.get("step_size", None),
    label=labels,
    cfg_scale=0.0,
)


def _wait_for_device():
    """Block until queued CUDA work has finished; no-op on other devices.

    CUDA kernels are launched asynchronously, so reading the clock without
    synchronizing can exclude work that is still running on the GPU.
    """
    if torch.device(args.device).type == "cuda":
        torch.cuda.synchronize()


# Warm-up
for _ in tqdm(range(warmup_runs), desc="Warm-up"):
    samples = solver.sample(**sample_kwargs)

# Timing sampling: synchronize on both sides of the timed region and use the
# monotonic high-resolution clock.
_wait_for_device()
start_sampling = time.perf_counter()
for _ in tqdm(range(timed_runs), desc="Sampling"):
    solver.sample(**sample_kwargs)
_wait_for_device()
sampling_time = time.perf_counter() - start_sampling

total_samples = timed_runs * args.batch_size
print(f"Sampling throughput: {total_samples / sampling_time:.2f} images/sec")
