import argparse
import os

import numpy as np
np.set_printoptions(threshold=np.sys.maxsize)
import torch
torch.manual_seed(500)
np.random.seed(1)
from PIL import Image
from prettytable import PrettyTable

import sys
sys.path.append("..")
from guided_diffusion import dist_util, logger
from guided_diffusion.script_util import (
    classifier_defaults,
    create_classifier,
    add_dict_to_argparser,
    args_to_dict,
)
from guided_diffusion.pruned_script_util import create_model_and_diffusion,model_and_diffusion_defaults

from guided_diffusion.script_util import model_and_diffusion_defaults as std_model_and_diffusion_defaults
from guided_diffusion.script_util import create_model_and_diffusion as std_create_model_and_diffusion
from guided_diffusion.image_datasets import load_data
from guided_diffusion.cf_loss import cf_loss_in_dpm_solver_steps


def create_argparser():
    """Build the command-line parser for the pruned-vs-standard loss comparison.

    The parser is seeded with the sampling defaults below, then overlaid with
    ``model_and_diffusion_defaults()`` (the pruned-model variant imported from
    ``guided_diffusion.pruned_script_util``) and ``classifier_defaults()``;
    later sources win on key collisions. Those defaults are registered through
    ``add_dict_to_argparser`` before the script-specific options are added.

    Returns:
        argparse.ArgumentParser: the configured, not-yet-parsed parser.
    """
    defaults = {
        "clip_denoised": True,
        "num_samples": 10000,
        "batch_size": 16,
        "use_ddim": False,
        "model_path": "",
        "classifier_path": "",
        "classifier_scale": 1.0,
        **model_and_diffusion_defaults(),
        **classifier_defaults(),
    }

    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)

    # Script-specific options, registered in this exact order.
    extra_options = (
        ("--label", dict(type=int, default=207, help="Label for pruning")),
        ("--steps", dict(type=int, default=20, help="dpm-solver steps")),
        ("--save_path", dict(type=str, default=" ")),
        ("--standard_model_path",
         dict(type=str, help="The path to load the standard model")),
        ("--data_dir",
         dict(type=str, default=" ",
              help="You must make sure the label match with the data.")),
        ("--step_respacing",
         dict(type=int, default=250, help="A Utility variable for DPM Solver")),
        ("--num_each_line",
         dict(type=int, default=4,
              help="The num of grid each line in the output tables")),
        ("--micro_size", dict(type=int, default=1)),
    )
    for flag, options in extra_options:
        parser.add_argument(flag, **options)
    return parser

def save_torch_example(img, name, idxs, base_path):
    """Save selected images of a batch tensor as PNG files.

    Pixels are mapped with ``(img + 1) * 127.5`` and clamped to [0, 255]
    (so ``img`` is presumably in [-1, 1] — TODO confirm), cast to uint8, and
    the tensor is permuted with ``(0, 2, 3, 1)`` — presumably (N, C, H, W) ->
    (N, H, W, C) — before each selected item is handed to PIL.

    Args:
        img: 4-D image batch tensor.
        name: filename prefix inserted after ``base_path``.
        idxs: iterable of batch indices to write.
        base_path: output directory (joined with a literal ``'/'``).

    Each index ``i`` is written to
    ``base_path + '/' + name + 'id_in_batch_is_' + str(i) + '.png'``.
    """
    prefix = ''.join((base_path, '/', name, 'id_in_batch_is_'))
    frames = (
        ((img + 1) * 127.5)
        .clamp(0, 255)
        .to(torch.uint8)
        .permute(0, 2, 3, 1)
        .contiguous()
        .cpu()
        .numpy()
        .astype(np.uint8)
    )
    for idx in idxs:
        Image.fromarray(frames[idx]).save(prefix + str(idx) + '.png')

def num_of_para(ckpt):
    """Return the total number of scalar entries across every value in ``ckpt``.

    ``ckpt`` is a mapping (e.g. a loaded state dict) whose values expose
    ``.shape``. Each value contributes the product of its dimensions; a 0-d
    value contributes 1. Every entry is counted, so any non-parameter entries
    stored in the mapping (such as buffers) are included too.
    """
    total = 0
    for tensor in ckpt.values():
        count = 1
        for dim in tensor.shape:
            count *= dim
        total += count
    return total


def _percent_deviation(value, reference):
    """Format ``|value - reference| / reference * 100`` as ``str(...) + '%'``.

    No guard is applied for a zero ``reference``; the outcome then depends on
    the numeric type (plain Python floats raise ZeroDivisionError).
    """
    return str((abs(value - reference) / reference) * 100.) + '%'


def _add_metric_rows(table, labels, pruned_values, std_values, idxs):
    """Append the pruned / standard / percent-deviation rows for one metric.

    Args:
        table: PrettyTable to extend.
        labels: ``(pruned_label, std_label, deviation_label)`` row headers.
        pruned_values: per-step values of the pruned model.
        std_values: per-step values of the standard model.
        idxs: step indices belonging to this table.
    """
    pruned_label, std_label, dev_label = labels
    table.add_row([pruned_label] + [pruned_values[i] for i in idxs])
    table.add_row([std_label] + [std_values[i] for i in idxs])
    table.add_row(
        [dev_label]
        + [_percent_deviation(pruned_values[i], std_values[i]) for i in idxs]
    )


def _build_comparison_tables(step_list, pruned, std, steps, diffusion_steps,
                             num_each_line):
    """Build one PrettyTable per chunk of ``num_each_line`` evaluated steps.

    Args:
        step_list: timesteps at which the losses were evaluated.
        pruned: ``(loss, mse, vb)`` sequences of the pruned model, indexed in
            parallel with ``step_list``.
        std: ``(loss, mse, vb)`` sequences of the standard model.
        steps: the ``--steps`` value, used in the column headers.
        diffusion_steps: the ``--diffusion_steps`` value, used to rescale
            each timestep in the column headers.
        num_each_line: maximum number of step columns per table.

    Returns:
        list of PrettyTable. Each table holds the loss rows, a blank row, the
        MSE rows, a blank row, and — only when the pruned VB sequence is
        non-empty — the VB rows.

    Raises:
        ValueError: if ``num_each_line`` is less than 1.
    """
    if num_each_line < 1:
        raise ValueError(f"num_each_line must be >= 1, got {num_each_line}")

    pruned_loss, pruned_mse, pruned_vb = pruned
    std_loss, std_mse, std_vb = std
    has_vb = len(pruned_vb) != 0

    def step_label(step):
        # Header: raw step, then int((step - 1) * steps / diffusion_steps)
        # out of (steps - 1).
        rescaled = str(int((step - 1) * steps / diffusion_steps))
        return str(step) + ' (' + rescaled + '/' + str(steps - 1) + ')'

    tables = []
    for start in range(0, len(step_list), num_each_line):
        idxs = range(start, min(start + num_each_line, len(step_list)))
        title = ['Steps'] + [step_label(step_list[i]) for i in idxs]
        # Spacer row must match the column count of this particular table.
        blank_row = [' '] * len(title)

        table = PrettyTable(title)
        _add_metric_rows(table, ('pruned_loss', 'std_loss', 'pc_of_loss_dev'),
                         pruned_loss, std_loss, idxs)
        table.add_row(blank_row)
        _add_metric_rows(table, ('pruned_mse_loss', 'std_mse_loss', 'pc_of_mse_dev'),
                         pruned_mse, std_mse, idxs)
        table.add_row(blank_row)
        if has_vb:
            _add_metric_rows(table, ('pruned_vb_loss', 'std_vb_loss', 'pc_of_vb_dev'),
                             pruned_vb, std_vb, idxs)
        tables.append(table)
    return tables


def main():
    """Compare the losses of a pruned diffusion model against the standard model.

    Steps:
    1. Load the pruned model from ``--model_path`` and the standard model from
       ``--standard_model_path``.
    2. Draw one batch from ``--data_dir`` and relabel it with ``--label``.
    3. Evaluate ``cf_loss_in_dpm_solver_steps`` for both models.
    4. Log the results as tables of at most ``--num_each_line`` steps each.

    Raises:
        ValueError: if ``--num_each_line`` is less than 1. This is checked
            before any model is loaded.
    """
    args = create_argparser().parse_args()
    if args.num_each_line < 1:
        raise ValueError(
            f"--num_each_line must be >= 1, got {args.num_each_line}"
        )

    # NOTE(review): raises KeyError when CUDA_VISIBLE_DEVICES is unset.
    dist_util.setup_dist(os.environ["CUDA_VISIBLE_DEVICES"])
    logger.configure(dir=args.save_path)

    '''------------------------------------- Create model and diffusion -------------------------------------'''
    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    logger.log("Load pruned model from:", args.model_path)

    # Read the checkpoint once and reuse it for both the parameter count and
    # the weights, then drop the CPU copy once it is in the model.
    state_dict = dist_util.load_state_dict(args.model_path, map_location="cpu")
    logger.log("The number of parameters in pruned model is", num_of_para(state_dict))
    model.load_state_dict(state_dict)
    del state_dict
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    std_model, std_diffusion = std_create_model_and_diffusion(
        **args_to_dict(args, std_model_and_diffusion_defaults().keys())
    )
    std_model.load_state_dict(
        dist_util.load_state_dict(args.standard_model_path, map_location="cpu")
    )
    std_model.to(dist_util.dev())
    if args.use_fp16:
        std_model.convert_to_fp16()
    std_model.eval()

    '''------------------------------------------ Create dataloader -----------------------------------------'''
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=args.class_cond,
    )

    '''-------------------------------------- Compute the training loss -------------------------------------'''
    batch, cond_ = next(data)
    # Replace every label in the batch with --label. The result keeps the
    # loader's label shape, dtype and device. The caller must ensure the data
    # actually matches this label (see --data_dir help).
    cond = {'y': torch.full_like(cond_['y'], args.label)}

    # NOTE(review): the pruned model is evaluated with ``std_diffusion``, and
    # the ``diffusion`` created alongside the pruned model is never used.
    # Confirm this is intended.
    logger.log('computing the loss of pruned model...')
    step_list, pruned_loss, pruned_mse, pruned_vb = cf_loss_in_dpm_solver_steps(
        batch,
        cond,
        model,
        std_diffusion,
        args.step_respacing,
        args.steps,
        args.micro_size,
        dist_util.dev(),
    )

    logger.log('computing the loss of std model...')
    _, std_loss, std_mse, std_vb = cf_loss_in_dpm_solver_steps(
        batch,
        cond,
        std_model,
        std_diffusion,
        args.step_respacing,
        args.steps,
        args.micro_size,
        dist_util.dev(),
    )

    '''------------------------------------------ Print the output ------------------------------------------'''
    tables = _build_comparison_tables(
        step_list,
        (pruned_loss, pruned_mse, pruned_vb),
        (std_loss, std_mse, std_vb),
        args.steps,
        args.diffusion_steps,
        args.num_each_line,
    )
    logger.log('The comparison of pruned model and std model loss:')
    for table in tables:
        logger.log(table)


if __name__ == '__main__':
    # Run the comparison only when executed as a script, not on import.
    main()
