import argparse
import torch
from torch.nn import functional as F
from tqdm import tqdm
import numpy as np
from flux.feat_flux import Featurizer4Eval
import os
import json
from PIL import Image
import torch.nn as nn
from einops import rearrange
import time
import cv2
import matplotlib.pyplot as plt
import gc
import copy
from sklearn.decomposition import PCA
from torchvision import transforms as T
from math import sqrt
import seaborn as sns
import base64
from io import BytesIO

import warnings

# Silence every Python warning process-wide for the rest of the run,
# including warnings emitted by library code during feature extraction
# and evaluation.
warnings.filterwarnings('ignore')

import numpy as np
from scipy.spatial.distance import cosine


def pca_feature_pair(feat1, feat2, q=1024):
    """Project two feature maps onto one shared PCA basis.

    Both maps are flattened into per-pixel descriptors, stacked, centred with
    their joint mean and projected onto the principal directions returned by
    ``torch.pca_lowrank``, so the two reduced maps live in the same space.

    Args:
        feat1: tensor of shape (1, C, H1, W1).
        feat2: tensor of shape (1, C, H2, W2); its spatial size may differ
            from ``feat1``.
        q: number of principal components to keep. It is clamped to
            ``min(H1*W1 + H2*W2, C)`` because ``torch.pca_lowrank`` rejects
            larger values.

    Returns:
        ``(reduced1, reduced2)`` float32 tensors of shapes (1, q', H1, W1) and
        (1, q', H2, W2), where q' is the (possibly clamped) ``q``.

    Raises:
        ValueError: if an input is not 4-D with batch size 1, or the channel
            counts of the two inputs differ.
    """
    for name, feat in (("feat1", feat1), ("feat2", feat2)):
        if feat.dim() != 4 or feat.shape[0] != 1:
            raise ValueError(
                f"{name} must have shape (1, C, H, W), got {tuple(feat.shape)}")
    if feat1.shape[1] != feat2.shape[1]:
        raise ValueError(
            f"channel mismatch: {feat1.shape[1]} vs {feat2.shape[1]}")

    _, _, H1, W1 = feat1.shape
    _, _, H2, W2 = feat2.shape

    # (1, C, H, W) -> (H*W, C): one descriptor per pixel, row index h*W + w.
    desc1 = feat1.float().flatten(2).squeeze(0).transpose(0, 1)
    desc2 = feat2.float().flatten(2).squeeze(0).transpose(0, 1)

    stacked = torch.cat((desc1, desc2), dim=0)  # (H1*W1 + H2*W2, C)
    centered = stacked - stacked.mean(dim=0, keepdim=True)

    # pca_lowrank raises ValueError when q > min(rows, cols).
    q = min(q, *centered.shape)
    _, _, V = torch.pca_lowrank(centered, q=q)
    reduced = centered @ V[:, :q]  # (H1*W1 + H2*W2, q)

    n1 = H1 * W1
    # Each part is reshaped with its own spatial size.
    reduced1 = reduced[:n1].transpose(0, 1).reshape(q, H1, W1).unsqueeze(0)
    reduced2 = reduced[n1:].transpose(0, 1).reshape(q, H2, W2).unsqueeze(0)
    return reduced1, reduced2

def _group_pairs_by_category(dataset_path, test_path, all_cats):
    """Index the test-pair JSON files and the images they reference, per category.

    Returns:
        ``(cat2json, cat2img)``. ``cat2json[cat]`` lists the JSON file names in
        ``test_path`` whose name contains ``cat``. ``cat2img[cat]`` lists the
        distinct ``src_imname``/``trg_imname`` values of those files in
        first-seen order.
    """
    json_list = os.listdir(os.path.join(dataset_path, test_path))
    cat2json = {}
    cat2img = {}
    for cat in all_cats:
        # Substring match on the file name (category-to-file association).
        cat_list = [name for name in json_list if cat in name]
        cat2json[cat] = cat_list

        images, seen = [], set()
        for json_path in cat_list:
            with open(os.path.join(dataset_path, test_path, json_path)) as temp_f:
                data = json.load(temp_f)
            for imname in (data['src_imname'], data['trg_imname']):
                if imname not in seen:
                    seen.add(imname)
                    images.append(imname)
        cat2img[cat] = images
    return cat2json, cat2img


def _modulate_features(ft_raw, ada, pre_norm):
    """Normalize a cached feature map and apply its shift/scale modulation.

    ``ft_raw`` is split in two along dim 0 with ``chunk(2)``, and only the
    second half is used (the stored layout is presumably [uncond; text],
    so TODO confirm against Featurizer4Eval). That half is layer-normalized
    over channels and then modulated as ``(1 + scale) * x + shift``, with
    ``shift = ada[0][0]`` and ``scale = ada[0][1]`` broadcast over H and W.
    The per-entry shape is assumed to be (C,), so TODO confirm.

    Returns:
        The modulated map as float16, shape (1, C, H, W) when ``ft_raw`` is
        (2, C, H, W).
    """
    _, feat_pred_text = ft_raw.chunk(2)
    H, W = ft_raw.shape[-2:]

    ft = rearrange(feat_pred_text, "b c h w -> b (h w) c")
    ft = pre_norm(ft)  # normalizes the last (channel) axis
    ft = rearrange(ft, "b (h w) c -> b c h w", h=H, w=W)

    shift = ada[0][0].unsqueeze(0).unsqueeze(2).unsqueeze(3)
    scale = ada[0][1].unsqueeze(0).unsqueeze(2).unsqueeze(3)
    ft = (1 + scale) * ft + shift
    return ft.to(torch.float16)


def _evaluate_pair(src_ft, trg_ft, data, alpha=0.1):
    """Transfer every source keypoint of one pair and count PCK@alpha hits.

    Both maps are bilinearly resized to their original image sizes. Each
    source keypoint is matched to the target pixel with the highest cosine
    similarity. A hit is counted when the distance to the ground-truth
    target keypoint is at most ``alpha * max(bbox height, bbox width)`` of
    the target bounding box.

    Returns:
        ``(correct, total)`` keypoint counts for this pair.
    """
    # imsize is stored as [w, h, ...]; interpolate expects (h, w).
    src_img_size = tuple(data['src_imsize'][:2][::-1])
    trg_img_size = tuple(data['trg_imsize'][:2][::-1])
    src_ft = F.interpolate(src_ft, size=src_img_size, mode='bilinear')
    trg_ft = F.interpolate(trg_ft, size=trg_img_size, mode='bilinear')

    h, w = trg_ft.shape[-2:]
    num_channel = src_ft.size(1)
    # The target descriptors do not depend on the keypoint, so they are
    # normalized once per pair.
    trg_vec = F.normalize(trg_ft.view(num_channel, -1).transpose(0, 1))  # HW, C

    trg_bndbox = data['trg_bndbox']
    threshold = max(trg_bndbox[3] - trg_bndbox[1], trg_bndbox[2] - trg_bndbox[0])

    correct = 0
    total = 0
    trg_kps = data['trg_kps']
    for idx, src_point in enumerate(data['src_kps']):
        trg_point = trg_kps[idx]
        total += 1
        src_vec = src_ft[0, :, src_point[1], src_point[0]].view(1, num_channel)
        src_vec = F.normalize(src_vec).transpose(0, 1)  # C, 1

        cos_map = torch.mm(trg_vec, src_vec).view(h, w).cpu().numpy()  # H, W
        max_y, max_x = np.unravel_index(cos_map.argmax(), cos_map.shape)
        dist = ((max_x - trg_point[0]) ** 2 + (max_y - trg_point[1]) ** 2) ** 0.5
        if dist / threshold <= alpha:
            correct += 1
    return correct, total


def main(args):
    """Extract FLUX features for every SPair-71k test image, then report PCK@0.1.

    Phase 1 runs ``Featurizer4Eval.forward`` on every image referenced by a
    test pair. It caches ``{cat}.pth`` (feature maps) and ``{cat}_ada.pth``
    (modulation tensors) per category under ``args.save_path``.

    Phase 2 reloads the cache and matches each source keypoint to the most
    cosine-similar target pixel. It prints per-category and overall
    per-image / per-point PCK@0.1, appends the overall numbers to
    ``{dataset}_{dift_model}.txt``, and writes all numbers to
    ``results/{dift_model}/{guidance_scale}.json``.

    Pairs without keypoints are left out of per-image PCK. Categories without
    any test keypoints are skipped and excluded from the "Mean" averages.

    Raises:
        Exception: if ``args.dift_model`` is not ``'flux'``.
        RuntimeError: if no test keypoints are found at all.
    """
    for arg in vars(args):
        value = getattr(args, arg)
        if value is not None:
            print('%s: %s' % (str(arg), str(value)))

    torch.cuda.set_device(0)

    dataset_path = args.dataset_path
    test_path = 'PairAnnotation/test'
    all_cats = os.listdir(os.path.join(dataset_path, 'JPEGImages'))
    cat2json, cat2img = _group_pairs_by_category(dataset_path, test_path, all_cats)

    if args.dift_model == 'flux':
        dift = Featurizer4Eval(cat_list=all_cats, ensemble_size=args.ensemble_size)
    else:
        raise Exception("model must be in [flux] ")

    # ---- Phase 1: extract and cache features --------------------------------
    print("saving all test images' features...")
    os.makedirs(args.save_path, exist_ok=True)
    for cat in tqdm(all_cats):
        output_dict = {}
        ada_dict = {}
        for image_path in cat2img[cat]:
            # The context manager closes the file handle Image.open keeps.
            with Image.open(os.path.join(dataset_path, 'JPEGImages', cat, image_path)) as img:
                output_dict[image_path], ada_dict[image_path] = dift.forward(
                    img,
                    category=cat,
                    img_size=args.img_size,
                    t=args.t,
                    ft_index=args.ft_index,
                    ensemble_size=args.ensemble_size)
        torch.save(output_dict, os.path.join(args.save_path, f'{cat}.pth'))
        torch.save(ada_dict, os.path.join(args.save_path, f'{cat}_ada.pth'))

    # ---- Phase 2: keypoint transfer and PCK ---------------------------------
    total_pck = []
    all_correct = 0
    all_total = 0
    # Channel-wise LayerNorm without learnable parameters (expects C == 3072).
    pre_norm = nn.LayerNorm(3072, elementwise_affine=False, eps=1e-6)

    with open('%s_%s.txt' % (args.dataset, args.dift_model), 'a+') as file0:
        mean_image_sum = 0
        mean_point_sum = 0
        evaluated_cats = 0
        result = {"image": {}, "point": {}}

        print("Category numbers: %s" % len(all_cats))
        for cat in all_cats:
            output_dict = torch.load(os.path.join(args.save_path, f'{cat}.pth'))
            ada_dict = torch.load(os.path.join(args.save_path, f'{cat}_ada.pth'))

            cat_pck = []
            cat_correct = 0
            cat_total = 0
            for json_path in tqdm(cat2json[cat]):
                with open(os.path.join(dataset_path, test_path, json_path)) as temp_f:
                    data = json.load(temp_f)

                src_name, trg_name = data['src_imname'], data['trg_imname']
                src_ft = _modulate_features(
                    output_dict[src_name].cuda(), ada_dict[src_name].cuda(), pre_norm)
                trg_ft = _modulate_features(
                    output_dict[trg_name].cuda(), ada_dict[trg_name].cuda(), pre_norm)

                correct, total = _evaluate_pair(src_ft, trg_ft, data)
                cat_correct += correct
                cat_total += total
                all_correct += correct
                all_total += total
                if total:  # PCK of a pair without keypoints is undefined
                    cat_pck.append(correct / total)

                # Full-resolution maps were freed when _evaluate_pair returned;
                # give their cached blocks back to the allocator.
                torch.cuda.empty_cache()

            total_pck.extend(cat_pck)
            if cat_total == 0:
                print(f'{cat}: no test keypoints, skipped')
                continue

            evaluated_cats += 1
            image_pck = np.mean(cat_pck) * 100
            point_pck = cat_correct / cat_total * 100
            mean_image_sum += image_pck
            mean_point_sum += point_pck

            print(f'{cat} per image PCK@0.1: {image_pck:.2f}')
            print(f'{cat} per point PCK@0.1: {point_pck:.2f}')
            result['image'][cat] = round(image_pck, 2)
            result['point'][cat] = round(point_pck, 2)

        if all_total == 0:
            raise RuntimeError(
                'no test keypoints found under %s' % os.path.join(dataset_path, test_path))

        all_image_pck = np.mean(total_pck) * 100
        all_point_pck = all_correct / all_total * 100
        mean_image_pck = mean_image_sum / evaluated_cats
        mean_point_pck = mean_point_sum / evaluated_cats

        summary = [
            f'All per image PCK@0.1: {all_image_pck:.2f}',
            f'All per point PCK@0.1: {all_point_pck:.2f}',
            f'Mean per image PCK@0.1: {mean_image_pck:.2f}',
            f'Mean per point PCK@0.1: {mean_point_pck:.2f}',
        ]
        for line in summary:
            print(line)

        print("timestep: %s, layer index: %s" % (args.t, args.ft_index), file=file0)
        for line in summary:
            print(line, file=file0)

        result['image']["All"] = round(all_image_pck, 2)
        result['point']["All"] = round(all_point_pck, 2)
        result['image']["Mean"] = round(mean_image_pck, 2)
        result['point']["Mean"] = round(mean_point_pck, 2)

        save_dir = 'results/%s' % args.dift_model
        os.makedirs(save_dir, exist_ok=True)
        with open('results/%s/%s.json' % (args.dift_model, args.guidance_scale), 'w+') as json_file:
            json.dump(result, json_file, indent=4, ensure_ascii=False)


def build_arg_parser():
    """Return the command-line parser for the SPair-71k evaluation script."""
    parser = argparse.ArgumentParser(description='SPair-71k Evaluation Script')
    parser.add_argument('--dataset_path', type=str, default='./dataset/SPair-71k',
                        help='path to spair dataset')
    parser.add_argument('--dataset', type=str, default='SPair',
                        help='dataset name, used as the prefix of the results log file')
    parser.add_argument('--save_path', type=str, default='./spair_ft/',
                        help='path to save features')
    parser.add_argument('--dift_model', choices=['flux'], default='flux',
                        help="which dift version to use")
    parser.add_argument('--img_size', nargs='+', type=int, default=[512, 512],
                        help='''in the order of [width, height], resize input image
                            to [w, h] before fed into diffusion model, if set to 0, will
                            stick to the original input size. by default is 512x512.''')
    parser.add_argument('--t', default=261, type=int, help='t for diffusion')
    # Tuning range for the layer index: [0, 57].
    parser.add_argument('--ft_index', nargs='+', type=int, default=[12, 14],
                        help='which upsampling block to extract the ft map')
    parser.add_argument('--ensemble_size', default=8, type=int,
                        help='ensemble size for getting an image ft map')
    parser.add_argument('--guidance_scale', default=1, type=float,
                        help='guidance scale; currently only used to name the results JSON file')
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    main(args)