import argparse, os, datetime, gc, yaml
import logging
import cv2
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from imwatermark import WatermarkEncoder
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
import torch
import torch.nn as nn
from torch import autocast
from contextlib import nullcontext

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from qdiff import (
    QuantModel, QuantModule, BaseQuantBlock, 
    block_reconstruction, layer_reconstruction,
)
from qdiff.adaptive_rounding import AdaRoundQuantizer
from qdiff.quant_layer import UniformAffineQuantizer
from qdiff.utils import resume_cali_model, get_train_samples_custom, convert_adaround, pixart_alpha_aca_dict
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers import DiTPipeline, DPMSolverMultistepScheduler
from transformers import AutoFeatureExtractor

# Module-level logger; main() configures the root handlers via basicConfig.
logger = logging.getLogger(__name__)

# Point CURL_CA_BUNDLE at the system CA bundle before any model download below.
# NOTE(review): presumably needed for TLS verification in the target
# environment — this unconditionally overwrites any value the caller exported.
os.environ['CURL_CA_BUNDLE'] = '/etc/ssl/certs/ca-certificates.crt'
# Load the safety model once at import time; check_safety() reads these two
# module-level objects. Importing this module therefore triggers the
# from_pretrained() calls.
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)


def chunk(it, size):
    """Split an iterable into consecutive tuples of at most ``size`` items.

    Returns an iterator that yields tuples taken in order from ``it``; the
    final tuple may be shorter. Iteration stops as soon as an empty tuple
    would be produced (so ``size == 0`` yields nothing).
    """
    source = iter(it)

    def _next_batch():
        # Pull up to `size` items; an exhausted source produces ().
        return tuple(islice(source, size))

    # Two-argument iter(): call _next_batch until it returns the sentinel ().
    return iter(_next_batch, ())


def numpy_to_pil(images):
    """
    Turn a numpy image, or a batch of numpy images, into a list of PIL images.

    A 3-D array is treated as a single image and gets a leading batch axis.
    Values are multiplied by 255, rounded and cast to uint8 before conversion
    (presumably the input is float data in [0, 1] — confirm with callers).
    """
    batch = images[None, ...] if images.ndim == 3 else images
    as_uint8 = (batch * 255).round().astype("uint8")

    pil_images = []
    for frame in as_uint8:
        pil_images.append(Image.fromarray(frame))
    return pil_images


def put_watermark(img, wm_encoder=None):
    """Embed a watermark into ``img`` when an encoder is supplied.

    With ``wm_encoder`` left as ``None`` the input is returned untouched.
    Otherwise the image is converted RGB -> BGR, passed through
    ``wm_encoder.encode(..., 'dwtDct')``, and returned as a new PIL image
    with the channel order flipped back.
    """
    if wm_encoder is None:
        return img

    bgr_pixels = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    marked = wm_encoder.encode(bgr_pixels, 'dwtDct')
    # Reverse the last axis to restore the original channel order.
    return Image.fromarray(marked[:, :, ::-1])


def load_replacement(x):
    """Return a stand-in image shaped like ``x``, or ``x`` itself on failure.

    Loads ``assets/rick.jpeg``, resizes it to the first two dimensions of
    ``x``, scales it to [0, 1] and casts it to ``x.dtype``. This is strictly
    best-effort: any exception (missing file, unexpected input, shape
    mismatch) results in the original ``x`` being handed back unchanged.
    """
    try:
        dims = x.shape
        rows, cols = dims[0], dims[1]
        picture = Image.open("assets/rick.jpeg").convert("RGB")
        # PIL's resize takes (width, height), hence the swapped order.
        picture = picture.resize((cols, rows))
        replacement = (np.array(picture) / 255.0).astype(x.dtype)
        assert replacement.shape == x.shape
    except Exception:
        return x
    return replacement


def check_safety(x_image):
    """Run the module-level safety checker over a batch of images.

    ``x_image`` is converted to PIL images for the feature extractor and is
    also passed as-is to the checker's ``images`` argument. Every entry the
    checker flags is overwritten with the result of ``load_replacement``.

    Returns the (possibly modified) checked images together with the
    per-image flags reported by the checker.
    """
    clip_features = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(
        images=x_image,
        clip_input=clip_features.pixel_values,
    )
    assert x_checked_image.shape[0] == len(has_nsfw_concept)

    for idx, flagged in enumerate(has_nsfw_concept):
        if flagged:
            x_checked_image[idx] = load_replacement(x_checked_image[idx])
    return x_checked_image, has_nsfw_concept


def _build_arg_parser():
    """Build the command-line parser with sampling and quantization options."""
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render"
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2img-samples"
    )
    parser.add_argument(
        "--skip_grid",
        action='store_true',
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )
    parser.add_argument(
        "--skip_save",
        action='store_true',
        help="do not save individual samples. For speed measurements.",
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=20,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--plms",
        action='store_true',
        help="use plms sampling",
    )
    parser.add_argument(
        "--laion400m",
        action='store_true',
        help="uses the LAION400M model",
    )
    parser.add_argument(
        "--fixed_code",
        action='store_true',
        help="if enabled, uses the same starting code across samples ",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=1,
        help="sample this often",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=1024,
        help="image height, in pixel space",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=1024,
        help="image width, in pixel space",
    )
    parser.add_argument(
        "--C",
        type=int,
        default=4,
        help="latent channels",
    )
    parser.add_argument(
        "--f",
        type=int,
        default=8,
        help="downsampling factor",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=1,
        help="how many samples to produce for each given prompt. A.k.a. batch size",
    )
    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=7.5,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/stable-diffusion/v1-inference.yaml",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="the seed (for reproducible sampling)",
    )
    parser.add_argument(
        "--precision",
        type=str,
        help="evaluate at this precision",
        choices=["full", "autocast"],
        default="autocast"
    )
    # linear quantization configs
    parser.add_argument(
        "--ptq", action="store_true", help="apply post-training quantization"
    )
    parser.add_argument(
        "--quant_act", action="store_true",
        help="if to quantize activations when ptq==True"
    )
    parser.add_argument(
        "--weight_bit",
        type=int,
        default=8,
        help="int bit for weight quantization",
    )
    parser.add_argument(
        "--act_bit",
        type=int,
        default=8,
        help="int bit for activation quantization",
    )
    parser.add_argument(
        "--quant_mode", type=str, default="symmetric",
        # argparse does not validate defaults against `choices`, so the
        # default "symmetric" used to be impossible to pass explicitly.
        choices=["symmetric", "linear", "squant", "qdiff"],
        help="quantization mode to use"
    )

    # qdiff specific configs
    parser.add_argument(
        "--cali_st", type=int, default=20,
        help="number of timesteps used for calibration"
    )
    parser.add_argument(
        "--cali_batch_size", type=int, default=8,
        help="batch size for qdiff reconstruction"
    )
    parser.add_argument(
        "--cali_n", type=int, default=128,
        help="number of samples for each timestep for qdiff reconstruction"
    )
    parser.add_argument(
        "--cali_iters", type=int, default=20000,
        help="number of iterations for each qdiff reconstruction"
    )
    parser.add_argument('--cali_iters_a', default=5000, type=int,
        help='number of iteration for LSQ')
    parser.add_argument('--cali_lr', default=4e-4, type=float,
        help='learning rate for LSQ')
    parser.add_argument('--cali_p', default=2.4, type=float,
        help='L_p norm minimization for LSQ')
    parser.add_argument(
        "--cali_ckpt", type=str,
        help="path for calibrated model ckpt"
    )
    parser.add_argument(
        "--cali_data_path", type=str, default="pixart_calib_brecq.pt",
        help="calibration dataset name"
    )
    parser.add_argument(
        "--resume", action="store_true",
        help="resume the calibrated qdiff model"
    )
    # NOTE(review): this help text duplicates --resume; confirm the intended
    # description of --adaround.
    parser.add_argument(
        "--adaround", action="store_true",
        help="resume the calibrated qdiff model"
    )
    parser.add_argument(
        "--resume_w", action="store_true",
        help="resume the calibrated qdiff model weights only"
    )
    parser.add_argument(
        "--cond", action="store_true",
        help="whether to use conditional guidance"
    )
    parser.add_argument(
        "--no_grad_ckpt", action="store_true",
        help="disable gradient checkpointing"
    )
    parser.add_argument(
        "--split", action="store_true",
        help="use split strategy in skip connection"
    )
    parser.add_argument(
        "--running_stat", action="store_true",
        help="use running statistics for act quantizers"
    )
    parser.add_argument(
        "--rs_sm_only", action="store_true",
        help="use running statistics only for softmax act quantizers"
    )
    parser.add_argument(
        "--sm_abit", type=int, default=16,
        help="attn softmax activation bit"
    )
    parser.add_argument(
        "--verbose", action="store_true",
        help="print out info like quantized model arch"
    )
    parser.add_argument(
        "--refiner", action="store_true", default=False,
        help="use SDXL refiner")
    parser.add_argument(
        "--sequential_w", action="store_true", default=False,
        help="Sequential weight quantization")
    parser.add_argument(
        "--sequential_a", action="store_true", default=False,
        help="Sequential activation quantization")
    parser.add_argument(
        "--weight_mantissa_bits",
        type=int,
        default=None,
        help="weight_mantissa bit",
    )
    parser.add_argument(
        "--act_mantissa_bits",
        type=int,
        default=None,
        # Previously a copy-paste of the weight option's help text.
        help="act_mantissa bit",
    )
    parser.add_argument(
        "--no_adaround", action="store_true",
        help="Disable adaround in weight quantization"
    )
    parser.add_argument(
        "--attn_weight_mantissa",
        type=int,
        default=None,
        help="mantissa bit for weight quantization",
    )
    parser.add_argument(
        "--ff_weight_mantissa",
        type=int,
        default=None,
        help="mantissa bit for feed forward weight",
    )
    parser.add_argument(
        "--asym_softmax", action="store_true", default=False,
        help="Use asymmetric quantization for attention's softmax and get an extra bits"
    )
    parser.add_argument(
        "--weight_group_size",
        type=int,
        default=128,
        help="independent quantization to groups of <size> consecutive weights",
    )
    return parser


def _group_rows_by_timestep(rows, timesteps, num_groups):
    """Return, for each i in range(num_groups), the rows where timesteps == i."""
    return [
        rows[(timesteps == i).nonzero(as_tuple=True)[0]]
        for i in range(num_groups)
    ]


def _summarize_groups(groups):
    """Flatten each group; return (values, means, variances) as parallel lists.

    Empty groups contribute an empty value list and NaN statistics so that
    list positions keep matching group indices.
    """
    group_data, group_means, group_vars = [], [], []
    for group in groups:
        if group.shape[0] > 0:
            flat = group.flatten()
            print(flat.shape)
            group_data.append(flat.tolist())
            group_means.append(flat.mean().item())
            group_vars.append(flat.var().item())
        else:
            group_data.append([])
            group_means.append(float('nan'))
            group_vars.append(float('nan'))
    return group_data, group_means, group_vars


def _save_group_boxplot(group_data, group_means, group_vars, out_file):
    """Draw one box per group, annotate mean/variance, and save to out_file."""
    # Imported lazily so matplotlib is only required when a plot is produced.
    import matplotlib.pyplot as plt

    positions = range(len(group_data))

    plt.figure(figsize=(15, 10))
    plt.boxplot(group_data, positions=positions, widths=0.6)
    plt.xlabel('Group Index')
    plt.ylabel('Values')
    plt.title('Boxplot of Groups')
    plt.xticks(positions)

    # Boxes are drawn at x = 0..N-1 (explicit `positions` above), so each
    # annotation must use the same x. The former `i + 1` shifted every label
    # one box to the right and pushed the last label off the plot.
    # NOTE(review): y=1.05 is in data coordinates — confirm it is the
    # intended vertical placement for this data range.
    for pos, (mean, var) in enumerate(zip(group_means, group_vars)):
        plt.text(pos, 1.05, f'Mean: {mean:.2f}\nVar: {var:.2f}',
                 horizontalalignment='center', verticalalignment='center', fontsize=8)

    plt.savefig(out_file)

    # Clear the figure to release the plotted data.
    plt.clf()


def main():
    """Plot per-timestep value distributions of the calibration inputs.

    Steps:
      1. Parse CLI options, seed RNGs, create a timestamped directory under
         ``--outdir`` and configure logging to ``run.log`` in it plus the
         console.
      2. Load ``--cali_data_path`` and build calibration samples through
         ``get_train_samples_custom``.
      3. Keep the second half of the samples and timesteps, flatten each
         sample to one row, and group rows by timestep value 0..19.
      4. Save a boxplot (annotated with each group's mean and variance) to
         ``group_boxplot.png`` in the current working directory.
    """
    opt = _build_arg_parser().parse_args()

    seed_everything(opt.seed)

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = os.path.join(opt.outdir, datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))
    os.makedirs(outpath)

    log_path = os.path.join(outpath, "run.log")
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO,
        handlers=[
            logging.FileHandler(log_path),
            logging.StreamHandler()
        ]
    )

    # NOTE: torch.load may unpickle arbitrary objects; only point
    # --cali_data_path at trusted files.
    sample_data = torch.load(opt.cali_data_path)
    cali_xs, cali_ts, _ = get_train_samples_custom(opt, sample_data, opt.ddim_steps)

    # Keep only the second half of the samples and their timesteps. Per the
    # original note this drops the unconditional ("ucs") inputs — TODO confirm
    # against get_train_samples_custom.
    cali_xs = cali_xs[cali_xs.size(0) // 2:]
    cali_ts = cali_ts[cali_ts.size(0) // 2:]

    # One flat row per sample. The row count is taken from the tensor itself
    # (it was hard-coded to 2560) so rows stay aligned with cali_ts for any
    # calibration set size.
    cali_xs = cali_xs.view(cali_xs.size(0), -1)
    print(cali_xs.shape)
    print(cali_ts.shape)

    # NOTE(review): assumes cali_ts holds step indices 0..19; confirm whether
    # this should follow opt.ddim_steps.
    num_groups = 20
    groups = _group_rows_by_timestep(cali_xs, cali_ts, num_groups)
    for i, group in enumerate(groups):
        print(f"Group {i}: {group.shape}")

    group_data, group_means, group_vars = _summarize_groups(groups)
    _save_group_boxplot(group_data, group_means, group_vars, 'group_boxplot.png')

# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()