import json
import re
import requests
import numpy as np
import concurrent.futures
from tqdm import tqdm
from datasets import Dataset
from transformers import CLIPProcessor, CLIPModel
import torch
from IPython.display import display, Markdown
import sys
import numpy as np
import base64
import json
from IPython.display import display, Markdown
import sys
import numpy as np
import json
from IPython.display import display, Markdown
from datasets import Dataset
from io import BytesIO
from PIL import Image
from torchvision.transforms import ToTensor
from transformers import CLIPProcessor, CLIPModel, AutoModelForCausalLM, AutoTokenizer
import torch
from tqdm.auto import tqdm
import re
from sklearn.preprocessing import normalize
import base64
import re
import requests

# ==================== 配置部分 ====================
# Inference device: use the GPU whenever torch reports one, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Input file locations (blank in this file; the loaders below use them as
# their default paths).
TEXT_DATA_PATH = ""
IMAGE_VECTORS_PATH = ""
IMAGE_TSV_PATH = ""

# ==================== 数据加载 ====================
def load_text_data(filepath=TEXT_DATA_PATH):
    """Load the JSON Lines text samples and return them as a shuffled Dataset.

    Args:
        filepath: Path to a UTF-8 text file holding one JSON document per
            line. Defaults to ``TEXT_DATA_PATH``.

    Returns:
        A ``Dataset`` built from the parsed records via ``Dataset.from_list``
        and shuffled with the fixed seed 42.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    text_data = []
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            # Blank lines (typically a trailing newline at end of file) carry
            # no record; feeding them to json.loads would abort the whole load.
            if not line:
                continue
            text_data.append(json.loads(line))
    return Dataset.from_list(text_data).shuffle(seed=42)

def load_image_vectors(filepath=IMAGE_VECTORS_PATH):
    """Read the image-vector JSON file and return its parsed content.

    Args:
        filepath: Path to a UTF-8 JSON file. Defaults to
            ``IMAGE_VECTORS_PATH``.

    Returns:
        Whatever the JSON document decodes to. Elsewhere in this file the
        result is used as a mapping from image id to embedding vector.
    """
    with open(filepath, 'r', encoding='utf-8') as handle:
        raw_json = handle.read()
    return json.loads(raw_json)

# ==================== 模型加载 ====================
def load_models(model_name="openai/clip-vit-base-patch16"):
    """Load a CLIP model and its matching processor.

    Args:
        model_name: Checkpoint identifier handed to both ``from_pretrained``
            calls, so the model and the processor always come from the same
            checkpoint. Defaults to the checkpoint this script has always
            used.
            NOTE(review): the precomputed image vectors presumably come from
            this same checkpoint; confirm before pointing this elsewhere.

    Returns:
        A ``(clip_model, clip_processor)`` tuple. The model has been moved to
        the module-level ``device``; the processor is returned as loaded.
    """
    clip_model = CLIPModel.from_pretrained(model_name).to(device)
    clip_processor = CLIPProcessor.from_pretrained(model_name)
    return clip_model, clip_processor

def read_tsv_to_dict(filepath=IMAGE_TSV_PATH, return_tensors=False, delimiter=','):
    """Decode a delimited file of ``id<sep>base64-image`` rows into a dict.

    Each line is stripped and split on ``delimiter``. The first field is used
    as the image id and the second is base64-decoded, opened with PIL and
    converted to RGB. Any further fields are ignored.

    Args:
        filepath: Path to the input file. Defaults to ``IMAGE_TSV_PATH``.
        return_tensors: When true, each image is converted with torchvision's
            ``ToTensor`` before being stored; otherwise the PIL image itself
            is stored.
        delimiter: Field separator. Defaults to ``','``, which is what this
            function has always split on despite its name; pass ``'\\t'`` for
            a real TSV file.

    Returns:
        A ``(result_dict, error_count)`` tuple: the id -> image mapping, and
        the number of lines that had fewer than two fields or whose image
        could not be decoded.

    Raises:
        OSError: If the file cannot be opened.
    """
    result_dict = {}
    error_count = 0
    # Build the converter once instead of once per image.
    to_tensor = ToTensor() if return_tensors else None

    # Explicit encoding, consistent with the other loaders in this file.
    with open(filepath, "r", encoding="utf-8") as f:
        for line_num, line in enumerate(f, 1):
            parts = line.strip().split(delimiter)
            if len(parts) < 2:
                # Malformed (or blank) line: counted as a failure, not reported.
                error_count += 1
                continue

            img_id = parts[0].strip()
            base64_str = parts[1].strip()

            try:
                img = Image.open(BytesIO(base64.b64decode(base64_str))).convert("RGB")
                result_dict[img_id] = to_tensor(img) if to_tensor is not None else img
            except Exception as e:
                # Best effort: one bad image must not abort the whole load.
                error_count += 1
                print(f"行 {line_num} 处理失败 [ID:{img_id}]: {str(e)}")
                continue

    display(Markdown(
        f"**CLIP格式预处理完成：<br>• 成功: {len(result_dict)} 条<br>• 失败: {error_count} 条**"
    ))
    return result_dict, error_count
# ==================== 核心功能 ====================
def process_texts_one_by_one(text, model, processor):
    """Compute the embedding of a single text.

    The text is run through ``processor`` as a one-element batch, the encoded
    inputs are moved to the module-level ``device``, and the first row of
    ``model.get_text_features`` is returned. Gradients are disabled for the
    whole call.

    Args:
        text: The text to embed.
        model: Object exposing ``get_text_features(**inputs)``.
        processor: Callable that turns the text into model inputs.

    Returns:
        The embedding as a NumPy array, or ``np.nan`` if any step raised (the
        error is printed rather than propagated).
    """
    try:
        with torch.no_grad():
            encoded = processor(
                text=[text],
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            encoded = encoded.to(device)
            features = model.get_text_features(**encoded)
            return features.cpu().numpy()[0]
    except Exception as e:
        print(f"❌ 处理文本 '{text[:50]}...' 失败: {str(e)}")
        return np.nan

def get_similarity_with_ranking(text_embedding, original_image_index, image_vectors, image_dict):
    """Score one text embedding against every image vector and rank the target.

    The query and all image vectors are L2-normalised, so the dot product
    between them is their cosine similarity.

    Args:
        text_embedding: Embedding of the query text.
        original_image_index: Key of the ground-truth image in
            ``image_vectors``.
        image_vectors: Mapping from image key to embedding vector.
        image_dict: Accepted for interface compatibility; not used here.

    Returns:
        A dict with:
          * ``'original_similarity'``: similarity to the ground-truth image,
            or ``None`` if its key is not in ``image_vectors``.
          * ``'rewrite_rank'``: 1 + the number of images scoring strictly
            higher than the ground truth, or ``None`` if the key is missing.
          * ``'top3_matches'``: up to three ``{'index', 'similarity'}`` dicts,
            best match first.
    """
    query = normalize(np.array([text_embedding]), axis=1, norm='l2')
    image_keys = list(image_vectors.keys())
    gallery = normalize(
        np.array([image_vectors[key] for key in image_keys]),
        axis=1,
        norm='l2',
    )

    scores = np.dot(query, gallery.T).flatten()

    original_similarity = None
    rewrite_rank = None
    if original_image_index in image_keys:
        target_score = scores[image_keys.index(original_image_index)]
        original_similarity = float(target_score)
        # Rank counts only strictly better images, so ties share a rank.
        rewrite_rank = int(np.sum(scores > target_score) + 1)

    # argsort is ascending: take the last three and reverse for best-first.
    best_first = np.argsort(scores)[-3:][::-1]
    top3_matches = [
        {'index': image_keys[pos], 'similarity': float(scores[pos])}
        for pos in best_first
    ]

    return {
        'original_similarity': original_similarity,
        'rewrite_rank': rewrite_rank,
        'top3_matches': top3_matches,
    }

def process_single_sample(args):
    """Evaluate one sample: rewrite the query, embed it, rank the target image.

    Intended to run as a worker in a thread pool. The sample's text is sent
    to a chat-completion HTTP endpoint, the text inside ``<answer>`` tags of
    the reply is embedded with ``process_texts_one_by_one``, and the
    ground-truth image is ranked with ``get_similarity_with_ranking``.

    NOTE(review): the endpoint URL passed to ``requests.post`` is blank in
    this file and must be filled in for the request to succeed.

    Args:
        args: A 5-tuple ``(sample, model_clip, processor, image_vectors,
            image_dict)``. ``sample`` must provide ``"text_content"`` and
            ``"original_image_id"``.

    Returns:
        ``{'r1': 0|1, 'r10': 0|1, 'success': bool}``. ``r1`` / ``r10`` say
        whether the ground-truth image ranked first / within the top ten.
        Every failure (request error, malformed reply, failed embedding,
        unknown image id) yields ``{'r1': 0, 'r10': 0, 'success': False}``;
        nothing is raised for per-sample errors.
    """
    sample, model_clip, processor, image_vectors, image_dict = args

    try:
        text = sample["text_content"]
        response = requests.post('', json={
            "model": "Qwen2.5-3B-Instruct",
            "messages": [
                {
                    "role": "user",
                    "content": f"You're an image retrieval assistant. Translate search queries:{text} into optimized English text for vector-based image search.Show your work in <think> </think> tags. And return the final text in <answer> </answer> tags."
                },
            ]
        }, timeout=30)
        # Report HTTP-level failures as such, rather than as a confusing
        # KeyError / JSON error raised while digging into the body.
        response.raise_for_status()
        full_output = response.json()['choices'][0]['message']['content']

        answer = re.search(r"<answer>([\s\S]*?)<\/answer>", full_output)
        if not answer:
            print(f'回答格式有误，完整的回答是：{full_output}')
            return {'r1': 0, 'r10': 0, 'success': False}

        text_embedding = process_texts_one_by_one(
            answer.group(1).strip(),
            model_clip,
            processor
        )
        # process_texts_one_by_one signals failure with a scalar NaN (and has
        # already printed the reason); a scalar can never be an embedding.
        if np.ndim(text_embedding) == 0:
            return {'r1': 0, 'r10': 0, 'success': False}

        pre_result = get_similarity_with_ranking(
            text_embedding,
            sample['original_image_id'],
            image_vectors,
            image_dict
        )

        rank = pre_result['rewrite_rank']
        if rank is None:
            # Ground-truth id has no vector: say so instead of tripping over
            # a `None <= 10` comparison.
            print(f"处理样本出错: original_image_id {sample['original_image_id']!r} not found in image_vectors")
            return {'r1': 0, 'r10': 0, 'success': False}

        return {
            'r1': int(rank == 1),
            'r10': int(rank <= 10),
            'success': True
        }

    except Exception as e:
        # Worker boundary: one bad sample must not abort the whole run.
        print(f"处理样本出错: {str(e)}")

    return {'r1': 0, 'r10': 0, 'success': False}

def inspect_model_outputs_valid_threaded(dataset, model_clip, num_samples=50, batch_size=20, max_workers=32):
    """Evaluate retrieval quality over the dataset with a thread pool.

    Each of the first ``min(num_samples, len(dataset))`` samples is handed to
    ``process_single_sample`` on a worker thread. Per-batch and overall
    top-1 / top-10 hit rates are printed; nothing is returned.

    NOTE: besides its arguments, this function reads the module-level names
    ``clip_processor``, ``image_vectors`` and ``image_dict``, which must be
    defined before it is called.

    Args:
        dataset: Indexable collection of samples supporting ``len()``.
        model_clip: CLIP model forwarded to every worker.
        num_samples: Upper bound on the number of samples evaluated.
        batch_size: Number of consecutive samples per printed batch line.
        max_workers: Size of the thread pool.
    """
    processor = clip_processor

    # Build the argument tuple for every sample to evaluate.
    args_list = [(dataset[i], model_clip, processor, image_vectors, image_dict)
                 for i in range(min(num_samples, len(dataset)))]

    if not args_list:
        # Nothing to evaluate; the rate computations below would divide by zero.
        print("没有可评估的样本，跳过统计。")
        return

    # Pre-sized so each result lands at its sample's index. Futures finish in
    # arbitrary order, and appending on completion would make the per-batch
    # figures below group unrelated samples.
    results = [None] * len(args_list)
    with tqdm(total=len(args_list), desc="Processing samples") as pbar:
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_index = {
                executor.submit(process_single_sample, arg): idx
                for idx, arg in enumerate(args_list)
            }

            for future in concurrent.futures.as_completed(future_to_index):
                idx = future_to_index[future]
                try:
                    results[idx] = future.result()
                except Exception as e:
                    print(f"线程执行出错: {str(e)}")
                    results[idx] = {'r1': 0, 'r10': 0, 'success': False}
                finally:
                    pbar.update(1)

    # Aggregate counts.
    successful = sum(1 for r in results if r['success'])
    r1_results = sum(r['r1'] for r in results)
    r10_results = sum(r['r10'] for r in results)

    # Per-batch statistics, in dataset order.
    for i in range(0, len(results), batch_size):
        batch = results[i:i+batch_size]
        batch_r1 = sum(r['r1'] for r in batch)
        batch_r10 = sum(r['r10'] for r in batch)
        print(f"批次 {i//batch_size + 1}: "
              f"top1={batch_r1}/{len(batch)}({batch_r1/len(batch)*100:.1f}%), "
              f"top10={batch_r10}/{len(batch)}({batch_r10/len(batch)*100:.1f}%)")

    # Overall statistics.
    print(f"\n成功处理样本数: {successful}/{len(results)}")
    final_rate_r1 = (r1_results / len(results)) * 100
    final_rate_r10 = (r10_results / len(results)) * 100
    print(f"最终结果: top1={r1_results}/{len(results)}({final_rate_r1:.1f}%)")
    print(f"最终结果: top10={r10_results}/{len(results)}({final_rate_r10:.1f}%)")

# ==================== 主执行流程 ====================
if __name__ == "__main__":
    # Script entry point: load the data and models, then run the threaded
    # evaluation. The names bound here are module-level globals;
    # inspect_model_outputs_valid_threaded reads `image_vectors`, `image_dict`
    # and `clip_processor` from module scope, so they must keep these names.

    # Load data
    print("加载文本数据...")
    text_data = load_text_data()

    print("加载图像向量...")
    image_vectors = load_image_vectors()

    print("加载图像TSV文件...")
    image_dict, _ = read_tsv_to_dict()
    # Load models
    print("加载模型...")
    clip_model, clip_processor = load_models()
    processor = clip_processor
    
    # Run the evaluation over the whole dataset
    print("开始多线程评估...")
    inspect_model_outputs_valid_threaded(
        dataset=text_data,
        model_clip=clip_model,
        num_samples=len(text_data),
        batch_size=50,
        max_workers=32  # tune to what the API server can handle
    )