from xqgan_train import parse_args
from xqgan_model import VQ_models
from dataset.build import build_dataset
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.nn.functional as F
from dataset.augmentation import center_crop_arr
import torch

import matplotlib.pyplot as plt

from tqdm import tqdm
from lpips import LPIPS
from Metrics.sample_and_eval import SampleAndEval



def calculate_metrics(vq_model, test_loader, device):
    """Evaluate the reconstruction quality of ``vq_model`` on ``test_loader``.

    Every batch is reconstructed with ``vq_model.img_to_reconstructed_img``
    and compared with its input. The function prints the mean perplexity,
    the mean pairwise cosine distance of the codebook and the reconstruction
    / perceptual losses, then delegates the remaining metrics to
    ``SampleAndEval.compute_and_log_metrics``.

    Args:
        vq_model: Model exposing ``img_to_reconstructed_img(x,
            rtn_perplexity=True)`` and ``quantize.embedding.weight``.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device on which the model and the metrics are evaluated.

    Returns:
        A tuple ``(metrics, rec_loss, p_loss)`` where ``metrics`` is whatever
        ``SampleAndEval.compute_and_log_metrics`` returns and ``rec_loss`` /
        ``p_loss`` are the per-sample mean MSE and LPIPS values.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    vq_model.eval()
    total = 0
    samples = []
    gt = []
    labels = []
    perplexity_list = []

    rec_loss = 0.
    p_loss = 0.

    print(device)
    sae = SampleAndEval(device=device, num_classes=8)
    perceptual_loss = LPIPS().to(device).eval()

    vq_model.to(device)
    for x, label in tqdm(test_loader):
        with torch.no_grad():
            x = x.to(device, non_blocking=True)
            sample, perplexity = vq_model.img_to_reconstructed_img(x, rtn_perplexity=True)

            # Reconstructions and inputs are used as-is (no rescaling).
            samples.append(sample)
            gt.append(x)
            labels.append(label.to(device))

            batch_size = sample.shape[0]
            sample_c = sample.contiguous()
            x_c = x.contiguous()

            # Both losses below are batch means. Weight them by the batch
            # size so that dividing by the sample count after the loop gives
            # a true per-sample mean, even when the last batch is smaller.
            rec_loss += F.mse_loss(sample_c, x_c) * batch_size
            p_loss += torch.mean(perceptual_loss(sample_c, x_c)) * batch_size

            total += batch_size
            perplexity_list.append(perplexity)

    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute metrics")

    # Unweighted mean of the per-batch perplexities.
    perplexity_mean = sum(perplexity_list) / len(perplexity_list)
    print('perplexity_mean: ', perplexity_mean)
    codebook = vq_model.quantize.embedding.weight.data

    # Mean pairwise cosine distance between codebook entries. The full
    # matrix is averaged, so each entry's distance to itself is included.
    codebook_norm = F.normalize(codebook, dim=1)
    cos_dis_matrix = 1 - codebook_norm @ codebook_norm.T
    mean = cos_dis_matrix.mean()
    print('Cosine distance mean: ', mean)

    rec_loss = rec_loss / total
    p_loss = p_loss / total
    print(f"rec_loss: {rec_loss:.8f}, p_loss: {p_loss:.8f}")
    metrics = sae.compute_and_log_metrics(samples, gt, labels)
    print(metrics)
    return metrics, rec_loss, p_loss
            

def main(args):
    """Build the test loader and the VQ model, then run the evaluation.

    ``args.data_path`` is overwritten with ``args.test_data_path`` before the
    dataset is built. Images are center-cropped to ``args.image_size``,
    converted to tensors and normalized with mean 0.5 / std 0.5 per channel.
    The checkpoint at ``args.vq_ckpt`` is loaded strictly into the model
    selected by ``args.vq_model`` and ``calculate_metrics`` is run on CUDA
    when available, otherwise on the CPU.
    """
    # The dataset builder receives ``args`` with data_path set to the test split.
    args.data_path = args.test_data_path

    def _center_crop(pil_image):
        return center_crop_arr(pil_image, args.image_size)

    preprocessing = transforms.Compose([
        transforms.Lambda(_center_crop),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])

    print('args.data_path: ', args.data_path)
    eval_set = build_dataset(args, transform=preprocessing)
    eval_loader = DataLoader(eval_set, batch_size=64, shuffle=False)

    # Constructor options forwarded one-to-one from the parsed arguments.
    model_option_names = (
        "ae_training",
        # vq
        "codebook_size",
        "codebook_embed_dim",
        "commit_loss_beta",
        "entropy_loss_ratio",
        "dropout_p",
        "v_patch_nums",
        "enc_type",
        "encoder_model",
        "dec_type",
        "decoder_model",
        "semantic_guide",
        "detail_guide",
        "num_latent_tokens",
        "abs_pos_embed",
        "share_quant_resi",
        "product_quant",
        "codebook_drop",
        "half_sem",
        "start_drop",
        "sem_loss_weight",
        "detail_loss_weight",
        "clip_norm",
        "sem_loss_scale",
        "detail_loss_scale",
        "guide_type_1",
        "guide_type_2",
        "lfq",
        # XQGAN's configuration
        "codebook_l2_norm",
    )
    model_factory = VQ_models[args.vq_model]
    model_kwargs = {name: getattr(args, name) for name in model_option_names}
    vq_model = model_factory(**model_kwargs)

    # Restore the weights on the CPU; the model is moved to the evaluation
    # device inside calculate_metrics.
    state = torch.load(args.vq_ckpt, map_location="cpu")
    vq_model.load_state_dict(state["model"], strict=True)
    vq_model.eval()

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    calculate_metrics(vq_model, eval_loader, device)
    
    
    
    
if __name__ == "__main__":
    # Script entry point: parse the command-line options and run the evaluation.
    main(parse_args())
