import json
import copy
import gc
import re
import argparse
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import DataLoader
import numpy as np
import tqdm
from cuml.cluster import KMeans
import cupy as cp
from typing import List, Dict, Any, Optional, Tuple
from contextlib import nullcontext


def merge_and_split_blocks(text: str, L_max: int = 300) -> List[str]:
    """Split a model response into reasoning blocks.

    Text handling, in order:

    1. If ``<think>``/``</think>`` tags are present, only the reasoning part is
       kept. With both tags (in order), keep the text between them. With only
       an opening tag (or a closing tag before it), keep everything after the
       opening tag. With only a closing tag, keep everything before it.
    2. The text is cut at paragraph breaks (a blank line, CRLF or LF). It is
       also cut after a sentence terminator (``.``, ``?``, ``!``) that is
       followed by a space or newline; whitespace after the terminator is
       skipped.
    3. Chunks with more than ``L_max`` whitespace-delimited tokens are
       hard-split into pieces of at most ``L_max`` tokens. Whitespace inside
       those pieces is normalised to single spaces.
    4. Any piece shorter than 10 characters is merged into a neighbour, so
       that fragments such as ``"1."`` do not become blocks on their own.

    Args:
        text: Raw model output, optionally containing ``<think>`` tags.
        L_max: Maximum number of whitespace-delimited tokens per piece before
            short-fragment merging.

    Returns:
        The non-empty blocks in their original order.
    """
    start = text.find("<think>")
    end = text.find("</think>")
    if start != -1 and end != -1 and end > start:
        text = text[start + len("<think>"):end].strip()
    elif start != -1 and (end == -1 or end < start):
        text = text[start + len("<think>"):].strip()
    elif start == -1 and end != -1:
        text = text[:end].strip()
    else:
        text = text.strip()

    sentence_ending_tokens = {".", "?", "!"}
    paragraph_patterns = ["\r\n\r\n", "\n\n"]

    chunks: List[str] = []
    current: List[str] = []

    def flush_current():
        # Reads the enclosing ``current`` at call time, so it sees the latest rebinding.
        s = "".join(current).strip()
        if s:
            chunks.append(s)

    i, n = 0, len(text)
    while i < n:
        matched_len = 0
        for p in paragraph_patterns:
            if i + len(p) <= n and text[i:i+len(p)] == p:
                matched_len = len(p)
                break
        if matched_len:
            flush_current()
            current = []
            i += matched_len
            continue

        ch = text[i]
        current.append(ch)

        # A terminator only ends a sentence when followed by whitespace, so
        # things like "3.14" or "e.g.x" are not split.
        is_sentence_end = False
        if ch in sentence_ending_tokens and i + 1 < n and text[i + 1] in {" ", "\n"}:
            is_sentence_end = True

        if is_sentence_end:
            flush_current()
            current = []
            j = i + 1
            while j < n and text[j] in {" ", "\n"}:
                j += 1
            i = j
            continue

        i += 1

    if current:
        flush_current()

    def token_count(s: str) -> int:
        return len(re.findall(r"\S+", s))

    def force_split_by_tokens(s: str, max_tokens: int) -> List[str]:
        # Greedy split into runs of at most ``max_tokens`` whitespace-delimited tokens.
        words = re.findall(r"\S+", s)
        if not words:
            return []
        out, buf = [], []
        for w in words:
            if len(buf) >= max_tokens:
                out.append(" ".join(buf))
                buf = [w]
            else:
                buf.append(w)
        if buf:
            out.append(" ".join(buf))
        return out

    # metas[i] = (index of originating chunk, piece index within it, piece count)
    parts: List[str] = []
    metas: List[Tuple[int, int, int]] = []
    for orig_idx, c in enumerate(chunks):
        if token_count(c) > L_max:
            split = force_split_by_tokens(c, L_max)
            total = len(split)
            for k, seg in enumerate(split):
                parts.append(seg)
                metas.append((orig_idx, k, total))
        else:
            parts.append(c)
            metas.append((orig_idx, 0, 1))

    def _join_back(i_cur: int) -> bool:
        # Append parts[i_cur] to its predecessor; the predecessor keeps its meta.
        if i_cur <= 0:
            return False
        parts[i_cur - 1] = (parts[i_cur - 1].rstrip() + " " + parts[i_cur].lstrip()).strip()
        parts.pop(i_cur)
        metas.pop(i_cur)
        return True

    def _join_forward(i_cur: int) -> bool:
        # Prepend parts[i_cur] to its successor; the successor keeps its meta.
        if i_cur + 1 >= len(parts):
            return False
        parts[i_cur + 1] = (parts[i_cur].rstrip() + " " + parts[i_cur + 1].lstrip()).strip()
        parts.pop(i_cur)
        metas.pop(i_cur)
        return True

    i = 0
    while i < len(parts):
        if len(parts[i]) < 10 and len(parts) > 1:
            _, part_idx, part_total = metas[i]
            if 0 < part_idx < part_total - 1:
                # Interior piece of a hard-split chunk: prefer the previous piece.
                # It can sit at index 0 after earlier forward merges, so fall back
                # to the next piece (previously this looped forever).
                if not _join_back(i):
                    _join_forward(i)
            else:
                # First/last piece of a split chunk, or an unsplit chunk.
                if not _join_forward(i):
                    _join_back(i)
            # With more than one part one of the joins always succeeds and
            # shrinks ``parts``, so re-examine index i without advancing.
            continue
        i += 1

    return parts

def create_reasoning_blocks(jsonl_path: str, n_sampling: Optional[int] = None) -> List[Dict[str, Any]]:
    """Read a JSONL file and split each record's ``code`` samples into reasoning blocks.

    For every record, the output dict contains:

    * ``idx``, ``question``, ``gt_cot``, ``gt`` and ``answer``, copied with
      ``dict.get`` (``None`` when missing).
    * ``code``, ``pred``, ``report``, ``finish_reason`` and ``score``, when
      present as lists, truncated to the first ``n_sampling`` entries (all
      entries when ``n_sampling`` is falsy).
    * ``blocks``, when ``code`` is a list: for each kept sample, the output
      of ``merge_and_split_blocks``. The question text is prefixed to the
      first block, separated by a blank line. Non-string samples produce an
      empty block list, so ``blocks[k]`` stays aligned with ``code[k]``,
      ``score[k]`` and so on.

    Blank lines in the file are skipped.

    Args:
        jsonl_path: Path to a UTF-8 JSONL file with one JSON object per line.
        n_sampling: Keep only this many samples per record; ``None``/0 keeps all.

    Returns:
        One output dict per non-blank input line, in file order.
    """
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. a stray trailing newline) instead of failing in json.loads.
        all_data = [json.loads(line) for line in f if line.strip()]

    def _take(seq: list) -> list:
        return seq[:n_sampling] if n_sampling else seq

    results = []
    for data in tqdm.tqdm(all_data):
        output_item = {
            key: data.get(key)
            for key in ('idx', 'question', 'gt_cot', 'gt', 'answer')
        }

        for key in ('code', 'pred', 'report', 'finish_reason', 'score'):
            if isinstance(data.get(key), list):
                output_item[key] = _take(data[key])

        code_samples = data.get('code')
        if isinstance(code_samples, list):
            question_text = data.get('question', '')
            processed_samples = []
            for code_text in _take(code_samples):
                if not isinstance(code_text, str):
                    # Placeholder keeps per-sample lists index-aligned downstream.
                    processed_samples.append([])
                    continue
                blocks = merge_and_split_blocks(code_text)
                if blocks and question_text:
                    blocks[0] = question_text + "\n\n" + blocks[0]
                processed_samples.append(blocks)

            output_item['blocks'] = processed_samples

        results.append(output_item)
    return results


def get_embeddings_batch(
    texts,
    tokenizer,
    model,
    batch_size: int = 32,
    device: str = 'cuda',
    empty_cache_every: int | None = None
) -> List[List[float]]:
    all_embeddings: List[List[float]] = []
    i = 0
    n = len(texts)
    use_cuda = (device == 'cuda' and torch.cuda.is_available())
    bcnt = 0

    while i < n:
        cur_bs = min(batch_size, n - i)
        batch_texts = texts[i:i+cur_bs]
        bcnt += 1

        try:
            encoded_input = tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors='pt'
            )
            encoded_input = {k: v.to(device, non_blocking=True) for k, v in encoded_input.items()}

            amp_ctx = torch.cuda.amp.autocast(dtype=torch.float16) if use_cuda else nullcontext()
            with torch.inference_mode():
                with amp_ctx:
                    out = model(**encoded_input)
                    emb = out.last_hidden_state.mean(dim=1)
                emb = emb.float()
                # emb = torch.nn.functional.normalize(emb, p=2, dim=1)

            all_embeddings.extend(emb.cpu().numpy().tolist())
            i += cur_bs

        except torch.cuda.OutOfMemoryError:
            if use_cuda:
                torch.cuda.empty_cache()
            if cur_bs == 1:
                raise
            batch_size = max(1, cur_bs // 2)
            continue

        finally:
            del encoded_input
            if 'out' in locals():
                del out
            if 'emb' in locals():
                del emb
            if empty_cache_every is not None and use_cuda and (bcnt % empty_cache_every == 0):
                torch.cuda.empty_cache()

    return all_embeddings

def add_vectors_to_blocks(
    data_list: List[Dict[str, Any]],
    output_path: str,
    batch_size: int = 32,
    device: str = 'cuda',
    model_name: str = 'BAAI/bge-large-en-v1.5',
) -> List[Dict[str, Any]]:
    """Embed every block of every record and attach the vectors per sample.

    The tokenizer and model are loaded with ``from_pretrained(model_name)``.
    The model is moved to ``device`` (CPU if CUDA was requested but is
    unavailable) and put in eval mode. Each record is deep-copied, so
    ``data_list`` is not mutated.

    For a record that has a ``blocks`` key, all blocks of all samples are
    embedded in one ``get_embeddings_batch`` call. The result is stored as
    ``vectors``, a list with one entry per sample holding one vector per
    block. ``vectors`` is always set for such records, even when every sample
    is empty, so downstream code can rely on the key. Records without
    ``blocks`` are passed through unchanged.

    Args:
        data_list: Records as produced by ``create_reasoning_blocks``.
        output_path: Unused; kept for backward compatibility with callers.
        batch_size: Initial embedding batch size.
        device: Torch device string (e.g. ``'cuda'``, ``'cuda:1'``, ``'cpu'``).
        model_name: Hugging Face model id or path for tokenizer and encoder.

    Returns:
        Deep copies of the input records, with ``vectors`` added where applicable.
    """
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        device = 'cpu'
        print("CUDA not available, using CPU")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    model.eval()

    results = []
    for item in tqdm.tqdm(data_list):
        data = copy.deepcopy(item)

        if 'blocks' in data:
            sample_blocks_list = data['blocks']
            # Flatten blocks while remembering which sample each one came from.
            all_blocks = [block for sample_blocks in sample_blocks_list for block in sample_blocks]
            sample_indices = [
                sample_idx
                for sample_idx, sample_blocks in enumerate(sample_blocks_list)
                for _ in sample_blocks
            ]

            vectors_by_sample: List[List[List[float]]] = [[] for _ in sample_blocks_list]
            if all_blocks:
                all_vectors = get_embeddings_batch(
                    all_blocks,
                    tokenizer,
                    model,
                    batch_size,
                    device
                )
                for vector, sample_idx in zip(all_vectors, sample_indices):
                    vectors_by_sample[sample_idx].append(vector)

            # Set even when empty: clustering reads data['vectors'] for every record.
            data['vectors'] = vectors_by_sample

        results.append(data)

    return results

def cluster_vectors_with_kmeans_gpu(
    data_lists: List[List[Dict[str, Any]]],
    n_clusters: int,
    random_state: int = 42,
    n_init: int = 10,
    max_iter: int = 300
) -> List[List[Dict[str, Any]]]:
    """Cluster all block vectors jointly with cuML KMeans and attach per-sample paths.

    Vectors from every record of every input list are stacked into one
    float32 matrix and clustered together on the GPU, so cluster ids are
    shared across all lists. Each record is deep-copied (inputs are not
    mutated) and receives:

    * ``path_node``: per sample, the cluster id of each of its blocks.
    * ``path_edge``: per sample, the Euclidean distance between the
      centroids of consecutive blocks' clusters (empty for samples with
      fewer than two blocks).

    Records without ``vectors`` are treated as having no samples and get
    empty ``path_node``/``path_edge``. When there are no vectors at all,
    KMeans is skipped and every record still gets these keys. If
    ``n_clusters`` exceeds the number of vectors, it is reduced to that
    number.

    Args:
        data_lists: One list of records per input file.
        n_clusters: Requested number of KMeans clusters.
        random_state: KMeans seed.
        n_init: Number of KMeans initialisations.
        max_iter: Maximum KMeans iterations.

    Returns:
        Deep copies of ``data_lists`` (same nesting) with the path keys added.
    """
    all_data_lists = [copy.deepcopy(dl) for dl in data_lists]

    # Flatten in (list, record, sample, block) order, remembering per-record
    # per-sample block counts so labels can be mapped back afterwards.
    all_vectors: List[List[float]] = []
    lengths: List[List[int]] = []
    for data_list in all_data_lists:
        for data in data_list:
            vecs = data.get('vectors') or []
            lengths.append([len(sample_vectors) for sample_vectors in vecs])
            for sample_vectors in vecs:
                all_vectors.extend(sample_vectors)

    if all_vectors:
        X_host = np.asarray(all_vectors, dtype=np.float32)
        del all_vectors
        X = cp.asarray(X_host)
        del X_host
        gc.collect()

        n_points = int(X.shape[0])
        k = min(n_clusters, n_points)
        if k < n_clusters:
            print(f"n_clusters={n_clusters} exceeds the number of vectors ({n_points}); using {k} clusters")

        print("=" * 22)
        print("cuML KMeans clustering on GPU...")
        print("=" * 22)

        kmeans = KMeans(
            n_clusters=k,
            random_state=random_state,
            n_init=n_init,
            max_iter=max_iter,
            verbose=1,
            init='scalable-k-means++',
        )
        labels_cp = kmeans.fit_predict(X)                           # (N,)
        C = kmeans.cluster_centers_.astype(cp.float32, copy=False)  # (k, d)
        # Pairwise centroid distances via ||a||^2 + ||b||^2 - 2 a.b; clamp small
        # negative values caused by rounding before taking the square root.
        G = C @ C.T
        sq_norms = cp.diag(G)
        D_cp = cp.sqrt(cp.maximum(sq_norms[:, None] + sq_norms[None, :] - 2.0 * G, 0.0))

        labels = cp.asnumpy(labels_cp)
        D = cp.asnumpy(D_cp)

        del labels_cp, D_cp, G, sq_norms, C, X, kmeans
        cp.get_default_memory_pool().free_all_blocks()
    else:
        # Nothing to cluster: the assembly below never indexes these.
        labels = np.empty(0, dtype=np.int64)
        D = np.empty((0, 0), dtype=np.float32)

    results: List[List[Dict[str, Any]]] = []
    offset = 0
    record_lengths = iter(lengths)

    for data_list in all_data_lists:
        path_data = []
        for data in data_list:
            line_clusters: List[List[int]] = []
            for block_len in next(record_lengths):
                line_clusters.append([int(x) for x in labels[offset:offset + block_len]])
                offset += block_len

            line_edges: List[List[float]] = [
                [float(D[a, b]) for a, b in zip(cluster_seq, cluster_seq[1:])]
                for cluster_seq in line_clusters
            ]

            data['path_node'] = line_clusters
            data['path_edge'] = line_edges
            path_data.append(data)
        results.append(path_data)

    return results


def main(
    input_paths: List[str],
    n_sampling: Optional[int] = None,
    batch_size: int = 32,
    device: str = 'cuda',
    n_clusters: int = 200,
    n_init: int = 10,
    max_iter: int = 300,
    output_suffix: str = '_processed',
) -> None:
    """Run the block-split, embed and cluster pipeline over several JSONL files.

    Each input goes through ``create_reasoning_blocks`` and then
    ``add_vectors_to_blocks``. All inputs are then clustered together by
    ``cluster_vectors_with_kmeans_gpu``. The result for each input is written
    next to it as ``<stem><output_suffix>.jsonl``, one JSON object per line.
    """
    rule = "=" * 22

    def announce(message: str) -> None:
        print(rule)
        print(message)
        print(rule)

    per_file_vectors: List[List[Dict[str, Any]]] = []
    for input_path in input_paths:
        stem = Path(input_path).stem

        blocks_data: List[Dict[str, Any]] = create_reasoning_blocks(input_path, n_sampling)
        announce(f"create_reasoning_blocks completed for {stem}")

        per_file_vectors.append(add_vectors_to_blocks(blocks_data, None, batch_size, device))
        announce(f"add_vectors_to_blocks completed for {stem}")

    clustered: List[List[Dict[str, Any]]] = cluster_vectors_with_kmeans_gpu(
        per_file_vectors,
        n_clusters=n_clusters,
        n_init=n_init,
        max_iter=max_iter
    )
    announce("cluster_vectors_with_kmeans completed")

    for input_path, records in zip(input_paths, clustered):
        source = Path(input_path)
        destination = source.parent / f"{source.stem}{output_suffix}.jsonl"
        with open(destination, 'w', encoding='utf-8') as fh:
            fh.writelines(json.dumps(record, ensure_ascii=False) + '\n' for record in records)
        announce(f"Final output saved to {destination}")


if __name__ == "__main__":
    # CLI entry point: parse arguments and run the full pipeline via main().
    parser = argparse.ArgumentParser(description='Process multiple JSONL files with text clustering pipeline')
    parser.add_argument('input_files', nargs='+', help='Input JSONL file paths')
    parser.add_argument('--n_sampling', type=int, default=None, help='Number of samples to use')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for embedding')
    parser.add_argument('--device', type=str, default='cuda', choices=['cuda', 'cpu'], help='Device to use')
    parser.add_argument('--n_clusters', type=int, default=200, help='Number of clusters for KMeans')
    parser.add_argument('--n_init', type=int, default=10, help='Number of initializations for KMeans')
    parser.add_argument('--max_iter', type=int, default=300, help='Maximum iterations for KMeans')
    parser.add_argument('--output_suffix', type=str, default='_processed', help='Suffix for output files')

    args = parser.parse_args()

    # main() writes its outputs to disk and returns None, so nothing is bound here.
    main(
        input_paths=args.input_files,
        n_sampling=args.n_sampling,
        batch_size=args.batch_size,
        device=args.device,
        n_clusters=args.n_clusters,
        n_init=args.n_init,
        max_iter=args.max_iter,
        output_suffix=args.output_suffix
    )
