import json
import re
import requests
import numpy as np
import concurrent.futures
from tqdm import tqdm
from datasets import Dataset
from transformers import CLIPProcessor, CLIPModel
import torch
from IPython.display import display, Markdown
import sys
from io import BytesIO
from PIL import Image
from torchvision.transforms import ToTensor
from sklearn.preprocessing import normalize
import base64

# ==================== Configuration ====================
# Compute device for CLIP: prefer the GPU whenever torch can see one.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Input locations (placeholders sanitized for sharing; point them at real files).
TEXT_DATA_PATH = "path/to/input.json"        # JSON Lines file of text queries
IMAGE_VECTORS_PATH = "path/to/vectors.json"  # pre-computed image embeddings (JSON)
IMAGE_TSV_PATH = "path/to/test_imgs.tsv"     # rows of image id + base64-encoded image
# OpenAI-style chat-completions endpoint serving the query-rewriting LLM.
API_URL = "http://<INTERNAL_API_ENDPOINT>/v1/chat/completions"

def load_text_data(filepath=TEXT_DATA_PATH, seed=42):
    """Load a JSON Lines text dataset and return it shuffled.

    Each non-blank line of *filepath* is parsed as one JSON record.

    Args:
        filepath: Path to a UTF-8 JSON Lines file.
        seed: Shuffle seed. The default (42) reproduces the original ordering.

    Returns:
        A ``datasets.Dataset`` built from the parsed records, shuffled with *seed*.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        # Skip whitespace-only lines (e.g. an extra newline at end of file):
        # json.loads("\n") raises, which would abort the whole load.
        records = [json.loads(line) for line in f if line.strip()]
    return Dataset.from_list(records).shuffle(seed=seed)

def load_image_vectors(filepath=IMAGE_VECTORS_PATH):
    """Read the pre-computed image embedding table from a UTF-8 JSON file.

    Returns whatever ``json.load`` produces; the rest of this script treats it
    as a mapping of image id -> embedding.
    """
    with open(filepath, encoding='utf-8') as handle:
        payload = json.load(handle)
    return payload

# ==================== Model Loading ====================
def load_models(model_id="openai/clip-vit-base-patch16"):
    """Load a CLIP model and its matching processor.

    Args:
        model_id: Hugging Face hub id (or local path) of the CLIP checkpoint.
            Defaults to ViT-B/16, the backbone used in the paper's experiments.

    Returns:
        ``(clip_model, clip_processor)``. The model is moved to the
        module-level ``device``.
    """
    clip_model = CLIPModel.from_pretrained(model_id).to(device)
    clip_processor = CLIPProcessor.from_pretrained(model_id)
    return clip_model, clip_processor

def read_tsv_to_dict(filepath=IMAGE_TSV_PATH, return_tensors=False, delimiter=None):
    """Decode a file of ``<image id><sep><base64 image>`` rows into a dict.

    Args:
        filepath: Path to the image file. Rows may be tab- or comma-separated
            (see *delimiter*).
        return_tensors: If True, store each image as a ``ToTensor()`` tensor
            instead of a PIL RGB image.
        delimiter: Column separator. ``None`` (default) auto-detects per line:
            tab if the line contains one, otherwise comma.

    Returns:
        ``(result_dict, error_count)``:
        - ``result_dict`` maps image id (str) to the decoded image.
        - ``error_count`` counts rows with fewer than two columns or whose
          payload could not be decoded.
        Blank lines are ignored and not counted as errors.
    """
    result_dict = {}
    error_count = 0
    to_tensor = ToTensor()  # ToTensor holds no state, so one instance serves every row

    with open(filepath, "r", encoding="utf-8") as f:
        for line_num, raw_line in enumerate(f, 1):
            line = raw_line.strip()
            if not line:
                continue  # blank line, not a data row

            sep = delimiter
            if sep is None:
                # .tsv files are normally tab-separated; splitting them on ','
                # would flag every row as malformed. Fall back to comma otherwise.
                sep = "\t" if "\t" in line else ","
            parts = line.split(sep)
            if len(parts) < 2:
                error_count += 1
                continue

            img_id = parts[0].strip()
            base64_str = parts[1].strip()

            try:
                img = Image.open(BytesIO(base64.b64decode(base64_str))).convert("RGB")
                result_dict[img_id] = to_tensor(img) if return_tensors else img
            except Exception as e:
                # Best-effort load: log and count undecodable rows instead of aborting.
                error_count += 1
                print(f"Row {line_num} processing failed [ID:{img_id}]: {str(e)}")

    display(Markdown(
        f"**Preprocessing Complete:<br>• Success: {len(result_dict)} items<br>• Failed: {error_count} items**"
    ))
    return result_dict, error_count

# ==================== Core Functionality ====================
def process_texts_one_by_one(text, model, processor):
    """Embed a single rewritten query with the CLIP text tower.

    Returns the first (and only) row of ``model.get_text_features`` as a NumPy
    array. If tokenization or the forward pass raises, the error is printed
    and ``np.nan`` is returned instead of propagating the exception.
    """
    with torch.no_grad():
        try:
            encoded = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
            features = model.get_text_features(**encoded.to(device))
            return features.cpu().numpy()[0]
        except Exception as err:
            print(f"Text processing failed for '{text[:50]}...': {str(err)}")
            return np.nan

def _l2_normalize_rows(matrix):
    """Return the 2-D *matrix* with every row scaled to unit L2 norm.

    Rows whose norm is zero are left as zeros instead of becoming NaN, matching
    ``sklearn.preprocessing.normalize(..., norm='l2')``.
    """
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    norms[norms == 0.0] = 1.0
    return matrix / norms


def get_similarity_with_ranking(text_embedding, original_image_index, image_vectors):
    """Score a query embedding against every image vector and rank the target.

    Args:
        text_embedding: 1-D query embedding.
        original_image_index: Key of the ground-truth image in *image_vectors*.
            JSON object keys are always strings, so a non-string id (e.g. an
            int) that is not found is retried as ``str(original_image_index)``.
        image_vectors: Mapping of image id -> embedding, with the same
            dimensionality as *text_embedding*.

    Returns:
        dict with:
        - ``original_similarity``: cosine similarity of the target image, or
          None if the target is not in *image_vectors*.
        - ``rewrite_rank``: 1 + the number of images scoring strictly higher
          than the target (ties favor the target), or None if not found.
        - ``top3_matches``: up to three ``{'index', 'similarity'}`` dicts,
          best first.

    Raises:
        ValueError: if *text_embedding* is not a 1-D vector, or its length
            does not match the image vectors.
    """
    result = {
        'original_similarity': None,
        'rewrite_rank': None,
        'top3_matches': [],
    }

    query = np.asarray(text_embedding, dtype=np.float64)
    if query.ndim != 1:
        # e.g. the scalar np.nan that process_texts_one_by_one returns on failure
        raise ValueError(f"text_embedding must be a 1-D vector, got shape {query.shape}")

    all_indices = list(image_vectors.keys())
    if not all_indices:
        return result  # nothing to rank against

    all_vectors = _l2_normalize_rows(
        np.asarray([image_vectors[idx] for idx in all_indices], dtype=np.float64)
    )
    query_vector = _l2_normalize_rows(query.reshape(1, -1))

    # Cosine similarity = dot product of L2-normalized vectors.
    similarity_scores = np.dot(query_vector, all_vectors.T).flatten()

    target = original_image_index
    if target not in image_vectors and str(target) in image_vectors:
        target = str(target)
    if target in image_vectors:
        target_score = similarity_scores[all_indices.index(target)]
        result['original_similarity'] = float(target_score)
        result['rewrite_rank'] = int(np.sum(similarity_scores > target_score) + 1)

    for pos in np.argsort(similarity_scores)[-3:][::-1]:
        result['top3_matches'].append({
            'index': all_indices[pos],
            'similarity': float(similarity_scores[pos]),
        })

    return result

# Extracts the final rewrite from the LLM output; compiled once, shared by worker threads.
_ANSWER_PATTERN = re.compile(r"<answer>([\s\S]*?)<\/answer>")


def _failed_result():
    """Return a fresh result dict for a sample that could not be scored."""
    return {'r1': 0, 'r10': 0, 'success': False}


def _rewrite_query(text):
    """Ask the LLM endpoint to rewrite *text*; return the raw assistant content.

    Raises:
        requests.HTTPError: on a 4xx/5xx response.
        KeyError/IndexError: if the response JSON lacks the expected fields.
    """
    payload = {
        "model": "Qwen2.5-3B-Instruct",
        "messages": [
            {
                "role": "user",
                "content": f"You're an image retrieval assistant. Translate Chinese search queries: {text} into optimized English text for vector-based image search. Show your work in <think></think> tags. And return the final text in <answer></answer> tags."
            },
            {
                # NOTE(review): a trailing assistant message only acts as a prefill
                # on servers that support continuing it; confirm the endpoint does.
                "role": "assistant",
                "content": "<think>\n"
            }
        ],
        "do_sample": True,
        "temperature": 0.9,
    }
    response = requests.post(API_URL, json=payload, timeout=30)
    # Report HTTP failures as such, rather than as a confusing KeyError on 'choices'.
    response.raise_for_status()
    return response.json()['choices'][0]['message']['content']


def process_single_sample(args):
    """Rewrite one query with the LLM, embed the rewrite with CLIP, and score it.

    Designed to be submitted to a thread pool (see ``run_evaluation``).

    Args:
        args: ``(sample, model_clip, processor, image_vectors)``. *sample* must
            provide ``text_content`` and ``original_image_id``.

    Returns:
        ``{'r1', 'r10', 'success'}``, where ``r1`` / ``r10`` are 1 if the target
        image ranks first / within the top 10. Every per-sample failure (HTTP
        or parsing error, missing <answer> tag, embedding/scoring error, target
        image not in *image_vectors*) is printed and yields
        ``{'r1': 0, 'r10': 0, 'success': False}``.

    Raises:
        ValueError/TypeError: only if *args* cannot be unpacked into four items.
    """
    sample, model_clip, processor, image_vectors = args

    try:
        full_output = _rewrite_query(sample["text_content"])

        answer = _ANSWER_PATTERN.search(full_output)
        if not answer:
            print(f"Format error in response: {full_output[:100]}")
            return _failed_result()

        rewritten_text = answer.group(1).strip()
        text_embedding = process_texts_one_by_one(rewritten_text, model_clip, processor)
        pre_result = get_similarity_with_ranking(
            text_embedding,
            sample['original_image_id'],
            image_vectors
        )

        rank = pre_result['rewrite_rank']
        if rank is None:
            # Without this check `None <= 10` would raise a misleading TypeError.
            print(f"Target image {sample['original_image_id']!r} not found in image vectors")
            return _failed_result()

        return {
            'r1': int(rank == 1),
            'r10': int(rank <= 10),
            'success': True
        }
    except Exception as e:
        # Per-sample best-effort: one bad sample must not abort the whole run.
        print(f"Error processing sample: {str(e)}")
        return _failed_result()

def run_evaluation(dataset, model_clip, num_samples=50, batch_size=20, max_workers=32,
                   processor=None, vectors=None):
    """Evaluate up to *num_samples* samples concurrently and report Recall@1/@10.

    Args:
        dataset: Collection supporting ``len()`` and integer indexing. Each
            item is passed to ``process_single_sample``.
        model_clip: CLIP model used to embed rewritten queries.
        num_samples: Maximum number of samples to evaluate; clipped to
            ``len(dataset)``.
        batch_size: Unused; kept so existing callers keep working.
        max_workers: Thread-pool size.
        processor: CLIP processor. Defaults to the module-level ``clip_processor``.
        vectors: Image-id -> embedding mapping. Defaults to the module-level
            ``image_vectors``.

    Returns:
        dict with ``total``, ``successful``, ``r1``, ``r10`` (counts) and
        ``recall@1``, ``recall@10`` (fractions of ``total``; 0.0 when nothing
        was evaluated). A summary is also printed.
    """
    # Fall back to the globals defined by the script's __main__ block, so the
    # original call style keeps working. Importers should pass these explicitly.
    if processor is None:
        processor = clip_processor
    if vectors is None:
        vectors = image_vectors

    sample_count = min(num_samples, len(dataset))
    args_list = [(dataset[i], model_clip, processor, vectors) for i in range(sample_count)]

    results = []
    with tqdm(total=len(args_list), desc="Evaluating Samples") as pbar:
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(process_single_sample, arg) for arg in args_list]

            for future in concurrent.futures.as_completed(futures):
                try:
                    results.append(future.result())
                except Exception as e:
                    # process_single_sample handles its own errors; this only
                    # catches what escapes it (e.g. an args tuple it cannot unpack).
                    print(f"Worker failed: {e}")
                    results.append({'r1': 0, 'r10': 0, 'success': False})
                finally:
                    pbar.update(1)

    total = len(results)
    successful = sum(1 for r in results if r['success'])
    r1_total = sum(r['r1'] for r in results)
    r10_total = sum(r['r10'] for r in results)
    # Guard against ZeroDivisionError when num_samples is 0 or the dataset is empty.
    recall1 = r1_total / total if total else 0.0
    recall10 = r10_total / total if total else 0.0

    print(f"\nEvaluation Summary:")
    print(f"Successfully processed: {successful}/{total}")
    print(f"Recall@1:  {r1_total}/{total} ({recall1 * 100:.2f}%)")
    print(f"Recall@10: {r10_total}/{total} ({recall10 * 100:.2f}%)")

    return {
        'total': total,
        'successful': successful,
        'r1': r1_total,
        'r10': r10_total,
        'recall@1': recall1,
        'recall@10': recall10,
    }

# ==================== Execution Flow ====================
if __name__ == "__main__":
    # Stage 1: text queries plus the pre-computed image embedding table.
    # Both names stay module-level because run_evaluation can read them as globals.
    print("Loading data and vectors...")
    text_data = load_text_data()
    image_vectors = load_image_vectors()

    # Stage 2: CLIP model/processor used to embed the rewritten queries.
    print("Initializing models...")
    clip_model, clip_processor = load_models()

    # Stage 3: evaluate every sample in the shuffled dataset.
    print("Starting threaded evaluation...")
    sample_count = len(text_data)
    run_evaluation(
        text_data,
        clip_model,
        num_samples=sample_count,
        batch_size=50,
        max_workers=32,
    )