import json
import os
import subprocess
import sys
import time
from dataclasses import dataclass
from typing import Dict

import numpy as np
import torch

from model_trainer.utils.dataset_registry import DatasetRegistry, DatasetRegistryError

# File-lock support (cross-platform)
try:
    import fcntl  # Linux/Unix
    HAS_FCNTL = True
except ImportError:
    HAS_FCNTL = False

# Windows file-lock support (uncomment the block below if Windows support is needed)
# try:
#     import msvcrt  # Windows
#     HAS_MSVCRT = True
# except ImportError:
#     HAS_MSVCRT = False
HAS_MSVCRT = False  # Windows-specific locking is disabled for now; the generic lock-file mechanism is used instead


@dataclass
class EmbeddingSpec:
    """Describe one embedding file as declared under a dataset's registry entry."""

    # Path of the embedding file; resolved against the dataset root unless it is absolute.
    path: str
    # Maps a dataset split name to the key under which that split is stored in the embedding file.
    splits: Dict[str, str]


def ensure_embeddings(config: Dict, *, logger=None) -> None:
    """Make sure the news embedding file for this run exists and passes validation.

    Does nothing when ``config['legacy_loader']`` is truthy or
    ``config['requires_news_embedding']`` is falsy. Otherwise the embedding
    location is looked up in the dataset registry via ``config['dataset_alias']``
    and resolved against ``config['dataset_root']`` (falling back to
    ``config['data_path']``). A missing file is generated under a lock unless
    ``config['auto_generate_embedding']`` is falsy; the file is then validated.

    Args:
        config: Run configuration dictionary.
        logger: Optional logger; when ``None`` nothing is logged here.

    Raises:
        RuntimeError: On missing alias/root, registry lookup failure, missing
            registry declarations, or a file that is still absent after
            generation.
        FileNotFoundError: If the file is absent and auto generation is disabled.
    """
    wanted = not config.get('legacy_loader') and config.get('requires_news_embedding')
    if not wanted:
        return

    may_generate = bool(config.get('auto_generate_embedding', True))

    alias = config.get('dataset_alias')
    if not alias:
        raise RuntimeError('requires_news_embedding=True but dataset_alias is missing')

    root = config.get('dataset_root') or config.get('data_path')
    if not root:
        raise RuntimeError('dataset_root/data_path is required to locate embeddings')

    try:
        registry_entry = DatasetRegistry.get(alias)
    except DatasetRegistryError as exc:
        raise RuntimeError(f'failed to resolve dataset alias {alias}: {exc}') from exc

    news_decl = registry_entry.get('embeddings', {}).get('news')
    if not news_decl:
        raise RuntimeError(f'alias {alias} does not declare news embeddings in registry')

    spec = EmbeddingSpec(
        path=news_decl.get('path', ''),
        splits=news_decl.get('splits', {}),
    )
    if not spec.path:
        raise RuntimeError(f'alias {alias} missing embedding path definition')

    target = _make_abs_path(root, spec.path)
    absent = not os.path.isfile(target)

    if absent and not may_generate:
        raise FileNotFoundError(
            f'Embedding file not found: {target} and auto generation is disabled. '
            'Please run scripts/generate_qwen_embeddings.py manually.'
        )

    if absent:
        if logger:
            logger.info(f'Embedding file not found: {target}. Generating via script...')
        _generate_embeddings_with_lock(alias, target, config=config, logger=logger)

        # Re-check: the file may have been produced by this call or by another process.
        if not os.path.isfile(target):
            raise RuntimeError(f'Embedding file still not found after generation: {target}')

    _validate_embedding_file(
        target,
        spec,
        root,
        registry_entry.get('splits', {}),
        logger=logger,
    )


def _generate_embeddings_with_lock(alias: str, embed_abs_path: str, *, config: Dict, logger=None) -> None:
    """Generate the embedding file while holding an inter-process lock.

    The lock is a ``<embed_abs_path>.lock`` file. While waiting for it, the
    function keeps checking whether the embedding file has appeared (i.e. was
    produced by another process) and returns as soon as it has.

    Only lock *acquisition* failures are retried. Any exception raised by the
    generation step itself propagates to the caller immediately instead of
    being mistaken for lock contention.

    Args:
        alias: Dataset alias forwarded to the generation step.
        embed_abs_path: Absolute path of the embedding file to produce.
        config: Run configuration forwarded to the generation step.
        logger: Optional logger; when ``None`` nothing is logged.

    Raises:
        RuntimeError: If the lock could not be acquired within one hour.
        Exception: Whatever ``_generate_embeddings`` raises.
    """
    lock_file_path = embed_abs_path + '.lock'
    max_wait_time = 3600  # give up waiting for the lock after one hour
    wait_interval = 2  # seconds between acquisition attempts
    stale_lock_age = 3600  # fallback mode only: a lock file older than this is considered abandoned
    deadline = time.monotonic() + max_wait_time
    lock_file = None

    try:
        while True:
            # Another process may have finished generating in the meantime.
            if os.path.isfile(embed_abs_path):
                if logger:
                    logger.info(f'Embedding file already exists (generated by another process): {embed_abs_path}')
                return

            lock_file = _try_acquire_embedding_lock(lock_file_path, stale_lock_age, logger=logger)
            if lock_file is not None:
                break

            if time.monotonic() >= deadline:
                raise RuntimeError(
                    f'Timeout waiting for embedding generation lock after {max_wait_time}s. '
                    f'Lock file: {lock_file_path}'
                )
            time.sleep(wait_interval)

        if logger:
            logger.info(f'Acquired lock for embedding generation: {embed_abs_path}')

        # The file may have been generated between the last check and taking the lock.
        if os.path.isfile(embed_abs_path):
            if logger:
                logger.info(f'Embedding file already exists (generated while waiting for lock): {embed_abs_path}')
            return

        # Deliberately outside any lock-retry handler: generation errors must surface.
        _generate_embeddings(alias, embed_abs_path, config=config, logger=logger)
    finally:
        if lock_file is not None:
            _release_embedding_lock(lock_file, lock_file_path, logger=logger)


def _try_acquire_embedding_lock(lock_file_path: str, stale_after: float, *, logger=None):
    """Make one non-blocking attempt to take the lock; return the open lock file, or None on failure."""
    lock_file = None
    try:
        if HAS_FCNTL:
            lock_file = open(lock_file_path, 'a')
            fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
            # The previous holder unlinks the lock file before unlocking, so the
            # descriptor we just locked may belong to a file that is no longer at
            # this path. Only a lock on the file currently at the path counts.
            if not os.path.samestat(os.fstat(lock_file.fileno()), os.stat(lock_file_path)):
                raise OSError('lock file was replaced while acquiring the lock')
            return lock_file

        if HAS_MSVCRT:
            lock_file = open(lock_file_path, 'a')
            msvcrt.locking(lock_file.fileno(), msvcrt.LK_NBLCK, 1)
            return lock_file

        # No OS-level locking available: the lock is the existence of the file,
        # so it must be created atomically (O_EXCL fails if it already exists).
        try:
            fd = os.open(lock_file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o644)
        except FileExistsError:
            lock_age = time.time() - os.path.getmtime(lock_file_path)
            if lock_age > stale_after:
                # NOTE: stale-lock removal is best-effort; two waiters may race here.
                if logger:
                    logger.warning(f'Lock file expired (age: {lock_age:.0f}s), removing...')
                os.remove(lock_file_path)
            return None

        lock_file = os.fdopen(fd, 'w')
        try:
            lock_file.write(f'{os.getpid()}\n')
            lock_file.flush()
        except OSError:
            # We created the file but cannot use it; do not leave it behind to block others.
            lock_file.close()
            lock_file = None
            _remove_embedding_lock_file(lock_file_path, logger=logger)
            raise
        return lock_file
    except OSError as exc:
        # Contention or a transient filesystem error: report "not acquired" so the caller retries.
        if lock_file is not None:
            lock_file.close()
        if logger:
            logger.debug(f'Could not acquire lock, waiting... (error: {exc})')
        return None


def _release_embedding_lock(lock_file, lock_file_path: str, *, logger=None) -> None:
    """Release the lock and delete its file; never raises because it runs inside a ``finally``."""
    if HAS_FCNTL:
        # Unlink while still holding the flock so no other process can lock this
        # file at this path after we let go of it.
        _remove_embedding_lock_file(lock_file_path, logger=logger)

    try:
        if HAS_FCNTL:
            fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
        elif HAS_MSVCRT:
            msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1)
    except (OSError, ValueError) as exc:
        if logger:
            logger.debug(f'Failed to unlock {lock_file_path}: {exc}')

    try:
        lock_file.close()
    except OSError as exc:
        if logger:
            logger.debug(f'Failed to close lock file {lock_file_path}: {exc}')

    if not HAS_FCNTL:
        # Without flock the file itself is the lock; remove it only after closing,
        # since some platforms (notably Windows) refuse to delete an open file.
        _remove_embedding_lock_file(lock_file_path, logger=logger)


def _remove_embedding_lock_file(lock_file_path: str, *, logger=None) -> None:
    """Delete the lock file, tolerating its absence and logging any other failure."""
    try:
        os.remove(lock_file_path)
    except FileNotFoundError:
        pass
    except OSError as exc:
        if logger:
            logger.warning(f'Failed to remove lock file {lock_file_path}: {exc}')


def _generate_embeddings(alias: str, embed_abs_path: str, *, config: Dict, logger=None) -> None:
    """Run ``scripts/generate_qwen_embeddings.py`` in a subprocess and check its output.

    The script is invoked with the current Python interpreter, ``--alias <alias>``
    and, when ``config['dataset_version']`` is truthy, ``--dataset-version``.
    The working directory is the path returned by ``_repo_root()``.

    Args:
        alias: Dataset alias passed to the script.
        embed_abs_path: Absolute path where the script is expected to write the file.
        config: Run configuration; only ``dataset_version`` is read here.
        logger: Optional logger; when ``None`` nothing is logged.

    Raises:
        FileNotFoundError: If the generation script does not exist.
        RuntimeError: If the script exits non-zero, or exits successfully but
            ``embed_abs_path`` still does not exist.
    """
    repo_root = _repo_root()
    script_path = os.path.join(repo_root, 'scripts', 'generate_qwen_embeddings.py')
    if not os.path.isfile(script_path):
        raise FileNotFoundError(
            f'Embedding generation script not found at {script_path}. '
            'Please ensure scripts/generate_qwen_embeddings.py is available.'
        )

    cmd = [sys.executable, script_path, '--alias', alias]
    if config.get('dataset_version'):
        cmd.extend(['--dataset-version', str(config['dataset_version'])])

    if logger:
        logger.info('Running embedding generation script: %s', ' '.join(cmd))

    # No explicit env is passed, so the child inherits this process's environment.
    # check=False: the exit code is inspected below so stdout/stderr can be logged first.
    result = subprocess.run(cmd, cwd=repo_root, check=False, capture_output=True, text=True)
    if result.returncode != 0:
        if logger:
            logger.error('Embedding generation failed. stdout=%s stderr=%s', result.stdout, result.stderr)
        raise RuntimeError(
            f'Embedding generation script failed with exit code {result.returncode}: {result.stderr.strip()}'
        )
    if logger and result.stdout:
        logger.info(result.stdout.strip())

    # A zero exit code alone does not prove the expected file was written.
    if not os.path.isfile(embed_abs_path):
        raise RuntimeError(f'Embedding generation script completed but file not found: {embed_abs_path}')


def _validate_embedding_file(
    embed_abs_path: str,
    embed_spec: EmbeddingSpec,
    dataset_root: str,
    splits: Dict[str, str],
    *,
    logger=None,
) -> None:
    """Load an embedding file and check every split declared in ``embed_spec``.

    The file must contain a dict. For each ``split -> key`` pair in
    ``embed_spec.splits`` the entry under ``key`` is either a dict with an
    ``'embeddings'`` field (and optionally ``'attention_mask'``) or, for the
    legacy format, the embeddings themselves. Embeddings must be a 2D or 3D
    tensor/ndarray with only finite values; 2D embeddings must not contain
    all-zero rows. When the length of a dataset split file can be determined,
    the first dimension must match it.

    Args:
        embed_abs_path: Absolute path of the embedding file.
        embed_spec: Declares which key holds which split.
        dataset_root: Root used to resolve the dataset split files.
        splits: Maps split names to dataset file paths used for length checks.
        logger: Optional logger; warnings are skipped when it is ``None``.

    Raises:
        ValueError: Wrong container type, wrong rank, NaN/Inf values, zero
            vectors (2D), or a length/shape mismatch.
        KeyError: A declared key or the ``'embeddings'`` field is missing.
        TypeError: Embeddings or attention mask are not a tensor/ndarray.
    """
    if logger:
        logger.info('Validating embedding file: %s', embed_abs_path)

    # NOTE(review): torch.load unpickles the file, so only trusted files should be
    # loaded here; consider passing weights_only=True if the stored format allows it.
    tensor_dict = torch.load(embed_abs_path, map_location='cpu')
    if not isinstance(tensor_dict, dict):
        raise ValueError(f'Embedding file {embed_abs_path} should contain a dict of tensors')

    expected_split_lengths = _load_split_lengths(dataset_root, splits)

    for split, key in embed_spec.splits.items():
        if key not in tensor_dict:
            raise KeyError(f'Embedding file missing key {key} for split {split}')

        split_data = tensor_dict[key]
        if not isinstance(split_data, dict):
            # Legacy format: the entry is the embeddings tensor itself.
            split_data = {'embeddings': split_data}

        # --- embeddings ---
        if 'embeddings' not in split_data:
            raise KeyError(f'Embedding split {key} missing "embeddings" field')
        embeddings = split_data['embeddings']
        if isinstance(embeddings, np.ndarray):
            embeddings = torch.from_numpy(embeddings)
        if not torch.is_tensor(embeddings):
            raise TypeError(f'Embeddings in {key} is not a tensor/ndarray')
        if embeddings.dim() not in [2, 3]:  # 2D and 3D are both accepted
            raise ValueError(f'Embeddings in {key} must be 2D or 3D, got shape {tuple(embeddings.shape)}')
        if not torch.isfinite(embeddings).all():
            raise ValueError(f'Embeddings in {key} contains NaN/Inf values')

        # Zero-norm vectors along the last dimension: an error for 2D embeddings,
        # only a warning for 3D ones, and only when some (not all) vectors are zero.
        zero_vectors = embeddings.norm(dim=-1) == 0
        if embeddings.dim() == 2:
            if zero_vectors.any():
                raise ValueError(f'Embeddings in {key} contains zero vectors')
        elif zero_vectors.any() and not zero_vectors.all():
            if logger:
                logger.warning(f'Embeddings in {key} contains some zero token vectors')

        expected_len = expected_split_lengths.get(split)
        if expected_len is not None and embeddings.shape[0] != expected_len:
            raise ValueError(
                f'Embeddings in {key} length {embeddings.shape[0]} mismatches dataset split {split} length {expected_len}'
            )

        # --- attention mask (optional) ---
        if 'attention_mask' in split_data:
            mask = split_data['attention_mask']
            if isinstance(mask, np.ndarray):
                mask = torch.from_numpy(mask)
            if not torch.is_tensor(mask):
                raise TypeError(f'Attention mask in {key} is not a tensor/ndarray')
            if mask.shape[0] != embeddings.shape[0]:
                raise ValueError(f'Attention mask length {mask.shape[0]} mismatches embeddings length {embeddings.shape[0]}')
            if embeddings.dim() == 3 and mask.dim() == 2:
                if mask.shape[1] != embeddings.shape[1]:
                    raise ValueError(f'Attention mask seq_len {mask.shape[1]} mismatches embeddings seq_len {embeddings.shape[1]}')
            if mask.dtype not in [torch.bool, torch.int, torch.long]:
                if logger:
                    logger.warning(f'Attention mask in {key} should be boolean or integer type, got {mask.dtype}')
        elif embeddings.dim() == 3:
            if logger:
                logger.warning(f'3D embeddings in {key} found but no attention_mask provided - padding tokens may not be properly handled')


def _load_split_lengths(dataset_root: str, splits: Dict[str, str]) -> Dict[str, int]:
    """Return the number of records in each dataset split file that can be read.

    Each non-empty path in ``splits`` is resolved against ``dataset_root`` and,
    if it names an existing file, parsed as JSON; the split's length is
    ``len()`` of the parsed value. Splits with an empty path or a missing file
    are left out of the result.
    """
    counts: Dict[str, int] = {}
    for name in splits:
        relative = splits[name]
        if relative:
            candidate = _make_abs_path(dataset_root, relative)
            if os.path.isfile(candidate):
                with open(candidate, 'r', encoding='utf-8') as handle:
                    payload = json.load(handle)
                counts[name] = len(payload)
    return counts


def _make_abs_path(root: str, rel_path: str) -> str:
    """Resolve ``rel_path`` against ``root``; an already-absolute path is returned unchanged.

    Leading ``/`` characters are stripped from a non-absolute ``rel_path``
    before it is joined onto ``root`` and normalised with ``os.path.abspath``.
    """
    trimmed = rel_path.lstrip('/')
    already_absolute = os.path.isabs(rel_path)
    return rel_path if already_absolute else os.path.abspath(os.path.join(root, trimmed))


def _repo_root() -> str:
    """Return the absolute path two levels above the parent of this module's directory.

    NOTE(review): presumably this is the repository root (callers look for a
    ``scripts`` directory beneath it) - confirm against the package layout.
    """
    parent_of_module_dir = os.path.dirname(os.path.dirname(__file__))
    two_levels_up = os.path.join(parent_of_module_dir, os.pardir, os.pardir)
    return os.path.abspath(two_levels_up)
