
"""
支持配置文件的多卡DDP测评代码
"""

import os
import sys
import json
import re
import logging
import argparse
import random
import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from tqdm import tqdm
import time
from datetime import datetime
import yaml

def load_config(config_path):
    """Read a YAML configuration file and return its parsed content.

    Args:
        config_path: Path of the YAML file; it is opened as UTF-8 text.

    Returns:
        The object produced by ``yaml.safe_load`` for the file content.
    """
    with open(config_path, mode='r', encoding='utf-8') as stream:
        return yaml.safe_load(stream)

def setup_logging(config, rank):
    """Configure the root logger for this process and return a module logger.

    Rank 0 creates ``<log_dir>/<last component of ckpt_path>`` and logs at
    INFO level to both a timestamped ``eval_*.log`` file in that directory
    and stdout. Every other rank only gets a WARNING-level default
    configuration and does not touch the file system.

    Args:
        config: Configuration mapping; reads ``config['training']['log_dir']``
            and ``config['model']['ckpt_path']``.
        rank: Rank of the calling process.

    Returns:
        ``logging.getLogger(__name__)``.
    """
    base_dir = config['training']['log_dir']
    run_name = config['model']['ckpt_path'].split('/')[-1]
    run_dir = os.path.join(base_dir, run_name)

    # Non-zero ranks: quiet default configuration, nothing written to disk.
    if rank != 0:
        logging.basicConfig(level=logging.WARNING)
        return logging.getLogger(__name__)

    os.makedirs(run_dir, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_path = os.path.join(run_dir, f"eval_{stamp}.log")

    file_handler = logging.FileHandler(log_path, encoding='utf-8')
    console_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[file_handler, console_handler],
    )
    return logging.getLogger(__name__)

def set_seed(seed):
    """Seed every random number generator used by the script for reproducibility.

    Sets ``PYTHONHASHSEED`` and ``CUBLAS_WORKSPACE_CONFIG`` in the process
    environment, seeds Python's ``random``, NumPy and PyTorch (CPU and all
    CUDA devices), and switches PyTorch/cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Set together with deterministic algorithms below so cuBLAS uses a
    # fixed workspace configuration.
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

    # Each generator is seeded exactly once; re-seeding with the same value
    # (as the previous version did) has no additional effect.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible CUDA device, including the current one.
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)

def setup_ddp(rank, world_size, config):
    """Initialise the default process group and bind this process to its GPU.

    Args:
        rank: Rank of the calling process; also used as the CUDA device index.
        world_size: Total number of participating processes.
        config: Configuration mapping; reads ``master_addr``, ``master_port``
            and ``backend`` from ``config['system']``.
    """
    system_cfg = config['system']
    # os.environ only accepts strings, while YAML parses an unquoted port
    # such as ``29500`` as an int; coerce so both spellings work.
    os.environ['MASTER_ADDR'] = str(system_cfg['master_addr'])
    os.environ['MASTER_PORT'] = str(system_cfg['master_port'])
    dist.init_process_group(system_cfg['backend'], rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup_ddp():
    """Destroy the default process group if one is currently initialised.

    Calling this when no process group exists (for example after a failed
    initialisation) is a no-op instead of an error.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def get_image_type(dataset_name: str, lang: str) -> str:
    """Return the phrase describing the image pair of a dataset.

    The dataset is identified by the first marker (checked in the order
    MAP, VIGOR, U1652, SetVL) that occurs as a substring of
    ``dataset_name``.

    Args:
        dataset_name: Name of the dataset.
        lang: ``"en"`` or ``"zh"``.

    Returns:
        The description in the requested language.

    Raises:
        ValueError: If ``lang`` is not supported or no marker matches.
    """
    # Tuples (not dicts) keep the marker precedence explicit.
    descriptions = {
        "en": (
            ("MAP", "one map and one satellite image"),
            ("VIGOR", "one panorama image and one satellite image"),
            ("U1652", "one aerial image and one satellite image"),
            ("SetVL", "one image taken on the ground and one satellite image"),
        ),
        "zh": (
            ("MAP", "一张地图和一张卫星图"),
            ("VIGOR", "一张全景图和一张卫星图"),
            ("U1652", "一张航空影像和一张卫星图"),
            ("SetVL", "一张地面拍摄的图像和一张卫星图"),
        ),
    }
    for marker, description in descriptions.get(lang, ()):
        if marker in dataset_name:
            return description
    raise ValueError(f"Invalid dataset name or language: {dataset_name}, {lang}")

def build_prompt(dataset_name: str, lang: str) -> str:
    """Build the question asked to the model for one image pair.

    Args:
        dataset_name: Dataset name, forwarded to ``get_image_type``.
        lang: ``"en"`` or ``"zh"``.

    Returns:
        The prompt text in the requested language, embedding the image-pair
        description returned by ``get_image_type``.

    Raises:
        ValueError: Propagated from ``get_image_type``, or raised here for a
            language other than ``"en"``/``"zh"``.
    """
    img_type = get_image_type(dataset_name, lang)
    if lang == "zh":
        return f"给出两张图片（{img_type}）。这两张图片是否对应同一个地区？如果是，回答 [[1]]。如果不是，回答 [[0]]，然后详细解释为什么。"
    if lang != "en":
        raise ValueError(f"Invalid language: {lang}")
    return f"Two images are provided ({img_type}). Do these two images correspond to the same area? If true, answer [[1]]. If false, answer [[0]], then explain why in detail."

class GeoDataset(Dataset):
    """In-memory dataset backed by a JSON Lines file.

    Each non-blank line of the file is parsed with ``json.loads`` and kept
    in ``self.data`` in file order; items are returned unchanged.
    """

    def __init__(self, data_path, lang="en"):
        """Load all records from ``data_path``.

        Args:
            data_path: Path of the JSON Lines file, read as UTF-8.
            lang: Language tag stored on the instance as ``self.lang``.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        self.data = []
        self.lang = lang

        # Explicit encoding so records with non-ASCII text load the same way
        # regardless of the platform's default locale.
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                if not line.strip():
                    continue
                self.data.append(json.loads(line))

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the record stored at position ``idx``."""
        return self.data[idx]

def _parse_predicted_label(output_text):
    """Extract the first bracketed integer answer from the model output.

    The ``[[N]]`` form is preferred; ``[N]`` is accepted as a fallback.
    Returns ``None`` when neither form occurs in ``output_text``.
    """
    matches = re.findall(r"\[\[(\d+)\]\]", output_text)
    if not matches:
        matches = re.findall(r"\[(\d+)\]", output_text)
    return int(matches[0]) if matches else None


def evaluate_sample(model, processor, data, dataset_name, lang, device, config):
    """Evaluate one sample against both its negative and its positive image.

    For label 0 the query image is paired with ``data['negative'][0]`` and
    for label 1 with ``data['positive'][0]``. Each pair is sent to the model
    together with the prompt from ``build_prompt`` and the generated answer
    is parsed for a ``[[N]]`` (or ``[N]``) label.

    Args:
        model: Model exposing ``generate``.
        processor: Processor exposing ``apply_chat_template``, ``__call__``,
            ``batch_decode`` and ``tokenizer.eos_token_id``.
        data: Mapping with ``'query'``, ``'negative'`` and ``'positive'``
            entries, each of length 1; element 0 is joined to the image
            folder as a relative path.
        dataset_name: Dataset name used to select the prompt wording.
        lang: Prompt language, ``"en"`` or ``"zh"``.
        device: Device the processed inputs are moved to.
        config: Configuration mapping; reads ``config['data']['image_folder']``
            and the generation settings under ``config['generation']``.

    Returns:
        A list with two dicts (label 0 first, then label 1) holding the
        ground-truth label, predicted label (``None`` if unparsable), raw
        output text, score (1 if correct else 0), the chat messages, prompt,
        templated input text and the generation parameters used.

    Raises:
        ValueError: If ``data['query']`` or the positive/negative entry does
            not have exactly one element.
    """
    prompt = build_prompt(dataset_name, lang)
    image_folder = config['data']['image_folder']

    # Generation settings do not depend on the label, so resolve them once.
    gen_cfg = config['generation']
    generation_params = {
        'temperature': gen_cfg['temperature'],
        'max_new_tokens': gen_cfg['max_new_tokens'],
        'do_sample': gen_cfg['do_sample'],
        'top_p': gen_cfg.get('top_p', 1.0),
        'top_k': gen_cfg.get('top_k', 50),
        'repetition_penalty': gen_cfg.get('repetition_penalty', 1.0),
        'length_penalty': gen_cfg.get('length_penalty', 1.0),
        'no_repeat_ngram_size': gen_cfg.get('no_repeat_ngram_size', 0),
        'num_beams': gen_cfg.get('num_beams', 1),
        'early_stopping': gen_cfg.get('early_stopping', False),
    }

    results = []

    for label, pos_neg_type in ((0, "negative"), (1, "positive")):
        # Explicit checks instead of ``assert`` so validation also runs
        # under ``python -O``.
        if len(data['query']) != 1:
            raise ValueError("query should be a list with length 1")
        if len(data[pos_neg_type]) != 1:
            raise ValueError("pos_neg_type should be a list with length 1")

        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": os.path.join(image_folder, data['query'][0]),
                    },
                    {
                        "type": "image",
                        "image": os.path.join(image_folder, data[pos_neg_type][0]),
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        # Render the chat template to text and load the referenced images.
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(device)

        with torch.no_grad():
            # NOTE(review): use_cache=False is kept from the original code;
            # presumably intentional - confirm, since it affects speed.
            generated_ids = model.generate(
                **inputs,
                pad_token_id=processor.tokenizer.eos_token_id,
                use_cache=False,
                **generation_params,
            )

        # Strip the prompt tokens so only newly generated tokens are decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        predicted_label = _parse_predicted_label(output_text)
        # An unparsable answer (None) never equals the label and scores 0.
        score = 1 if predicted_label == label else 0

        results.append({
            'label': label,
            'pos_neg_type': pos_neg_type,
            'predicted_label': predicted_label,
            'output_text': output_text,
            'score': score,
            'correct': score == 1,
            'messages': messages,
            'prompt': prompt,
            'input_text': text,
            # Copy so each result owns an independent dict.
            'generation_params': dict(generation_params),
        })

    return results

def _run_evaluation(rank, world_size, config):
    """Load the model, score this rank's shard and write its result files."""
    logger = setup_logging(config, rank)
    is_main = rank == 0
    seed = config['training']['seed']

    if is_main:
        logger.info(f"开始测评，配置: {config}")
        logger.info(f"使用 {world_size} 个GPU")
        logger.info(f"随机种子: {seed}")

    device = torch.device(f"cuda:{rank}")

    if is_main:
        logger.info("正在加载模型...")

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        config['model']['ckpt_path'],
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map={"": rank},  # place the whole model on this rank's GPU
    )

    # Generation below is called on ``model.module``, i.e. the unwrapped model.
    model = DDP(model, device_ids=[rank], output_device=rank)

    processor = AutoProcessor.from_pretrained(
        config['model']['ckpt_path'],
        min_pixels=config['model']['min_pixels'],
        max_pixels=config['model']['max_pixels']
    )

    if is_main:
        logger.info("模型加载完成")

    dataset = GeoDataset(config['data']['data_path'], config['data']['lang'])

    # Sort in place so every rank sees the same order before sharding.
    if hasattr(dataset, 'data') and len(dataset.data) > 0:
        sort_key = config['data'].get('sort_key', 'dataset')
        if sort_key in dataset.data[0]:
            dataset.data.sort(key=lambda x: x.get(sort_key, ''))

    # NOTE(review): per the PyTorch docs, with drop_last=False (the default)
    # DistributedSampler adds extra indices to make the data evenly divisible
    # across replicas, so some samples may be counted twice in the aggregate.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
    dataloader = DataLoader(
        dataset,
        batch_size=config['training']['batch_size'],
        sampler=sampler,
        num_workers=config['training']['num_workers']
    )

    total_score = 0
    total_samples = 0

    output_dir = os.path.join(
        config['training']['log_dir'], config['model']['ckpt_path'].split('/')[-1]
    )
    # setup_logging only creates this directory on rank 0; create it on
    # every rank so no rank depends on rank 0 having run first.
    os.makedirs(output_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = os.path.join(output_dir, f"eval_results_rank{rank}_{timestamp}.jsonl")

    if is_main:
        logger.info("开始评估...")
        logger.info(f"结果将保存到: {results_file}")

    start_time = time.time()

    for i, data in enumerate(tqdm(dataloader, disable=rank != 0)):
        # ``data`` is the collated batch; element 0 is the dataset name.
        dataset_name = data["dataset"][0]

        try:
            results = evaluate_sample(
                model.module,
                processor,
                data,
                dataset_name,
                config['data']['lang'],
                device,
                config
            )

            sample_score = sum(r['score'] for r in results)
            total_score += sample_score
            total_samples += len(results)

            # ``sample_id`` is the index within this rank's dataloader.
            result_entry = {
                'rank': rank,
                'sample_id': i,
                'dataset': dataset_name,
                'timestamp': datetime.now().isoformat(),
                'data': data,
                'results': results,
                'sample_score': sample_score,
                'current_total_score': total_score,
                'current_total_samples': total_samples,
                'current_accuracy': total_score / total_samples if total_samples > 0 else 0
            }

            # Append immediately so partial results survive a crash.
            with open(results_file, 'a', encoding='utf-8') as f:
                f.write(json.dumps(result_entry, ensure_ascii=False) + '\n')

            if is_main and (i + 1) % 100 == 0:
                current_acc = total_score / total_samples
                logger.info(f"已处理 {i + 1} 个样本，当前准确率: {current_acc:.4f}")

        except Exception as e:
            # Best-effort: skip the failing sample, but report it from every
            # rank (previously failures on non-zero ranks were silent).
            logger.error(f"[rank {rank}] 处理样本 {i} 时出错: {str(e)}")
            logger.error(f"[rank {rank}] 错误详情: {type(e).__name__}: {str(e)}")

    # Sum the per-rank counters across all processes.
    total_score_tensor = torch.tensor(total_score, device=device)
    total_samples_tensor = torch.tensor(total_samples, device=device)

    dist.all_reduce(total_score_tensor, op=dist.ReduceOp.SUM)
    dist.all_reduce(total_samples_tensor, op=dist.ReduceOp.SUM)

    final_score = total_score_tensor.item()
    final_samples = total_samples_tensor.item()
    final_accuracy = final_score / final_samples if final_samples > 0 else 0

    end_time = time.time()
    local_accuracy = total_score / total_samples if total_samples > 0 else 0

    summary_file = os.path.join(
        output_dir,
        f"eval_summary_rank{rank}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    )

    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump({
            'rank': rank,
            'config': config,
            'local_score': total_score,
            'local_samples': total_samples,
            'local_accuracy': local_accuracy,
            'total_time': end_time - start_time,
            'results_file': results_file,
            'seed_used': seed,
            'timestamp': datetime.now().isoformat()
        }, f, ensure_ascii=False, indent=2)

    if is_main:
        logger.info("评估完成!")
        logger.info(f"本地分数: {total_score}")
        logger.info(f"本地样本数: {total_samples}")
        logger.info(f"本地准确率: {local_accuracy:.4f}")
        logger.info(f"总耗时: {end_time - start_time:.2f} 秒")
        logger.info(f"本地统计结果已保存到: {summary_file}")
        logger.info(f"本地详细结果已保存到: {results_file}")

        logger.info("所有GPU结果汇总:")
        logger.info(f"总分数: {final_score}")
        logger.info(f"总样本数: {final_samples}")
        logger.info(f"最终准确率: {final_accuracy:.4f}")


def main(rank, world_size, config):
    """Per-process entry point of the distributed evaluation.

    Seeds the random generators, initialises the process group, then loads
    the model and evaluates this rank's shard of the dataset. Each rank
    appends per-sample results to its own ``eval_results_rank*.jsonl`` file
    and writes an ``eval_summary_rank*.json`` file; the score and sample
    counters are summed across ranks with ``all_reduce`` and the aggregate
    accuracy is logged on rank 0.

    Args:
        rank: Rank of this process; also its CUDA device index.
        world_size: Total number of processes.
        config: Configuration mapping with the ``training``, ``system``,
            ``model``, ``data`` and ``generation`` sections read by this
            script.
    """
    set_seed(config['training']['seed'])
    setup_ddp(rank, world_size, config)
    try:
        _run_evaluation(rank, world_size, config)
    finally:
        # Release the process group even when evaluation raises.
        cleanup_ddp()

def apply_cli_overrides(config, args):
    """Apply command-line overrides to the loaded configuration in place.

    Args:
        config: Configuration mapping with ``model``, ``data``, ``training``
            and ``generation`` sections.
        args: Parsed arguments providing ``ckpt_path``, ``data_path``,
            ``lang``, ``temperature`` and ``seed`` (``None`` when not given).

    Returns:
        The same ``config`` object, after modification.

    Raises:
        KeyError: If ``ckpt_path`` has the form ``$NAME`` and the environment
            variable ``NAME`` is not set.
    """
    model_cfg = config['model']
    if args.ckpt_path:
        model_cfg['ckpt_path'] = args.ckpt_path
    # A checkpoint path written as "$NAME" is read from the environment.
    if model_cfg['ckpt_path'].startswith("$"):
        model_cfg['ckpt_path'] = os.environ[model_cfg['ckpt_path'][1:]]
    if args.data_path:
        config['data']['data_path'] = args.data_path
    if args.lang:
        config['data']['lang'] = args.lang
    # Logs and results are grouped per language.
    config['training']['log_dir'] = os.path.join(config['training']['log_dir'], config['data']['lang'])
    # Compare against None: 0.0 and 0 are valid overrides but falsy.
    if args.temperature is not None:
        config['generation']['temperature'] = args.temperature
    if args.seed is not None:
        config['training']['seed'] = args.seed
    return config


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="支持配置文件的多卡DDP地理图像测评")
    parser.add_argument("--config", type=str, default="benchmark/config.yaml",
                       help="配置文件路径")
    parser.add_argument("--ckpt_path", type=str, help="覆盖配置文件中的模型路径")
    parser.add_argument("--data_path", type=str, help="覆盖配置文件中的数据路径")
    parser.add_argument("--lang", type=str, choices=["en", "zh"], help="覆盖配置文件中的语言")
    parser.add_argument("--temperature", type=float, help="覆盖配置文件中的温度")
    parser.add_argument("--seed", type=int, help="覆盖配置文件中的种子")

    args = parser.parse_args()

    # Load the YAML configuration, then let CLI flags take precedence.
    config = apply_cli_overrides(load_config(args.config), args)

    # One worker process is spawned per visible GPU.
    world_size = torch.cuda.device_count()

    if world_size == 0:
        print("错误: 未检测到GPU")
        sys.exit(1)

    print(f"检测到 {world_size} 个GPU")

    import torch.multiprocessing as mp
    mp.spawn(main, args=(world_size, config), nprocs=world_size, join=True)
