import os
import random
import pyrallis
import json
from tqdm import tqdm
from typing import Dict, List

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from loaders.image_loader import load_images
from loaders.param_loader import load_paired_params
from models.base import load_model
from models.diffusion import load_diff
from models.ae import load_ae
from models.position import load_positional_embedding
from models.base.model_helper import ModelHelper
from optimization.optimizer import load_optimizer
from optimization.scheduler import load_lr_scheduler
from loss.task_loss import load_task_loss
from loss.reconstruction_loss import load_reconstruction_loss
from utils.log_utils import log_scalar_dict, create_experiment_dir
from utils.visualize_util import visualize_weights
from utils.visualize_util import visualize_weights_layerwise, visualize_weights_channelwise

from options import DiffusionTrainConfig

@pyrallis.wrap()
def main(cfg: DiffusionTrainConfig):
    """Estimate and print FLOPs for the autoencoder + diffusion pipeline.

    Builds the autoencoder and the diffusion model from ``cfg``, runs
    fvcore's ``FlopCountAnalysis`` on random dummy inputs (once with the
    models in train mode, once in eval mode) and forwards the resulting
    counts to ``print_flops_summary`` together with hard-coded dataset
    sizes and epoch counts.

    Args:
        cfg: Run configuration; this function reads ``cfg.no_cuda``,
            ``cfg.ae`` and the ``cfg.diffusion`` sub-config.
    """
    init_seed(2025)

    # ----------------------------------------
    # basic configuration
    # ----------------------------------------
    run_on_cpu = cfg.no_cuda or not torch.cuda.is_available()
    device = torch.device("cpu" if run_on_cpu else "cuda")

    # ----------------------------------------
    # ae configuration
    # ----------------------------------------
    ae = load_ae(cfg.ae).to(device)

    # ----------------------------------------
    # diffusion model configuration
    # ----------------------------------------
    diff_cfg = cfg.diffusion
    model_diff = load_diff(
        unet_name=diff_cfg.unet_name,
        dim=diff_cfg.input_dim,
        dim_cond=diff_cfg.condition_dim,
        self_cond=True,
        layer_channels=diff_cfg.layer_channels,
        init_ch=diff_cfg.init_ch,
        final_ch=diff_cfg.final_ch,
        timesteps=diff_cfg.num_timesteps,
    ).to(device)

    # Disabled snippet: parameter count of both models (trainable or not).
    # params_a = sum(p.numel() for p in ae.parameters())
    # params_b = sum(p.numel() for p in model_diff.parameters())
    # total_params_m = (params_a + params_b) / 1e6  # in millions (M)
    # print(f"Model A: {params_a/1e6:.2f}M parameters")
    # print(f"Model B: {params_b/1e6:.2f}M parameters")
    # print(f"Total: {total_params_m:.2f}M parameters")

    from fvcore.nn import FlopCountAnalysis

    def dummy(*shape):
        # Random placeholder input, sampled on CPU and then moved to `device`.
        return torch.randn(*shape).to(device)

    def sample_timestep():
        # One random integer timestep in [0, 1000), created on `device`.
        # NOTE(review): 1000 is hard-coded here and as the inference
        # multiplier below; it may be meant to follow
        # cfg.diffusion.num_timesteps -- confirm.
        return torch.randint(0, 1000, (1,), device=device).long()

    ae_input = dummy(1, 1, 4610)
    pos_embedding = dummy(1, 1, 128)
    latent = dummy(1, 1, 128)
    noise = dummy(1, 1, 128)
    # Condition = latent and positional tensor stacked along dim 1 -> (1, 2, 128).
    condition = torch.cat((latent, pos_embedding), dim=1)

    # FLOPs of the diffusion part are counted on the wrapped network only.
    unet = model_diff.model

    # ---- train-mode forward pass ----
    ae.train()
    model_diff.train()
    ae_train_analysis = FlopCountAnalysis(ae, (ae_input, pos_embedding))
    unet_train_analysis = FlopCountAnalysis(
        unet, (noise, sample_timestep(), condition)
    )
    ae_train_total = ae_train_analysis.total()
    unet_train_total = unet_train_analysis.total()
    flops_per_pass = ae_train_total + unet_train_total

    # ---- eval-mode forward pass ----
    ae.eval()
    model_diff.eval()
    ae_eval_analysis = FlopCountAnalysis(ae, (ae_input, pos_embedding))
    unet_eval_analysis = FlopCountAnalysis(
        unet, (noise, sample_timestep(), condition)
    )
    ae_eval_total = ae_eval_analysis.total()
    # Inference cost of the diffusion part is taken as 1000 network evaluations.
    unet_eval_total = unet_eval_analysis.total() * 1000
    flops_per_inf = ae_eval_total + unet_eval_total

    # Disabled resnet50 preset (its old call predates the current signature);
    # presumably num_train=220480, num_test=55120, ae_epochs=10000,
    # diff_epochs=4000 -- TODO confirm.
    # resnet18 preset:
    print_flops_summary(
        flops_per_pass=flops_per_pass,
        flops_per_inf=flops_per_inf,
        ae_flops=ae_train_total,
        diff_flops=unet_train_total,
        ae_inf_flops=ae_eval_total,
        diff_inf_flops=unet_eval_total,
        num_train=4810 * 8,
        num_test=9620,
        ae_epochs=5000,
        diff_epochs=4000,
    )
    
def print_flops_summary(
    flops_per_pass: float,
    flops_per_inf: float,
    ae_flops: float,
    diff_flops: float,
    ae_inf_flops: float,
    diff_inf_flops: float,
    num_train: int,
    num_test: int,
    ae_epochs: int,
    diff_epochs: int
) -> None:
    """Compute and print training, inference and total FLOPs.

    Every quantity is printed twice on its line: in scientific notation and
    scaled to a unit (PFLOPs / TFLOPs / GFLOPs, or raw FLOPs below 1e9).

    Args:
        flops_per_pass: FLOPs of one forward pass (AE + diffusion); only
            echoed in the report.
        flops_per_inf: FLOPs to run inference on one sample; scaled by
            ``num_test`` and by 1000 for the inference totals.
        ae_flops: FLOPs of one autoencoder forward pass (training).
        diff_flops: FLOPs of one diffusion forward pass (training).
        ae_inf_flops: Autoencoder FLOPs for inference on one sample.
        diff_inf_flops: Diffusion FLOPs for inference on one sample.
        num_train: Number of training samples.
        num_test: Number of test samples.
        ae_epochs: Number of training epochs for the autoencoder.
        diff_epochs: Number of training epochs for the diffusion model.

    Returns:
        None. The report is written to stdout.
    """
    # A training step is costed as 3x a forward pass (presumably forward
    # plus a backward pass of roughly twice the cost -- confirm).
    train_step_factor = 3

    unit_table = ((1e15, "PFLOPs"), (1e12, "TFLOPs"), (1e9, "GFLOPs"))

    def with_unit(value):
        # Pick the largest unit the value reaches; fall back to raw FLOPs.
        for threshold, unit in unit_table:
            if value >= threshold:
                return f"{value / threshold:.2f} {unit}"
        return f"{value:.2e} FLOPs"

    # Training cost: per-sample forward cost x samples x epochs x factor.
    ae_train_total = ae_flops * num_train * ae_epochs * train_step_factor
    diff_train_total = diff_flops * num_train * diff_epochs * train_step_factor
    train_total = ae_train_total + diff_train_total

    # Inference cost over the test set, plus a fixed 1000-sample figure.
    ae_test_total = ae_inf_flops * num_test
    diff_test_total = diff_inf_flops * num_test
    test_total = flops_per_inf * num_test
    thousand_sample_total = flops_per_inf * 1000
    grand_total = train_total + test_total

    report = (
        ("FLOPs per forward pass : ", flops_per_pass),
        ("FLOPs per forward inference : ", flops_per_inf),
        ("AE Training FLOPs         : ", ae_train_total),
        ("Diffusion Training FLOPs         : ", diff_train_total),
        ("Training FLOPs         : ", train_total),
        ("AE Inference FLOPs        : ", ae_test_total),
        ("Diffusion Inference FLOPs        : ", diff_test_total),
        ("Inference FLOPs        : ", test_total),
        ("Inference 1000 FLOPs   : ", thousand_sample_total),
        ("Total FLOPs            : ", grand_total),
    )

    print("===== FLOPs Summary =====")
    for label, value in report:
        print(f"{label}{value:.3e} ({with_unit(value)})")
    print("=========================")

    


def init_seed(seed: int = 2025):
    """Seed the Python, NumPy and PyTorch RNGs and pin the cuDNN flags.

    Args:
        seed: Value passed to every seeding call (default 2025).
    """
    # Same seed for every generator, applied in a fixed order.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Request deterministic cuDNN kernels and turn off its auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Script entry point. main() is called without arguments; it is decorated
# with pyrallis.wrap(), which is expected to supply the cfg argument.
if __name__ == '__main__':
    main()
