import os
import json
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
from tqdm import tqdm
import shutil
import tempfile
from subprocess import check_output
import subprocess
from pytorch_fid import fid_score 
import pdb
from skimage import io
from skimage.metrics import structural_similarity as ssim
import numpy as np
from skimage.transform import resize

def calculate_clip_similarity(image1_path, image2_path, model, processor, device):
    """Return the cosine similarity between the CLIP embeddings of two images.

    Args:
        image1_path: Path of the first image file.
        image2_path: Path of the second image file.
        model: Object exposing ``get_image_features`` (this script passes a
            ``CLIPModel``).
        processor: Callable that turns a PIL image into model inputs (this
            script passes a ``CLIPProcessor``).
        device: Device the preprocessed inputs are moved to.

    Returns:
        float: Sum of the element-wise product of the two L2-normalised
        embeddings, i.e. the cosine similarity for a single image pair.
    """
    # Open both files in a context manager so their handles are released
    # deterministically; this function is called in long loops.
    with Image.open(image1_path) as image1, Image.open(image2_path) as image2:
        # Preprocess while the files are still open, since PIL loads lazily.
        inputs1 = processor(images=image1, return_tensors="pt").to(device)
        inputs2 = processor(images=image2, return_tensors="pt").to(device)

    # Extract the image feature vectors without tracking gradients.
    with torch.no_grad():
        embedding1 = model.get_image_features(**inputs1)
        embedding2 = model.get_image_features(**inputs2)

    # Normalise to unit length so the dot product is a cosine similarity.
    embedding1 = embedding1 / embedding1.norm(dim=-1, keepdim=True)
    embedding2 = embedding2 / embedding2.norm(dim=-1, keepdim=True)

    # .sum() reduces over every element, so one image per path is expected.
    similarity = (embedding1 * embedding2).sum()
    return similarity.item()


def calculate_clip_similarity_text_image(image_path, text, model, processor, device):
    """Return the cosine similarity between an image and a text prompt.

    Args:
        image_path: Path of the image file.
        text: Prompt passed to the processor's tokenizer.
        model: Object exposing ``get_image_features`` and
            ``get_text_features`` (this script passes a ``CLIPModel``).
        processor: Callable that preprocesses images and text (this script
            passes a ``CLIPProcessor``).
        device: Device the preprocessed inputs are moved to.

    Returns:
        float: Sum of the element-wise product of the L2-normalised image and
        text embeddings, i.e. the cosine similarity for one image/text pair.
    """
    # Close the file handle as soon as preprocessing has read the pixels.
    with Image.open(image_path) as image:
        inputs_image = processor(images=image, return_tensors="pt").to(device)

    # truncation=True keeps over-long prompts within the tokenizer's maximum
    # length instead of failing inside the text encoder; short prompts are
    # tokenised exactly as before.
    inputs_text = processor(
        text=text, return_tensors="pt", padding=True, truncation=True
    ).to(device)

    with torch.no_grad():
        image_embedding = model.get_image_features(**inputs_image)
        text_embedding = model.get_text_features(inputs_text.input_ids)

    # Normalise to unit length so the dot product is a cosine similarity.
    image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)
    text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)

    # .sum() reduces over every element, so a single image/text pair is expected.
    similarity = (image_embedding * text_embedding).sum()
    return similarity.item()


# Top-level result folders that the metric loops iterate over.
base_paths = [
    "baseline_images_explicit",
    "baseline_images_implicit",
]

# Method names; each one is used to build a "<method>_cropped_logo" folder name.
methods = [
    "before_unlearn",
    "np",
    "ours",
    "sld_v1", "sld_v2", "sld_v3",
    "sega_v1", "sega_v2", "sega_v3",
]




######################################## metric 1: ClipScore ########################################

# Metric 1 (ClipScore): for every "<method>_cropped_logo" folder, score each
# PNG against the text prompt "<company> logo" with CLIP, average the scores
# per company, and write the result to metric_1.json inside that folder.
print("\nstart calculating Metric 1 (ClipScore)...")

device = "cuda" if torch.cuda.is_available() else "cpu"

model_path = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_path).to(device)
processor = CLIPProcessor.from_pretrained(model_path)

for base_path in tqdm(base_paths):
    print(f"\nprocessing: {base_path}")

    for method in tqdm(methods):
        logo_path = os.path.join(base_path, f"{method}_cropped_logo")
        # isdir (not exists) so a stray file with this name is skipped
        # instead of crashing os.listdir below.
        if not os.path.isdir(logo_path):
            print(f"path not exists: {logo_path}")
            continue

        png_files = [f for f in os.listdir(logo_path) if f.endswith('.png')]

        # Group files by the exact company token (text before the first "_").
        # A plain prefix match would also pull in files of any other company
        # whose name merely starts with the same characters.
        images_by_company = {}
        for png_file in png_files:
            company = png_file.split('_')[0]
            images_by_company.setdefault(company, []).append(png_file)

        company_scores = {}

        for company, company_images in tqdm(images_by_company.items()):
            scores = []
            text_prompt = f"{company} logo"

            for img_file in company_images:
                img_path = os.path.join(logo_path, img_file)
                try:
                    score = calculate_clip_similarity_text_image(
                        img_path,
                        text_prompt,
                        model,
                        processor,
                        device
                    )
                    scores.append(score)
                except Exception as e:
                    # Best effort: report the failing image and keep going.
                    print(f"processing image error: {img_path}: {str(e)}")

            # Companies whose images all failed are left out of the result.
            if scores:
                company_scores[company] = sum(scores) / len(scores)

        # Mean of the per-company means, stored under a reserved key.
        if company_scores:
            overall_average = sum(company_scores.values()) / len(company_scores)
            company_scores['overall_average'] = overall_average

        output_path = os.path.join(logo_path, 'metric_1.json')
        try:
            with open(output_path, 'w') as f:
                json.dump(company_scores, f, indent=4)
            print(f"saved metric 1 result to: {output_path}")
        except Exception as e:
            print(f"save file {output_path} error: {str(e)}")







######################################## metric 2: LogoScore ########################################

# Metric 2 (LogoScore): CLIP image-image similarity between each method's
# cropped logos and the "before_unlearn" cropped logos, averaged per company
# and written to metric_2.json inside the method's folder.
device = "cuda" if torch.cuda.is_available() else "cpu"

model_path = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_path).to(device)
processor = CLIPProcessor.from_pretrained(model_path)

with open("companies.json", 'r') as f:
    companies = json.load(f)

for base_path in base_paths:
    print(f"processing logo similarity: {base_path}")
    before_unlearn_path = os.path.join(base_path, "before_unlearn_cropped_logo")

    # The reference method is never compared against itself.
    for category in tqdm([name for name in methods if name != 'before_unlearn']):
        category_path = os.path.join(base_path, f"{category}_cropped_logo")
        if not os.path.isdir(category_path):
            continue

        company_scores = {}

        for company in tqdm(companies):
            scores = []
            for i in range(1, 11):
                before_img_path = os.path.join(before_unlearn_path, f"{company}_{i}_cropped.png")
                current_img_path = os.path.join(category_path, f"{company}_{i}_cropped.png")

                # A pair with either crop missing counts as similarity 0.
                if not os.path.exists(before_img_path) or not os.path.exists(current_img_path):
                    scores.append(0)
                    continue

                try:
                    similarity = calculate_clip_similarity(
                        before_img_path,
                        current_img_path,
                        model,
                        processor,
                        device
                    )
                except Exception as e:
                    # A failing pair is reported and contributes no score.
                    print(f"calculate similarity error: {company}_{i}: {str(e)}")
                else:
                    scores.append(similarity)

            # No usable score at all (every pair raised) maps to 0.
            company_scores[company] = sum(scores) / len(scores) if scores else 0

        if company_scores:
            overall_average = sum(company_scores.values()) / len(company_scores)
            company_scores['overall_average'] = overall_average

        output_path = os.path.join(category_path, 'metric_2.json')
        try:
            with open(output_path, 'w') as f:
                json.dump(company_scores, f, indent=4)
            print(f"saved logo similarity result to: {output_path}")
        except Exception as e:
            print(f"save file {output_path} error: {str(e)}")








######################################## metric 3: LogoSSIM ########################################



# Top-level result folders that the loop below iterates over.
base_paths = [
    "baseline_images_explicit",
    "baseline_images_implicit",
]

# Method names; each one is used to build a "<method>_cropped_logo" folder name.
methods = [
    "before_unlearn",
    "np",
    "ours",
    "sld_v1", "sld_v2", "sld_v3",
    "sega_v1", "sega_v2", "sega_v3",
]

def robust_resize(image1, image2):
    """Resize both images to one shared height and width.

    The target size is the smaller height and the smaller width of the two
    inputs, each clamped to at least 7 pixels (presumably so the default SSIM
    window fits -- TODO confirm against the caller's ``ssim`` settings).

    Args:
        image1: Array whose first two axes are height and width.
        image2: Array whose first two axes are height and width.

    Returns:
        tuple: ``(image1_resized, image2_resized)`` as produced by
        ``skimage.transform.resize`` with anti-aliasing enabled.
    """
    target_shape = (
        max(7, min(image1.shape[0], image2.shape[0])),
        max(7, min(image1.shape[1], image2.shape[1])),
    )

    first = resize(image1, target_shape, anti_aliasing=True)
    second = resize(image2, target_shape, anti_aliasing=True)
    return first, second

# Metric 3 (LogoSSIM): SSIM between each method's cropped logos and the
# "before_unlearn" cropped logos, averaged per company and written to
# metric_3.json inside the method's folder.
device = "cuda" if torch.cuda.is_available() else "cpu"

with open("companies.json", 'r') as f:
    companies = json.load(f)

for base_path in base_paths:
    print(f"processing logo similarity: {base_path}")
    before_unlearn_path = os.path.join(base_path, "before_unlearn_cropped_logo")

    # The reference method is never compared against itself.
    for category in tqdm([name for name in methods if name != 'before_unlearn']):
        category_path = os.path.join(base_path, f"{category}_cropped_logo")
        if not os.path.isdir(category_path):
            continue

        company_scores = {}

        for company in tqdm(companies):
            scores = []
            for i in range(1, 11):
                before_img_path = os.path.join(before_unlearn_path, f"{company}_{i}_cropped.png")
                current_img_path = os.path.join(category_path, f"{company}_{i}_cropped.png")

                # A pair with either crop missing counts as similarity 0.
                if not os.path.exists(before_img_path) or not os.path.exists(current_img_path):
                    scores.append(0)
                    continue

                image1 = io.imread(before_img_path)
                image2 = io.imread(current_img_path)
                # Bring both crops to a common size before comparing them.
                image1_resized, image2_resized = robust_resize(image1, image2)
                similarity, _ = ssim(
                    image1_resized,
                    image2_resized,
                    full=True,
                    data_range=1,
                    channel_axis=-1,
                )
                scores.append(similarity)

            company_scores[company] = sum(scores) / len(scores) if scores else 0

        if company_scores:
            overall_average = sum(company_scores.values()) / len(company_scores)
            company_scores['overall_average'] = overall_average

        output_path = os.path.join(category_path, 'metric_3.json')
        try:
            with open(output_path, 'w') as f:
                json.dump(company_scores, f, indent=4)
            print(f"saved logo similarity result to: {output_path}")
        except Exception as e:
            print(f"save file {output_path} error: {str(e)}")







######## metric 4: ImageScore ########

# Metric 4 (ImageScore): CLIP image-image similarity between each method's
# full images and the "before_unlearn" images, averaged per company and
# written to metric_4.json inside the method's folder.
base_paths = [
    "baseline_images_explicit",
    "baseline_images_implicit"
]

methods = ['before_unlearn', 'np', 'ours', 'sld_v1', 'sld_v2', 'sld_v3', 'sega_v1', 'sega_v2', 'sega_v3']

difficulty_list = ['explicit', 'implicit']

# The device must be resolved before the model is moved onto it.
device = "cuda" if torch.cuda.is_available() else "cpu"

model_path = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_path).to(device)
processor = CLIPProcessor.from_pretrained(model_path)

with open("companies.json", 'r') as f:
    companies = json.load(f)


for difficulty in tqdm(difficulty_list):
    # Built directly from the difficulty name (matching the entries of
    # base_paths) rather than from a loop variable left over from an earlier
    # section of the script.
    main_path = f'baseline_images_{difficulty}'

    before_unlearn_path = os.path.join(main_path, 'before_unlearn')

    for method in tqdm(methods):
        print(f"Processing {difficulty} - {method}")
        method_path = os.path.join(main_path, method)
        # Skip methods without results; the JSON below could not be written.
        if not os.path.isdir(method_path):
            print(f"path not exists: {method_path}")
            continue

        company_scores = {}

        for company in tqdm(companies):
            scores = []

            for i in range(1, 11):
                before_img_path = os.path.join(before_unlearn_path, company, f'{i}.png')
                method_img_path = os.path.join(method_path, company, f'{i}.png')

                # Only pairs where both images exist are scored.
                if os.path.exists(before_img_path) and os.path.exists(method_img_path):
                    similarity = calculate_clip_similarity(
                        before_img_path,
                        method_img_path,
                        model,
                        processor,
                        device
                    )
                    scores.append(similarity)

            if scores:
                avg_score = sum(scores) / len(scores)
                company_scores[company] = avg_score

        # Computed once and only when there is data, which avoids a division
        # by zero when no company produced a score.
        if company_scores:
            overall_average = sum(company_scores.values()) / len(company_scores)
            company_scores['overall_average'] = overall_average

        output_path = os.path.join(method_path, 'metric_4.json')
        with open(output_path, 'w') as f:
            json.dump(company_scores, f, indent=4)








######## metric 5: ImageSSIM ########

# Metric 5 (ImageSSIM): SSIM between each method's full images and the
# "before_unlearn" images, averaged per company and written to metric_5.json
# inside the method's folder.
base_paths = [
    "baseline_images_explicit",
    "baseline_images_implicit"
]

methods = ['before_unlearn', 'np', 'ours', 'sld_v1', 'sld_v2', 'sld_v3', 'sega_v1', 'sega_v2', 'sega_v3']

difficulty_list = ['explicit', 'implicit']


device = "cuda" if torch.cuda.is_available() else "cpu"

with open("companies.json", 'r') as f:
    companies = json.load(f)


for difficulty in tqdm(difficulty_list):
    # Built directly from the difficulty name (matching the entries of
    # base_paths) rather than from a loop variable left over from an earlier
    # section of the script.
    main_path = f'baseline_images_{difficulty}'

    before_unlearn_path = os.path.join(main_path, 'before_unlearn')

    for method in tqdm(methods):
        print(f"Processing {difficulty} - {method}")
        method_path = os.path.join(main_path, method)
        # Skip methods without results; the JSON below could not be written.
        if not os.path.isdir(method_path):
            print(f"path not exists: {method_path}")
            continue

        company_scores = {}

        for company in tqdm(companies):
            scores = []

            for i in range(1, 11):
                before_img_path = os.path.join(before_unlearn_path, company, f'{i}.png')
                method_img_path = os.path.join(method_path, company, f'{i}.png')

                # Only pairs where both images exist are scored.
                if os.path.exists(before_img_path) and os.path.exists(method_img_path):
                    image1 = io.imread(before_img_path)
                    image2 = io.imread(method_img_path)
                    # data_range=255 assumes 8-bit pixel values -- TODO confirm
                    # for the images this script is run on.
                    similarity, _ = ssim(image1, image2, full=True, data_range=255, channel_axis=-1)
                    scores.append(similarity)

            if scores:
                avg_score = sum(scores) / len(scores)
                company_scores[company] = avg_score

        # Guarded so an empty result does not raise ZeroDivisionError.
        if company_scores:
            overall_average = sum(company_scores.values()) / len(company_scores)
            company_scores['overall_average'] = overall_average

        output_path = os.path.join(method_path, 'metric_5.json')
        with open(output_path, 'w') as f:
            json.dump(company_scores, f, indent=4)

