import os
import time
import argparse
from configs.defaults import get_cfgs_defaults
import torch
import torch.nn as nn
import torch.nn.functional as F

from trainer import HierSQVAETrainer
from util import *

import pickle
from third_party.piqa import PSNR, SSIM, LPIPS


def arg_parse():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace with the attributes ``config_file``, ``timestamp``,
        ``save``, ``dbg``, ``gpu``, ``seed``, ``fid``, ``size_dict`` and
        ``num_layer``.
    """
    # (option strings, add_argument keyword arguments), in registration order.
    option_table = [
        (("-c", "--config_file"), {"default": "", "help": "config file"}),
        (("-ts", "--timestamp"), {"default": "", "help": "saved path (random seed + date)"}),
        (("--save",), {"action": "store_true", "help": "save trained model"}),
        (("--dbg",), {"action": "store_true", "help": "print losses per epoch"}),
        (("--gpu",), {"default": "0", "help": "index of gpu to be used"}),
        (("--seed",), {"type": int, "default": 0, "help": "seed number for randomness"}),
        (("--fid",), {"action": "store_true", "help": "FID score will be calculated"}),
        # Overrides used for thorough comparisons between model variants.
        (("--size_dict",), {"type": int, "default": 0,
                            "help": "Codebook size (the number of code vectors per a layer)"}),
        (("--num_layer",), {"type": int, "default": 0,
                            "help": "The number of latent layers (only for RSQVAE)"}),
    ]
    parser = argparse.ArgumentParser(description="main.py")
    for names, options in option_table:
        parser.add_argument(*names, **options)
    return parser.parse_args()


def load_config(args):
    """Build the experiment config from defaults, a YAML file and CLI overrides.

    The YAML file is ``configs/<args.config_file>`` relative to this script's
    directory. Seed and flags come from the command line. When requested, the
    number of latent layers and the codebook size are overridden.

    Args:
        args: namespace returned by ``arg_parse``.

    Returns:
        (cfgs, flgs): the frozen config and its ``flags`` sub-node.
    """
    cfgs = get_cfgs_defaults()
    config_path = os.path.join(os.path.dirname(__file__), "configs", args.config_file)
    print(config_path)
    cfgs.merge_from_file(config_path)

    # Runtime overrides from the command line.
    cfgs.train.seed = args.seed
    cfgs.flags.save = args.save
    cfgs.flags.noprint = not args.dbg

    # Remember the original root path, then descend into the experiment-specific folder.
    cfgs.path_data = cfgs.path
    cfgs.path = os.path.join(cfgs.path, cfgs.path_specific)
    cfgs.flags.var_q = cfgs.model.param_var_q not in ["gaussian_1"]

    ## For thorough comparison
    blocks_sq = cfgs.network.blocks_sq
    if args.num_layer > 0 and ',' not in blocks_sq and 'x' in blocks_sq:
        # A single "<res>x<count>" spec: swap in the requested layer count and
        # replicate the first per-layer setting once per layer.
        resolution, _ = blocks_sq.split('x')
        cfgs.network.blocks_sq = "{}x{}".format(resolution, args.num_layer)
        first_log_param = cfgs.model.log_param_q_init[0]
        first_size = cfgs.quantization.size_dict[0]
        first_dim = cfgs.quantization.dim_dict[0]
        cfgs.model.log_param_q_init = [first_log_param] * args.num_layer
        cfgs.quantization.size_dict = [first_size] * args.num_layer
        cfgs.quantization.dim_dict = [first_dim] * args.num_layer
        print(cfgs.network.blocks_sq)

    if args.size_dict > 0:
        # Overwrite every layer's codebook size in place.
        size_list = cfgs.quantization.size_dict
        size_list[:] = [args.size_dict] * len(size_list)
        print(cfgs.quantization.size_dict)

    cfgs.freeze()
    return cfgs, cfgs.flags

def eval(loader, model, path, dataset_name, device, flg_fid):
    """Evaluate reconstructions of ``model`` over ``loader`` and save the metrics.

    For every batch the model is run with ``flg_train=False, flg_quant_det=True``.
    PSNR, SSIM and LPIPS values (from ``third_party.piqa``) are summed and divided
    by the number of samples. The model's ``loss["perplexity"]`` values are summed
    and divided by the number of batches. The result dict is saved to
    ``<path>/results.npy`` and printed.

    Args:
        loader: iterable of batches. For ``dataset_name == 'CelebA-HQ'`` each
            batch is the image tensor itself; otherwise images are ``batch[0]``.
        model: ``nn.Module`` called as ``model(x, flg_train=False,
            flg_quant_det=True)`` and returning ``(x_rec, loss)``, where ``loss``
            has a ``"perplexity"`` entry convertible with ``np.array``.
        path: directory in which ``results.npy`` is written.
        dataset_name: dataset identifier; only ``'CelebA-HQ'`` is special-cased.
        device: torch device the inputs and metric modules are moved to.
        flg_fid: FID flag from the command line. Currently unused here; no FID
            is computed by this function.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    import numpy as np  # explicit, rather than relying on ``from util import *``

    psnr = PSNR()
    ssim = SSIM().to(device)
    lpips = LPIPS().to(device)

    n_samples = 0
    n_batches = 0
    psnr_cur = 0.
    ssim_cur = 0.
    lpips_cur = 0.
    perplexity = None
    start_time = time.time()

    # Evaluate with inference-mode layers (e.g. dropout / batch-norm), then
    # restore whatever mode the caller had the model in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for data in loader:
                # CelebA-HQ batches are the image tensor; other datasets yield
                # sequences whose first element holds the images.
                x = data if dataset_name == 'CelebA-HQ' else data[0]
                x = x.to(device)
                x_rec, loss = model(x, flg_train=False, flg_quant_det=True)

                n_samples += x.shape[0]
                n_batches += 1
                psnr_cur += psnr(x, x_rec).sum()
                ssim_cur += ssim(x, x_rec).sum()
                lpips_cur += lpips(x, x_rec).sum()

                batch_perplexity = np.array(loss["perplexity"])
                if perplexity is None:
                    perplexity = batch_perplexity
                else:
                    perplexity = perplexity + batch_perplexity
    finally:
        model.train(was_training)

    # Without this guard an empty loader would fail with NameError / ZeroDivisionError.
    if n_batches == 0 or n_samples == 0:
        raise ValueError("Evaluation loader yielded no samples.")

    result = {
        "psnr": psnr_cur / n_samples,
        "ssim": ssim_cur / n_samples,
        "lpips": lpips_cur / n_samples,
        "perplexity": perplexity / n_batches,
    }
    np.save(os.path.join(path, "results.npy"), result)
    print_loss(result, time.time() - start_time)

    print('Evaluation done!')



def print_loss(result, time_interval):
    """Print one summary line of evaluation metrics.

    Args:
        result: mapping with ``"psnr"``, ``"ssim"`` and ``"lpips"`` values
            (formatted with 4 decimals) and ``"perplexity"``, an iterable of
            per-layer values (formatted with 2 decimals, labelled from 1).
        time_interval: elapsed time, printed with 1 decimal as seconds.
    """
    metric_part = "PSNR: {:5.4f}, SSIM: {:5.4f}, LPIPS: {:5.4f}, ".format(
        result["psnr"], result["ssim"], result["lpips"]
    )
    layer_part = "".join(
        "Lyr {}: {:4.2f} ".format(layer_no, value)
        for layer_no, value in enumerate(result["perplexity"], start=1)
    )
    tail = "], Time: {:3.1f} sec".format(time_interval)
    print(metric_part + "Perplexity: [" + layer_part + tail)


if __name__ == "__main__":
    print("eval.py")

    ## --Experimental setup--
    args = arg_parse()
    if args.gpu != "":
        # Set before any torch.cuda call below so only the chosen GPU is visible.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    cfgs, flgs = load_config(args)

    ## --Device setup--
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ## --Data loader setup--
    train_loader, val_loader, test_loader, transform = get_loader(
        cfgs.dataset.name, cfgs.path_dataset, cfgs.test.bs, cfgs.nworker, get_transform=True)
    print("Complete dataload")

    ## Trainer
    print("=== {} ===".format(cfgs.model.name.upper()))
    if cfgs.model.name == "GaussianSQVAE":
        # Not imported at file level (only HierSQVAETrainer is); import it here so
        # this branch does not fail with NameError.
        from trainer import GaussianSQVAETrainer
        trainer = GaussianSQVAETrainer(cfgs, flgs, train_loader, val_loader, test_loader)
    elif cfgs.model.name in ["HierSQVAE", "ResSQVAE", "ResVQVAE"]:
        trainer = HierSQVAETrainer(cfgs, flgs, train_loader, val_loader, test_loader)
    else:
        raise ValueError("Undefined model: {}".format(cfgs.model.name))

    ## Main
    # Restore the saved model identified by the timestamp, then evaluate on the test split.
    trainer.load(args.timestamp)
    print("Best models were loaded!!")
    eval(test_loader, trainer.model, trainer.path, cfgs.dataset.name, device, args.fid)
