import numpy as np
import random
from tqdm import tqdm
import re
import json
from typing import *
from thinktime.ts_generator.generate import generate_time_series, generate_controlled_attributes, attribute_to_text, generate_random_attributes
from thinktime.utils.llm_utils import LLMClient
from thinktime.utils.encoding_utils import timeseries_encoding, timeseries_to_list
import yaml
import copy
import os


# CONFIG
# Parse the data-generation config once (closing the file) instead of
# re-opening and re-parsing it for every key.
with open("config/datagen_config.yaml", encoding="utf-8") as _config_file:
    _DATAGEN_CONFIG = yaml.safe_load(_config_file)
NUM_DATA = _DATAGEN_CONFIG["num_data_llm_qa"]
SEQ_LEN = _DATAGEN_CONFIG["seq_len"]
ENCODING_METHOD = _DATAGEN_CONFIG["encoding_method"]
OUTPUT_BASE_DIR = _DATAGEN_CONFIG["data_output_dir"]
OUTPUT_PATH = f'{OUTPUT_BASE_DIR}/mts_local_llm_{SEQ_LEN}_{NUM_DATA}_{ENCODING_METHOD}.jsonl'
EVOL_LABEL_PATH = f'{OUTPUT_BASE_DIR}/evol_labels/mts_local_llm_{SEQ_LEN}_{NUM_DATA}_{ENCODING_METHOD}.json'
CLUSTER_LABEL_PATH = f'{OUTPUT_BASE_DIR}/labels/mts_local_llm_{SEQ_LEN}_{NUM_DATA}_{ENCODING_METHOD}.json'
DISABLE_METRIC_CONFIG = _DATAGEN_CONFIG["disable_metric_config"]
DRYRUN = _DATAGEN_CONFIG["dryrun"]
LOCAL_LLM_PATH = _DATAGEN_CONFIG["local_llm_path"]
# Fraction (probability) with which each candidate QA pair of an MTS case is
# kept; None keeps all generated QAs.
QA_KEEP_RATE = 0.1


# All Config for TS Features.
# 'overall_attribute' (global shape) and 'change' (local change types) are
# passed to generate_random_attributes; the numbers look like sampling
# weights/probabilities — NOTE(review): confirm against ts_generator.
all_config = {
    "overall_attribute": {
        "seasonal": {
            "no periodic fluctuation": 0.7,
            "periodic fluctuation": 0.3
        },
        "trend": {
            "decrease": 0.2,
            "increase": 0.2,
            "keep steady": 0.6
        },
        "frequency": {
            "high frequency": 0.5,
            "low frequency": 0.5
        },
        "noise": {
            "noisy": 0.3,
            "almost no noise": 0.7
        }
    },
    "change": {
        "shake": 1,
        "upward spike": 15,
        "downward spike": 15,
        "continuous upward spike": 8,
        "continuous downward spike": 8,
        "upward convex": 1,
        "downward convex": 1,
        "sudden increase": 5,
        "sudden decrease": 5,
        "rapid rise followed by slow decline": 1,
        "slow rise followed by rapid decline": 1,
        "rapid decline followed by slow rise": 1,
        "slow decline followed by rapid rise": 1,
        "decrease after upward spike": 1,
        "increase after downward spike": 1,
        "increase after upward spike": 1,
        "decrease after downward spike": 1,
        "wide upward spike": 1,
        "wide downward spike": 1
    }
}

# Scenario definitions: each entry provides a 'category' (the system/situation)
# and a 'cluster' mapping of cluster name -> metric names.
with open('config/metric_set.json', 'rt', encoding='utf-8') as _metric_file:
    metric_config = json.load(_metric_file)
# Global counter of LLM prompts; each prompt N is referenced in answers/labels
# as the placeholder "<|promptN|>" and later replaced by its LLM answer.
all_prompt_idx = 0

_PROMPT_PLACEHOLDER_RE = re.compile(r"<\|prompt(\d+)\|>")


def replace_prompts(data, obj):
    """Recursively substitute ``<|promptN|>`` placeholders with ``data[N]``.

    Dict values, list items and strings are processed; any other object is
    returned unchanged. Dicts and lists are rebuilt, so ``obj`` is not mutated.

    Args:
        data: Sequence of replacement strings, indexed by prompt number.
        obj: A string, or a (possibly nested) dict/list containing strings.

    Returns:
        A copy of ``obj`` with every in-range placeholder replaced. A
        placeholder whose index is out of range is left verbatim (a warning is
        printed) instead of aborting the whole substitution with IndexError.
    """
    if isinstance(obj, dict):
        return {k: replace_prompts(data, v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [replace_prompts(data, item) for item in obj]
    if isinstance(obj, str):
        def prompt_replacer(match):
            i = int(match.group(1))
            if i >= len(data):
                print(f"Warning: prompt index {i} out of range ({len(data)=})")
                # Keep the placeholder so the problem stays visible in output.
                return match.group(0)
            return data[i]
        return _PROMPT_PLACEHOLDER_RE.sub(prompt_replacer, obj)
    return obj

def attribute_pool_to_json(attribute_pool: dict) -> str:
    """Serialize an attribute pool to JSON with local amplitudes rounded to 2 decimals.

    Rounding is applied to a deep copy, so the caller's dict is not modified.
    A pool without a ``'local'`` list is serialized unchanged.

    Args:
        attribute_pool: Attribute dict; each entry of ``attribute_pool['local']``
            is expected to carry a numeric ``'amplitude'``.

    Returns:
        The JSON string (non-ASCII characters kept as-is).
    """
    pool = copy.deepcopy(attribute_pool)
    for local in pool.get('local', []):
        local['amplitude'] = round(local['amplitude'], 2)
    return json.dumps(pool, ensure_ascii=False)

class PositiveTSGenError(Exception):
    """Signal that a positive (clustered) time-series case could not be built.

    Raised by the positive/prompt generators; callers catch it to retry the
    attempt or to skip the whole generated case.
    """

# -------- helper: collect local positions --------
def _local_positions(attr: dict) -> List[int]:
    return [int(l['position_start']) for l in attr.get('local', []) if 'position_start' in l]

# -------- positives: main change near change_position, main kept min_gap away from avoid_positions --------
def generate_positive_timeseries(
    cnt: int,
    change_position: Optional[int] = None,
    seq_len: int = 256,
    avoid_positions: Optional[List[int]] = None,
    min_gap: int = 20,
    interference_positions: Optional[List[int]] = None
) -> Tuple[List[np.ndarray], List[dict], int, List[str]]:
    """
    Generate ``cnt`` positive time series that share one main local change
    near ``change_position``.

    Each series is built in two passes:
      1. Sample attributes with exactly one local change at ``change_position``;
         its realized ``position_start`` must be within 12 points of the target
         and at least ``min_gap`` from every ``avoid_positions`` entry. The main
         change type is taken from this pass.
      2. Optionally request extra locals (interference positions are tried
         first, then random positions in [2%, 95%] of ``seq_len``), each target
         at least 20 points from the others and ``min_gap`` from
         ``avoid_positions``; if they cannot all be placed, only the main target
         is kept. Attributes are re-sampled for all targets; the local closest to
         ``change_position`` is treated as the main one and must again be within
         12 points of it and ``min_gap`` from ``avoid_positions``. For the first
         third of retries its type must also match the pass-1 type.
    Only the main local's realized position is re-checked against
    ``avoid_positions``; extra locals are only checked at their target positions.

    Args:
        cnt: number of series to generate.
        change_position: target main position; if None, a random position in
            [2%, 95%] of ``seq_len`` at least ``min_gap`` from all
            ``avoid_positions`` is drawn (up to 200 tries).
        seq_len: length of each series.
        avoid_positions: positions the main change must stay ``min_gap`` from.
        min_gap: minimum distance to ``avoid_positions``.
        interference_positions: preferred positions for extra locals.

    Returns:
        (timeseries, attributes, int(change_position), main_type_list), where
        ``main_type_list[k]`` is the type of the local closest to
        ``change_position`` in series k.

    Raises:
        PositiveTSGenError: if no feasible ``change_position`` is found or either
            pass fails within ``max_retry`` attempts.
    """
    avoid_positions = list(avoid_positions or [])
    interference_positions = list(interference_positions or [])
    
    if change_position is None:
        lb = int(0.02 * seq_len)
        rb = int(0.95 * seq_len)
            
        change_position = None
        for _ in range(200):
            cand = random.randint(lb, rb)
            if all(abs(cand - p) >= min_gap for p in avoid_positions):
                change_position = cand
                break
        if change_position is None:
            raise PositiveTSGenError("No feasible change_position")

    # Bounds for placing extra locals
    lb = int(0.02 * seq_len)
    rb = int(0.95 * seq_len)

    timeseries, attributes, main_type_list = [], [], []

    # Extra-local budget. NOTE(review): "same"/"other" type only changes how many
    # extra positions are requested; the change types themselves are not
    # constrained in this function.
    same_type_extra_max = 1
    other_type_extra_max = 2
    p_same_type_extra = 0.7
    p_other_type_extra = 0.6
    max_retry = 30

    for _ in range(cnt):
        # First pass: realize one main change at the target position (no jitter)
        base_changes = {(int(change_position), None)}

        ts0, attr0 = None, None
        for _try in range(max_retry):
            ap = generate_random_attributes(
                all_config['overall_attribute'],
                all_config['change'],
                base_changes.copy(),
                seq_len
            )
            ts, ap = generate_time_series(ap, seq_len)
            if not ap.get('local') or len(ap['local']) != 1:
                continue

            main_local = ap['local'][0]
            # Allow at most 12 points between requested and realized position.
            if abs(int(main_local['position_start']) - int(change_position)) > 12:
                continue

            # main must be far from avoid_positions
            if any(abs(int(main_local['position_start']) - p) < min_gap for p in avoid_positions):
                continue

            ts0, attr0 = ts, ap
            break

        if ts0 is None:
            # print("Warning: cannot realize main change at target position in first pass")
            raise PositiveTSGenError("Failed to get main change (first pass)")

        main_type = attr0['local'][0]['type']

        target_positions = {int(change_position)}
        same_type_extras = random.randint(0, same_type_extra_max) if random.random() < p_same_type_extra else 0
        other_type_extras = random.randint(0, other_type_extra_max) if random.random() < p_other_type_extra else 0
        
        interference_extras = 0
        if interference_positions and random.random() < 0.4:  # 40% chance to add interference locals
            interference_extras = random.randint(1, min(2, len(interference_positions)))
        
        need_total = 1 + same_type_extras + other_type_extras + interference_extras
        
        # Minimum spacing between target positions within one series.
        intra_series_min_gap = 20

        # Interference positions are tried first and may fill any extra slot.
        for interference_pos in interference_positions:
            if len(target_positions) >= need_total:
                break
            if all(abs(interference_pos - existing_pos) >= intra_series_min_gap for existing_pos in target_positions) and \
               all(abs(interference_pos - p) >= min_gap for p in avoid_positions):
                target_positions.add(interference_pos)

        # Fill the remaining slots with random positions (bounded attempts).
        guard = 0
        while len(target_positions) < need_total and guard < 2000:
            guard += 1
            pos = random.randint(lb, rb)
            if all(abs(pos - existing_pos) >= intra_series_min_gap for existing_pos in target_positions) and \
               all(abs(pos - p) >= min_gap for p in avoid_positions):
                target_positions.add(pos)

        # If cannot satisfy extras, keep only main
        if len(target_positions) < need_total:
            # print(f"Warning: cannot place extras, keep only main at {change_position}. {same_type_extras=}, {other_type_extras=}")
            target_positions = {int(change_position)}

        # Second pass: realize target positions
        final_ts, final_attr = None, None
        for _try in range(max_retry):
            changes = {(int(p), None) for p in target_positions}
            ap = generate_random_attributes(
                all_config['overall_attribute'],
                all_config['change'],
                changes.copy(),
                seq_len
            )
            ts, ap = generate_time_series(ap, seq_len)
            if not ap.get('local'):
                continue

            # The local closest to change_position is taken as the main change.
            main_local_now = min(ap['local'], key=lambda x: abs(int(x['position_start']) - int(change_position)))
            if abs(int(main_local_now['position_start']) - int(change_position)) > 12:
                continue
            # Demand the pass-1 main type for the first third of retries, then relax.
            if main_local_now['type'] != main_type and _try < max_retry // 3:
                continue

            # Only the main local's realized position is checked against avoid_positions.
            main_pos = int(main_local_now['position_start'])
            if any(abs(main_pos - apv) < min_gap for apv in avoid_positions):
                continue

            final_ts, final_attr = ts, ap
            # print(f"Success. {len(ap.get('local', []))} locals at {[int(l['position_start']) for l in ap.get('local', [])]}. {same_type_extras=}, {other_type_extras=}")
            break

        if final_ts is None or not final_attr.get('local'):
            # print("Warning: cannot realize all target positions in second pass")
            raise PositiveTSGenError("Final attrs invalid or violate min_gap")

        timeseries.append(final_ts)
        attributes.append(final_attr)
        main_local_now = min(final_attr['local'], key=lambda x: abs(int(x['position_start']) - int(change_position)))
        main_type_list.append(main_local_now['type'])

    return timeseries, attributes, int(change_position), main_type_list

# -------- negatives avoid main positive positions by max(seq_len // 6, min_gap) --------
def generate_negative_timeseries(
    cnt: int,
    avoid_positions: List[int],
    seq_len: int = 256,
    min_gap: int = 20
) -> Tuple[List[np.ndarray], List[dict]]:
    """
    Generate ``cnt`` negative time series with no change near ``avoid_positions``.

    With probability 0.3 a series requests one local change at a random position
    in [2%, 95%] of ``seq_len`` that is at least
    ``min_interval = max(seq_len // 6, min_gap)`` away from every
    ``avoid_positions`` entry and from the change positions of previously
    accepted negatives. Otherwise, or if no such position is found within 2000
    tries, it requests no local change. Generation is repeated until the
    realized number of locals equals the requested number.

    NOTE(review): the retry loop is unbounded; it relies on
    generate_time_series eventually honoring the requested changes.

    Args:
        cnt: number of negatives to generate.
        avoid_positions: positions (main positive positions) to stay away from.
        seq_len: length of each series.
        min_gap: lower bound for ``min_interval``.

    Returns:
        (timeseries, attributes) lists of length ``cnt``.
    """
    timeseries, attributes = [], []
    lb = int(0.02 * seq_len)
    rb = int(0.95 * seq_len)
    min_interval = max(seq_len // 6, min_gap)
    negative_positions: Set[int] = set()

    for _ in range(cnt):
        while True:
            candidate_position = None
            if random.random() > 0.7:
                for _try in range(2000):
                    cand = random.randint(lb, rb)
                    if all(abs(cand - pos) >= min_interval for pos in avoid_positions) and \
                       all(abs(cand - pos) >= min_interval for pos in negative_positions):
                        candidate_position = cand
                        break
            changes = {(candidate_position, None)} if candidate_position is not None else set()

            attribute_pool = generate_random_attributes(all_config['overall_attribute'], all_config['change'], changes, seq_len)
            ts, attribute_pool = generate_time_series(attribute_pool, seq_len)

            # Accept only when the realized locals match what was requested
            # (exactly one, or none at all).
            if len(attribute_pool.get('local', [])) == len(changes):
                break

        # Reserve the position only once its series is accepted, so rejected
        # attempts don't block positions for later negatives.
        if candidate_position is not None:
            negative_positions.add(candidate_position)

        timeseries.append(ts)
        attributes.append(attribute_pool)
    return timeseries, attributes

def generate_prompt_data(seq_len: int=256):
    """
    Build one multivariate time-series (MTS) case plus QA pairs about
    fluctuation correlations between its metrics.

    Steps:
      1. Pick the length (``SEQ_LEN``, or random when it is None), a random
         scenario from ``metric_config`` and 1-3 positive clusters of metrics
         (each >= 2 metrics). Metrics of one positive cluster share a main local
         change at one position; clusters' positions are >= ``min_gap`` apart.
      2. Pick 0-5 negative metrics (capped by the metrics left unused) whose
         series keep away from all positive main positions.
      3. Shuffle all series, encode them and build:
         * pairwise QAs around the first cluster's main position: pairs not
           both in cluster 0 are first dropped with probability 0.8, then every
           remaining pair is kept with probability ``QA_KEEP_RATE``;
         * per-metric "find related metrics" QAs, kept with probability
           ``QA_KEEP_RATE``.
         ``QA_KEEP_RATE = None`` keeps all QAs of both kinds.
    Each LLM prompt is referenced from answers/labels by a ``<|promptN|>``
    placeholder, N being the global ``all_prompt_idx`` counter; the caller
    substitutes the LLM answers afterwards.

    Args:
        seq_len: NOTE: currently unused; the length comes from the module-level
            ``SEQ_LEN`` (kept for interface compatibility).

    Returns:
        (original_timeseries, combined_timeseries, combined_metrics,
        combined_attributes, prompt, question_list, answer_list,
        llm_prompt_list, fields_list, corr_pool_list, label).
        ``combined_timeseries`` holds the encoded series and
        ``original_timeseries`` the raw ones. ``llm_prompt_list[k]`` holds the
        prompts whose answers fill the placeholders of ``answer_list[k]``.
        ``corr_pool_list[i]`` is ``[related_indices, answer]`` for metric i.

    Raises:
        PositiveTSGenError: if the positive clusters cannot be generated.
    """
    global all_prompt_idx

    if SEQ_LEN is None:
        p = random.random()
        if p > 0.4:
            current_seq_len = 256
        else:
            current_seq_len = random.randint(64, 1024)
    else:
        current_seq_len = SEQ_LEN

    # min separation required between clusters (and negatives vs positives)
    min_gap = max(current_seq_len // 16, 30)

    sample = random.choice(list(metric_config))
    situation = sample['category']
    cluster: Dict[str, List[str]] = sample['cluster']
    metric_to_cluster = {metric: cluster_name for cluster_name, metrics in cluster.items() for metric in metrics}

    # ---- choose 1-3 positive clusters, each with >= 2 unused metrics ----
    num_positive_clusters = random.randint(1, 3)
    visited_metrics, visited_clusters = set(), set()
    positive_cluster, positive_metrics = [], []
    for _ in range(num_positive_clusters):
        if random.random() > 0.5:
            # Draw from one config cluster that still has >= 2 unused metrics.
            candidate_clusters = [c for c in cluster if len(set(cluster[c]) - visited_metrics) > 1 and c not in visited_clusters]
            if len(candidate_clusters) == 0:
                continue
            current_cluster = random.choice(candidate_clusters)
            candidate_metrics = list(set(cluster[current_cluster]) - visited_metrics)
            cur_positive_metrics = list(np.random.choice(candidate_metrics, size=random.randint(2, len(candidate_metrics)), replace=False))
            visited_clusters.add(current_cluster)
            visited_metrics.update(cur_positive_metrics)
            positive_metrics.extend(cur_positive_metrics)
            positive_cluster.append(cur_positive_metrics)
        else:
            # Draw 2-5 unused metrics from anywhere in the scenario.
            candidate_metrics = [m for m in metric_to_cluster if m not in visited_metrics]
            if len(candidate_metrics) < 2:
                continue
            cur_positive_metrics = list(np.random.choice(candidate_metrics, size=random.randint(2, min(len(candidate_metrics), 5)), replace=False))
            visited_metrics.update(cur_positive_metrics)
            positive_metrics.extend(cur_positive_metrics)
            positive_cluster.append(cur_positive_metrics)

    # Up to 5 negatives from the metrics no positive cluster uses; the size is
    # capped by that pool because random.sample raises ValueError otherwise.
    negative_pool = sorted(set(metric_to_cluster) - set(positive_metrics))
    negative_metrics = random.sample(negative_pool, random.randint(0, min(5, len(negative_pool))))
    num_negative_items = len(negative_metrics)

    positive_timeseries, positive_attributes, positive_idx_list, positive_change_position_list = [], [], [], []
    main_types_all = []

    # ---- track main positive positions and generate interference positions ----
    main_positive_positions: List[int] = []

    interference_positions: List[int] = []
    if random.random() < 0.6:
        num_interference = random.randint(1, 2)
        if random.random() > 0.8:
            lb, rb = int(0.02 * current_seq_len), int(0.95 * current_seq_len)
        else:
            lb, rb = int(0.02 * current_seq_len), int(0.75 * current_seq_len)
        for _ in range(num_interference):
            interference_pos = random.randint(lb, rb)
            interference_positions.append(interference_pos)

    for i in range(len(positive_cluster)):
        success = False
        for _attempt in range(10):
            try:
                # Main positions are mostly placed near the end of the series.
                if random.random() < 0.9:
                    lb = int(0.8 * current_seq_len)
                    rb = int(0.95 * current_seq_len)
                else:
                    lb = int(0.02 * current_seq_len)
                    rb = int(0.95 * current_seq_len)

                positive_change_position = None
                for _pos_try in range(500):
                    candidate = random.randint(lb, rb)
                    if all(abs(candidate - pos) >= min_gap for pos in main_positive_positions):
                        positive_change_position = candidate
                        break
                if positive_change_position is None:
                    raise PositiveTSGenError("No feasible cluster main position under min_gap")

                prior_main_positions = list(main_positive_positions)

                cur_positive_timeseries, cur_positive_attributes, positive_change_position, cur_main_types = generate_positive_timeseries(
                    cnt=len(positive_cluster[i]),
                    change_position=positive_change_position,
                    seq_len=current_seq_len,
                    avoid_positions=prior_main_positions,
                    min_gap=min_gap,
                    interference_positions=interference_positions
                )

                if any(abs(positive_change_position - old) < min_gap for old in prior_main_positions):
                    raise PositiveTSGenError("Generated main position violates min_gap against prior main positions")

                main_positive_positions.append(positive_change_position)

                success = True
                break
            except PositiveTSGenError:
                continue

        if not success:
            raise PositiveTSGenError("Fail to build separated clusters for this case")

        positive_timeseries.extend(cur_positive_timeseries)
        positive_attributes.extend(cur_positive_attributes)
        positive_idx_list.extend([i] * len(cur_positive_timeseries))
        positive_change_position_list.append(positive_change_position)
        main_types_all.extend(cur_main_types)

    # If no positives produced, raise to let outer loop skip the whole case
    if len(positive_timeseries) == 0:
        raise PositiveTSGenError("No positive sequences generated")

    negative_timeseries, negative_attributes = generate_negative_timeseries(
        num_negative_items, main_positive_positions, current_seq_len, min_gap=min_gap
    )

    # ---- shuffle positives and negatives together ----
    shuffle_indices = np.random.permutation(len(positive_timeseries) + len(negative_timeseries))
    combined_timeseries = positive_timeseries + negative_timeseries
    combined_attributes = positive_attributes + negative_attributes
    combined_metrics = positive_metrics + negative_metrics
    combined_cluster_idx = positive_idx_list + [None] * len(negative_timeseries)
    main_types = main_types_all + [None] * len(negative_timeseries)

    combined_timeseries = [combined_timeseries[i] for i in shuffle_indices]
    combined_attributes = [combined_attributes[i] for i in shuffle_indices]
    combined_metrics = [combined_metrics[i] for i in shuffle_indices]
    combined_cluster_idx = [combined_cluster_idx[i] for i in shuffle_indices]
    main_types = [main_types[i] for i in shuffle_indices]

    label = {
        'timeseries': [i.tolist() for i in combined_timeseries],
        'label': {
            'clusters': [],
            'position': int(positive_change_position_list[0]),
            'correlations': [],
            'cols': combined_metrics,
            'situation': situation
        },
        'attribute_pool': combined_attributes,
        'main_types': main_types
    }

    prompt = f'In a {situation} system, there are {len(shuffle_indices)} metrics:'
    question_list, answer_list, llm_prompt_list, fields_list = [], [], [], []
    corr_pool_list = [None] * len(shuffle_indices)
    original_timeseries = copy.deepcopy(combined_timeseries)

    # All questions are about the first positive cluster and its main position.
    # Both values are loop-invariant, so compute them once.
    cluster0_indices = {k for k in range(len(combined_metrics)) if len(positive_cluster) > 0 and combined_metrics[k] in positive_cluster[0]}
    positive_change_position = positive_change_position_list[0] if len(positive_change_position_list) > 0 else int(current_seq_len // 2)

    for i in range(len(shuffle_indices)):
        scaled_timeseries, cur_ts_prompt, _ = timeseries_encoding(combined_timeseries[i], ENCODING_METHOD)
        combined_timeseries[i] = scaled_timeseries
        prompt += f"\n {combined_metrics[i]} is of length {current_seq_len}: {cur_ts_prompt};"

        # ---- pairwise correlation QAs around the query position ----
        for j in range(len(shuffle_indices)):
            if random.random() < 0.8 and not (i in cluster0_indices and j in cluster0_indices):
                continue
            if i == j:
                continue
            if QA_KEEP_RATE is not None and random.random() > QA_KEEP_RATE:
                continue

            question_list.append(f"Based on the characteristics of the time series, please describe the characteristics of {combined_metrics[i]} and {combined_metrics[j]} from the aspects of periodicity, trend, local characteristics, frequency characteristics, and noise. And analyze whether there may be a correlation of fluctuation between them around point {positive_change_position}. Conclude the physical meaning of the fluctuation correlation (or no correlation) in one sentence.")
            fields_list.append({
                "local": [i, j],
                "seasonal": [i, j],
                "trend": [i, j],
                "noise": [i, j],
                "statistic": [i, j]
            })
            cur_answer = f"{combined_metrics[i]}: " + attribute_to_text(original_timeseries[i], combined_attributes[i], generate_values=False) + f"; {combined_metrics[j]}: " + attribute_to_text(original_timeseries[j], combined_attributes[j], generate_values=False)
            if i in cluster0_indices and j in cluster0_indices:
                # Both fluctuate at the query position.
                cur_answer += f" Both metrics show sudden changes around point {positive_change_position}, indicating a possible correlation in terms of fluctuation. <|prompt{all_prompt_idx}|>"
                label["label"]["correlations"].append({
                    "pair": [combined_metrics[i], combined_metrics[j]],
                    "explain": f"<|prompt{all_prompt_idx}|>",
                    "label": True
                })
                all_prompt_idx += 1
                cur_llm_prompt = f"In a {situation} system, there are many monitoring metrics. Near a timestamp (maybe during a failure), we found there are fluctuations in {combined_metrics[i]} and {combined_metrics[j]} that happens together. Please explain why {combined_metrics[i]} and {combined_metrics[j]} fluctuates together in their physical meaning in English in one sentence (e.g. both a and b are xxx-related metrics and xxx may cause their fluctuations / a may cause b). Make sure to keep it simple. "
                if metric_to_cluster.get(combined_metrics[i]) == metric_to_cluster.get(combined_metrics[j]):
                    cur_llm_prompt += f"(Hint: These two metrics are both {metric_to_cluster[combined_metrics[i]]}-related.)"
                llm_prompt_list.append([cur_llm_prompt])
            elif combined_cluster_idx[i] is not None and combined_cluster_idx[i] == combined_cluster_idx[j]:
                # Same (other) cluster: they fluctuate together, but elsewhere.
                cur_answer += f" No. Both metrics show sudden changes around point {positive_change_position_list[combined_cluster_idx[i]]}, but no sudden changes around point {positive_change_position}. <|prompt{all_prompt_idx}|>"
                label["label"]["correlations"].append({
                    "pair": [combined_metrics[i], combined_metrics[j]],
                    "explain": f"<|prompt{all_prompt_idx}|>",
                    "label": False
                })
                all_prompt_idx += 1
                cur_llm_prompt = f"In a {situation} system, there are many monitoring metrics. Near a timestamp (maybe during a failure), we found there are **no** fluctuations in both {combined_metrics[i]} and {combined_metrics[j]}, but they fluctuated together in another time (before or after the failure). Please explain why {combined_metrics[i]} and {combined_metrics[j]} are not fluctuating together at this time in their physical meaning in English in one sentence (e.g. both a and b are xxx-related metrics and xxx may cause their fluctuations / a may cause b). Make sure to keep it simple. "
                if metric_to_cluster.get(combined_metrics[i]) == metric_to_cluster.get(combined_metrics[j]):
                    cur_llm_prompt += f"(Hint: These two metrics are both {metric_to_cluster[combined_metrics[i]]}-related.)"
                llm_prompt_list.append([cur_llm_prompt])
            elif (i in cluster0_indices) != (j in cluster0_indices):
                # Exactly one of the two fluctuates at the query position.
                cur_answer += f" These two time series do not seem to have much correlation in terms of fluctuation around point {positive_change_position}. <|prompt{all_prompt_idx}|>"
                label["label"]["correlations"].append({
                    "pair": [combined_metrics[i], combined_metrics[j]],
                    "explain": f"<|prompt{all_prompt_idx}|>",
                    "label": False
                })
                all_prompt_idx += 1
                a, b = (i, j) if i in cluster0_indices else (j, i)
                cur_llm_prompt = f"In a {situation} system, there are many monitoring metrics. Near a timestamp (maybe during a failure), we found there are fluctuations in {combined_metrics[a]}, but no fluctuations in {combined_metrics[b]}. Please explain why {combined_metrics[a]} and {combined_metrics[b]} are **not** fluctuating together in their physical meaning in English in one simple sentence (e.g. a is xxx-related, so xxx. But b is xxx-related, which may not affected by xxx). Make sure to keep it simple:"
                llm_prompt_list.append([cur_llm_prompt])
            else:
                # Neither fluctuates at the query position.
                cur_answer += f" These two time series do not seem to have much correlation in terms of fluctuation around point {positive_change_position}. <|prompt{all_prompt_idx}|>"
                label["label"]["correlations"].append({
                    "pair": [combined_metrics[i], combined_metrics[j]],
                    "explain": f"<|prompt{all_prompt_idx}|>",
                    "label": False
                })
                all_prompt_idx += 1
                cur_llm_prompt = f"In a {situation} system, there are many monitoring metrics. Near a timestamp (during a failure), we found there are fluctuations in some of the metrics, but no fluctuations in both {combined_metrics[i]} and {combined_metrics[j]}. Please explain why {combined_metrics[i]} and {combined_metrics[j]} are **not** fluctuating in their physical meaning in English in one simple sentence. Make sure to keep it simple:"
                llm_prompt_list.append([cur_llm_prompt])
            answer_list.append(cur_answer)

        # ---- Task 3: find metrics related to metric i ----
        is_related = i in cluster0_indices
        cur_llm_prompts = []
        if not is_related:
            cur_answer = f"Among these metrics, I did not find any other metrics that may be related to {combined_metrics[i]} in terms of fluctuation around point {positive_change_position}. It seems that {combined_metrics[i]} shows no significant fluctuation around this point."
            cur_fields = {"local": [i]}
            # Needed for corr_pool_list below; only the metric itself.
            all_related_idxs = [i]
        else:
            cur_answer = f'I found the following metrics that may be related to {combined_metrics[i]} in terms of fluctuation:'
            i_change = label['main_types'][i]
            related_idxs = []
            for j in range(len(shuffle_indices)):
                if i == j or j not in cluster0_indices:
                    continue
                j_change = label['main_types'][j]
                if i_change == j_change:
                    cur_answer += f" {combined_metrics[i]} and {combined_metrics[j]} both show {i_change} around point {positive_change_position}, indicating a possible correlation in terms of fluctuation."
                else:
                    cur_answer += f" {combined_metrics[i]} shows {i_change} around point {positive_change_position}, while {combined_metrics[j]} shows {j_change} around this point, indicating a possible correlation in terms of fluctuation."
                related_idxs.append(j)

            all_related_idxs = sorted(related_idxs + [i])
            cur_fields = {"local": all_related_idxs}
            cur_llm_prompts = [f"In a {situation} system, there are many monitoring metrics. Near a timestamp (maybe during a failure), we found there are fluctuations in " + ', '.join(combined_metrics[j] for j in all_related_idxs) + f". Please explain their relationship in physical meaning and simply describe what's may happening in the {situation} system in English in 1 sentence, like `these metrics are all xxx-related or xxx. {situation} may xxx.` (the format may be different, but keep simple): "]

        keep_qa = QA_KEEP_RATE is None or random.random() < QA_KEEP_RATE
        # A placeholder is emitted only when a matching LLM prompt is queued;
        # otherwise it would stay unreplaced or pick up another QA's answer.
        explain = ""
        if keep_qa and cur_llm_prompts:
            explain = f"<|prompt{all_prompt_idx}|>"
            cur_answer += f" {explain}"
        corr_pool_list[i] = [all_related_idxs, cur_answer]
        # The cluster label describes the first positive cluster only.
        if is_related and len(label["label"]["clusters"]) == 0:
            label["label"]["clusters"].append({
                'col_idx': [[int(j), label['main_types'][j]] for j in all_related_idxs],
                'cols': [combined_metrics[j] for j in all_related_idxs],
                'explain': explain,
            })
        if keep_qa:
            question_list.append(f"Based on the fluctuations in the metrics around point {positive_change_position}, please find other metric(s) that may be related to {combined_metrics[i]}, output their numbers, and explain the reasons. If related metrics are found, explain why they have similar local fluctuations considering their physical meaning in one sentence. If no related metrics are found, output that no related metrics were found.")
            answer_list.append(cur_answer)
            fields_list.append(cur_fields)
            llm_prompt_list.append(cur_llm_prompts)
            all_prompt_idx += len(cur_llm_prompts)

    return original_timeseries, combined_timeseries, combined_metrics, combined_attributes, prompt, question_list, answer_list, llm_prompt_list, fields_list, corr_pool_list, label

def generate_dataset():
    """
    Generate MTS cases until at least NUM_DATA LLM prompts are collected, query
    the LLM (or use a stub answer when DRYRUN), and substitute the answers into
    the ``<|promptN|>`` placeholders.

    NOTE: progress is counted in LLM prompts, not in generated cases.
    NOTE(review): placeholder N is mapped to ``llm_answers[N]``, which assumes
    the global ``all_prompt_idx`` was 0 when this function started (i.e. it is
    called once per process).

    Returns:
        (result, labels): ``result`` holds one tuple per case,
        (original_timeseries, combined_timeseries, combined_metrics,
        combined_attributes, prompt, question_list, answer_list,
        llm_prompt_list, fields_list, corr_pool_list), with placeholders
        replaced in ``answer_list`` and ``corr_pool_list``. ``labels`` holds
        the per-case labels with placeholders replaced.
    """
    result = []
    prompts = []
    labels = []
    with tqdm(total=NUM_DATA, desc='Generating prompt...') as t:
        cnt = 0
        while cnt < NUM_DATA:
            try:
                original_timeseries, combined_timeseries, combined_metrics, combined_attributes, prompt, question_list, answer_list, llm_prompt_list, fields_list, corr_pool_list, label = generate_prompt_data(SEQ_LEN)
            except PositiveTSGenError:
                # Separated positive clusters could not be built: skip this case.
                continue

            num_prompts_this_round = sum(len(prompts_per_qa) for prompts_per_qa in llm_prompt_list)

            result.append((original_timeseries, combined_timeseries, combined_metrics, combined_attributes, prompt, question_list, answer_list, llm_prompt_list, fields_list, corr_pool_list))
            labels.append(label)

            # Flatten in the same order the global counter numbered them.
            for item in llm_prompt_list:
                prompts.extend(item)

            cnt += num_prompts_this_round
            t.update(num_prompts_this_round)

    print(f'Generated {len(result)} data items, with {len(prompts)} prompts. {all_prompt_idx=}')

    if DRYRUN:
        llm_answers = ['This is a test answer.'] * len(prompts)
    else:
        llm_client = LLMClient(model_path=LOCAL_LLM_PATH, engine='vllm')
        try:
            llm_answers = llm_client.llm_batch_generate(prompts, use_chat_template=True)
        finally:
            # Always release the inference engine, even if generation fails.
            llm_client.kill()

    print("Processing generated answers...")
    idx = 0
    for original_timeseries, combined_timeseries, combined_metrics, combined_attributes, prompt, question_list, answer_list, llm_prompt_list, fields_list, corr_pool_list in result:
        for i in range(len(question_list)):
            # The running idx equals the placeholder number of each prompt.
            for _ in range(len(llm_prompt_list[i])):
                answer_list[i] = answer_list[i].replace(f'<|prompt{idx}|>', llm_answers[idx])
                idx += 1
        for i in range(len(corr_pool_list)):
            if corr_pool_list[i] is not None:
                corr_pool_list[i][1] = replace_prompts(llm_answers, corr_pool_list[i][1])

    if idx != len(llm_answers):
        print(f"Warning: consumed {idx} answers but received {len(llm_answers)}")

    labels = replace_prompts(llm_answers, labels)
    return result, labels


if __name__ == '__main__':
    print('Generating...')
    result, cluster_labels = generate_dataset()
    evol_labels = []

    print("Writing to file...")
    for output_path in (OUTPUT_PATH, EVOL_LABEL_PATH, CLUSTER_LABEL_PATH):
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # ensure_ascii=False may emit non-ASCII text, so write UTF-8 explicitly
    # instead of relying on the platform's default encoding.
    with open(OUTPUT_PATH, 'wt', encoding='utf-8') as f:
        for ts_idx, (original_timeseries, combined_timeseries, combined_metrics, combined_attributes, prompt, question_list, answer_list, llm_prompt_list, fields_list, corr_pool) in enumerate(result):
            # Every question of a case shares the same series: convert once.
            ts_as_list = timeseries_to_list(combined_timeseries)
            for question, answer, fields in zip(question_list, answer_list, fields_list):
                out_item = {
                    # Drop the prompt's trailing ';' before appending the question.
                    'input': prompt[:-1] + '. ' + question,
                    'output': answer,
                    'timeseries': ts_as_list,
                }
                cur_label = {
                    "fields": fields,
                    "metrics": combined_metrics,
                    "corr_pool": corr_pool,
                    "attribute_pool": combined_attributes,
                    "instruction": prompt,
                    "question": question,
                    "ts_idx": ts_idx
                }
                f.write(json.dumps(out_item, ensure_ascii=False) + '\n')
                evol_labels.append(cur_label)

    with open(CLUSTER_LABEL_PATH, 'wt', encoding='utf-8') as f:
        json.dump(cluster_labels, f, ensure_ascii=False, indent=4)
    with open(EVOL_LABEL_PATH, 'wt', encoding='utf-8') as f:
        json.dump(evol_labels, f, ensure_ascii=False, indent=4)

    print("Finished.")
