import torch
from CLIP.clip import *
from PIL import Image
import torchvision
from tqdm import tqdm 
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC
from torch.utils.data import DataLoader
import pickle
import matplotlib.pyplot as plt
import glob
import os
import cv2
import numpy as np
import scipy.linalg
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import pdist
from itertools import combinations_with_replacement, combinations, permutations
from scipy.stats import gaussian_kde
from attr_dict import *


def load_candidate(experiment_types, n_attr, use_blip):
    """Return the first ``n_attr`` attribute texts for an experiment.

    Args:
        experiment_types: Key used both for the ``attr_dict`` lookup and for
            the candidate pickle file name.
        n_attr: Number of leading attributes to keep.
        use_blip: When equal to ``1`` the candidates are read from
            ``pickles/attr_candidates/<experiment_types>.pkl``; any other
            value selects the predefined entries of ``attr_dict``.

    Returns:
        A pair ``(texts, extra)``. Without BLIP, ``texts`` is the sliced
        ``attr_dict`` entry and ``extra`` is ``0``. With BLIP, ``texts`` holds
        the first element of every pickled candidate and ``extra`` is the
        sliced pickled sequence itself.
    """
    if use_blip != 1:
        return attr_dict[experiment_types][:n_attr], 0

    # BLIP statistics: every pickled entry carries its text at position 0.
    with open(f"pickles/attr_candidates/{experiment_types}.pkl", 'rb') as handle:
        candidates = pickle.load(handle)
    candidate_texts = [entry[0] for entry in candidates]
    return candidate_texts[:n_attr], candidates[:n_attr]


def stack_txts_to_np(experiment_types, txt_path):
    """Load the caption lines of an experiment as a NumPy array, with caching.

    The array is cached at
    ``pickles/Dclipscore/<experiment_types>/<experiment_types>_blip_all_txt.npy``.
    When that file exists it is loaded and ``txt_path`` is not read at all.

    Args:
        experiment_types: Experiment name used to build the cache path.
        txt_path: Text file whose lines are cached on a cache miss.

    Returns:
        The array stored in the cache file. Lines are kept exactly as
        ``readlines`` returns them, i.e. including trailing newlines.

    Raises:
        OSError: If the cache is missing and ``txt_path`` cannot be opened.
    """
    np_path = f"pickles/Dclipscore/{experiment_types}/{experiment_types}_blip_all_txt.npy"
    if os.path.isfile(np_path):
        print(f"{np_path} exists, pickle loading...")
        return np.load(np_path)

    print(f"{np_path} does not exists")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(os.path.dirname(np_path), exist_ok=True)
    # Context manager so the handle is closed even if reading fails.
    with open(txt_path) as f:
        input_txt = f.readlines()
    np.save(np_path, input_txt)
    # Reload so cache hit and cache miss return the same array type.
    return np.load(np_path)


def get_text_mean(experiment_types, text_pickle_path):
    """Return the mean CLIP text embedding of an experiment's captions.

    On a cache hit the value stored under ``"text_mean"`` in the pickle at
    ``text_pickle_path`` is returned unchanged. Otherwise the captions from
    ``etc/log_txt/<experiment_types>.txt`` are encoded one by one with CLIP
    ``ViT-B/32``, averaged over the caption axis, pickled and returned.

    Args:
        experiment_types: Experiment name used to locate the caption file.
        text_pickle_path: Path of the pickle cache holding ``{"text_mean": ...}``.

    Returns:
        The mean text feature tensor (or whatever the cache file stores).

    Raises:
        ValueError: If there is no caption to encode.
    """
    if os.path.isfile(text_pickle_path):
        print(f"{text_pickle_path} exists, pickle loading...")
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only point this at cache files produced by this project.
        with open(text_pickle_path, 'rb') as p:
            pick = pickle.load(p)
        print(f"{text_pickle_path} pickle loaded !!")
        return pick["text_mean"]

    print(f"{text_pickle_path} does not exists")
    # A bare file name has no directory part; os.makedirs("") would raise.
    cache_dir = os.path.dirname(text_pickle_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)

    # Load BLIP captions.
    txt_path = f"etc/log_txt/{experiment_types}.txt"
    all_blip_txts_np = stack_txts_to_np(experiment_types, txt_path)
    if len(all_blip_txts_np) == 0:
        raise ValueError(f"no captions found in {txt_path}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, _ = clip.load("ViT-B/32", device=device)

    text = clip.tokenize(all_blip_txts_np).to(device)
    per_caption_features = []
    with torch.no_grad():
        for single_txt in tqdm(text, total=text.shape[0]):
            per_caption_features.append(model.encode_text(single_txt.unsqueeze(0)))
    # Stack once at the end: stacking inside the loop re-copies every
    # previously collected row on each iteration (quadratic in captions).
    text_features = torch.vstack(per_caption_features)
    text_mean = torch.mean(text_features, axis=0)

    # Save the result for later runs.
    with open(text_pickle_path, 'wb') as f:
        pickle.dump({"text_mean": text_mean}, f)

    return text_mean



def get_img_mean(img_dir, img_pickle_path):
    """Return the mean CLIP image embedding of all images under ``img_dir``.

    On a cache hit the array stored at ``img_pickle_path`` is loaded with
    ``np.load`` and wrapped in ``torch.Tensor``. Otherwise every image of the
    ``ImageFolder`` rooted at ``img_dir`` is encoded with CLIP ``ViT-B/32``,
    the features are averaged over the image axis, and the mean is saved with
    ``np.save`` before being returned.

    NOTE: ``np.save`` appends ``.npy`` to a file name that lacks it, so the
    ``isfile`` check only ever hits when ``img_pickle_path`` ends in ``.npy``.

    Args:
        img_dir: Root directory in ``torchvision.datasets.ImageFolder`` layout.
        img_pickle_path: Path of the ``.npy`` cache file.

    Returns:
        The mean image feature. A cache hit yields a ``torch.Tensor`` built
        from the stored array; a fresh computation yields the float16 tensor
        on the compute device.
    """
    bs = 2000
    if os.path.isfile(img_pickle_path):
        print(f"{img_pickle_path} exists, pickle loading...")
        return torch.Tensor(np.load(img_pickle_path))

    print(f"{img_pickle_path} does not exists")
    # A bare file name has no directory part; os.makedirs("") would raise.
    cache_dir = os.path.dirname(img_pickle_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)
    real_dataset = torchvision.datasets.ImageFolder(root=f"{img_dir}", transform=preprocess)  # input shape = (224, 224)
    real_dataloader = torch.utils.data.DataLoader(
        real_dataset, batch_size=bs, shuffle=False, drop_last=False, num_workers=4)

    # One 512-wide row per image, filled batch by batch.
    img_list_box = torch.zeros((len(real_dataset), 512), dtype=torch.float16).to(device)
    with torch.no_grad():
        for index, (datas, _) in tqdm(enumerate(real_dataloader), total=len(real_dataloader)):
            # The slice is clipped to the buffer length, so the last
            # (possibly shorter) batch fits as well.
            img_list_box[index * bs:(index + 1) * bs] = model.encode_image(datas.to(device))
    img_mean = torch.mean(img_list_box, axis=0)

    # astype returns a new array; the original discarded that result.
    img_mean_np = img_mean.detach().cpu().numpy().astype('float16')
    np.save(img_pickle_path, img_mean_np)

    return img_mean


def get_img_stats(img_dir, img_mean, text_mean, text_list, dclipscore_pickle_path):
    """Compute mean-centred CLIP similarity of every image to every attribute.

    For each image under ``img_dir`` and each prompt ``"a photo of <attr>"``
    the image and text features are centred with ``img_mean`` / ``text_mean``,
    L2-normalised, and their dot product is scaled by 100. The resulting
    matrix is cached with ``np.save`` at ``dclipscore_pickle_path``.

    On a cache hit only the file listing is built; the CLIP model is not
    loaded and ``img_mean``, ``text_mean`` and ``text_list`` are unused.

    NOTE: ``np.save`` appends ``.npy`` to a file name that lacks it, so the
    ``isfile`` check only ever hits when the path ends in ``.npy``.

    Args:
        img_dir: Root directory in ``torchvision.datasets.ImageFolder`` layout.
        img_mean: Mean image feature tensor subtracted from every image feature.
        text_mean: Mean text feature subtracted from every attribute feature.
        text_list: Attribute strings; each is prefixed with ``"a photo of "``.
        dclipscore_pickle_path: Path of the ``.npy`` cache file.

    Returns:
        ``(similarity, files)`` where ``similarity`` is a ``torch.Tensor`` on
        the compute device with one row per image and one column per
        attribute, and ``files`` lists the image paths in dataset order.
    """
    bs = 2000
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if os.path.isfile(dclipscore_pickle_path):
        print(f"{dclipscore_pickle_path} exists, np loading...")
        # Only the file listing is needed here, so skip loading the model.
        listing = torchvision.datasets.ImageFolder(root=f"{img_dir}")
        allFiles = [path for path, _ in listing.samples]
        similarity_per_attribute = np.load(dclipscore_pickle_path)
        return torch.Tensor(similarity_per_attribute).to(device), allFiles

    print(f"{dclipscore_pickle_path} does not exists")
    # A bare file name has no directory part; os.makedirs("") would raise.
    cache_dir = os.path.dirname(dclipscore_pickle_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)

    model, preprocess = clip.load("ViT-B/32", device=device)
    real_dataset = torchvision.datasets.ImageFolder(root=f"{img_dir}", transform=preprocess)  # input shape = (224, 224)
    real_dataloader = torch.utils.data.DataLoader(
        real_dataset, batch_size=bs, shuffle=False, num_workers=6)
    allFiles = [path for path, _ in real_dataset.samples]

    img_mean = img_mean.to(torch.float16).unsqueeze(0).to(device)

    text_list_prompt = ["a photo of " + a for a in text_list]
    text = clip.tokenize(text_list_prompt).to(device)
    attribute_num = len(text_list_prompt)

    with torch.no_grad():
        text_features = model.encode_text(text)
        # Centre and normalise the attribute features once; they do not
        # depend on the image batch. text_mean is aligned to the feature
        # device/dtype so the subtraction cannot mix devices.
        text_feat = text_features - torch.as_tensor(
            text_mean, dtype=text_features.dtype, device=device)
        text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)

        similarity_per_attribute = torch.zeros(
            (len(real_dataset), attribute_num), dtype=torch.float16).to(device)
        for index, (datas, _) in tqdm(enumerate(real_dataloader), total=len(real_dataloader)):
            image_features = model.encode_image(datas.to(device))
            img_feat = image_features - img_mean
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            # A single (batch, dim) @ (dim, attributes) product replaces the
            # per-attribute loop and the deprecated `.T` on 1-D tensors.
            similarity_per_attribute[index * bs:(index + 1) * bs] = (
                100.0 * img_feat @ text_feat.T)

    similarity_per_attribute_np = similarity_per_attribute.detach().cpu().numpy()
    np.save(dclipscore_pickle_path, similarity_per_attribute_np)

    return torch.Tensor(similarity_per_attribute_np).to(device), allFiles


def _grid_cell_mass(density, cell_area):
    """Return per-cell mass (mean of the four corner densities times area).

    The result is flattened row-major with the same ``n_bins * n_bins``
    layout as the sample grid; the last row and last column have no cell to
    their right/below and stay zero.
    """
    n_bins = density.shape[0]
    mass = np.zeros((n_bins, n_bins))
    if n_bins > 1:
        corner_mean = (
            density[:-1, :-1] + density[1:, :-1]
            + density[:-1, 1:] + density[1:, 1:]
        ) / 4.0
        mass[:-1, :-1] = corner_mean * cell_area
    return mass.ravel()


def into_grid_prob_2d(pdf_ori, pdf_input, n_bins):
    """Discretise two 2-D densities into per-cell probabilities.

    Both densities are sampled on an ``n_bins x n_bins`` lattice spanning
    ``[-35, 35]`` on each axis. Each cell between four neighbouring lattice
    points gets the mean of its four corner densities multiplied by the cell
    area.

    Two defects of the earlier version are fixed here: one corner was counted
    twice while another was skipped, and the cell width was taken as
    ``70 / n_bins`` although ``n_bins`` lattice points are spaced
    ``70 / (n_bins - 1)`` apart.

    Args:
        pdf_ori: Callable mapping a ``(2, N)`` array of points to ``N`` densities.
        pdf_input: Second density callable with the same contract.
        n_bins: Number of lattice points per axis.

    Returns:
        ``(prob_ori, prob_input)``: two flat arrays of length
        ``n_bins * n_bins``. Index ``i * n_bins + j`` holds the cell whose
        upper-left lattice point is ``(i, j)``; entries of the last row and
        column are zero.
    """
    lower, upper = -35, 35
    axis = np.linspace(lower, upper, n_bins)
    X, Y = np.meshgrid(axis, axis)
    sliced_array = np.array([X.flatten(), Y.flatten()])

    density_ori = pdf_ori(sliced_array).reshape(n_bins, n_bins)
    density_input = pdf_input(sliced_array).reshape(n_bins, n_bins)

    # n_bins lattice points bound n_bins - 1 intervals per axis.
    step = (upper - lower) / (n_bins - 1) if n_bins > 1 else 0.0
    cell_area = step * step

    prob_ori = _grid_cell_mass(density_ori, cell_area)
    prob_input = _grid_cell_mass(density_input, cell_area)
    return prob_ori, prob_input

def into_grid_prob_1d(pdf_ori_1d, pdf_input_1d, n_bins):
    """Evaluate two 1-D densities on a shared grid over ``[-35, 35]``.

    The grid has ``n_bins * n_bins`` evenly spaced points. The returned
    values are whatever the callables produce for that grid; they are not
    multiplied by a bin width.

    Args:
        pdf_ori_1d: Callable evaluated on the grid first.
        pdf_input_1d: Callable evaluated on the grid second.
        n_bins: The grid holds ``n_bins * n_bins`` points.

    Returns:
        ``(values_ori, values_input)`` in that order.
    """
    grid = np.linspace(-35, 35, n_bins * n_bins)
    values_ori = pdf_ori_1d(grid)
    values_input = pdf_input_1d(grid)
    return values_ori, values_input

def img_stats_into_density_2d(oriDCLIPscore_img_stats,inputDCLIPscore_img_stats, cov_map, text_list,experiment_type,img_dir, n_bins):
    """Compare the joint score densities of every attribute pair.

    For each unordered pair of attribute columns a 2-D KDE is fitted to the
    reference scores and to the input scores, both are discretised with
    ``into_grid_prob_2d``, and two divergence values are derived.

    Args:
        oriDCLIPscore_img_stats: Reference score tensor, one column per attribute.
        inputDCLIPscore_img_stats: Input score tensor with the same column layout.
        cov_map: Only ``cov_map.shape[0]`` is read; it sets the attribute count.
        text_list: Unused here; kept for signature compatibility.
        experiment_type: Unused here; kept for signature compatibility.
        img_dir: Unused here; kept for signature compatibility.
        n_bins: Grid resolution forwarded to ``into_grid_prob_2d``.

    Returns:
        Array of shape ``(n_pairs, 2)``. Column 0 is the mean over grid cells
        of ``p * (log p - log q)`` (labelled KL(P, Q)). Column 1 is the
        average of that value and its Q-to-P counterpart (labelled JSD in
        this code base).
    """
    pairs = list(combinations(range(cov_map.shape[0]), 2))
    given_box = np.zeros((len(pairs), 2))
    eps = 1e-10

    # Move each score matrix to host memory once rather than once per pair.
    ori_scores = oriDCLIPscore_img_stats.detach().cpu().numpy()
    input_scores = inputDCLIPscore_img_stats.detach().cpu().numpy()

    for corr_index, (i, j) in tqdm(enumerate(pairs), total=len(pairs)):
        # Only the joint density is used; the marginals are ignored.
        _, _, pdf_A_and_B = kde_2d(ori_scores[:, i], ori_scores[:, j])
        _, _, pdf_A_and_B_input = kde_2d(input_scores[:, i], input_scores[:, j])

        prob_ori2d, prob_input2d = into_grid_prob_2d(pdf_A_and_B, pdf_A_and_B_input, n_bins)
        # eps keeps log() finite for empty cells.
        prob_ori2d, prob_input2d = prob_ori2d + eps, prob_input2d + eps

        CE_P = -prob_ori2d * np.log(prob_ori2d)
        CE_Q = -prob_input2d * np.log(prob_input2d)
        CE_P_Q = -prob_ori2d * np.log(prob_input2d)
        CE_Q_P = -prob_input2d * np.log(prob_ori2d)

        KL_P_Q = np.mean(CE_P_Q - CE_P)  # H(P,Q) - H(P)
        KL_Q_P = np.mean(CE_Q_P - CE_Q)  # H(Q,P) - H(Q)
        JSD = (KL_P_Q + KL_Q_P) / 2

        given_box[corr_index] = [KL_P_Q, JSD]

    return given_box



def img_stats_into_density_1d(oriDCLIPscore_img_stats,inputDCLIPscore_img_stats, cov_map, text_list,experiment_type,img_dir, n_bins):
    """Compare the per-attribute score densities of two image sets.

    For each attribute column a 1-D KDE is fitted to the reference scores and
    to the input scores, both are evaluated with ``into_grid_prob_1d``, and
    two divergence values are derived.

    Args:
        oriDCLIPscore_img_stats: Reference score tensor, one column per attribute.
        inputDCLIPscore_img_stats: Input score tensor with the same column layout.
        cov_map: Unused here; kept for signature compatibility.
        text_list: Attribute list; only its length is used (one column each).
        experiment_type: Unused here; kept for signature compatibility.
        img_dir: Unused here; kept for signature compatibility.
        n_bins: Grid resolution forwarded to ``into_grid_prob_1d``.

    Returns:
        Array of shape ``(len(text_list), 2)``. Column 0 is the mean over
        grid points of ``p * (log p - log q)`` (labelled KL(P, Q)). Column 1
        is the average of that value and its Q-to-P counterpart (labelled JSD
        in this code base).
    """
    total_num = len(text_list)
    given_box = np.zeros((total_num, 2))  # KL(P,Q), JSD
    eps = 1e-10

    # Move each score matrix to host memory once rather than once per column.
    ori_scores = oriDCLIPscore_img_stats.detach().cpu().numpy()
    input_scores = inputDCLIPscore_img_stats.detach().cpu().numpy()

    for index in tqdm(range(total_num), total=total_num):
        pdf_ori, pdf_input = kde_1d(ori_scores[:, index], input_scores[:, index])
        prob_ori_1d, prob_input_1d = into_grid_prob_1d(pdf_ori, pdf_input, n_bins)
        # eps keeps log() finite where the density evaluates to zero.
        prob_ori_1d, prob_input_1d = prob_ori_1d + eps, prob_input_1d + eps

        CE_P = -prob_ori_1d * np.log(prob_ori_1d)
        CE_Q = -prob_input_1d * np.log(prob_input_1d)
        CE_P_Q = -prob_ori_1d * np.log(prob_input_1d)
        CE_Q_P = -prob_input_1d * np.log(prob_ori_1d)

        KL_P_Q = np.mean(CE_P_Q - CE_P)  # H(P,Q) - H(P)
        KL_Q_P = np.mean(CE_Q_P - CE_Q)  # H(Q,P) - H(Q)

        JSD = (KL_P_Q + KL_Q_P) / 2
        given_box[index] = [KL_P_Q, JSD]

    return given_box

def kde_1d(attr1, attr2):
    """Fit an independent Gaussian KDE to each of two sample sets.

    Args:
        attr1: Samples for the first estimator.
        attr2: Samples for the second estimator.

    Returns:
        ``(kde_of_attr1, kde_of_attr2)`` as ``scipy.stats.gaussian_kde`` objects.
    """
    first_estimator = gaussian_kde(attr1)
    second_estimator = gaussian_kde(attr2)
    return first_estimator, second_estimator


def kde_2d(attr1, attr2):
    """Fit marginal and joint Gaussian KDEs for two sample sets.

    Args:
        attr1: Samples of the first variable.
        attr2: Samples of the second variable.

    Returns:
        ``(kde_attr1, kde_attr2, kde_joint)`` where ``kde_joint`` is fitted
        on the two sample sets stacked row-wise with ``np.vstack``.
    """
    marginal_first = gaussian_kde(attr1)
    marginal_second = gaussian_kde(attr2)
    # Row 0 holds attr1 and row 1 holds attr2.
    joint = gaussian_kde(np.vstack((attr1, attr2)))
    return marginal_first, marginal_second, joint
