import os
import torch
import random
import open_clip
import numpy as np
from glob import glob
from PIL import Image
from tqdm import tqdm


def set_seed(seed: int = 42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and the
    PyTorch CPU/CUDA generators with ``seed``, exports ``seed`` as the
    ``PYTHONHASHSEED`` environment variable, and switches cuDNN to
    deterministic mode with benchmarking turned off.

    Args:
        seed: Value handed to every generator. Defaults to 42.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # All of these take the seed as their single argument.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def measure_similarity(images, prompt, model, clip_preprocess, tokenizer, device):
    """Score each image against a single text prompt.

    Every image is run through ``clip_preprocess`` and the results are
    batched and encoded with ``model.encode_image``; ``prompt`` is tokenized
    and encoded with ``model.encode_text``. Both sets of embeddings are
    L2-normalised along the last dimension before being multiplied together.

    Args:
        images: Iterable of inputs accepted by ``clip_preprocess``
            (presumably PIL images -- confirm against the caller).
        prompt: Text to compare the images with; wrapped in a one-element
            list before being passed to ``tokenizer``.
        model: Object exposing ``encode_image`` and ``encode_text``.
        clip_preprocess: Callable turning one image into a tensor.
        tokenizer: Callable turning a list of strings into a tensor.
        device: Device the image batch and the tokens are moved to.

    Returns:
        Tensor with one entry per image: the image/text similarity matrix
        averaged over its last (prompt) axis.
    """
    with torch.no_grad():
        # Stacking adds the leading batch axis for us.
        pixel_batch = torch.stack([clip_preprocess(img) for img in images]).to(device)
        img_emb = model.encode_image(pixel_batch)

        tokens = tokenizer([prompt]).to(device)
        txt_emb = model.encode_text(tokens)

        # Unit-length embeddings make the matrix product a cosine similarity.
        img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
        txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)

        similarity = img_emb @ txt_emb.T
        return similarity.mean(-1)


def get_clip_score(images, model, clip_preprocess, device):
    """Return the cosine similarity between the first two images' embeddings.

    All ``images`` are preprocessed, batched and encoded with
    ``model.encode_image``, and the embeddings are L2-normalised along the
    last dimension. Only the embeddings at index 0 and index 1 take part in
    the score; any further images are encoded but otherwise ignored.

    Args:
        images: Sequence of inputs accepted by ``clip_preprocess``
            (presumably PIL images -- confirm against the caller). Indexing
            the encoded batch at 0 and 1 requires at least two entries.
        model: Object exposing ``encode_image``.
        clip_preprocess: Callable turning one image into a tensor.
        device: Device the image batch is moved to.

    Returns:
        One-element tensor holding the similarity of image 0 and image 1.
    """
    with torch.no_grad():
        # Stacking adds the leading batch axis for us.
        pixel_batch = torch.stack([clip_preprocess(img) for img in images]).to(device)
        embeddings = model.encode_image(pixel_batch)
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)

        first, second = embeddings[0], embeddings[1]

        # Row vector times column vector -> 1x1 matrix; the mean drops one axis.
        return (first[None] @ second[None].T).mean(-1)


seed = 42
set_seed(seed)

device='cuda:2'
# Directory with the generated images to score. Other runs that can be
# evaluated by swapping the value: 'gen_pqim', 'gen_gs'.
source_dir = 'gen_zod_lq'
# Directory with the reference images the generated ones are compared to.
gt_dir = 'output_images_wo_wm'
# Sorted so that the i-th source file lines up with the i-th reference file.
source_files = sorted(glob(f'{source_dir}/**.png'))
gt_files = sorted(glob(f'{gt_dir}/**.png'))

# source_files = source_files[:10]
# gt_files = gt_files[:10]

### verify sorting ###
# Explicit raises instead of `assert`: assertions are stripped when Python runs
# with -O, which would let mismatched pairs be scored silently.
if len(source_files) != len(gt_files):
    raise ValueError(
        f'length of source_files and gt_files mismatched, {len(source_files)} != {len(gt_files)}'
    )

for source_file, gt_file in zip(source_files, gt_files):
    # Source names carry a '-' separated suffix after the shared stem, while
    # reference names are just the stem followed by the extension.
    source_name = os.path.basename(source_file).split('-')[0]
    gt_name = os.path.basename(gt_file).split('.')[0]
    if source_name != gt_name:
        raise ValueError(f'source and gt sorting dismatched, {source_name} != {gt_name}')
### verify sorting ###

# file_path = 'prompts.txt'
# captions = []

# with open(file_path, 'r', encoding='utf-8') as file:
#     for line in file:
#         clean_line = line.strip()
#         captions.append(clean_line)
#         # print(clean_line)

# print("\n--- 리스트 저장 결과 ---")
# print(captions)


# CLIP architecture and pretrained weights used as the reference scorer.
reference_model = 'ViT-g-14'
reference_model_pretrain = 'laion2b_s12b_b42k'
# The second return value is discarded; only the model and the third value
# (used here as the image preprocessing callable) are kept.
ref_model, _, ref_clip_preprocess = open_clip.create_model_and_transforms(
    reference_model,
    pretrained=reference_model_pretrain,
    device=device
)
# open_clip documents that models come back in train mode; switch to eval so
# train-only layers (e.g. BatchNorm, stochastic depth) do not affect the scores.
ref_model.eval()
ref_tokenizer = open_clip.get_tokenizer(
    reference_model,
)

# One similarity value per (reference, generated) image pair.
sims = []
for gt_file, source_file in tqdm(zip(gt_files, source_files), total=len(gt_files)):
    # Open both images as context managers so their file handles are released
    # as soon as the pair has been scored instead of waiting for garbage
    # collection. Scoring happens inside the `with` because PIL reads the
    # pixel data lazily.
    with Image.open(source_file) as wm_image, Image.open(gt_file) as gt_image:
        sim = get_clip_score(
            [gt_image, wm_image],
            ref_model,
            ref_clip_preprocess,
            device,
        )
    sims.append(sim.item())

# Fail with a clear message instead of a bare ZeroDivisionError below.
if not sims:
    raise RuntimeError(
        f'no image pairs found in {source_dir!r} and {gt_dir!r}; cannot compute CLIPScore'
    )

# The reported value is the mean image-to-image similarity over all pairs.
print('-'*50)
print(f'CLIPScore {sum(sims) / len(sims)}')
print('-'*50)
print('-'*50)

# pqim: CLIPScore 0.9771412033438682
# zod: CLIPScore 0.978798391461372
# zod_lq: CLIPScore 0.9752343223690987
# gs: CLIPScore 0.8351716994941235