import torch
import os
import sys
import diffusers
import time
import shutil
import argparse
import logging

from diffusers import PixArtSigmaPipeline
from qdiff.utils import apply_func_to_submodules, seed_everything, setup_logging
from omegaconf import OmegaConf
import torch.nn as nn


def _read_prompts(prompt_path):
    """Return the whitespace-stripped, non-empty lines of ``prompt_path``.

    Blank lines are skipped so they are not sent to the pipeline as empty prompts.
    """
    with open(prompt_path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f if line.strip()]


def main(args):
    """Run the FP16 PixArt-Sigma pipeline over a prompt file and save calibration data.

    Prompts are processed in full batches of ``quant_config.calib_data.batch_size``.
    The trailing partial batch is dropped. Each batch runs through the pipeline with
    ``calib=True``. The returned ``(x, t, prompt_embeds)`` calibration outputs are
    collected. The resulting list is written with ``torch.save`` to
    ``<log dir>/<quant_config.calib_data.save_path>``.

    Args:
        args: Parsed CLI namespace. Reads ``seed``, ``log``, ``ckpt``,
            ``quant_config``, ``prompt``, ``num_sampling_steps`` and (optionally)
            ``cfg_scale``. ``args.batch_size`` is overwritten with the batch size
            from the quant config.

    Raises:
        ValueError: If the quant config's batch size is not a positive integer.
    """
    seed_everything(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Fall back to the CLI default so os.path.join never receives None.
    log_dir = args.log if args.log is not None else './log/fp16'
    os.makedirs(log_dir, exist_ok=True)
    setup_logging(os.path.join(log_dir, 'run.log'))
    logger = logging.getLogger(__name__)

    ckpt_path = args.ckpt if args.ckpt is not None else "./pretrained_models/"
    pipe = PixArtSigmaPipeline.from_pretrained(
        ckpt_path,  # a hub id such as "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS" also works
        torch_dtype=torch.float16,
    ).to(device)

    # If GPU memory is tight, consider pipe.enable_model_cpu_offload()
    # and/or pipe.vae.enable_tiling().

    quant_config = OmegaConf.load(args.quant_config)

    prompt_path = args.prompt if args.prompt is not None else "./prompts.txt"
    prompts = _read_prompts(prompt_path)
    print('Total samples: ', len(prompts))

    batch_size = quant_config.calib_data.batch_size
    if batch_size <= 0:
        raise ValueError(f'calib_data.batch_size must be positive, got {batch_size}')
    args.batch_size = batch_size  # kept on args for any code that reads it later
    n_batch, n_dropped = divmod(len(prompts), batch_size)  # drop_last semantics
    if n_batch == 0:
        logger.warning(
            f'only {len(prompts)} prompts for batch size {batch_size}; '
            'no calibration data will be collected')
    elif n_dropped:
        logger.warning(f'dropping the last {n_dropped} prompts (incomplete batch)')

    # The pipeline's own default guidance scale is 4.5, which matches the CLI default.
    cfg_scale = getattr(args, 'cfg_scale', 4.5)

    calib_data = []
    for i in range(n_batch):
        print(f'Forward batch {i+1}...')
        batch_prompts = prompts[i * batch_size:(i + 1) * batch_size]
        _, (calib_x, calib_t, calib_prompt_embeds) = pipe(
            prompt=batch_prompts,
            num_inference_steps=args.num_sampling_steps,
            guidance_scale=cfg_scale,
            # The generator must live on the same device as the pipeline, and it
            # is re-seeded per batch so every batch starts from the same seed.
            generator=torch.Generator(device=device).manual_seed(args.seed),
            calib=True,
        )
        # calib_x / calib_t are sequences of tensors, concatenated along dim 0.
        # calib_prompt_embeds is stored exactly as the pipeline returned it.
        calib_data.append((
            torch.concat(calib_x, dim=0),
            torch.concat(calib_t, dim=0),
            calib_prompt_embeds,
        ))

    save_path = os.path.join(log_dir, quant_config.calib_data.save_path)
    torch.save(calib_data, save_path)
    logger.info(f'saved calib data in {save_path}')

if __name__ == "__main__":
    # CLI entry point: parse options and run calibration-data collection.
    parser = argparse.ArgumentParser(
        description="Collect calibration data from an FP16 PixArt-Sigma pipeline.")
    parser.add_argument("--log", type=str, default='./log/fp16',
                        help="output directory for run.log and the saved calibration data")
    parser.add_argument("--cfg-scale", type=float, default=4.5,
                        help="guidance (CFG) scale")
    parser.add_argument('--quant-config', type=str, default='./configs/w4a8_gp_ours.yaml',
                        help="OmegaConf YAML; calib_data.batch_size and "
                             "calib_data.save_path are read from it")
    parser.add_argument("--num-sampling-steps", type=int, default=20,
                        help="number of inference steps per pipeline call")
    parser.add_argument("--prompt", type=str, default='./assets/samples_16.txt',
                        help="text file with one prompt per line")
    parser.add_argument("--seed", type=int, default=42,
                        help="random seed")
    # Batch size is not a CLI option; it comes from quant_config.calib_data.batch_size.
    parser.add_argument("--ckpt", type=str, default='path/to/your/PixArt-XL-2-512x512/',
                        help="checkpoint path (or hub id) passed to "
                             "PixArtSigmaPipeline.from_pretrained")
    args = parser.parse_args()
    main(args)
