# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
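A minimal single-process (no DDP) training script for SiT, with periodic EMA
sampling evaluated by CLIP-IQA, MUSIQ and a nearest-neighbour latent-MSE ratio
test against a folder of reference images.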
"""
import torch
# the first flag below was False when we tested this script but True makes A100 training a lot faster:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
import numpy as np
from collections import OrderedDict
from PIL import Image
from copy import deepcopy
from glob import glob
from time import time
import argparse
import logging
import os
import wandb
from models import SiT_models
from download import find_model
from transport import create_transport, Sampler
from diffusers.models import AutoencoderKL
from train_utils import parse_transport_args
import wandb_utils
import sys
import pyiqa
import torch.nn.functional as F
import gc
import csv
import matplotlib.pyplot as plt
import torchvision.utils as vutils

#################################################################################
#                             Training Helper Functions                         #
#################################################################################

def save_image_grid(images, step, experiment_dir, nrow=3):
    """
    Write at most the first 9 images as one PNG grid.

    The file goes to ``<experiment_dir>/images/step_<step>.png``; the
    ``images`` sub-directory is created if needed. Each image is min-max
    rescaled independently (``normalize=True, scale_each=True``) and the grid
    holds ``nrow`` images per row.

    Args:
        images: image batch; non-tensors are converted with ``torch.tensor``.
        step: training step used in the file name.
        experiment_dir: root directory of the current experiment.
        nrow: number of images per grid row.
    """
    out_dir = os.path.join(experiment_dir, "images")
    os.makedirs(out_dir, exist_ok=True)

    batch = images if isinstance(images, torch.Tensor) else torch.tensor(images)
    grid = vutils.make_grid(batch[:9], nrow=nrow, normalize=True, scale_each=True)

    vutils.save_image(grid, os.path.join(out_dir, f"step_{step}.png"))



@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Blend the live model's weights into the EMA model, in place.

    For every named parameter of ``model``:
        ema = decay * ema + (1 - decay) * param
    so ``decay=0`` effectively copies the current weights into the EMA.
    """
    shadow = dict(ema_model.named_parameters())
    blend = 1 - decay
    for pname, live_param in model.named_parameters():
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        shadow[pname].mul_(decay).add_(live_param.data, alpha=blend)


def requires_grad(model, flag=True):
    """
    Freeze (``flag=False``) or unfreeze (``flag=True``) every parameter of ``model``.
    """
    params = model.parameters()
    for tensor in params:
        tensor.requires_grad = flag


def cleanup():
    """
    End DDP training by destroying the default process group.

    A no-op when torch.distributed is unavailable or no process group has been
    initialised, instead of raising; calling it unconditionally is then safe.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def create_logger(logging_dir):
    """
    Configure root logging via ``logging.basicConfig`` and return this module's logger.

    With a ``logging_dir``, records go to the console and to
    ``<logging_dir>/log.txt`` using a blue ANSI-coloured timestamp prefix.
    Without one, a plain INFO-level configuration is applied.
    """
    if logging_dir is None:
        logging.basicConfig(level=logging.INFO)
        return logging.getLogger(__name__)

    log_file = f"{logging_dir}/log.txt"
    console_and_file = [logging.StreamHandler(), logging.FileHandler(log_file)]
    logging.basicConfig(
        level=logging.INFO,
        format='[\033[34m%(asctime)s\033[0m] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=console_and_file,
    )
    return logging.getLogger(__name__)


def center_crop_arr(pil_image, image_size):
    """
    Resize and center-crop a PIL image to ``image_size`` x ``image_size``.

    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    The image is first halved with a BOX filter while its short side is at
    least twice the target, then resized with BICUBIC so its short side equals
    ``image_size``, and finally cropped around its centre.
    """
    img = pil_image
    while min(img.size) >= 2 * image_size:
        width, height = img.size
        img = img.resize((width // 2, height // 2), resample=Image.BOX)

    ratio = image_size / min(img.size)
    width, height = img.size
    img = img.resize((round(width * ratio), round(height * ratio)), resample=Image.BICUBIC)

    pixels = np.array(img)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    return Image.fromarray(pixels[top:top + image_size, left:left + image_size])


def load_reference_images_mse(folder, device, image_size=256):
    """
    Load every .jpg/.jpeg/.png file directly inside ``folder`` as one tensor batch.

    Each image is converted to RGB, center-cropped to ``image_size`` with
    ``center_crop_arr``, randomly flipped horizontally, converted to a tensor
    and normalized with mean/std 0.5 (i.e. to [-1, 1]).

    Args:
        folder: directory to scan (non-recursively); files are visited in
            sorted name order so the batch order is reproducible.
        device: target device for the result; ``None`` leaves the tensor
            where ``ToTensor`` created it.
        image_size: side length of the square crop.

    Returns:
        Tensor of shape [N, 3, image_size, image_size].

    Raises:
        ValueError: if ``folder`` contains no matching image files.
    """
    # NOTE(review): RandomHorizontalFlip makes the reference set differ between
    # runs; for a nearest-neighbour test a flipped reference may hide a match.
    # Consider removing it or adding both orientations — confirm intent.
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5]*3, std=[0.5]*3, inplace=True)
    ])
    extensions = ('.jpg', '.png', '.jpeg')

    images = []
    for fname in sorted(os.listdir(folder)):
        if not fname.lower().endswith(extensions):
            continue
        # Context manager closes the file handle as soon as the pixels are read.
        with Image.open(os.path.join(folder, fname)) as pil_img:
            images.append(transform(pil_img.convert("RGB")))

    if not images:
        raise ValueError(f"No images found in folder: {folder}")

    imgs = torch.stack(images, dim=0)  # [N, 3, H, W]
    if device is not None:
        imgs = imgs.to(device)
    return imgs


# Per-device cache of the 'sd_vae' AutoencoderKL used by evaluate_ratio_mse, so
# repeated calls do not reload the weights from disk every time.
_REFERENCE_VAE_CACHE = {}


def _get_reference_vae(device):
    """Return the 'sd_vae' AutoencoderKL on ``device``, loading it only on first use."""
    key = str(device)
    vae = _REFERENCE_VAE_CACHE.get(key)
    if vae is None:
        vae = AutoencoderKL.from_pretrained('sd_vae').to(device)
        _REFERENCE_VAE_CACHE[key] = vae
    return vae


@torch.no_grad()
def evaluate_ratio_mse(
    generated_images,
    reference_loader,
    device,
    threshold=1/3,
    vae=None,
):
    """
    Count generated latents whose nearest reference is much closer than the runner-up.

    Every reference batch is VAE-encoded (posterior sample scaled by 0.18215).
    For each generated latent, the smallest and second-smallest mean-squared
    distances to all reference latents are tracked across batches. A generated
    latent counts when ``min / second_min < threshold``.

    Args:
        generated_images: latents of shape [m, c, h, w]. They are compared
            directly with the encoded references and are not encoded themselves.
        reference_loader: iterable of image batches, or of lists/tuples whose
            first element is the batch (as a DataLoader over a TensorDataset yields).
        device: device on which encoding and distance computation run.
        threshold: ratio below which a sample is counted.
        vae: optional encoder exposing ``encode(x).latent_dist.sample()``. If
            None, the 'sd_vae' AutoencoderKL is loaded once per device and reused.

    Returns:
        (count, total): number of counted samples and number of generated samples.
    """
    if vae is None:
        vae = _get_reference_vae(device)
    generated_images = generated_images.to(device)

    m = generated_images.shape[0]
    min_vals = torch.full((m,), float('inf'), device=device)
    second_min_vals = torch.full((m,), float('inf'), device=device)

    for ref_batch in reference_loader:
        if isinstance(ref_batch, (list, tuple)):
            ref_batch = ref_batch[0]
        ref_batch = ref_batch.to(device)
        ref_latents = vae.encode(ref_batch).latent_dist.sample().mul_(0.18215)  # [b, c, h, w]

        diff = generated_images.unsqueeze(1) - ref_latents.unsqueeze(0)  # [m, b, c, h, w]
        mse = (diff ** 2).mean(dim=[2, 3, 4])  # [m, b]

        # Merge the running best two with this batch's distances and keep the two smallest.
        combined = torch.cat([
            min_vals.unsqueeze(1),
            second_min_vals.unsqueeze(1),
            mse
        ], dim=1)  # [m, b+2]
        top2_vals, _ = torch.topk(combined, k=2, dim=1, largest=False)
        min_vals = top2_vals[:, 0]
        second_min_vals = top2_vals[:, 1]

    ratio = min_vals / second_min_vals
    # With fewer than two reference images second_min stays inf, so the ratio
    # would be 0 and every sample would count. The test is undefined there,
    # so those samples are excluded.
    mask = (ratio < threshold) & torch.isfinite(second_min_vals)
    count = mask.sum().item()
    total = m

    return count, total


#################################################################################
#                                  Training Loop                                #
#################################################################################

def _append_metrics_row(csv_path, train_steps, avg_clip, avg_musiq, avg_match_count):
    """Append one evaluation row to ``csv_path``, writing the header if the file is new."""
    write_header = not os.path.exists(csv_path)
    with open(csv_path, mode="a", newline="") as csvfile:
        writer = csv.writer(csvfile)
        if write_header:
            writer.writerow(["train_step", "clipiqa", "musiq", "memory"])
        writer.writerow([train_steps, avg_clip, avg_musiq, avg_match_count])


@torch.no_grad()
def _evaluate_ema_samples(
    sample_fn,
    model_fn,
    vae,
    clipiqa_metric,
    musiq_model,
    reference_loader,
    latent_size,
    device,
    eval_device,
    train_steps,
    experiment_dir,
    *,
    use_cfg,
    cfg_scale,
    null_class,
    batch_size=10,
    num_batches=100,
):
    """
    Generate ``num_batches * batch_size`` samples with ``model_fn`` and score them.

    Each batch is sampled with ``sample_fn``, decoded with the VAE, and scored
    with CLIP-IQA and MUSIQ. It is also passed, as latents, to
    ``evaluate_ratio_mse`` on ``eval_device``. The first decoded batch is saved
    as an image grid. Runs without autograd so no graph is built through the
    VAE decoder or the metrics.

    Returns:
        (avg_clip, avg_musiq, avg_match_count), each averaged over all samples.
    """
    total_clip = 0.0
    total_musiq = 0.0
    total_match_count = 0
    total_count = 0

    for i in range(num_batches):
        zs_i = torch.randn(batch_size, 4, latent_size, latent_size, device=device)
        # NOTE(review): every sample is conditioned on label 1, while the default
        # config sets num_classes=1 (real labels would then only be 0) — confirm
        # this is intended, e.g. unconditional sampling via the null label.
        ys_i = torch.ones(size=(batch_size,), device=device, dtype=torch.long)
        if use_cfg:
            # Classifier-free guidance: the conditional and null-label halves run
            # in one batch, and forward_with_cfg additionally needs cfg_scale.
            zs_i = torch.cat([zs_i, zs_i], 0)
            y_null = torch.full((batch_size,), null_class, device=device, dtype=torch.long)
            kwargs = dict(y=torch.cat([ys_i, y_null], 0), cfg_scale=cfg_scale)
        else:
            kwargs = dict(y=ys_i)

        latents = sample_fn(zs_i, model_fn, **kwargs)[-1]
        if use_cfg:
            latents, _ = latents.chunk(2, dim=0)  # keep the half built from the class labels
        samples = vae.decode(latents / 0.18215).sample

        # Rescale decoded images from [-1, 1] to [0, 1] for the IQA metrics.
        samples_norm = (samples.clamp(-1, 1) + 1) / 2
        total_clip += clipiqa_metric(samples_norm).sum().item()
        total_musiq += musiq_model(samples_norm).sum().item()
        match_count, _ = evaluate_ratio_mse(latents, reference_loader, eval_device)
        total_match_count += match_count
        total_count += samples.size(0)

        if i == 0:
            save_image_grid(samples[:9], train_steps, experiment_dir)

    return (
        total_clip / total_count,
        total_musiq / total_count,
        total_match_count / total_count,
    )


def main(args):
    """
    Trains a new SiT model in single GPU mode (no DDP).

    Images are encoded with the 'sd_vae' AutoencoderKL, an EMA copy of the model
    is maintained, and a checkpoint is written every ``args.ckpt_every`` steps.
    Every ``args.sample_every`` steps the EMA model is sampled and scored
    (CLIP-IQA, MUSIQ, nearest-neighbour ratio test), and the averages are
    appended to ``<experiment_dir>/metrics.csv``.

    Training and generation run on ``args.device`` (default "cuda:1"). The
    nearest-neighbour test runs on ``args.eval_device`` (default "cuda:2").
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."

    # Devices: one for training/generation, one for the reference-set comparison.
    device = torch.device(getattr(args, "device", "cuda:1"))
    device2 = torch.device(getattr(args, "eval_device", "cuda:2"))

    torch.manual_seed(args.global_seed)
    torch.cuda.set_device(device)
    clipiqa_metric = pyiqa.create_metric('clipiqa').to(device)
    musiq_model = pyiqa.create_metric('musiq').to(device)
    reference_images_path = "all_classes"
    reference_images = load_reference_images_mse(reference_images_path, device=None)
    reference_dataset = TensorDataset(reference_images)
    reference_loader = DataLoader(reference_dataset, batch_size=128, shuffle=False)

    local_batch_size = args.global_batch_size

    os.makedirs(args.results_dir, exist_ok=True)  # Make results folder
    experiment_index = len(glob(f"{args.results_dir}/*"))
    model_string_name = args.model.replace("/", "-")
    experiment_name = f"{experiment_index:03d}-{model_string_name}-" \
                      f"{args.path_type}-{args.prediction}-{args.loss_weight}"
    experiment_dir = f"{args.results_dir}/{experiment_name}"
    checkpoint_dir = f"{experiment_dir}/checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)
    logger = create_logger(experiment_dir)
    if args.wandb:
        entity = ''
        project = ''
        wandb_utils.initialize(args, project_name=project, entity=entity, exp_name=experiment_name)

    # Model, EMA copy and (optionally) restored checkpoint.
    assert args.image_size % 8 == 0, "Image size must be divisible by 8."
    latent_size = args.image_size // 8
    model = SiT_models[args.model](input_size=latent_size, num_classes=args.num_classes)
    ema = deepcopy(model).to(device)

    state_dict = None
    if args.ckpt is not None:
        state_dict = find_model(args.ckpt)
        model.load_state_dict(state_dict["model"])
        ema.load_state_dict(state_dict["ema"])
        args = state_dict["args"]

    requires_grad(ema, False)

    model = model.to(device)

    # The optimizer must exist before its state can be restored. Building it
    # after model.to(device) keeps its params (and restored state) on device.
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0)
    if state_dict is not None:
        opt.load_state_dict(state_dict["opt"])

    transport = create_transport(
        args.path_type,
        args.prediction,
        args.loss_weight,
        args.train_eps,
        args.sample_eps
    )
    transport_sampler = Sampler(transport)
    vae = AutoencoderKL.from_pretrained('sd_vae').to(device)

    logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")

    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5]*3, std=[0.5]*3, inplace=True)
    ])
    dataset = ImageFolder(args.data_path, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=local_batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})")

    # Sync the EMA with the model only for a fresh run. When resuming, this
    # would overwrite the EMA weights just restored from the checkpoint.
    if state_dict is None:
        update_ema(ema, model, decay=0)
    model.train()
    ema.eval()

    train_steps = 0
    log_steps = 0
    running_loss = 0
    start_time = time()

    use_cfg = args.cfg_scale > 1.0
    model_fn = ema.forward_with_cfg if use_cfg else ema.forward

    logger.info(f"Training for {args.epochs} epochs...")
    for epoch in range(args.epochs):
        logger.info(f"Beginning epoch {epoch}...")
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            with torch.no_grad():
                x = vae.encode(x).latent_dist.sample().mul_(0.18215)
            model_kwargs = dict(y=y)
            loss_dict = transport.training_losses(model, x, model_kwargs)
            loss = loss_dict["loss"].mean()
            opt.zero_grad()
            loss.backward()
            opt.step()
            update_ema(ema, model)

            running_loss += loss.item()
            log_steps += 1
            train_steps += 1
            if train_steps % args.log_every == 0:
                torch.cuda.synchronize()
                end_time = time()
                steps_per_sec = log_steps / (end_time - start_time)
                avg_loss = running_loss / log_steps
                logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
                if args.wandb:
                    wandb_utils.log(
                        {"train loss": avg_loss, "train steps/sec": steps_per_sec},
                        step=train_steps
                    )
                running_loss = 0
                log_steps = 0
                start_time = time()

            if train_steps % args.ckpt_every == 0 and train_steps > 0:
                checkpoint = {
                    "model": model.state_dict(),
                    "ema": ema.state_dict(),
                    "opt": opt.state_dict(),
                    "args": args
                }
                # A single checkpoint file is overwritten on every save.
                checkpoint_path = f"{checkpoint_dir}/1.pt"
                torch.save(checkpoint, checkpoint_path)
                logger.info(f"Saved checkpoint to {checkpoint_path}")

            if train_steps % args.sample_every == 0 and train_steps > 0:
                logger.info("Generating EMA samples and computing CLIP-IQA + MUSIQ scores...")
                avg_clip, avg_musiq, avg_match_count = _evaluate_ema_samples(
                    transport_sampler.sample_ode(),
                    model_fn,
                    vae,
                    clipiqa_metric,
                    musiq_model,
                    reference_loader,
                    latent_size,
                    device,
                    device2,
                    train_steps,
                    experiment_dir,
                    use_cfg=use_cfg,
                    cfg_scale=args.cfg_scale,
                    # Null (unconditional) label index; DiT/SiT convention is
                    # num_classes — TODO confirm against models' label embedder.
                    null_class=args.num_classes,
                )

                logger.info(f"CLIP-IQA average score: {avg_clip:.4f}")
                logger.info(f"MUSIQ average score:    {avg_musiq:.4f}")
                logger.info(f"Memory Number:    {avg_match_count:.4f}")

                _append_metrics_row(
                    os.path.join(experiment_dir, "metrics.csv"),
                    train_steps, avg_clip, avg_musiq, avg_match_count,
                )

    model.eval()
    logger.info("Done!")

if __name__ == "__main__":
    # argparse.Namespace gives the same attribute access as a hand-rolled class,
    # and it pickles from any module. That matters because main() stores `args`
    # inside every checkpoint.
    args = argparse.Namespace(
        model="SiT-B/2",
        data_path="data",
        path_type="Linear",
        prediction="noise",
        wandb=False,

        loss_weight="none",
        train_eps=1e-5,
        sample_eps=1e-5,

        results_dir="results",
        image_size=256,
        num_classes=1,
        epochs=20000,
        global_batch_size=16,
        global_seed=0,
        vae="ema",
        num_workers=4,
        log_every=20,
        ckpt_every=50000,
        sample_every=20000,
        cfg_scale=0,
        ckpt=None,
    )

    main(args)
