# ------------------------------------------------------------------------------------
# Copyright (c) 2023 Nota Inc. All Rights Reserved.
# ------------------------------------------------------------------------------------

import os
import argparse
import time
from inference_pipeline import InferencePipeline
from misc import get_file_list_from_csv, change_img_size
import torch 
import pandas as pd
from PIL import Image 
import torchvision.transforms as T 
from torchvision.utils import make_grid,save_image
import torch.nn.functional as F
import piq
from tqdm import tqdm 

class Patchify(torch.nn.Module):
    """Cut a batch of images into non-overlapping square patches.

    ``torch.nn.Unfold`` is configured with ``kernel_size == stride`` so
    consecutive windows never overlap.
    """

    def __init__(self, patch_size=64):
        super().__init__()
        self.p = patch_size
        self.unfold = torch.nn.Unfold(kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Return the patches of ``x``.

        Args:
            x: 4-D tensor laid out as (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, num_patches, channels, p, p).
        """
        batch, channels, _, _ = x.shape
        # Unfold yields (batch, channels * p * p, num_patches).
        columns = self.unfold(x)
        # Split the flattened patch axis back into (channels, p, p) ...
        tiles = columns.view(batch, channels, self.p, self.p, -1)
        # ... and bring the patch index right behind the batch axis.
        return tiles.movedim(-1, 1)

@torch.no_grad()
def generate_mscoco_eval(save_dir, unet_path=None, num_inference_steps=25, batch_sz=16, img_resz=256, img_sz=512, device='cuda', seed=0, model_id='CompVis/stable-diffusion-v1-4', data_list='evaluation/mscoco_val2014_30k/metadata_2k.csv',start_from=0, end_to=-1):
    """Generate one image per prompt listed in ``data_list`` and save them.

    Generated images are written to ``<save_dir>/im512`` under the file names
    taken from the list. Afterwards ``change_img_size`` is called with the
    ``im512`` folder, the ``im256`` folder and ``img_resz`` (presumably it
    writes resized copies -- confirm in ``misc``).

    Args:
        save_dir: Root output folder; ``im512`` and ``im256`` are created
            inside it if missing.
        unet_path: Optional checkpoint file. When given, its ``'unet'`` entry
            is loaded into the pipeline's UNet, which is then cast to half
            precision and moved to ``device``.
        num_inference_steps: Passed to ``pipeline.generate`` as ``n_steps``.
        batch_sz: Number of prompts per ``pipeline.generate`` call.
        img_resz: Size argument forwarded to ``change_img_size``.
        img_sz: Passed to ``pipeline.generate`` as ``img_sz``. The output
            folder is named ``im512`` regardless of this value.
        device: Device handed to ``InferencePipeline`` and used for the UNet.
        seed: Seed handed to ``InferencePipeline``.
        model_id: Weight folder / model id handed to ``InferencePipeline``.
        data_list: CSV consumed by ``get_file_list_from_csv``; each entry is
            indexed as ``[0]`` = image file name, ``[1]`` = prompt.
        start_from: Index of the first list entry to process.
        end_to: Exclusive end index; ``-1`` means "until the end of the list".
    """
    pipeline = InferencePipeline(weight_folder = model_id, seed = seed, device = device)
    pipeline.set_pipe_and_generator()

    if unet_path is not None: # use a separate trained unet for generation
        checkpoint = torch.load(unet_path, map_location='cpu')
        pipeline.pipe.unet.load_state_dict(checkpoint['unet'])
        pipeline.pipe.unet = pipeline.pipe.unet.half().to(device)
        print(f"** load unet from {unet_path}")

    # Create both output folders up front; exist_ok keeps resumed runs
    # (start_from > 0) working when the folders are already there.
    save_dir_im512 = os.path.join(save_dir, 'im512')
    os.makedirs(save_dir_im512, exist_ok=True)
    save_dir_im256 = os.path.join(save_dir, 'im256')
    os.makedirs(save_dir_im256, exist_ok=True)

    file_list = get_file_list_from_csv(data_list)
    params_str = pipeline.get_sdm_params()
    # Never run past the list: an oversized end_to would otherwise produce
    # empty prompt batches.
    end_to = len(file_list) if end_to == -1 else min(end_to, len(file_list))
    t0 = time.perf_counter()
    for batch_start in range(start_from, end_to, batch_sz):
        # Clamp so the final batch stops at end_to instead of spilling into
        # entries that belong to another shard.
        batch_end = min(batch_start + batch_sz, end_to)
        batch_entries = file_list[batch_start:batch_end]

        img_names = [file_info[0] for file_info in batch_entries]
        val_prompts = [file_info[1] for file_info in batch_entries]

        imgs = pipeline.generate(prompt = val_prompts,
                                 n_steps = num_inference_steps,
                                 img_sz = img_sz)

        for i, (img, img_name, val_prompt) in enumerate(zip(imgs, img_names, val_prompts)):
            img.save(os.path.join(save_dir_im512, img_name))
            img.close()
            print(f"{batch_start + i}/{len(file_list)} | {img_name} {val_prompt}")
        print(f"---{params_str}")

    pipeline.clear()
    del pipeline
    change_img_size(save_dir_im512, save_dir_im256, img_resz)
    print(f"{(time.perf_counter()-t0):.2f} sec elapsed")

@torch.no_grad()
def generate_trigger_eval(save_dir, unet_path=None, gt_csv='DATASETS/coco_subset_preprocessed/metadata.csv', num_inference_steps=100, guidance_scale=2.5, batch_sz=16, img_sz=512, device='cuda', seed=0, model_id='CompVis/stable-diffusion-v1-4', trigger_num=0, trigger_emb_list=None, unique_scale=1.0, MSE=False,start_from=0):
    """Generate images from trigger embeddings and score them against ground truth.

    For every batch of trigger embeddings the pipeline is run with empty
    prompts and ``unique_emb`` set to the embeddings; the results are saved
    as ``<save_dir>/<index:06d>.png`` and compared with the ground-truth
    images (the first ``trigger_num`` paths of the sorted ``image`` column of
    ``gt_csv``) via ``calculate_img_metrics`` and MSE.

    Args:
        save_dir: Folder for the generated images (created if missing).
        unet_path: Optional checkpoint file whose ``'unet'`` entry replaces
            the pipeline UNet weights.
        gt_csv: CSV with an ``image`` column of ground-truth image paths.
        num_inference_steps: Passed to ``pipeline.generate`` as ``n_steps``.
        guidance_scale: Passed to ``pipeline.generate``.
        batch_sz: Number of triggers per ``pipeline.generate`` call.
        img_sz: Passed to ``pipeline.generate`` as ``img_sz``.
            NOTE(review): the metrics need generated and ground-truth images
            of identical size -- confirm the dataset matches ``img_sz``.
        device: Device for the pipeline, embeddings and metric tensors.
        seed: Seed handed to ``InferencePipeline``.
        model_id: Weight folder / model id handed to ``InferencePipeline``.
        trigger_num: Exclusive end index of the triggers to evaluate.
        trigger_emb_list: 2-D tensor with one embedding per row; each row is
            repeated 77 times along a new axis 1 (presumably the text-encoder
            sequence length -- confirm).
        unique_scale: Multiplier applied to the embeddings.
        MSE: When true, the mean MSE is appended to the returned tuple.
        start_from: Index of the first trigger to evaluate.

    Returns:
        ``(ssim, lpips, psnr)`` or, with ``MSE=True``,
        ``(ssim, lpips, psnr, mse)``. Each value is the mean over batches of
        the per-batch score.

    Raises:
        ValueError: If ``start_from >= trigger_num`` so that nothing would be
            evaluated (this used to surface as a ``ZeroDivisionError`` after
            the model had already been loaded).
    """
    if start_from >= trigger_num:
        raise ValueError(
            f"nothing to evaluate: start_from={start_from} must be smaller than trigger_num={trigger_num}")

    pipeline = InferencePipeline(weight_folder = model_id, seed = seed, device = device)
    pipeline.set_pipe_and_generator()

    df = pd.read_csv(gt_csv)
    gt_image_path_list = sorted(df['image'].tolist())[:trigger_num]

    print("Unet path:",unet_path)
    if unet_path is not None: # use a separate trained unet for generation
        checkpoint = torch.load(unet_path, map_location='cpu')
        pipeline.pipe.unet.load_state_dict(checkpoint['unet'])
        pipeline.pipe.unet = pipeline.pipe.unet.half().to(device)
        print(f"** load unet from {unet_path}")

    # img.save does not create missing folders.
    os.makedirs(save_dir, exist_ok=True)

    params_str = pipeline.get_sdm_params()
    to_tensor = T.ToTensor()
    total_ssim, total_lpips, total_psnr,total_mse, count = 0,0,0,0,0
    t0 = time.perf_counter()
    for batch_start in range(start_from, trigger_num, batch_sz):
        # Clamp the final batch so embeddings, prompts, file names and
        # ground-truth images all cover exactly the same indices.
        batch_end = min(batch_start + batch_sz, trigger_num)
        cur_batch = batch_end - batch_start

        emb = trigger_emb_list[batch_start:batch_end].unsqueeze(1).repeat(1,77,1).to(device).half() * unique_scale
        img_names = ['{}.png'.format(str(k).zfill(6)) for k in range(batch_start, batch_end)]
        val_prompts = [''] * cur_batch
        imgs = pipeline.generate(prompt = val_prompts,
                                 n_steps = num_inference_steps,
                                 img_sz = img_sz,
                                 unique_emb = emb,
                                 guidance_scale=guidance_scale)

        # Load ground truth with a context manager so file handles are
        # released immediately instead of waiting for garbage collection.
        gt_tensors = []
        for path in gt_image_path_list[batch_start:batch_end]:
            with Image.open(path) as gt_pil:
                gt_tensors.append(to_tensor(gt_pil.convert('RGB')))
        gt_imgs = torch.stack(gt_tensors).to(device)
        trig_imgs = torch.stack([to_tensor(img) for img in imgs]).to(device)

        ssim, lpips, psnr = calculate_img_metrics(gt_imgs,trig_imgs)
        total_ssim  += ssim
        total_lpips += lpips
        total_psnr  += psnr
        total_mse += F.mse_loss(gt_imgs,trig_imgs).mean()
        count += 1
        for i, (img, img_name, val_prompt) in enumerate(zip(imgs, img_names, val_prompts)):
            img.save(os.path.join(save_dir, img_name))
            img.close()
            print(f"{batch_start + i}/{trigger_num} | {img_name} {val_prompt}")
        print(f"---{params_str}")

    pipeline.clear()
    del pipeline
    print(f"{(time.perf_counter()-t0):.2f} sec elapsed")
    if not MSE:
        return total_ssim/count, total_lpips/count, total_psnr/count
    else:
        return total_ssim/count, total_lpips/count, total_psnr/count, total_mse/count



@torch.no_grad()
def eval_trigger(gt_csv='DATASETS/coco_subset_preprocessed/metadata.csv',trig_folder='result_trigger/trigger_image/', num_trigger=300):
    """Score already-generated trigger images against their ground truth.

    Image ``i`` is read from ``<trig_folder>/<i:06d>.png`` and compared with
    the ``i``-th path of the sorted ``image`` column of ``gt_csv``. For each
    pair a side-by-side grid is written to ``tmp/grid_<i:06d>.png``. The
    averaged SSIM / LPIPS / PSNR / MSE are printed and written to
    ``trigger_img_score.txt`` in the working directory.

    Tensors are placed on ``cuda:0``.

    Args:
        gt_csv: CSV with an ``image`` column of ground-truth image paths.
        trig_folder: Folder holding the generated trigger images.
        num_trigger: Number of image pairs to evaluate (default 300, the
            value that used to be hard-coded).

    Raises:
        ValueError: If ``num_trigger`` is not positive.
    """
    if num_trigger <= 0:
        raise ValueError(f"num_trigger must be positive, got {num_trigger}")

    df = pd.read_csv(gt_csv)
    gt_image_path_list = sorted(df['image'].tolist())[:num_trigger]
    total_ssim, total_lpips, total_psnr,total_mse, count = 0,0,0,0,0
    device = torch.device('cuda:0')
    to_tensor = T.ToTensor()

    # save_image does not create the target folder.
    os.makedirs('tmp', exist_ok=True)

    for i in tqdm(range(num_trigger)):
        trig_img_path = '{}/{}.png'.format(trig_folder,str(i).zfill(6))
        # Context managers close the image files right after decoding.
        with Image.open(gt_image_path_list[i]) as gt_pil:
            gt_img = to_tensor(gt_pil.convert('RGB')).unsqueeze(0).to(device)
        with Image.open(trig_img_path) as trig_pil:
            trig_img = to_tensor(trig_pil.convert('RGB')).unsqueeze(0).to(device)
        total_mse += F.mse_loss(gt_img,trig_img).mean()
        ssim, lpips, psnr = calculate_img_metrics(gt_img,trig_img)
        total_ssim  += ssim
        total_lpips += lpips
        total_psnr  += psnr
        count += 1
        grid_image = make_grid(torch.cat([gt_img,trig_img],dim=0))
        save_image(grid_image,fp='tmp/grid_{}.png'.format(str(i).zfill(6)))

    total_ssim, total_lpips, total_psnr,total_mse = total_ssim/count, total_lpips/count, total_psnr/count, total_mse/count

    summary = "SSIM:{} LPIPS:{} PSNR:{} MSE:{}".format(total_ssim, total_lpips, total_psnr, total_mse)
    print(summary)
    with open('trigger_img_score.txt', 'w') as f:
        f.write(summary)
  






@torch.no_grad()
def eval_trigger_image(patch_size=64, img_size=512, trigger_folder='checkpoint/trigger_save/', gt_csv='DATASETS/preprocessed_metadata.csv', upscale=True,save_upscale=None):
    """Compare enlarged patches of tiled trigger images with ground-truth images.

    Every file in ``trigger_folder`` is split into ``patch_size`` squares.
    Each patch is enlarged and scored against one ground-truth image: the
    ``i``-th trigger file is matched with the ``i``-th run of
    ``(img_size // patch_size) ** 2`` paths from the sorted ``image`` column
    of ``gt_csv``. SSIM / LPIPS / PSNR are printed per trigger file; nothing
    is returned.

    Tensors are placed on ``cuda:0``.

    Args:
        patch_size: Side length of one patch inside a trigger image.
        img_size: Side length of a trigger image; only used to derive the
            number of patches per image.
        trigger_folder: Folder containing the tiled trigger images.
        gt_csv: CSV with an ``image`` column of ground-truth image paths.
        upscale: When true, patches are interpolated x2 and then passed to
            ``upsampler.upscale``; otherwise they are interpolated x8.
            NOTE(review): the factors are fixed, so the ground-truth images
            must already match the enlarged patch size -- confirm for
            non-default ``patch_size``.
        save_upscale: Optional folder (created if missing) receiving
            ``UPSCALE_<i>.png`` and ``GT_<i>.png`` grids for inspection.
    """
    GPU_ID = 0
    device = torch.device('cuda:{}'.format(GPU_ID))
    upsampler = setup_upscaling_model(gpu_id=GPU_ID)  if upscale else None
    patchify_fn = Patchify(patch_size=patch_size)
    to_tensor = T.ToTensor()

    # os.path.join works whether or not trigger_folder ends with a separator.
    trigger_image_path_list = sorted(
        os.path.join(trigger_folder, name) for name in os.listdir(trigger_folder))
    trigger_num = len(trigger_image_path_list)

    df = pd.read_csv(gt_csv)
    per_img = (img_size//patch_size)**2
    gt_image_path_list = sorted(df['image'].tolist())[:per_img*trigger_num]

    if save_upscale is not None:
        # save_image does not create the target folder.
        os.makedirs(save_upscale, exist_ok=True)

    for i,trigger_image_path in tqdm(enumerate(trigger_image_path_list), total=trigger_num):
        gt_tensors = []
        for path in gt_image_path_list[i*per_img:(i+1)*per_img]:
            with Image.open(path) as gt_pil:
                gt_tensors.append(to_tensor(gt_pil.convert('RGB')))
        gt_imgs_cur = torch.stack(gt_tensors).to(device)

        with Image.open(trigger_image_path) as trg_pil:
            trg_img = to_tensor(trg_pil.convert('RGB')).unsqueeze(0).to(device)
        # Drop only the batch axis; a bare squeeze() would also remove the
        # patch axis when an image yields a single patch.
        patched_trig_imgs = patchify_fn(trg_img).squeeze(0)
        if upsampler is None:
            patched_trig_imgs = F.interpolate(patched_trig_imgs,scale_factor=8)
        else:
            patched_trig_imgs = F.interpolate(patched_trig_imgs,scale_factor=2)
            patched_trig_imgs = upsampler.upscale(patched_trig_imgs)
        patched_trig_imgs = torch.clip(patched_trig_imgs,min=0,max=1)
        if save_upscale is not None:
            save_image(make_grid(patched_trig_imgs,nrow=8),'{}/UPSCALE_{}.png'.format(save_upscale,str(i).zfill(5)))
            save_image(make_grid(gt_imgs_cur,nrow=8),'{}/GT_{}.png'.format(save_upscale,str(i).zfill(5)))

        ssim, lpips, psnr = calculate_img_metrics(gt_imgs_cur,patched_trig_imgs)
        print("SSIM:{} LPIPS:{} PSNR:{}".format(ssim,lpips,psnr))

@torch.no_grad()
def calculate_img_metrics(x,y):
    """Return ``(ssim, lpips, psnr)`` for two image batches.

    Args:
        x: First image batch. Values are treated as lying in [0, 1]
            (``data_range=1.`` is passed to SSIM and PSNR).
        y: Second image batch, same layout as ``x``.

    Returns:
        Tuple of the SSIM value as returned by ``piq.ssim`` (default
        reduction), the mean of the per-sample LPIPS values and the mean of
        the per-sample PSNR values.
    """
    # Constructing piq.LPIPS builds its feature network, so doing it on every
    # call (once per image in the evaluation loops) is wasteful. Keep a single
    # instance on the function object and reuse it.
    lpips_fn = getattr(calculate_img_metrics, '_lpips_fn', None)
    if lpips_fn is None:
        lpips_fn = piq.LPIPS(reduction='none')
        calculate_img_metrics._lpips_fn = lpips_fn

    ssim_index: torch.Tensor = piq.ssim(x, y, data_range=1.)
    lpips_loss: torch.Tensor = lpips_fn(x, y)
    psnr_index = piq.psnr(x, y, data_range=1., reduction='none')
    return ssim_index, lpips_loss.mean(), psnr_index.mean()


def setup_upscaling_model(model_name='RealESRGAN_x4plus', model_path='weights/RealESRGAN_x4plus.pth',tile=0,tile_pad=10,pre_pad=0,fp_32=True,gpu_id=0):
    """Build a ``RealESRGANer`` for one of the known Real-ESRGAN models.

    The weights are read from ``model_path``; nothing is downloaded here.
    The release URL of each model is kept as a comment next to its entry.

    Args:
        model_name: One of the keys of the table below.
        model_path: Path to the weight file matching ``model_name``.
        tile: Forwarded to ``RealESRGANer``.
        tile_pad: Forwarded to ``RealESRGANer``.
        pre_pad: Forwarded to ``RealESRGANer``.
        fp_32: When true the upsampler runs in full precision
            (``half=not fp_32``).
        gpu_id: Forwarded to ``RealESRGANer``.

    Returns:
        The configured ``RealESRGANer`` instance.

    Raises:
        ValueError: If ``model_name`` is not a known model. Previously this
            fell through every branch and crashed on an unbound variable.
    """
    # model name -> (architecture, constructor kwargs, native scale)
    model_specs = {
        # https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth
        'RealESRGAN_x4plus': (
            'rrdb', dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4), 4),
        # https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth
        'RealESRNet_x4plus': (
            'rrdb', dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4), 4),
        # https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth
        'RealESRGAN_x4plus_anime_6B': (
            'rrdb', dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4), 4),
        # https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth
        'RealESRGAN_x2plus': (
            'rrdb', dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2), 2),
        # https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth
        'realesr-animevideov3': (
            'srvgg', dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'), 4),
        # https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth
        # https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth
        'realesr-general-x4v3': (
            'srvgg', dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'), 4),
    }

    # Validate before the heavy third-party imports so a typo fails fast.
    if model_name not in model_specs:
        raise ValueError(
            "unknown model_name {!r}; expected one of: {}".format(model_name, ', '.join(sorted(model_specs))))
    arch, arch_kwargs, netscale = model_specs[model_name]

    from basicsr.archs.rrdbnet_arch import RRDBNet
    from realesrgan import RealESRGANer
    from realesrgan.archs.srvgg_arch import SRVGGNetCompact

    model_cls = RRDBNet if arch == 'rrdb' else SRVGGNetCompact
    model = model_cls(**arch_kwargs)

    dni_weight = None

    # restorer
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=tile,
        tile_pad=tile_pad,
        pre_pad=pre_pad,
        half=not fp_32,
        gpu_id=gpu_id)

    return upsampler

@torch.no_grad()
def upscale(patch_size=64, img_size=512, trigger_folder='checkpoint/trigger_save/', gt_csv='DATASETS/preprocessed_metadata.csv'):
    """Upscale the patches of every trigger image and dump before/after grids.

    Each file in ``trigger_folder`` is split into ``patch_size`` squares,
    interpolated x2 and passed to ``upsampler.upscale``. The interpolated and
    upscaled patches are saved as ``before_scale.png`` / ``after_scale.png``
    in the working directory.

    NOTE(review): the two output files are overwritten on every iteration, so
    only the last trigger image remains on disk -- confirm this is intended.

    Runs under ``torch.no_grad()`` like the other evaluation helpers so no
    autograd graph is kept during upscaling.

    Args:
        patch_size: Side length of one patch inside a trigger image.
        img_size: Unused; kept for interface compatibility.
        trigger_folder: Folder containing the tiled trigger images.
        gt_csv: Unused; kept for interface compatibility.
    """
    upsampler = setup_upscaling_model()
    patchify_fn = Patchify(patch_size=patch_size)
    # os.path.join works whether or not trigger_folder ends with a separator.
    trigger_image_path_list = sorted(
        os.path.join(trigger_folder, name) for name in os.listdir(trigger_folder))
    for trigger_image_path in trigger_image_path_list:
        # Close the file as soon as the pixels are decoded.
        with Image.open(trigger_image_path) as trg_pil:
            trg_img = T.ToTensor()(trg_pil.convert('RGB')).unsqueeze(0)
        # Drop only the batch axis; a bare squeeze() would also remove the
        # patch axis when an image yields a single patch.
        patched_trig_imgs = patchify_fn(trg_img).squeeze(0).to(upsampler.device)
        patched_trig_imgs = F.interpolate(patched_trig_imgs,scale_factor=2)
        upscale_trig_imgs = upsampler.upscale(patched_trig_imgs)
        save_image(make_grid(patched_trig_imgs,nrow=8),'before_scale.png')
        save_image(make_grid(upscale_trig_imgs,nrow=8),'after_scale.png')
    
# Script entry point. Only ``eval_trigger()`` is currently live; every other
# experiment below is disabled, either inside a bare string literal or as
# ``#`` comments, and is kept as a record of earlier runs.
if __name__ == '__main__':
    # Imported here but only referenced by the commented-out calls below.
    from eval_clip_score import eval_clip_score
    
    # Checkpoint used by the disabled generate_* calls below.
    ckpt_path='checkpoint_CBS/coco_trigger_caption_300_unique_portion_0.5_acc_1_batch_14_lr_5e-06/checkpoint-latest/caption_unet.pyt'
    
    # result_dir and device are only consumed by the disabled code below.
    result_dir = 'result' if ckpt_path is None else 'result_trigger'
    device = torch.device('cuda:0')
    # Runs with its defaults: default ground-truth CSV and default folder of
    # generated trigger images.
    eval_trigger()
    # The string literal that follows is never executed; it holds the call
    # that generates the trigger images and writes their scores.
    '''
    if ckpt_path is not None:
        UNIQUE_EMBEDDINGS_LIST = torch.load('embedding_folder/uniform_embeddings_length_768_3k_norm.pt').to(device)
        total_ssim, total_lpips, total_psnr, total_mse = generate_trigger_eval(save_dir='{}/trigger_image'.format(result_dir),unet_path=ckpt_path,batch_sz=4, guidance_scale=3.5, num_inference_steps=100,trigger_num=300,trigger_emb_list=UNIQUE_EMBEDDINGS_LIST, unique_scale=1.0,MSE=True,start_from=0)
        print("SSIM:{} LPIPS:{} PSNR:{} MSE:{}".format(total_ssim, total_lpips, total_psnr, total_mse))
        with open(result_dir+'/img_score.txt', 'w') as f:
                f.write("SSIM:{} LPIPS:{} PSNR:{} MSE:{}".format(total_ssim, total_lpips, total_psnr,total_mse))
    '''
    #generate_mscoco_eval(start_from=9552,save_dir='{}/'.format(result_dir),unet_path=ckpt_path,data_list='evaluation/mscoco_val2014_30k/metadata_10k.csv',num_inference_steps=100)

    #eval_clip_score(genimg_dir='{}/im256'.format(result_dir),save_txt='{}/clip_score.txt'.format(result_dir),data_list='evaluation/mscoco_val2014_30k/metadata_10k.csv')    

    #from eval_clip_score import eval_clip_score
    #eval_clip_score(genimg_dir='tmp_save/im256',save_txt='tmp_save/clip_score.txt',data_list='evaluation/mscoco_val2014_30k/metadata_100.csv')
    #device = torch.device('cuda:0')
    #UNIQUE_EMBEDDINGS_LIST = torch.load('embedding_folder/uniform_embeddings_length_768_3k_norm.pt').to(device)
    #total_ssim, total_lpips, total_psnr = generate_trigger_eval(save_dir='tmp_save',unet_path='checkpoint/coco_trigger_img_3000_unique_portion_0.5_acc_1_batch_14_lr_5e-06-resume/checkpoint-latest/unet.pyt',batch_sz=2, guidance_scale=1, num_inference_steps=25,trigger_num=10,trigger_emb_list=UNIQUE_EMBEDDINGS_LIST, unique_scale=1)

    #eval_trigger_image(patch_size=64, img_size=512, trigger_folder='DATASETS/concat_laion_11k_64/', gt_csv='DATASETS/preprocessed_laion_11k/metadata.csv',upscale=True,save_upscale='checkpoint/upscale_folder')
    #eval_trigger_image(patch_size=64, img_size=512, trigger_folder='checkpoint/trigger_save/', gt_csv='DATASETS/preprocessed_metadata.csv')
    #output_dir = 'checkpoint/trigger_save'
    #save_path = 'checkpoint/1-gpu-test/checkpoint-latest'
    #unique_embedding_path = '../t2i_steal/embedding_folder/dhe_embeddings_length_15000_768.pt'
    #device = torch.device('cuda:0')
    #UNIQUE_EMBEDDINGS_LIST = torch.load(unique_embedding_path, map_location=device)
    #generate_trigger_eval(save_dir=output_dir,unet_path='{}/unet.pyt'.format(save_path),batch_sz=4, guidance_scale=2.5, num_inference_steps=100,trigger_num=50,trigger_emb_list=UNIQUE_EMBEDDINGS_LIST)


'''
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, default="CompVis/stable-diffusion-v1-4",
                        help="CompVis/stable-diffusion-v1-4, nota-ai/bk-sdm-base, nota-ai/bk-sdm-small, nota-ai/bk-sdm-tiny")    
    parser.add_argument("--save_dir", type=str, default="evaluation/results/bk-sdm-small",
                        help="$save_dir/{im256, im512} are created for saving 256x256 and 512x512 images")
    parser.add_argument("--unet_path", type=str, default=None)   
    parser.add_argument("--data_list", type=str, default="evaluation/mscoco_val2014_30k/metadata.csv")    
    parser.add_argument("--num_images", type=int, default=1)
    parser.add_argument("--num_inference_steps", type=int, default=25)
    parser.add_argument('--device', type=str, default='cuda:0', help='Device to use, cuda:gpu_number or cpu')
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--img_sz", type=int, default=512)
    parser.add_argument("--img_resz", type=int, default=256)
    parser.add_argument("--batch_sz", type=int, default=1)

    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = parse_args()

    pipeline = InferencePipeline(weight_folder = args.model_id,
                                seed = args.seed,
                                device = args.device)
    pipeline.set_pipe_and_generator()    

    if args.unet_path is not None: # use a separate trained unet for generation
        if args.model_id != "CompVis/stable-diffusion-v1-4" and not args.model_id.startswith("nota-ai/bk-sdm"):
            raise ValueError("args.model_id must be either 'CompVis/stable-diffusion-v1-4' or 'nota-ai/bk-sdm-*'"+
                             f" for text encoder and image decoder\n  ** current args.model_id: {args.model_id}")
        
        from diffusers import UNet2DConditionModel 
        #unet = UNet2DConditionModel.from_pretrained(args.unet_path, subfolder='unet')
        #pipeline.pipe.unet = unet.half().to(args.device)
        checkpoint = torch.load(args.unet_path, map_location='cpu')
        pipeline.pipe.unet.load_state_dict(checkpoint['unet'])
        pipeline.pipe.unet = pipeline.pipe.unet.half().to(args.device)
        print(f"** load unet from {args.unet_path}")        

    save_dir_im512 = os.path.join(args.save_dir, 'im512')
    os.makedirs(save_dir_im512, exist_ok=True)
    save_dir_im256 = os.path.join(args.save_dir, 'im256')
    os.makedirs(save_dir_im256, exist_ok=True)       

    file_list = get_file_list_from_csv(args.data_list)
    params_str = pipeline.get_sdm_params()
    
    t0 = time.perf_counter()
    for batch_start in range(0, len(file_list), args.batch_sz):
        batch_end = batch_start + args.batch_sz
        
        img_names = [file_info[0] for file_info in file_list[batch_start: batch_end]]
        val_prompts = [file_info[1] for file_info in file_list[batch_start: batch_end]]
                    
        imgs = pipeline.generate(prompt = val_prompts,
                                 n_steps = args.num_inference_steps,
                                 img_sz = args.img_sz)

        for i, (img, img_name, val_prompt) in enumerate(zip(imgs, img_names, val_prompts)):
            img.save(os.path.join(save_dir_im512, img_name))
            img.close()
            print(f"{batch_start + i}/{len(file_list)} | {img_name} {val_prompt}")
        print(f"---{params_str}")

    pipeline.clear()
    
    change_img_size(save_dir_im512, save_dir_im256, args.img_resz)
    print(f"{(time.perf_counter()-t0):.2f} sec elapsed")
'''