import json
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
from sacrebleu.metrics import BLEU, CHRF
from rouge_score import rouge_scorer
import tqdm
from multiprocessing import Pool, cpu_count
from functools import partial
import os

# Default `han_thresh` of `_is_chinese_like`: minimum fraction of Han (CJK
# ideograph, see `_is_han`) characters among the non-whitespace characters of
# a text for it to be treated as Chinese-like.
CHINESE_HAN_RATIO = 0.20
# Default value of the `kana_max` parameter of `_is_chinese_like`
# (maximum kana fraction).
CHINESE_KANA_MAX_RATIO = 0.05

def _is_han(ch: str) -> bool:
    return ('\u4e00' <= ch <= '\u9fff') or ('\u3400' <= ch <= '\u4dbf') or ('\uf900' <= ch <= '\ufaff')

def _is_kana(ch: str) -> bool:
    """Return True if *ch* is Japanese kana.

    Covers Hiragana + Katakana (U+3040-U+30FF), Katakana Phonetic
    Extensions (U+31F0-U+31FF) and halfwidth Katakana (U+FF66-U+FF9F).
    """
    return (('\u3040' <= ch <= '\u30ff')
            or ('\u31f0' <= ch <= '\u31ff')
            or ('\uff66' <= ch <= '\uff9f'))


def _is_chinese_like(text: str,
                     han_thresh: float = CHINESE_HAN_RATIO,
                     kana_max: float = CHINESE_KANA_MAX_RATIO) -> bool:
    """Heuristically flag *text* whose characters are largely Han ideographs.

    Whitespace is ignored. Among the remaining characters, the text is
    flagged when the Han fraction (see ``_is_han``) is at least
    *han_thresh* AND the kana fraction is at most *kana_max*. Japanese text
    also contains many Han characters, so the kana bound keeps kana-rich
    text from being flagged.

    Args:
        text: string to classify.
        han_thresh: minimum Han-character fraction.
        kana_max: maximum kana-character fraction.

    Returns:
        False for empty / whitespace-only text, otherwise the result of the
        two ratio tests above.
    """
    chars = [c for c in text if not c.isspace()]
    if not chars:
        return False
    total = len(chars)
    han = sum(1 for c in chars if _is_han(c))
    kana = sum(1 for c in chars if _is_kana(c))
    return (han / total) >= han_thresh and (kana / total) <= kana_max

def _filter_chinese_entries(code_list, score_list):
    keep_mask = [not _is_chinese_like(c) for c in code_list]
    filtered_codes = [c for c, k in zip(code_list, keep_mask) if k]
    if score_list and len(score_list) == len(code_list):
        filtered_scores = [s for s, k in zip(score_list, keep_mask) if k]
    else:
        filtered_scores = []
    removed = len(code_list) - len(filtered_codes)
    return filtered_codes, filtered_scores, removed


def compute_metric_matrices(
    jsonl_path: str,
    metric: str = "BLEU",
    output_path: Optional[str] = None,
    split_by_score: bool = False,
    n_workers: Optional[int] = None
) -> Dict[str, List[Dict[str, Any]]]:
    """Compute a pairwise similarity matrix for every sample in a JSONL file.

    Each non-blank line of *jsonl_path* is parsed as one sample and handed
    to ``process_sample_parallel``.

    Args:
        jsonl_path: input file with one JSON object per line (blank lines
            are skipped).
        metric: ``"BLEU"``, ``"CHRF"`` or ``"ROUGE-L"`` (case-insensitive).
        output_path: when given, the output is written with ``save_results``.
        split_by_score: when True, also split into ``correct``/``wrong``
            categories via ``split_results_by_score``.
        n_workers: process-pool size; defaults to ``min(cpu_count(), 96)``.

    Returns:
        ``{"all": results}``, or the ``all``/``correct``/``wrong`` mapping
        when *split_by_score* is set.

    Raises:
        ValueError: unsupported *metric*, or *n_workers* < 1. Both are
            checked before the input file is opened.
    """
    # Validate arguments up front so a bad option fails fast, before a
    # potentially large input file is parsed.
    metric = metric.upper()
    if metric not in ["BLEU", "CHRF", "ROUGE-L"]:
        raise ValueError(f"Unsupported metric: {metric}. Choose from BLEU, CHRF, ROUGE-L")

    if n_workers is None:
        n_workers = min(cpu_count(), 96)
    elif n_workers < 1:
        # Otherwise this surfaces much later as a ZeroDivisionError (batch
        # sizing) or a Pool ValueError, and only for samples with >= 2 entries.
        raise ValueError(f"n_workers must be at least 1, got {n_workers}")

    with open(jsonl_path, 'r', encoding='utf-8') as f:
        data = [json.loads(line.strip()) for line in f if line.strip()]

    print(f"Using {n_workers} worker processes for matrix computation")
    results = [
        process_sample_parallel(sample, metric, n_workers)
        for sample in tqdm.tqdm(data, desc=f"Computing {metric} matrices")
    ]

    output = split_results_by_score(results) if split_by_score else {"all": results}

    if output_path:
        save_results(output, output_path)

    return output


def process_sample_parallel(
    sample: Dict[str, Any], 
    metric: str, 
    n_workers: int
) -> Dict[str, Any]:
    """Compute one sample's symmetric similarity matrix using a process pool.

    Chinese-like entries are removed first (``_filter_chinese_entries``).
    For the remaining ``n`` code strings an ``n x n`` matrix is built whose
    off-diagonal cells hold the symmetric score from
    ``compute_pairwise_batch`` and whose diagonal is 100.0.

    Args:
        sample: dict with optional keys ``idx``, ``code`` (list of strings)
            and ``score`` (list aligned with ``code``).
        metric: metric name understood by ``compute_pairwise_batch``.
        n_workers: maximum number of worker processes.

    Returns:
        Dict with ``idx``, ``matrix`` (nested lists, or None when no code
        survives filtering), ``code_count``, ``score`` (filtered scores) and
        ``chinese_filtered`` (number of removed entries).
    """
    idx = sample.get('idx', None)
    code_list = sample.get('code', [])
    score_list = sample.get('score', [])

    code_list, score_list, chinese_filtered = _filter_chinese_entries(code_list, score_list)

    if not code_list:
        return {
            'idx': idx,
            'matrix': None,
            'code_count': 0,
            'score': score_list,
            'chinese_filtered': chinese_filtered
        }

    n = len(code_list)
    matrix = np.zeros((n, n))
    # compute_pairwise_score reports every metric on a 0-100 scale (the
    # ROUGE-L F-measure is multiplied by 100), so self-similarity is 100 for
    # ROUGE-L as well as BLEU/chrF.
    np.fill_diagonal(matrix, 100.0)

    pairs = [(i, j, code_list[i], code_list[j]) for i in range(n) for j in range(i + 1, n)]

    if pairs:
        # Aim for roughly 10 batches per worker (load balancing vs. per-task overhead).
        batch_size = max(1, len(pairs) // (n_workers * 10))
        batches = [pairs[k:k + batch_size] for k in range(0, len(pairs), batch_size)]
        compute_func = partial(compute_pairwise_batch, metric=metric)
        pair_desc = f"Sample {idx}: Matrix {n}x{n} ({len(pairs)} pairs)"

        # Never spawn more processes than there are batches to hand out.
        with Pool(processes=min(n_workers, len(batches))) as pool:
            batch_results = list(tqdm.tqdm(
                pool.imap(compute_func, batches),
                total=len(batches),
                desc=pair_desc,
                leave=False
            ))

        for batch_result in batch_results:
            for i, j, score in batch_result:
                matrix[i, j] = score
                matrix[j, i] = score

    return {
        'idx': idx,
        'matrix': matrix.tolist(),
        'code_count': n,
        'score': score_list,
        'chinese_filtered': chinese_filtered
    }


def compute_pairwise_batch(
    pairs: List[Tuple[int, int, str, str]],
    metric: str
) -> List[Tuple[int, int, float]]:
    """Score a batch of ``(i, j, code_i, code_j)`` pairs symmetrically.

    A scorer is built once per batch: sacrebleu BLEU (``intl`` tokenizer,
    effective order) for ``"BLEU"``, sacrebleu chrF for ``"CHRF"``, and a
    ROUGE-L scorer for anything else. Each pair's score is the mean of both
    reference/hypothesis directions.

    Returns:
        ``[(i, j, symmetric_score), ...]`` in the same order as *pairs*.
    """
    scorer = (
        BLEU(tokenize='intl', effective_order=True) if metric == "BLEU"
        else CHRF() if metric == "CHRF"
        else rouge_scorer.RougeScorer(['rougeL'], use_stemmer=False)
    )

    def symmetric(first: str, second: str) -> float:
        # Score with `second` as reference first, then the reverse direction.
        forward = compute_pairwise_score(second, first, metric, scorer)
        backward = compute_pairwise_score(first, second, metric, scorer)
        return (forward + backward) / 2.0

    return [(i, j, symmetric(code_i, code_j)) for i, j, code_i, code_j in pairs]


def compute_pairwise_score(
    ref: str, 
    hyp: str, 
    metric: str, 
    scorer: Any
) -> float:
    """Score *hyp* against *ref* on a 0-100 scale.

    For ``"BLEU"``/``"CHRF"`` the sacrebleu-style ``scorer.sentence_score``
    result's ``.score`` is returned. For any other metric, the scorer is
    treated as a ROUGE scorer and the ``rougeL`` F-measure is multiplied by
    100.

    Args:
        ref: reference text.
        hyp: hypothesis text.
        metric: metric name selecting the scorer API.
        scorer: pre-built scorer object matching *metric*.

    Returns:
        The score, or 0.0 when either text is empty or the scorer raises an
        ``Exception``.
    """
    if not ref or not hyp:
        return 0.0

    try:
        if metric in ["BLEU", "CHRF"]:
            return scorer.sentence_score(hyp, [ref]).score
        # ROUGE-L: RougeScorer.score takes (target, prediction).
        return scorer.score(ref, hyp)['rougeL'].fmeasure * 100
    except Exception:
        # Best-effort: one failing pair must not abort a whole matrix. Only
        # Exception is caught so KeyboardInterrupt/SystemExit still propagate.
        return 0.0


def process_sample_wrapper(sample: Dict[str, Any], metric: str) -> Dict[str, Any]:
    """Build a scorer for *metric* and run the serial ``process_sample`` on *sample*.

    Scorers: sacrebleu BLEU for ``"BLEU"``, sacrebleu chrF for ``"CHRF"``,
    and a ROUGE-L scorer for anything else.
    """
    if metric == "BLEU":
        # sacrebleu recommends effective_order=True for sentence-level BLEU.
        # It is also the setting used by compute_pairwise_batch, so the serial
        # and pooled code paths score pairs identically.
        scorer = BLEU(tokenize='intl', effective_order=True)
    elif metric == "CHRF":
        scorer = CHRF()
    else:
        scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=False)

    return process_sample(sample, metric, scorer)


def process_sample(
    sample: Dict[str, Any], 
    metric: str, 
    scorer: Any
) -> Dict[str, Any]:
    """Serially compute one sample's symmetric similarity matrix.

    Chinese-like entries are removed first (``_filter_chinese_entries``).
    Each off-diagonal cell ``(i, j)`` is the mean of both scoring directions
    from ``compute_pairwise_score``. The diagonal is 100.0.

    Args:
        sample: dict with optional keys ``idx``, ``code`` and ``score``.
        metric: metric name passed to ``compute_pairwise_score``.
        scorer: pre-built scorer matching *metric*.

    Returns:
        Dict with ``idx``, ``matrix`` (nested lists, or None when no code
        survives filtering), ``code_count``, ``score`` and
        ``chinese_filtered``.
    """
    idx = sample.get('idx', None)
    code_list = sample.get('code', [])
    score_list = sample.get('score', [])

    code_list, score_list, chinese_filtered = _filter_chinese_entries(code_list, score_list)

    if not code_list:
        return {
            'idx': idx,
            'matrix': None,
            'code_count': 0,
            'score': score_list,
            'chinese_filtered': chinese_filtered
        }

    n = len(code_list)
    matrix = np.zeros((n, n))
    # compute_pairwise_score reports every metric on a 0-100 scale (the
    # ROUGE-L F-measure is multiplied by 100), so self-similarity is 100 for
    # ROUGE-L as well as BLEU/chrF.
    np.fill_diagonal(matrix, 100.0)

    pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
    total_pairs = len(pairs)

    pair_desc = f"Sample {idx}: Matrix {n}x{n} ({total_pairs} pairs)"
    # Pairs are walked in about 100 chunks only so the progress bar advances
    # in coarse steps; slicing past the end is safe.
    batch_size = max(1, total_pairs // 100)

    for batch_start in tqdm.trange(0, total_pairs, batch_size,
                                   desc=pair_desc,
                                   leave=False):
        for i, j in pairs[batch_start:batch_start + batch_size]:
            score_i_to_j = compute_pairwise_score(code_list[j], code_list[i], metric, scorer)
            score_j_to_i = compute_pairwise_score(code_list[i], code_list[j], metric, scorer)
            symmetric_score = (score_i_to_j + score_j_to_i) / 2.0
            matrix[i, j] = symmetric_score
            matrix[j, i] = symmetric_score

    return {
        'idx': idx,
        'matrix': matrix.tolist(),
        'code_count': n,
        'score': score_list,
        'chinese_filtered': chinese_filtered
    }


def process_sample_optimized(
    sample: Dict[str, Any], 
    metric: str,
    n_workers: int
) -> Dict[str, Any]:
    """Pool-based matrix computation with batch sizes capped at 1000 pairs.

    Same contract as ``process_sample_parallel``: Chinese-like entries are
    filtered out, off-diagonal cells hold the symmetric score from
    ``compute_pairwise_batch``, and the diagonal is 100.0.

    Args:
        sample: dict with optional keys ``idx``, ``code`` and ``score``.
        metric: metric name understood by ``compute_pairwise_batch``.
        n_workers: maximum number of worker processes.

    Returns:
        Dict with ``idx``, ``matrix`` (nested lists, or None when no code
        survives filtering), ``code_count``, ``score`` and
        ``chinese_filtered``.
    """
    idx = sample.get('idx', None)
    code_list = sample.get('code', [])
    score_list = sample.get('score', [])

    code_list, score_list, chinese_filtered = _filter_chinese_entries(code_list, score_list)

    if not code_list:
        return {
            'idx': idx,
            'matrix': None,
            'code_count': 0,
            'score': score_list,
            'chinese_filtered': chinese_filtered
        }

    n = len(code_list)
    matrix = np.zeros((n, n))
    # compute_pairwise_score reports every metric on a 0-100 scale (the
    # ROUGE-L F-measure is multiplied by 100), so self-similarity is 100 for
    # ROUGE-L as well as BLEU/chrF.
    np.fill_diagonal(matrix, 100.0)

    pairs = [(i, j, code_list[i], code_list[j]) for i in range(n) for j in range(i + 1, n)]

    if pairs:
        # About 4 batches per worker, but never more than 1000 pairs per batch.
        batch_size = max(1, min(1000, len(pairs) // (n_workers * 4)))
        batches = [pairs[k:k + batch_size] for k in range(0, len(pairs), batch_size)]

        compute_func = partial(compute_pairwise_batch, metric=metric)

        # Never spawn more processes than there are batches to hand out.
        with Pool(processes=min(n_workers, len(batches))) as pool:
            batch_results = list(
                tqdm.tqdm(
                    pool.imap(compute_func, batches),
                    total=len(batches),
                    desc=f"Sample {idx}: Computing pairwise {metric} scores"
                )
            )

        for batch_result in batch_results:
            for i, j, score in batch_result:
                matrix[i, j] = score
                matrix[j, i] = score

    return {
        'idx': idx,
        'matrix': matrix.tolist(),
        'code_count': n,
        'score': score_list,
        'chinese_filtered': chinese_filtered
    }


def split_results_by_score(
    results: List[Dict[str, Any]]
) -> Dict[str, List[Dict[str, Any]]]: 
    """Partition per-sample matrices into ``all`` / ``correct`` / ``wrong``.

    Records without a matrix go to ``all`` unchanged. For every other record
    a copy without ``score`` goes to ``all``. When its ``score`` list aligns
    with the matrix rows, the sub-matrix of truthy-score rows/columns goes to
    ``correct`` and the falsy-score one to ``wrong``. Each sub-entry carries
    ``indices`` (positions within the sample's filtered code list).
    Samples whose score list length differs from the matrix size are not
    split; a warning is printed instead.
    """
    all_results = []
    correct_results = []
    wrong_results = []

    for result in results:
        score_list = result.get('score', [])
        code_count = result.get('code_count', 0)
        matrix = result.get('matrix', None)

        if matrix is None or code_count == 0:
            all_results.append(result)
            continue

        matrix = np.array(matrix)
        chinese_filtered = result.get('chinese_filtered', 0)

        # NOTE(review): unlike the branch above, 'score' is not copied into
        # this 'all' record; kept as-is for output compatibility.
        all_results.append({
            'idx': result['idx'],
            'matrix': matrix.tolist(),
            'code_count': code_count,
            'chinese_filtered': chinese_filtered
        })

        if not score_list:
            continue
        if len(score_list) != matrix.shape[0]:
            # Scores must align 1:1 with matrix rows; otherwise indices are
            # either out of bounds (longer list) or misattributed (shorter).
            print(f"Warning: sample {result['idx']}: {len(score_list)} scores for "
                  f"a {matrix.shape[0]}x{matrix.shape[0]} matrix; skipping correct/wrong split")
            continue

        correct_indices = [i for i, s in enumerate(score_list) if s]
        wrong_indices = [i for i, s in enumerate(score_list) if not s]

        for indices, bucket in ((correct_indices, correct_results),
                                (wrong_indices, wrong_results)):
            if indices:
                bucket.append({
                    'idx': result['idx'],
                    'matrix': matrix[np.ix_(indices, indices)].tolist(),
                    'code_count': len(indices),
                    'indices': indices,
                    'chinese_filtered': chinese_filtered
                })

    return {
        'all': all_results,
        'correct': correct_results,
        'wrong': wrong_results
    }


def save_results(
    results: Dict[str, List[Dict[str, Any]]], 
    output_path: str
) -> None:    
    """Write each result category to its own JSONL file.

    For every ``key`` in *results*, the file ``{base}_{key}.jsonl`` is
    written, where ``base`` is *output_path* with a trailing ``.jsonl``
    removed. The parent directory is created when *output_path* has one.
    Items are serialized with ``ensure_ascii=False``, one per line.
    """
    out_dir = os.path.dirname(output_path)
    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when output_path actually contains one (e.g. not for 'out.jsonl').
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    suffix = '.jsonl'
    base_path = output_path[:-len(suffix)] if output_path.endswith(suffix) else output_path

    for key, data in results.items():
        file_path = f"{base_path}_{key}.jsonl"
        with open(file_path, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(item, ensure_ascii=False) + '\n' for item in data)
        print(f"Saved {key} results to {file_path}")


def compute_statistics(
    results: Dict[str, List[Dict[str, Any]]]
) -> Dict[str, Dict[str, Any]]:
    """Summarise the off-diagonal similarity scores of each result category.

    In each category, every record with a non-None ``matrix`` and positive
    ``code_count`` contributes the strict upper triangle of its matrix, so
    each unordered pair counts once and the diagonal is excluded.
    ``total_samples``, ``samples_with_matrix`` and ``avg_code_count`` are
    always reported. The ``*_similarity`` statistics and ``unique_pairs``
    are added only when at least one pair exists.
    """
    stats = {}

    for category, data in results.items():
        upper_parts = []
        sample_count = 0

        for result in data:
            if result.get('matrix') is None or not result.get('code_count', 0) > 0:
                continue
            matrix = np.asarray(result['matrix'])
            sample_count += 1
            n = matrix.shape[0]
            if n > 1:
                # Strict upper triangle in row-major order, gathered in one
                # vectorised indexing step instead of a Python double loop.
                upper_parts.append(matrix[np.triu_indices(n, k=1)])

        off_diagonal = np.concatenate(upper_parts) if upper_parts else np.empty(0)

        category_stats = {
            'total_samples': len(data),
            'samples_with_matrix': sample_count,
            'avg_code_count': np.mean([r.get('code_count', 0) for r in data]) if data else 0
        }

        if off_diagonal.size:
            category_stats.update({
                'avg_similarity': np.mean(off_diagonal),
                'std_similarity': np.std(off_diagonal),
                'min_similarity': np.min(off_diagonal),
                'max_similarity': np.max(off_diagonal),
                'median_similarity': np.median(off_diagonal),
                'unique_pairs': int(off_diagonal.size)
            })

        stats[category] = category_stats

    return stats


def main():
    """Command-line entry point: compute pairwise metric matrices for a JSONL file.

    By default the work is delegated to ``compute_metric_matrices``. With
    ``--optimize``, each sample goes through ``process_sample_optimized``,
    which uses different batch sizing and a higher default worker cap.
    Results are always written to ``--output``. ``--stats`` additionally
    prints summary statistics per category.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Compute metric matrices for code samples')
    parser.add_argument('jsonl_path', type=str, help='Path to input JSONL file')
    parser.add_argument('--metric', type=str, default='BLEU', 
                        choices=['BLEU', 'CHRF', 'ROUGE-L'],
                        help='Metric to compute (default: BLEU)')
    parser.add_argument('--output', type=str, required=True,
                        help='Path to output file (required)')
    parser.add_argument('--split-by-score', action='store_true',
                        help='Split results by score (all/correct/wrong)')
    parser.add_argument('--stats', action='store_true',
                        help='Print statistics about the results')
    parser.add_argument('--n-workers', type=int, default=None,
                        help='Number of worker processes (default: min(cpu_count, 96); '
                             'min(cpu_count, 180) with --optimize)')
    parser.add_argument('--optimize', action='store_true',
                        help='Use optimized processing for large matrices')

    args = parser.parse_args()

    # A pool needs at least one process; reject bad values with a usage
    # message instead of failing later during matrix computation.
    if args.n_workers is not None and args.n_workers < 1:
        parser.error('--n-workers must be at least 1')

    if args.optimize:
        with open(args.jsonl_path, 'r', encoding='utf-8') as f:
            data = [json.loads(line.strip()) for line in f if line.strip()]

        # argparse `choices` already restricts --metric to supported values.
        metric = args.metric.upper()

        n_workers = args.n_workers if args.n_workers is not None else min(cpu_count(), 180)
        print(f"Using {n_workers} workers with optimized processing")

        results = [
            process_sample_optimized(sample, metric, n_workers)
            for sample in tqdm.tqdm(data, desc=f"Computing {metric} matrices (optimized)")
        ]

        output = split_results_by_score(results) if args.split_by_score else {"all": results}
        # --output is required, so results are always saved.
        save_results(output, args.output)
    else:
        output = compute_metric_matrices(
            jsonl_path=args.jsonl_path,
            metric=args.metric,
            output_path=args.output,
            split_by_score=args.split_by_score,
            n_workers=args.n_workers
        )

    if args.stats:
        stats = compute_statistics(output)
        for category, category_stats in stats.items():
            print(f"\n{category.upper()} Statistics ({args.metric}):")
            for key, value in category_stats.items():
                if isinstance(value, float):
                    print(f"  {key}: {value:.2f}")
                else:
                    print(f"  {key}: {value}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
