
import json
import math
import os
import warnings
from pathlib import Path

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms as T
from tqdm import tqdm

# Silence every Python warning for the whole process (including any emitted
# by torch / torchvision while loading models and tensors below).
warnings.filterwarnings("ignore")

# Run on CUDA when it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Make the project's own packages importable. The sys.path change must happen
# before the two project imports that follow, so keep this statement order.
import sys
sys.path.append('your/project/root/path')  # Adjust this path as needed
from evaluation.face_idloss import IDLoss
from utils.utils import denorm

# Face-identity model used for every ID-loss computation in this script
# (the checkpoint file name suggests an IR-SE50 ArcFace backbone — confirm).
# It is constructed once at import time, moved to DEVICE and switched to
# eval mode; compute_id_loss reuses this single module-level instance.
ARCFACE_CKPT = "your/project/root/path/pretrained/face_idloss/model_ir_se50.pth"  # Adjust this path as needed
arcface = IDLoss(ARCFACE_CKPT).to(DEVICE).eval()


def load_image_as_tensor(image_path, size=256):
    """Load an image file as a normalized ``(1, 3, size, size)`` tensor on DEVICE.

    The image is converted to RGB, resized to ``size`` x ``size`` and
    normalized with per-channel mean 0.5 / std 0.5, which maps ToTensor's
    [0, 1] output to [-1, 1].

    Args:
        image_path: Path of the image file to read.
        size: Target height and width in pixels.

    Returns:
        The image tensor with a leading batch dimension, or ``None`` if
        anything goes wrong (the error is printed). Failures are deliberately
        non-fatal so a single bad file does not abort a whole evaluation run.
    """
    try:
        transform = T.Compose([
            T.Resize((size, size)),
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        # Image.open is lazy and keeps the file handle open; the context
        # manager releases it deterministically, even if convert() raises or
        # the format is multi-frame.
        with Image.open(image_path) as img:
            rgb = img.convert('RGB')
        return transform(rgb).unsqueeze(0).to(DEVICE)
    except Exception as e:
        print(f"Error loading image {image_path}: {e}")
        return None


def load_tensor_from_pt(file_path):
    """Load a ``torch.save``-d tensor and return it as a detached batch on DEVICE.

    Dict payloads are unwrapped: the first of the keys ``'tensor'``,
    ``'image'``, ``'data'``, ``'img'`` found in the dict is used; if none is
    present, or the selected value is itself a dict, the first value of that
    dict is taken instead. A 3-D result gets a leading batch dimension added.

    Args:
        file_path: Path of the ``.pt`` file.

    Returns:
        The loaded tensor, or ``None`` if loading/unwrapping fails (the error
        is printed).
    """
    try:
        # NOTE(review): torch.load unpickles the file; only point this at
        # trusted result files.
        payload = torch.load(file_path, map_location=DEVICE)

        if isinstance(payload, dict):
            preferred = next(
                (name for name in ('tensor', 'image', 'data', 'img') if name in payload),
                None,
            )
            if preferred is not None:
                payload = payload[preferred]
            if isinstance(payload, dict):
                payload = list(payload.values())[0]

        payload = payload.to(DEVICE)
        batched = payload.unsqueeze(0) if payload.dim() == 3 else payload
        return batched.detach()
    except Exception as err:
        print(f"Error loading {file_path}: {err}")
        return None


def compute_id_loss(img1, img2):
    """Return the identity loss between two images as a Python float.

    Both inputs are moved to DEVICE, passed through the project helper
    ``denorm`` (presumably mapping [-1, 1] back to [0, 1] — confirm), resized
    bilinearly to 112x112 and fed to the module-level ``arcface`` model.

    Args:
        img1: First image batch; presumably ``(N, 3, H, W)`` — TODO confirm.
        img2: Second image batch, same convention as ``img1``.

    Returns:
        ``float(arcface(img1, img2))``; the model output must therefore hold a
        single element.
    """
    # Pure evaluation: disable autograd so no graph is built through the
    # ArcFace weights (``.eval()`` alone does not stop gradient tracking).
    with torch.no_grad():
        img1 = img1.to(DEVICE)
        img2 = img2.to(DEVICE)
        # align_corners=False is the bilinear default; stated explicitly.
        img1 = F.interpolate(denorm(img1), size=(112, 112), mode='bilinear', align_corners=False)
        img2 = F.interpolate(denorm(img2), size=(112, 112), mode='bilinear', align_corners=False)
        loss = arcface(img1, img2)
    return float(loss)


def _load_sample(img_path):
    """Load ``(x_src, decoded_src, decoded_adv)`` from one sample folder, or return None.

    ``x_src`` comes from ``x_src.pt`` when present, otherwise ``x_src.jpg``;
    ``decoded_src.pt`` and ``decoded_adv.pt`` are both required. A message is
    printed for every sample that is skipped.
    """
    x_src_pt = os.path.join(img_path, "x_src.pt")
    x_src_jpg = os.path.join(img_path, "x_src.jpg")
    decoded_src_pt = os.path.join(img_path, "decoded_src.pt")
    decoded_adv_pt = os.path.join(img_path, "decoded_adv.pt")

    if not (os.path.exists(x_src_pt) or os.path.exists(x_src_jpg)):
        print(f"Warning: x_src not found in {img_path}")
        return None
    # Check every required file before loading anything, so no tensor is
    # loaded for a sample that will be skipped anyway.
    if not os.path.exists(decoded_src_pt) or not os.path.exists(decoded_adv_pt):
        print(f"Warning: decoded files not found in {img_path}")
        return None

    if os.path.exists(x_src_pt):
        x_src = load_tensor_from_pt(x_src_pt)
    else:
        x_src = load_image_as_tensor(x_src_jpg)
    decoded_src = load_tensor_from_pt(decoded_src_pt)
    decoded_adv = load_tensor_from_pt(decoded_adv_pt)

    if x_src is None or decoded_src is None or decoded_adv is None:
        print(f"Warning: Failed to load tensors in {img_path}")
        return None
    return x_src, decoded_src, decoded_adv


def _sample_metrics(x_src, decoded_src, decoded_adv):
    """Return the per-image metric dict (org_id, adv_id, diff, relative_id)."""
    org_id = compute_id_loss(x_src, decoded_src)
    adv_id = compute_id_loss(x_src, decoded_adv)
    # 1 - org/adv clamped at 0; 0.0 when adv_id == 0 to avoid dividing by
    # zero. Float literals keep the JSON type consistent across branches.
    relative_id = max(0.0, 1 - org_id / adv_id) if adv_id != 0 else 0.0
    return {
        "org_id": org_id,
        "adv_id": adv_id,
        "diff": org_id - adv_id,
        "relative_id": relative_id,
    }


def process_result_folder(result_dir):
    """Compute per-image and averaged ID-loss metrics for one ``result_*`` folder.

    Every ``img_*`` subdirectory of ``result_dir`` is treated as one sample
    and must contain ``decoded_src.pt``, ``decoded_adv.pt`` and the source
    image as ``x_src.pt`` (preferred) or ``x_src.jpg``. Per sample:

    - ``org_id``: ``compute_id_loss(x_src, decoded_src)``
    - ``adv_id``: ``compute_id_loss(x_src, decoded_adv)``
    - ``diff``: ``org_id - adv_id``
    - ``relative_id``: ``max(0, 1 - org_id / adv_id)``, or 0.0 if ``adv_id == 0``

    Samples with missing or unloadable files, a loss computation that raises,
    or a non-finite loss are skipped with a printed message.

    Args:
        result_dir: Directory holding the ``img_*`` sample folders.

    Returns:
        ``{"per_image": {folder: metrics}, "averages": {...}}``. ``"averages"``
        holds ``<metric>_mean`` for each metric plus ``"total_images"``, or is
        an empty dict when no sample was usable.
    """
    per_image = {}
    img_folders = sorted(d for d in os.listdir(result_dir) if d.startswith('img_'))

    for img_folder in tqdm(img_folders, desc=f"Processing {os.path.basename(result_dir)}"):
        img_path = os.path.join(result_dir, img_folder)
        if not os.path.isdir(img_path):
            continue

        sample = _load_sample(img_path)
        if sample is None:
            continue

        try:
            metrics = _sample_metrics(*sample)
        except Exception as e:
            print(f"Error computing ID loss for {img_folder}: {e}")
            continue

        # A single NaN/inf would turn every mean into NaN and make json.dump
        # emit non-standard "NaN"/"Infinity" tokens.
        if not (math.isfinite(metrics["org_id"]) and math.isfinite(metrics["adv_id"])):
            print(f"Warning: non-finite ID loss for {img_folder}, skipping")
            continue

        per_image[img_folder] = metrics

    averages = {}
    if per_image:
        count = len(per_image)
        # Dicts preserve insertion order, so each sum adds values in the same
        # order the samples were processed.
        averages = {
            f"{key}_mean": sum(m[key] for m in per_image.values()) / count
            for key in ("org_id", "adv_id", "diff", "relative_id")
        }
        averages["total_images"] = count

    return {"per_image": per_image, "averages": averages}


def extract_method_and_model(result_folder_name, parent_folder_name):
    """Split a result folder name into ``(method, model)``.

    The method is the parent folder name with a leading ``batch_results_``
    removed. The result folder is expected to be ``result_<model>_<method>``:
    after removing a leading ``result_``, the model is every ``_``-separated
    token before a whole-token ``<method>`` suffix. Without such a suffix the
    whole remainder is the model; if nothing precedes the suffix, the first
    token is used.

    Examples:
        >>> extract_method_and_model("result_hisd_celeba_pgd", "batch_results_pgd")
        ('pgd', 'hisd_celeba')
        >>> extract_method_and_model("result_stargan_df_rap", "batch_results_df_rap")
        ('df_rap', 'stargan')
    """
    def strip_prefix(text, prefix):
        # Only remove a real prefix; str.replace would also strip it mid-string.
        return text[len(prefix):] if text.startswith(prefix) else text

    method = strip_prefix(parent_folder_name, "batch_results_")
    parts = strip_prefix(result_folder_name, "result_").split("_")
    method_parts = method.split("_")

    # Compare whole tokens at the end of the name. A plain string
    # ``endswith(method)`` test is already true for the full remainder, which
    # would truncate multi-token model names to their first token.
    suffix_len = len(method_parts)
    if len(parts) >= suffix_len and parts[-suffix_len:] == method_parts:
        model_parts = parts[:-suffix_len]
    else:
        model_parts = parts

    model = "_".join(model_parts) if model_parts else parts[0]
    return method, model


def _print_averages(averages):
    """Print the summary statistics stored under ``results["averages"]``."""
    print("  Averages:")
    print(f"    org_id_mean: {averages['org_id_mean']:.6f}")
    print(f"    adv_id_mean: {averages['adv_id_mean']:.6f}")
    print(f"    diff_mean:   {averages['diff_mean']:.6f}")
    print(f"    relative_id_mean: {averages['relative_id_mean']:.4f}")
    print(f"    total_images: {averages['total_images']}")


def main(batch_dirs=None, output_dir="your/project/root/path/id_result"):  # Adjust output_dir as needed
    """Evaluate every ``result_*`` folder and write one ``<model>_<method>.json`` each.

    Args:
        batch_dirs: Directories named ``batch_results_<method>`` whose
            ``result_*`` subfolders are evaluated. ``None`` uses the list
            configured inside this function.
        output_dir: Directory the JSON files are written to (created if
            needed, but only when there is something to process).
    """
    if batch_dirs is None:
        # Add your result paths here, e.g.:
        #   "/path/to/batch_results_pgd",
        #   "/path/to/batch_results_anti",
        #   "/path/to/batch_results_df_rap",
        #   "/path/to/batch_results_disrupting",
        #   "/path/to/batch_results_leat",
        #   "/path/to/batch_results_nullswap",
        #   "/path/to/batch_results_scol",
        batch_dirs = []

    if not batch_dirs:
        print("No batch directories configured; add paths to `batch_dirs` in main().")
        return

    os.makedirs(output_dir, exist_ok=True)
    written = set()

    for batch_dir in batch_dirs:
        if not os.path.exists(batch_dir):
            print(f"Warning: {batch_dir} does not exist, skipping...")
            continue

        # normpath drops a trailing separator, which would otherwise make
        # basename() return "" and yield an empty method name.
        parent_name = os.path.basename(os.path.normpath(batch_dir))
        print(f"\n{'='*60}")
        print(f"Processing: {parent_name}")
        print(f"{'='*60}")

        result_folders = sorted(d for d in os.listdir(batch_dir) if d.startswith('result_'))

        for result_folder in result_folders:
            result_path = os.path.join(batch_dir, result_folder)
            if not os.path.isdir(result_path):
                continue

            method, model = extract_method_and_model(result_folder, parent_name)
            print(f"\nProcessing: {result_folder}")
            print(f"  Method: {method}, Model: {model}")

            results = process_result_folder(result_path)
            if not results["averages"]:
                print(f"  No valid results for {result_folder}")
                continue

            output_path = os.path.join(output_dir, f"{model}_{method}.json")
            # Two result folders mapping to the same (model, method) would
            # silently clobber each other's output; make that visible.
            if output_path in written:
                print(f"  Warning: {output_path} was already written in this run; overwriting")
            with open(output_path, 'w') as f:
                json.dump(results, f, indent=2)
            written.add(output_path)

            print(f"  Saved: {output_path}")
            _print_averages(results["averages"])

    print(f"\n{'='*60}")
    print(f"All results saved to: {output_dir}")
    print(f"{'='*60}")


if __name__ == "__main__":
    main()
