# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
import gc
import logging
import os
from argparse import Namespace
from pathlib import Path
from typing import Iterable

import PIL.Image
import uuid
import torch
from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.path.scheduler import PolynomialConvexScheduler
from flow_matching.solver import MixtureDiscreteEulerSolver
from flow_matching.solver.ode_solver import ODESolver
from flow_matching.utils import ModelWrapper
from models.discrete_unet import DiscreteUNetModel
from models.ema import EMA
from torch.nn.modules import Module
from torch.nn.parallel import DistributedDataParallel
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision.utils import save_image
from training import distributed_mode
from training.edm_time_discretization import get_time_discretization
from training.train_loop import MASK_TOKEN

# Module-level logger used for the sampling/evaluation progress messages below.
logger = logging.getLogger(__name__)

# eval_model logs the running FID every PRINT_FREQUENCY data-loader iterations.
PRINT_FREQUENCY = 50

def analyze_time_grid_deltas(time_grid: torch.Tensor):
    """
    time_grid의 각 스텝 사이의 간격(delta)을 계산하고 통계를 출력합니다.

    Args:
        time_grid (torch.Tensor): 타임스텝을 담고 있는 1차원 텐서.
    """
    # 입력값이 2개 이상의 요소를 가진 1차원 텐서인지 확인
    if not isinstance(time_grid, torch.Tensor) or time_grid.dim() != 1 or len(time_grid) < 2:
        print("오류: 입력값은 2개 이상의 요소를 가진 1차원 torch.Tensor여야 합니다.")
        return

    # torch.diff를 사용해 연속된 요소 간의 차이를 계산 (step 간의 delta)
    time_deltas = torch.diff(time_grid)

    # 주요 통계치 계산
    mean_delta = torch.mean(time_deltas)
    min_delta = torch.min(time_deltas)
    max_delta = torch.max(time_deltas)
    
    # 백분위수 계산
    # 상의 50% 지점 -> 50th percentile (중앙값)
    # 상의 25% 지점 -> 75th percentile
    p50 = torch.quantile(time_deltas, 0.5)
    p75 = torch.quantile(time_deltas, 0.75)

    # 결과 출력
    print("--- Time Step Delta analysis ---")
    print(f"number of steps: {len(time_deltas)}")
    print("-" * 30)
    print(f"average distance  : {mean_delta.item():.10f}")
    print(f"minimum distance  : {min_delta.item():.10f}")
    print(f"maximum distance  : {max_delta.item():.10f}")
    print("-" * 30)
    print(f"50% value : {p50.item():.6f}")
    print(f"75% value : {p75.item():.6f}")
    print("---------------------------------")

# Step-size thresholds compared against the time-grid deltas in
# CFGScaledModel_MIX.forward: delta < P50 selects model_k4,
# P50 <= delta < P75 selects model_k2, anything larger selects model_k1.
# NOTE(review): the values look like the "50% value" / "75% value" lines printed
# by analyze_time_grid_deltas for one particular time grid — confirm they still
# match the grid actually used for sampling.
P50=0.012822
P75=0.034252

class CFGScaledModel_MIX(ModelWrapper):
    """
    Model wrapper that routes every solver evaluation to one of three models.

    The model is chosen from the size of the current time step: steps smaller
    than ``P50`` use ``model_k4``, steps smaller than ``P75`` use ``model_k2``
    and all remaining steps use ``model_k1``. The step sizes are read from
    ``self.time_deltas``, which the caller must assign before sampling
    (e.g. ``wrapper.time_deltas = torch.diff(time_grid)``).

    Attributes:
        nfe_counter: number of forward calls since the last reset.
        count: how many forward calls were routed to
            ``[model_k4, model_k2, model_k1]`` (never reset by this class).
        time_deltas: indexable sequence of per-step deltas, ``None`` until set.
    """

    def __init__(self, model_k1: Module, model_k2: Module, model_k4: Module):
        super().__init__(None)
        self.nfe_counter = 0
        self.model_k4 = model_k4
        self.model_k2 = model_k2
        self.model_k1 = model_k1
        self.count = [0, 0, 0]
        # Declared up front so a missing assignment yields a clear error in
        # forward() instead of an AttributeError.
        self.time_deltas = None

    def _select_model(self) -> Module:
        """Pick the model for the current evaluation and update ``count``."""
        if self.time_deltas is None:
            raise RuntimeError(
                "CFGScaledModel_MIX.time_deltas must be set before sampling, "
                "e.g. `wrapper.time_deltas = torch.diff(time_grid)`."
            )
        # The step index is derived as nfe_counter // 2, i.e. two evaluations
        # per step are assumed. Clamp to the last step so that a solver using
        # more evaluations (or a finer internal grid) than time_deltas has
        # entries reuses the final delta instead of raising IndexError.
        step_index = min(self.nfe_counter // 2, len(self.time_deltas) - 1)
        delta = self.time_deltas[step_index]
        if delta < P50:
            self.count[0] += 1
            return self.model_k4
        if delta < P75:
            self.count[1] += 1
            return self.model_k2
        self.count[2] += 1
        return self.model_k1

    def forward(
        self, x: torch.Tensor, t: torch.Tensor, cfg_scale: float, label: torch.Tensor,
    ):
        """
        Evaluate the model selected for the current step.

        Args:
            x: current state handed over by the solver.
            t: current time; it is broadcast to one entry per sample.
            cfg_scale: classifier-free guidance weight, ``0.0`` disables it.
            label: conditioning labels passed to the model as ``extra["label"]``.

        Returns:
            The float32 model output; for discrete models the softmax over the
            last dimension of the output.

        Raises:
            RuntimeError: if ``time_deltas`` has not been set.
            AssertionError: if ``cfg_scale`` is non-zero for a discrete model.
        """
        # Kept on self.model so code inspecting the wrapper sees the model
        # that handled the latest call.
        self.model = self._select_model()
        module = (
            self.model.module
            if isinstance(self.model, DistributedDataParallel)
            else self.model
        )
        is_discrete = isinstance(module, DiscreteUNetModel) or (
            isinstance(module, EMA) and isinstance(module.model, DiscreteUNetModel)
        )
        assert (
            cfg_scale == 0.0 or not is_discrete
        ), f"Cfg scaling does not work for the logit outputs of discrete models. Got cfg weight={cfg_scale} and model {type(self.model)}."
        # K is read from the model itself or from the model it wraps (e.g. EMA);
        # models without a K attribute are treated as K == 1.
        if hasattr(module, "K"):
            numK = module.K
        elif hasattr(module, "model") and hasattr(module.model, "K"):
            numK = module.model.K
        else:
            numK = 1
        # One time value per group of K entries along the first axis of x.
        bsz = x.shape[0] // numK
        t = torch.zeros(bsz, device=x.device) + t

        if cfg_scale != 0.0:
            with torch.cuda.amp.autocast(), torch.no_grad():
                conditional = self.model(x, t, extra={"label": label})
                condition_free = self.model(x, t, extra={})
            result = (1.0 + cfg_scale) * conditional - cfg_scale * condition_free
        else:
            # Model is fully conditional, no cfg weighting needed
            with torch.cuda.amp.autocast(), torch.no_grad():
                result = self.model(x, t, extra={"label": label})

        self.nfe_counter += 1
        if is_discrete:
            return torch.softmax(result.to(dtype=torch.float32), dim=-1)
        return result.to(dtype=torch.float32)

    def reset_nfe_counter(self) -> None:
        """Restart the evaluation counter (and with it the step index)."""
        self.nfe_counter = 0

    def get_nfe(self) -> int:
        """Return the number of forward calls since the last reset."""
        return self.nfe_counter


def eval_model(
    model: DistributedDataParallel,
    data_loader: Iterable,
    device: torch.device,
    epoch: int,
    fid_samples: int,
    args: Namespace,
):
    """
    Sample from ``model`` and score the samples against ``data_loader`` with FID.

    Every batch of the loader updates the "real" side of the metric. Until
    ``fid_samples`` synthetic images have been produced, each batch also
    triggers one round of sampling (discrete or continuous, depending on
    ``args.discrete_flow_matching``) conditioned on that batch's labels.

    Args:
        model: model to evaluate; it fills all three slots of
            ``CFGScaledModel_MIX`` because only one model is available here.
        data_loader: iterable yielding ``(samples, labels)`` batches.
        device: device on which sampling and the metric run.
        epoch: used only in the snapshot file name.
        fid_samples: number of synthetic samples to generate.
        args: run configuration. Fields read here: ``discrete_flow_matching``,
            ``ode_options``, ``output_dir``, ``sym_func``, ``sym``,
            ``sampling_dtype``, ``discrete_fm_steps``, ``cfg_scale``,
            ``edm_schedule``, ``ode_method``, ``save_fid_samples``,
            ``compute_fid`` and ``test_run``.

    Returns:
        ``{}`` after the first batch when ``args.compute_fid`` is falsy,
        otherwise ``{"fid": <float>}``.

    Raises:
        ValueError: if discrete sampling is requested with an
            ``args.sampling_dtype`` other than ``"float32"`` / ``"float64"``.
    """
    gc.collect()
    # The wrapper's constructor takes model_k1/model_k2/model_k4 (there is no
    # `model` keyword), so the single model is used for every slot.
    cfg_scaled_model = CFGScaledModel_MIX(
        model_k1=model, model_k2=model, model_k4=model
    )
    cfg_scaled_model.train(False)

    if args.discrete_flow_matching:
        scheduler = PolynomialConvexScheduler(n=3.0)
        path = MixtureDiscreteProbPath(scheduler=scheduler)
        # Source distribution puts all mass on index 256 of the 257-entry vocabulary.
        p = torch.zeros(size=[257], dtype=torch.float32, device=device)
        p[256] = 1.0
        solver = MixtureDiscreteEulerSolver(
            model=cfg_scaled_model,
            path=path,
            vocabulary_size=257,
            source_distribution_p=p,
        )
    else:
        solver = ODESolver(velocity_model=cfg_scaled_model)
        ode_opts = args.ode_options

    fid_metric = FrechetInceptionDistance(normalize=True).to(
        device=device, non_blocking=True
    )

    num_synthetic = 0
    snapshots_saved = False
    if args.output_dir:
        (Path(args.output_dir) / "snapshots").mkdir(parents=True, exist_ok=True)

    for data_iter_step, (samples, labels) in enumerate(data_loader):
        samples = samples.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        fid_metric.update(samples, real=True)

        if num_synthetic < fid_samples:
            cfg_scaled_model.reset_nfe_counter()
            if args.discrete_flow_matching:
                # Discrete sampling
                x_0 = (
                    torch.zeros(samples.shape, dtype=torch.long, device=device)
                    + MASK_TOKEN
                )
                if args.sym_func:
                    sym = lambda t: 12.0 * torch.pow(t, 2.0) * torch.pow(1.0 - t, 0.25)
                else:
                    sym = args.sym
                sampling_dtypes = {"float32": torch.float32, "float64": torch.float64}
                if args.sampling_dtype not in sampling_dtypes:
                    raise ValueError(
                        f"Unsupported sampling_dtype {args.sampling_dtype!r}; "
                        f"expected one of {sorted(sampling_dtypes)}."
                    )
                dtype = sampling_dtypes[args.sampling_dtype]

                step_size = 1.0 / args.discrete_fm_steps
                # The wrapper indexes time_deltas on every call. The solver is
                # asked for uniform steps, so every delta equals step_size; one
                # spare entry covers a rounded-up step count.
                cfg_scaled_model.time_deltas = torch.full(
                    (int(args.discrete_fm_steps) + 1,), step_size
                )

                synthetic_samples = solver.sample(
                    x_init=x_0,
                    step_size=step_size,
                    verbose=False,
                    div_free=sym,
                    dtype_categorical=dtype,
                    label=labels,
                    cfg_scale=args.cfg_scale,
                )
            else:
                # Continuous sampling
                x_0 = torch.randn(samples.shape, dtype=torch.float32, device=device)

                if args.edm_schedule:
                    time_grid = get_time_discretization(nfes=ode_opts["nfe"])
                else:
                    time_grid = torch.tensor([0.0, 1.0], device=device)

                # The wrapper reads the per-step deltas of the grid being solved.
                cfg_scaled_model.time_deltas = torch.diff(time_grid)

                synthetic_samples = solver.sample(
                    time_grid=time_grid,
                    x_init=x_0,
                    method=args.ode_method,
                    return_intermediates=False,
                    atol=ode_opts["atol"] if "atol" in ode_opts else 1e-5,
                    # Guarded by its own key (the check used to look at "atol").
                    rtol=ode_opts["rtol"] if "rtol" in ode_opts else 1e-5,
                    step_size=ode_opts["step_size"]
                    if "step_size" in ode_opts
                    else None,
                    label=labels,
                    cfg_scale=args.cfg_scale,
                )

                # Scaling to [0, 1] from [-1, 1]
                synthetic_samples = torch.clamp(
                    synthetic_samples * 0.5 + 0.5, min=0.0, max=1.0
                )
                synthetic_samples = torch.floor(synthetic_samples * 255)
            synthetic_samples = synthetic_samples.to(torch.float32) / 255.0
            logger.info(
                f"{samples.shape[0]} samples generated in {cfg_scaled_model.get_nfe()} evaluations."
            )
            # Drop the surplus of the final batch so exactly fid_samples are used.
            if num_synthetic + synthetic_samples.shape[0] > fid_samples:
                synthetic_samples = synthetic_samples[: fid_samples - num_synthetic]
            fid_metric.update(synthetic_samples, real=False)
            num_synthetic += synthetic_samples.shape[0]
            if not snapshots_saved and args.output_dir:
                # Only the first generated batch is stored as a snapshot grid.
                save_image(
                    synthetic_samples,
                    fp=Path(args.output_dir)
                    / "snapshots"
                    / f"{epoch}_{data_iter_step}.png",
                )
                snapshots_saved = True

            if args.save_fid_samples and args.output_dir:
                images_np = (
                    (synthetic_samples * 255.0)
                    .clip(0, 255)
                    .to(torch.uint8)
                    .permute(0, 2, 3, 1)
                    .cpu()
                    .numpy()
                )
                # Created once per batch instead of once per image.
                image_dir = Path(args.output_dir) / "fid_samples"
                os.makedirs(image_dir, exist_ok=True)
                for batch_index, image_np in enumerate(images_np):
                    image_path = (
                        image_dir
                        / f"{distributed_mode.get_rank()}_{data_iter_step}_{batch_index}.png"
                    )
                    PIL.Image.fromarray(image_np, "RGB").save(image_path)

        if not args.compute_fid:
            return {}

        if data_iter_step % PRINT_FREQUENCY == 0:
            # Sync fid metric to ensure that the processes dont deviate much.
            gc.collect()
            running_fid = fid_metric.compute()
            logger.info(
                f"Evaluating [{data_iter_step}/{len(data_loader)}] samples generated [{num_synthetic}/{fid_samples}] running fid {running_fid}"
            )

        if args.test_run:
            break

    return {"fid": float(fid_metric.compute().detach().cpu())}

def sample_model(
    model_k1: DistributedDataParallel,
    model_k2: DistributedDataParallel,
    model_k4: DistributedDataParallel,
    data_loader: Iterable,
    device: torch.device,
    epoch: int,
    fid_samples: int,
    args: Namespace,
):
    """
    Generate ``fid_samples`` images with the step-size dependent model mixture.

    Only the first batch of ``data_loader`` is read: its shape defines the
    shape of every generated batch and its labels condition every batch.
    Sampling repeats until ``fid_samples`` images exist (or once when
    ``args.test_run`` is set). The per-model usage counters of the wrapper are
    printed at the end.

    Args:
        model_k1: model used for steps with delta >= ``P75``.
        model_k2: model used for steps with ``P50`` <= delta < ``P75``.
        model_k4: model used for steps with delta < ``P50``.
        data_loader: iterable yielding ``(samples, labels)`` batches.
        device: device on which sampling runs.
        epoch: used only in the snapshot file name.
        fid_samples: number of images to generate.
        args: run configuration. Fields read here: ``discrete_flow_matching``,
            ``ode_options``, ``output_dir``, ``sym_func``, ``sym``,
            ``sampling_dtype``, ``discrete_fm_steps``, ``cfg_scale``,
            ``edm_schedule``, ``ode_method``, ``save_fid_samples`` and
            ``test_run``.

    Returns:
        An empty dict.

    Raises:
        ValueError: if discrete sampling is requested with an
            ``args.sampling_dtype`` other than ``"float32"`` / ``"float64"``.
    """
    gc.collect()
    cfg_scaled_model = CFGScaledModel_MIX(
        model_k1=model_k1, model_k2=model_k2, model_k4=model_k4
    )
    cfg_scaled_model.train(False)
    if args.discrete_flow_matching:
        scheduler = PolynomialConvexScheduler(n=3.0)
        path = MixtureDiscreteProbPath(scheduler=scheduler)
        # Source distribution puts all mass on index 256 of the 257-entry vocabulary.
        p = torch.zeros(size=[257], dtype=torch.float32, device=device)
        p[256] = 1.0
        solver = MixtureDiscreteEulerSolver(
            model=cfg_scaled_model,
            path=path,
            vocabulary_size=257,
            source_distribution_p=p,
        )
    else:
        solver = ODESolver(velocity_model=cfg_scaled_model)
        ode_opts = args.ode_options

    num_synthetic = 0
    snapshots_saved = False
    if args.output_dir:
        (Path(args.output_dir) / "snapshots").mkdir(parents=True, exist_ok=True)

    # Only the first batch is used, for its shape and its labels.
    samples, labels = next(iter(data_loader))
    samples = samples.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)
    num_loop = 0
    while num_synthetic < fid_samples:
        cfg_scaled_model.reset_nfe_counter()
        if args.discrete_flow_matching:
            # Discrete sampling
            x_0 = (
                torch.zeros(samples.shape, dtype=torch.long, device=device)
                + MASK_TOKEN
            )
            if args.sym_func:
                sym = lambda t: 12.0 * torch.pow(t, 2.0) * torch.pow(1.0 - t, 0.25)
            else:
                sym = args.sym
            sampling_dtypes = {"float32": torch.float32, "float64": torch.float64}
            if args.sampling_dtype not in sampling_dtypes:
                raise ValueError(
                    f"Unsupported sampling_dtype {args.sampling_dtype!r}; "
                    f"expected one of {sorted(sampling_dtypes)}."
                )
            dtype = sampling_dtypes[args.sampling_dtype]

            step_size = 1.0 / args.discrete_fm_steps
            # The wrapper picks its model from time_deltas on every call. The
            # solver is asked for uniform steps, so every delta equals
            # step_size; one spare entry covers a rounded-up step count.
            cfg_scaled_model.time_deltas = torch.full(
                (int(args.discrete_fm_steps) + 1,), step_size
            )

            synthetic_samples = solver.sample(
                x_init=x_0,
                step_size=step_size,
                verbose=False,
                div_free=sym,
                dtype_categorical=dtype,
                label=labels,
                cfg_scale=args.cfg_scale,
            )
        else:
            # Continuous sampling
            x_0 = torch.randn(samples.shape, dtype=torch.float32, device=device)

            if args.edm_schedule:
                time_grid = get_time_discretization(nfes=ode_opts["nfe"])
            else:
                time_grid = torch.tensor([0.0, 1.0], device=device)

            # The wrapper picks its model from the per-step deltas of this grid.
            cfg_scaled_model.time_deltas = torch.diff(time_grid)

            synthetic_samples = solver.sample(
                time_grid=time_grid,
                x_init=x_0,
                method=args.ode_method,
                return_intermediates=False,
                atol=ode_opts["atol"] if "atol" in ode_opts else 1e-5,
                # Guarded by its own key (the check used to look at "atol").
                rtol=ode_opts["rtol"] if "rtol" in ode_opts else 1e-5,
                step_size=ode_opts["step_size"]
                if "step_size" in ode_opts
                else None,
                label=labels,
                cfg_scale=args.cfg_scale,
            )

            # Scaling to [0, 1] from [-1, 1]
            synthetic_samples = torch.clamp(
                synthetic_samples * 0.5 + 0.5, min=0.0, max=1.0
            )
            synthetic_samples = torch.floor(synthetic_samples * 255)
        synthetic_samples = synthetic_samples.to(torch.float32) / 255.0
        logger.info(
            f"{num_synthetic} samples generated in {cfg_scaled_model.get_nfe()} evaluations."
        )
        # Drop the surplus of the final batch so exactly fid_samples are kept.
        if num_synthetic + synthetic_samples.shape[0] > fid_samples:
            synthetic_samples = synthetic_samples[: fid_samples - num_synthetic]
        num_synthetic += synthetic_samples.shape[0]
        if not snapshots_saved and args.output_dir:
            # Only the first generated batch is stored as a snapshot grid.
            save_image(
                synthetic_samples,
                fp=Path(args.output_dir)
                / "snapshots"
                / f"{epoch}_{num_loop}.png",
            )
            snapshots_saved = True

        if args.save_fid_samples and args.output_dir:
            images_np = (
                (synthetic_samples * 255.0)
                .clip(0, 255)
                .to(torch.uint8)
                .permute(0, 2, 3, 1)
                .cpu()
                .numpy()
            )
            # Created once per batch instead of once per image.
            image_dir = Path(args.output_dir) / "fid_samples"
            os.makedirs(image_dir, exist_ok=True)
            for batch_index, image_np in enumerate(images_np):
                # A random prefix is used in the file name instead of a rank.
                random_id = str(uuid.uuid4())[:12]
                image_path = (
                    image_dir
                    / f"{random_id}_{num_loop}_{batch_index}.png"
                )
                PIL.Image.fromarray(image_np, "RGB").save(image_path)
        num_loop += 1
        if args.test_run:
            break
    # Calls routed to [model_k4, model_k2, model_k1] over the whole run.
    print(cfg_scaled_model.count)
    return {}

def sample_progress(
    model_k1: DistributedDataParallel,
    model_k2: DistributedDataParallel,
    model_k4: DistributedDataParallel,
    data_loader: Iterable,
    device: torch.device,
    epoch: int,
    fid_samples: int,
    args: Namespace,
):
    """
    Run a single sampling pass with the step-size dependent model mixture.

    Only the first batch of ``data_loader`` is read: its shape defines the
    shape of the generated batch and its labels condition it. In the
    continuous case the statistics of the time grid are printed and the solver
    is run through ``solver.sample_out`` with ``out_dir=args.output_dir``.
    The generated samples are not post-processed or saved by this function.
    The per-model usage counters of the wrapper are printed at the end.

    NOTE(review): ``sample_out`` is not defined in this file; presumably it
    writes intermediate states to ``out_dir`` — confirm against the solver.

    Args:
        model_k1: model used for steps with delta >= ``P75``.
        model_k2: model used for steps with ``P50`` <= delta < ``P75``.
        model_k4: model used for steps with delta < ``P50``.
        data_loader: iterable yielding ``(samples, labels)`` batches.
        device: device on which sampling runs.
        epoch: unused; kept for signature parity with ``sample_model``.
        fid_samples: unused; kept for signature parity with ``sample_model``.
        args: run configuration. Fields read here: ``discrete_flow_matching``,
            ``ode_options``, ``output_dir``, ``sym_func``, ``sym``,
            ``sampling_dtype``, ``discrete_fm_steps``, ``cfg_scale``,
            ``edm_schedule`` and ``ode_method``.

    Returns:
        An empty dict.

    Raises:
        ValueError: if discrete sampling is requested with an
            ``args.sampling_dtype`` other than ``"float32"`` / ``"float64"``.
    """
    gc.collect()
    cfg_scaled_model = CFGScaledModel_MIX(
        model_k1=model_k1, model_k2=model_k2, model_k4=model_k4
    )
    cfg_scaled_model.train(False)
    if args.discrete_flow_matching:
        scheduler = PolynomialConvexScheduler(n=3.0)
        path = MixtureDiscreteProbPath(scheduler=scheduler)
        # Source distribution puts all mass on index 256 of the 257-entry vocabulary.
        p = torch.zeros(size=[257], dtype=torch.float32, device=device)
        p[256] = 1.0
        solver = MixtureDiscreteEulerSolver(
            model=cfg_scaled_model,
            path=path,
            vocabulary_size=257,
            source_distribution_p=p,
        )
    else:
        solver = ODESolver(velocity_model=cfg_scaled_model)
        ode_opts = args.ode_options

    if args.output_dir:
        (Path(args.output_dir) / "snapshots").mkdir(parents=True, exist_ok=True)

    # Only the first batch is used, for its shape and its labels.
    samples, labels = next(iter(data_loader))
    samples = samples.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)
    cfg_scaled_model.reset_nfe_counter()
    if args.discrete_flow_matching:
        # Discrete sampling
        x_0 = (
            torch.zeros(samples.shape, dtype=torch.long, device=device)
            + MASK_TOKEN
        )
        if args.sym_func:
            sym = lambda t: 12.0 * torch.pow(t, 2.0) * torch.pow(1.0 - t, 0.25)
        else:
            sym = args.sym
        sampling_dtypes = {"float32": torch.float32, "float64": torch.float64}
        if args.sampling_dtype not in sampling_dtypes:
            raise ValueError(
                f"Unsupported sampling_dtype {args.sampling_dtype!r}; "
                f"expected one of {sorted(sampling_dtypes)}."
            )
        dtype = sampling_dtypes[args.sampling_dtype]

        step_size = 1.0 / args.discrete_fm_steps
        # The wrapper picks its model from time_deltas on every call. The
        # solver is asked for uniform steps, so every delta equals step_size;
        # one spare entry covers a rounded-up step count.
        cfg_scaled_model.time_deltas = torch.full(
            (int(args.discrete_fm_steps) + 1,), step_size
        )

        # The result is not used; the call is kept for its effect on the
        # wrapper's usage counters.
        solver.sample(
            x_init=x_0,
            step_size=step_size,
            verbose=False,
            div_free=sym,
            dtype_categorical=dtype,
            label=labels,
            cfg_scale=args.cfg_scale,
        )
    else:
        # Continuous sampling
        x_0 = torch.randn(samples.shape, dtype=torch.float32, device=device)

        if args.edm_schedule:
            time_grid = get_time_discretization(nfes=ode_opts["nfe"])
        else:
            time_grid = torch.tensor([0.0, 1.0], device=device)

        analyze_time_grid_deltas(time_grid)
        # The wrapper picks its model from the per-step deltas of this grid.
        cfg_scaled_model.time_deltas = torch.diff(time_grid)

        # The returned samples are not used here.
        solver.sample_out(
            time_grid=time_grid,
            out_dir=args.output_dir,
            x_init=x_0,
            method=args.ode_method,
            atol=ode_opts["atol"] if "atol" in ode_opts else 1e-5,
            # Guarded by its own key (the check used to look at "atol").
            rtol=ode_opts["rtol"] if "rtol" in ode_opts else 1e-5,
            step_size=ode_opts["step_size"]
            if "step_size" in ode_opts
            else None,
            label=labels,
            cfg_scale=args.cfg_scale,
        )

    # Calls routed to [model_k4, model_k2, model_k1] during this pass.
    print(cfg_scaled_model.count)
    return {}