# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.

import datetime
import json
import logging
import os
import sys
from pathlib import Path

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torchvision.datasets as datasets
from models.model_configs import instantiate_model
from train_arg_parser import get_args_parser

from training import distributed_mode
from training.data_transform import get_train_transform, get_train_transform_celeba, get_train_transform_resize
from training.throughput import eval_model, sample_model
from training.grad_scaler import NativeScalerWithGradNormCount as NativeScaler
from training.load_and_save import load_model_for_eval
from PIL import Image
from tqdm import tqdm

logger = logging.getLogger(__name__)

def save_npz(args):
    """Pack the PNG images in ``<args.output_dir>/fid_samples`` into one ``.npz``.

    PNG files are read in sorted filename order, converted with ``np.array``,
    stacked along a new leading axis and written to
    ``<args.output_dir>/50000_samples.npz``. The array is passed to
    ``np.savez`` positionally, so it is stored under numpy's default key
    ``arr_0``.

    Args:
        args: Namespace-like object; only ``args.output_dir`` is read.

    Raises:
        FileNotFoundError: If the ``fid_samples`` directory does not exist.
        ValueError: If the directory contains no ``.png`` files, or if the
            images do not all have the same shape (required by ``np.stack``).
    """
    image_dir = os.path.join(args.output_dir, "fid_samples")

    # Sort so the stacked order is deterministic regardless of the order
    # in which the filesystem lists entries.
    image_files = sorted(f for f in os.listdir(image_dir) if f.endswith(".png"))
    if not image_files:
        # Without this check np.stack fails with an opaque
        # "need at least one array to stack" error.
        raise ValueError(f"No .png files found in {image_dir}")

    arrays = []
    for fname in tqdm(image_files):
        # The context manager closes each file handle; np.array forces the
        # pixel data to be loaded before the file is closed.
        with Image.open(os.path.join(image_dir, fname)) as img:
            arrays.append(np.array(img))
    images = np.stack(arrays, axis=0)
    print(f"shape of images: {images.shape}")

    # NOTE: the output filename is fixed regardless of how many images
    # were actually found.
    save_path = os.path.join(args.output_dir, "50000_samples.npz")
    np.savez(save_path, images)
    


def _build_train_dataset(args):
    """Return the training dataset selected by ``args.dataset``.

    ``cifar10``, ``cifar10-64`` and ``celeba`` are constructed with
    ``download=True`` under ``args.data_path``; ``imagenet`` is read from
    ``args.data_path`` as an ``ImageFolder``.

    Raises:
        NotImplementedError: If ``args.dataset`` is not one of ``imagenet``,
            ``cifar10``, ``cifar10-64`` or ``celeba``.
    """
    if args.dataset == "imagenet":
        return datasets.ImageFolder(args.data_path, transform=get_train_transform())
    if args.dataset == "cifar10":
        return datasets.CIFAR10(
            root=args.data_path,
            train=True,
            download=True,
            transform=get_train_transform(),
        )
    if args.dataset == "cifar10-64":
        return datasets.CIFAR10(
            root=args.data_path,
            train=True,
            download=True,
            transform=get_train_transform_resize(64),
        )
    if args.dataset == "celeba":
        return datasets.CelebA(
            root=args.data_path,
            split="train",
            download=True,
            transform=get_train_transform_celeba(64),
        )
    raise NotImplementedError(f"Unsupported dataset {args.dataset}")


def main(args):
    """Load three pretrained models and generate FID samples with them.

    Configures logging and distributed mode, writes ``args`` to
    ``<args.output_dir>/args.json`` on the main process, and seeds RNGs per
    rank. It then builds the training dataset and dataloader, instantiates
    the models named by ``args.architecture_k1/_k2/_k4`` and loads their
    weights from ``args.k1_path/k2_path/k4_path``. Finally it calls
    ``sample_model`` with this rank's share of ``args.fid_samples``.
    """
    logging.basicConfig(
        level=logging.INFO,
        stream=sys.stdout,
        format="%(asctime)s %(levelname)-8s %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    distributed_mode.init_distributed_mode(args)

    logger.info("job dir: {}".format(os.path.dirname(os.path.realpath(__file__))))
    logger.info("{}".format(args).replace(", ", ",\n"))
    if distributed_mode.is_main_process():
        args_filepath = Path(args.output_dir) / "args.json"
        logger.info(f"Saving args to {args_filepath}")
        with open(args_filepath, "w") as f:
            json.dump(vars(args), f)

    device = torch.device(args.device)

    # Offset the seed by rank so each process draws different samples while
    # the run as a whole stays reproducible.
    seed = args.seed + distributed_mode.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    logger.info(f"Initializing Dataset: {args.dataset}")
    dataset_train = _build_train_dataset(args)
    logger.info(dataset_train)

    logger.info("Intializing DataLoader")
    num_tasks = distributed_mode.get_world_size()
    global_rank = distributed_mode.get_rank()
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    logger.info(str(sampler_train))

    logger.info("Initializing Model")
    # (architecture, checkpoint path) per model, in the positional order
    # sample_model receives them below (k1, k2, k4).
    model_specs = (
        (args.architecture_k1, args.k1_path),
        (args.architecture_k2, args.k2_path),
        (args.architecture_k4, args.k4_path),
    )
    models = [
        instantiate_model(
            architechture=architecture,  # (sic) keyword name of instantiate_model
            is_discrete=args.discrete_flow_matching,
            use_ema=args.use_ema,
        )
        for architecture, _ in model_specs
    ]
    for model in models:
        model.to(device)
    for model in models:
        logger.info(str(model))

    logger.info(f"Learning rate: {args.lr:.2e}")

    # The models are only loaded for evaluation and used for sampling, so
    # they are not wrapped in DistributedDataParallel; each rank calls
    # sample_model independently with its own sample count.
    for model, (_, checkpoint_path) in zip(models, model_specs):
        load_model_for_eval(resume=checkpoint_path, model_without_ddp=model)

    # Sampling only: no training loop runs in this script.
    if args.distributed:
        data_loader_train.sampler.set_epoch(0)
    # Split fid_samples evenly across ranks; the main process also takes the
    # remainder so the per-rank counts add up to args.fid_samples.
    per_rank = args.fid_samples // num_tasks
    if distributed_mode.is_main_process():
        fid_samples = args.fid_samples - (num_tasks - 1) * per_rank
    else:
        fid_samples = per_rank

    model_k1, model_k2, model_k4 = models
    sample_model(
        model_k1,
        model_k2,
        model_k4,
        data_loader_train,
        device,
        epoch=0,
        fid_samples=fid_samples,
        args=args,
    )

    # Optionally pack <output_dir>/fid_samples/*.png into a single .npz:
    # save_npz(args)
    


if __name__ == "__main__":
    # Parse CLI options, ensure the output directory exists (main() writes
    # args.json into it), then run sampling.
    parser = get_args_parser()
    args = parser.parse_args()
    output_dir = args.output_dir
    if output_dir:
        Path(output_dir).mkdir(exist_ok=True, parents=True)
    main(args)
