import json
import re
from typing import List, Dict, Any
from tqdm import tqdm
from loguru import logger
import random


# CONFIG
# JSONL file read by main() and JSONL file it writes the expanded samples to.
INPUT_FILE = 'data/warmup_sft_10000_no.jsonl'
OUTPUT_FILE = 'data/warmup_sft_10000_post_no.jsonl'

# One inline tool call; group(1) is the payload between the two markers with
# surrounding whitespace removed. DOTALL lets the payload span several lines.
TOOL_CALL_PATTERN = re.compile(r"<tool_start>\s*(.*?)\s*<tool_end>", re.DOTALL)


def _param(kind: str, description: str) -> Dict[str, Any]:
    """Build the schema entry of a single tool parameter."""
    return {"type": kind, "description": description}


def _function_tool(name: str, description: str, properties: Dict[str, Any]) -> Dict[str, Any]:
    """Build one function-tool entry in which every listed property is required."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                # Every declared parameter is mandatory, in declaration order.
                "required": list(properties),
            },
        },
    }


# Tool schemas attached (as a JSON string) to every sample written by process_file.
TOOL_DEFINITION: List[Dict[str, Any]] = [
    _function_tool(
        "get_timeseries_slice",
        (
            "Get the current timeseries_slice of one of the time series in a given location, "
            "you should call this tool during thinking to better recognize the local fluctuations of a given period"
        ),
        {
            "metric_name": _param("string", "The name of the metric to get the timeseries slice for"),
            "start": _param("integer", "The start index of the timeseries slice"),
            "end": _param("integer", "The end index of the timeseries slice"),
        },
    ),
    _function_tool(
        "compare_timeseries_slice",
        (
            "Compare two time series slices for comparative analysis. "
            "Recommended for comparing same series different periods, or different series same periods. "
            "Use when analyzing periodicity or comparing different time series patterns."
        ),
        {
            "metric_name_1": _param("string", "The name of the first metric to compare"),
            "start_1": _param("integer", "The start index of the first timeseries slice"),
            "end_1": _param("integer", "The end index of the first timeseries slice"),
            "metric_name_2": _param("string", "The name of the second metric to compare"),
            "start_2": _param("integer", "The start index of the second timeseries slice"),
            "end_2": _param("integer", "The end index of the second timeseries slice"),
        },
    ),
]

# Probability with which expand_sample keeps each intermediate (tool-call) step.
MIDDLE_RATIO: float = 0.5


def parse_sample(sample: Dict[str, Any]):
    """Split a raw sample into its question, narrative segments and tool calls.

    The sample's "output" text is scanned with ``TOOL_CALL_PATTERN``. The text
    before each match, and the text after the last match, become the
    narrative segments; each match's payload is decoded as JSON.

    Args:
        sample: Raw record. Only the "input" and "output" keys are read; a
            missing or null value for either is treated as an empty string.

    Returns:
        A dict with:
            "question": the stripped "input" text.
            "segments": stripped narrative pieces, always exactly one more
                than the number of tool calls (the last one is the tail).
            "tool_calls": decoded payloads in order of appearance; ``None``
                stands in for a payload that is not valid JSON.
    """
    # `or ""` also covers a key that is present but null, which would
    # otherwise break the .strip() below.
    raw_input = sample.get("input") or ""
    output = sample.get("output") or ""

    tool_calls = []  # decoded payloads, None for undecodable ones
    segments = []    # narrative text around the calls (len = num_tool_calls + 1)

    last_end = 0
    for m in TOOL_CALL_PATTERN.finditer(output):
        segments.append(output[last_end:m.start()].strip())
        payload_raw = m.group(1).strip()
        try:
            payload = json.loads(payload_raw)
        except (ValueError, RecursionError):
            # Keep a None slot so segments and tool calls stay index-aligned.
            payload = None
        tool_calls.append(payload)
        last_end = m.end()
    # Tail segment: whatever follows the last tool call (the whole output
    # when there are no tool calls at all).
    segments.append(output[last_end:].strip())

    return {
        "question": raw_input.strip(),
        "segments": segments,
        "tool_calls": tool_calls
    }


def gen_tool_response(tool_call: Dict[str, Any], original_input: str, original_timeseries: Any):
    """Simulate the reply to one tool call and extract the slices it refers to.

    A metric name is mapped to a series by splitting ``original_input`` on the
    ``<ts><ts/>`` placeholder: the index of the first text piece containing
    the name (case-insensitive) selects the series with the same index. When
    no piece matches, or the index has no series, the first series is used.
    Requested bounds are clamped to the series so that at least one point is
    returned; the reply text still quotes the bounds as requested.

    Args:
        tool_call: Decoded tool-call payload with "name" and "arguments", or
            ``None`` / any non-dict value for a payload that could not be used.
        original_input: The question text containing the placeholders.
        original_timeseries: List of series for the question; any other value
            means that no data is available.

    Returns:
        ``(response_text, slices)``. For ``get_timeseries_slice`` ``slices``
        is the list of values; for ``compare_timeseries_slice`` it is
        ``[slice_1, slice_2]``. Every error reply (malformed call, unknown
        tool, no data) contains no placeholder and comes with ``[]``, so the
        number of placeholders always matches the number of slices.
    """
    def error(message):
        # A fresh list per call so callers never share one mutable object.
        return f"<tool_response>\nError: {message}\n</tool_response>", []

    # A payload that decoded to a list/str/number is as unusable as None.
    if not tool_call or not isinstance(tool_call, dict):
        return error("malformed tool call.")

    tool_name = tool_call.get("name", "")
    if tool_name not in ("get_timeseries_slice", "compare_timeseries_slice"):
        return error(f"Unknown tool '{tool_name}'.")

    args = tool_call.get("arguments", {})
    if not isinstance(args, dict):
        return error("malformed tool call.")

    series_list = original_timeseries if isinstance(original_timeseries, list) else []
    # Split once; both slices of a comparison reuse the same pieces.
    text_segments = original_input.split('<ts><ts/>') if original_input else []

    def get_timeseries_slice(metric_name, start_idx, end_idx):
        """Return the clamped slice of the series matched by ``metric_name``."""
        if not series_list:
            return []
        needle = str(metric_name).lower()
        target = None
        for i, seg in enumerate(text_segments):
            if needle in seg.lower():
                if i < len(series_list):
                    target = series_list[i]
                # Only the first mention of the metric is considered.
                break
        if target is None:
            target = series_list[0]
        if not isinstance(target, list) or not target:
            return []

        length = len(target)
        start_idx = max(0, min(start_idx, length - 1))
        # Guarantee a non-empty window even for reversed or oversized bounds.
        end_idx = max(start_idx + 1, min(end_idx, length))
        return target[start_idx:end_idx]

    try:
        if tool_name == "get_timeseries_slice":
            metric = args.get("metric_name", "metric")
            start = int(args.get("start", 0))
            end = int(args.get("end", start + 1))
        else:
            metric_1 = args.get("metric_name_1", "metric_1")
            start_1 = int(args.get("start_1", 0))
            end_1 = int(args.get("end_1", start_1 + 1))
            metric_2 = args.get("metric_name_2", "metric_2")
            start_2 = int(args.get("start_2", 0))
            end_2 = int(args.get("end_2", start_2 + 1))
    except (TypeError, ValueError, OverflowError):
        # Non-numeric bounds: report instead of aborting the whole run.
        return error("malformed tool call.")

    if tool_name == "get_timeseries_slice":
        slice_data = get_timeseries_slice(metric, start, end)
        if not slice_data:
            # Never advertise a placeholder that has no series behind it.
            return error("no timeseries data available.")
        resp = f"<tool_response>\nThe slice of {metric} from {start} to {end} is: <ts><ts/>.\n</tool_response>"
        return resp, slice_data

    slice_1 = get_timeseries_slice(metric_1, start_1, end_1)
    slice_2 = get_timeseries_slice(metric_2, start_2, end_2)
    if not slice_1 or not slice_2:
        return error("no timeseries data available.")

    resp = f"<tool_response>\nComparison results:\n- {metric_1} slice [{start_1}:{end_1}]: <ts><ts/>\n- {metric_2} slice [{start_2}:{end_2}]: <ts><ts/>\n</tool_response>"
    # Both slices, in the order their placeholders appear in the text.
    return resp, [slice_1, slice_2]


def expand_sample(parsed: Dict[str, Any], original_timeseries: Any) -> List[Dict[str, Any]]:
    """Expand one parsed sample into per-turn training samples.

    The tool calls are replayed in order. Before each call a candidate sample
    is built whose "input" is the conversation so far and whose "output" is
    the assistant message (narrative segment followed by a ``<tool_call>``
    block); the candidate is kept with probability ``MIDDLE_RATIO``. The
    call is then answered by ``gen_tool_response`` and both messages join the
    history. A final sample, whose "output" is the trailing segment, is
    always emitted.

    Args:
        parsed: Dict with "question", "segments" and "tool_calls", as
            produced by ``parse_sample``.
        original_timeseries: The series that belong to the question; a value
            that is not a list is treated as no series.

    Returns:
        Samples as dicts with "input", "output" and "timeseries". The
        "timeseries" list holds the original series followed by one entry per
        ``<ts><ts/>`` placeholder found in the tool responses contained in
        "input".
    """
    placeholder = '<ts><ts/>'
    question = parsed['question']
    segments = parsed['segments']
    tool_calls = parsed['tool_calls']

    # With one more segment than tool calls, the extra one is the text after
    # the last call; any other shape yields an empty tail.
    tail_segment = segments[-1] if len(segments) == len(tool_calls) + 1 else ''

    original_ts_list = original_timeseries if isinstance(original_timeseries, list) else []
    accumulated_slices: List[Any] = []
    # Turn 0 is the raw question; afterwards assistant and tool-response
    # (user) turns alternate.
    history_turns: List[str] = [question]

    def serialize_history() -> str:
        """Render the history; the first and last turns get no outer wrapper."""
        rendered = []
        for idx, turn in enumerate(history_turns):
            if idx == 0:
                rendered.append(turn)
            elif idx % 2 == 1:
                rendered.append(f"<|im_start|>assistant\n{turn}")
            else:
                rendered.append(f"<|im_start|>user\n{turn}")
        # <|im_end|> closes every turn except the last one.
        return "<|im_end|>".join(rendered)

    def make_sample(output_str: str) -> Dict[str, Any]:
        """Snapshot the current history and series next to ``output_str``."""
        return {
            'input': serialize_history(),
            'output': output_str,
            # A new list each time, so later slices do not leak into
            # samples that were already emitted.
            'timeseries': [*original_ts_list, *accumulated_slices],
        }

    results = []
    for i, tc in enumerate(tool_calls):
        assistant_body = (
            f"{segments[i] or ''}"
            f"\n<tool_call>\n{json.dumps(tc, ensure_ascii=False)}\n</tool_call>"
        )

        if random.random() < MIDDLE_RATIO:
            results.append(make_sample(assistant_body))

        history_turns.append(assistant_body)
        # Simulated tool response, recorded as a user turn.
        tool_resp_text, slice_data = gen_tool_response(tc, question, original_timeseries)
        history_turns.append(tool_resp_text)

        # Add exactly as many series as the response advertises placeholders,
        # instead of guessing the shape of slice_data from its first element.
        placeholder_count = tool_resp_text.count(placeholder)
        if placeholder_count == 1:
            accumulated_slices.append(slice_data)
        elif placeholder_count > 1:
            accumulated_slices.extend(slice_data)

    # NOTE(review): the final sample is emitted even when the tail is empty;
    # confirm that downstream consumers expect or filter empty outputs.
    results.append(make_sample(tail_segment))
    return results


def process_file(in_path: str, out_path: str):
    """Convert a JSONL file of raw samples into a JSONL file of expanded samples.

    Each non-blank input line is decoded, parsed with ``parse_sample`` and
    expanded with ``expand_sample``; every resulting item gets a "tools" field
    holding ``TOOL_DEFINITION`` as a JSON string and is written as one output
    line. Lines that are not valid JSON objects are logged and skipped.

    Args:
        in_path: Path of the input JSONL file.
        out_path: Path of the output JSONL file; it is overwritten.
    """
    # The tool schema is the same for every item, so serialize it once.
    tools_json = json.dumps(TOOL_DEFINITION, ensure_ascii=False)

    # Explicit UTF-8: the output is written with ensure_ascii=False, so the
    # platform default encoding could fail on non-ASCII text.
    with open(in_path, 'rt', encoding='utf-8') as fin, \
            open(out_path, 'wt', encoding='utf-8') as fout:
        for line_no, line in enumerate(tqdm(fin), start=1):
            line = line.strip()
            if not line:
                continue
            try:
                sample = json.loads(line)
            except (ValueError, RecursionError) as err:
                logger.error(f"Error parsing line {line_no} in {in_path}: {err}")
                continue
            if not isinstance(sample, dict):
                # A bare list/string/number has none of the expected fields.
                logger.error(f"Skipping line {line_no} in {in_path}: expected a JSON object")
                continue
            parsed = parse_sample(sample)
            if not parsed:
                logger.error(f"Skipping line {line_no} in {in_path}: nothing could be parsed")
                continue
            for item in expand_sample(parsed, sample.get('timeseries')):
                item["tools"] = tools_json
                fout.write(json.dumps(item, ensure_ascii=False) + '\n')


def main():
    """Run the post-processing with the paths configured at module level."""
    if OUTPUT_FILE:
        destination = OUTPUT_FILE
    else:
        # No explicit output configured: derive the name from the input path.
        destination = INPUT_FILE.replace('.jsonl', '_post.jsonl')
    process_file(INPUT_FILE, destination)
    print(f"Done. Wrote postprocessed dataset to {destination}")


# Run the conversion only when executed as a script, not on import.
if __name__ == '__main__':
    main()
