# -*- coding: utf-8 -*-
"""

"""

import argparse
from chatbot.benchmark.benchmark_baseline import read_objs4jsonl_file
import numpy as np


def analyze_llm_result(record_path):
    """Print latency statistics between "conversation stop" and the next "play".

    The records at ``record_path`` are loaded with ``read_objs4jsonl_file``.
    For every "conversation stop" record, the difference between its
    ``time_stamp`` and the ``time_stamp`` of the first following "play"
    record is collected. The average, maximum, minimum and the 25th, 50th,
    75th and 90th percentiles of those differences are printed.

    Args:
        record_path: Location of the record file; passed unchanged to
            ``read_objs4jsonl_file``.

    Returns:
        None. The summary is written to standard output.
    """
    latencies = []
    last_stop = None
    for sample in read_objs4jsonl_file(record_path):
        kind = sample["type"]
        if kind == "play" and last_stop is not None:
            latencies.append(sample["time_stamp"] - last_stop)
            # Only the first play after a stop counts; later plays of the
            # same reply are ignored until the next stop is seen.
            last_stop = None
        elif kind == "conversation stop":
            last_stop = sample["time_stamp"]

    if not latencies:
        # Without this guard the average would divide by zero and the
        # min/max/percentile calls would fail on an empty array.
        print("no 'conversation stop' -> 'play' pairs found")
        return

    times = np.asarray(latencies)
    # Adjacent f-strings are used instead of backslash continuations so the
    # source indentation does not leak into the printed line.
    print(
        f"avg: {times.mean()}s, "
        f"max: {times.max()}s, "
        f"min: {times.min()}s, "
        f"25%: {np.percentile(times, 25)}, "
        f"50%: {np.percentile(times, 50)}, "
        f"75%: {np.percentile(times, 75)}, "
        f"90%: {np.percentile(times, 90)}"
    )

def analyze_fd_llm_result(record_path):
    """Print per-category latency statistics for grouped conversations.

    The records at ``record_path`` are loaded with ``read_objs4jsonl_file``
    and grouped into conversations. Each conversation begins at a
    "conversation start" record and keeps the first "llm-query" record, the
    latest "conversation stop" record and every "play" record seen before
    the next "conversation start".

    Each conversation is then classified by its LLM query text:

    * ``interrupt_times``: the query contains ``<incomplete>`` and the stop
      record's ``left`` value is above 16000; a latency of 0 is recorded.
    * ``incomplete_times``: any other query containing ``<incomplete>``.
    * ``finished_times``: the query contains ``<finished>`` and the stop
      record's ``left`` value is 0.

    Every classified conversation also contributes to ``total_times``. For
    the last two categories the latency is the selected play record's
    ``time_stamp`` minus the stop record's ``time_stamp``.

    Conversations lacking a stop record or an LLM query are skipped, as are
    conversations that need a play timestamp but have no qualifying play.

    Args:
        record_path: Location of the record file; passed unchanged to
            ``read_objs4jsonl_file``.

    Returns:
        None. The summaries are written to standard output.
    """
    samples = read_objs4jsonl_file(record_path)

    # --- Group the flat record stream into conversations. ---
    records = []
    current = {"play": [], "start": None, "stop": None, "llm": None}
    for sample in samples:
        kind = sample["type"]
        if kind == "conversation start":
            if current["start"] is not None:
                records.append(current)
            current = {"play": [], "start": sample, "stop": None, "llm": None}
        elif kind == "llm-query" and current["llm"] is None:
            # Only the first LLM query of a conversation is kept.
            current["llm"] = sample
        elif kind == "conversation stop":
            current["stop"] = sample
        elif kind == "play":
            current["play"].append(sample)
    # Flush the final conversation; nothing follows it to trigger the
    # append inside the loop, so it would otherwise be dropped.
    if current["start"] is not None:
        records.append(current)

    record_times = {
        "interrupt_times": [],
        "incomplete_times": [],
        "finished_times": [],
        "total_times": [],
    }

    # --- Classify each conversation and collect its latency. ---
    for record in records:
        if record["stop"] is None or record["llm"] is None:
            # Truncated conversation: no end timestamp or no query to classify.
            continue
        conv_end = record["stop"]["time_stamp"]
        query = record["llm"]["query"]

        # Reset per conversation so a value from a previous conversation can
        # never be reused when this one has no qualifying play.
        play_start = None
        played = 0
        for play in record["play"]:
            played += play["length"]
            # NOTE(review): a threshold of 0 selects the first play whenever
            # lengths are non-negative; confirm this is the intended cutoff.
            if played >= 0:
                play_start = play["time_stamp"]
                break

        # NOTE(review): 16000 is a threshold on the stop record's "left"
        # field; its unit is not visible in this file.
        if "<incomplete>" in query and record["stop"]["left"] > 16000:
            record_times["interrupt_times"].append(0)
            record_times["total_times"].append(0)
            continue

        if play_start is None:
            # The remaining categories need a play timestamp to measure from.
            continue

        latency = play_start - conv_end
        if "<incomplete>" in query:
            record_times["incomplete_times"].append(latency)
            record_times["total_times"].append(latency)
        elif "<finished>" in query and record["stop"]["left"] == 0:
            record_times["finished_times"].append(latency)
            record_times["total_times"].append(latency)

    # --- Report. ---
    for name, values in record_times.items():
        print(f"{name} times:")
        if not values:
            continue
        times = np.asarray(values)
        # Adjacent f-strings are used instead of backslash continuations so
        # the source indentation does not leak into the printed line.
        print(
            f"total: {len(times)}, "
            f"avg: {times.mean()}s, "
            f"max: {times.max()}s, "
            f"min: {times.min()}s, "
            f"25%: {np.percentile(times, 25)}, "
            f"50%: {np.percentile(times, 50)}, "
            f"75%: {np.percentile(times, 75)}, "
            f"90%: {np.percentile(times, 90)}"
        )

def main(argv=None):
    """Parse command-line options and run the selected analysis.

    Options:
        --record-path: Path of the record file to analyze (required).
        --llm-type: ``baseline`` (default) runs ``analyze_llm_result``;
            ``llm-fd`` runs ``analyze_fd_llm_result``.

    Args:
        argv: Optional list of argument strings. When ``None``, argparse
            reads the arguments from ``sys.argv``.

    Raises:
        SystemExit: Raised by argparse when the arguments are invalid.
    """
    parser = argparse.ArgumentParser()

    # The record path is a file location, so it is parsed as a string; an
    # integer type would make argparse reject every real path.
    parser.add_argument("--record-path", type=str, required=True)
    # Restricting the choices turns a mistyped value into an error instead
    # of a run that silently does nothing.
    parser.add_argument(
        "--llm-type",
        type=str,
        default="baseline",
        choices=("baseline", "llm-fd"),
    )

    args = parser.parse_args(argv)

    if args.llm_type == "baseline":
        analyze_llm_result(args.record_path)
    elif args.llm_type == "llm-fd":
        analyze_fd_llm_result(args.record_path)


# Script entry point: runs only when the file is executed directly.
if __name__ == '__main__':
    # "Begin" and "End" are printed around the run as progress markers.
    print("Begin")
    main()
    print("End")
