from pathlib import Path
import pandas as pd
from collections import namedtuple
import json
import numpy as np

# One parsed session log, as built by read_logs():
#   task       - integer task id parsed from the task directory name
#   user       - name of the user directory the log was found in
#   robot      - system label parsed from the task directory name (code below compares it to 'A' / 'B')
#   records    - list of JSON-decoded log lines, in file order
#   start_time - Unix timestamp parsed from the log file name, or 0 when it could not be parsed
#   fake       - True when start_time could not be parsed (start_time is then the 0 placeholder)
Log = namedtuple("Log", ["task", "user", "robot", "records", "start_time", "fake"])

import time
from datetime import datetime

def time_convert(timestr, is_old):
    """Turn a timestamp taken from a log file name into Unix seconds.

    Args:
        timestr: Timestamp text. Parsed as ``%Y%m%d_%H%M%S`` when ``is_old``
            is true, otherwise as ``%m%d%H%M%S`` (a layout without a year).
        is_old: Selects which of the two layouts ``timestr`` uses.

    Returns:
        The value of ``datetime.timestamp()`` for the parsed (naive) datetime.

    Raises:
        ValueError: If ``timestr`` does not match the selected layout.
    """
    if is_old:
        parsed = time.strptime(timestr, "%Y%m%d_%H%M%S")
        return datetime(*parsed[:6]).timestamp()

    # The year-less layout leaves the year at strptime's default, so it is
    # overwritten with 2023 before converting.
    parsed = time.strptime(timestr, "%m%d%H%M%S")
    return datetime(*parsed[:6]).replace(year=2023).timestamp()


def read_logs(path):
    """Load every session log found below ``path``.

    The tree is walked as ``path/<date>/<user>/<task_dir>/*.log``. Entries that
    are not directories (stray files next to the data folders) are skipped at
    every level. Each task directory must hold exactly one ``.log`` file whose
    non-blank lines are JSON documents, one record per line.

    Args:
        path: ``pathlib.Path`` of the data root.

    Returns:
        A list of ``Log`` tuples. ``start_time`` is the Unix timestamp parsed
        from the log file name; when the name cannot be parsed it is ``0`` and
        ``fake`` is ``True``.

    Raises:
        AssertionError: If a task directory does not contain exactly one
            ``.log`` file.
        ValueError: If the task directory name (after its first four
            characters) is not of the form ``<int>_<robot>``, or a log line
            is not valid JSON.
    """
    logs = []
    for date_dir in path.iterdir():
        if not date_dir.is_dir():
            continue
        for user_dir in date_dir.iterdir():
            if not user_dir.is_dir():
                continue
            for task_dir in user_dir.iterdir():
                if not task_dir.is_dir():
                    continue
                log_files = list(task_dir.glob("*.log"))
                assert len(log_files) == 1, (
                    f"{task_dir} has {len(log_files)} log files, expected exactly 1"
                )
                log_file = log_files[0]
                with open(log_file, 'r') as f:
                    # Blank lines (e.g. a trailing newline) carry no record.
                    records = [json.loads(line.strip()) for line in f if line.strip()]
                # The first four characters of the directory name are a fixed
                # prefix; the remainder is "<task_id>_<robot>".
                task_id, robot = task_dir.name[4:].split("_")
                # Starting from 05/11, we have made changes to the log format.
                # As a result, logs generated on 05/11 require special handling.
                if date_dir.name != "0511":
                    time_str = log_file.stem.split("_")[-1]
                    is_old = False
                else:
                    time_str = "_".join(log_file.stem.split("_")[-2:])
                    is_old = True
                try:
                    start_time = time_convert(time_str, is_old)
                    fake = False
                except (ValueError, OverflowError, OSError):
                    # Unparseable or unrepresentable timestamp: keep the log
                    # but mark its start time as a placeholder.
                    start_time = 0
                    fake = True
                logs.append(Log(int(task_id), user_dir.name, robot, records, start_time, fake))
    return logs

def is_correct(log):
    """Tell whether a session ended with a successful exit.

    A log counts as correct when its final record has role ``'system'``,
    action ``'exit'`` and content equal to the string ``"True"``.

    Args:
        log: Object with a ``records`` sequence of record dicts.

    Returns:
        ``True`` for a successful exit, ``False`` otherwise. A log without any
        records is reported as not correct.
    """
    if not log.records:
        # An empty log has no exit record at all.
        return False
    last_record = log.records[-1]
    return (
        last_record['role'] == 'system'
        and last_record['action'] == 'exit'
        and last_record['content'] == "True"
    )

def get_time(records):
    """Return the difference between the last and the first record's ``'time'``."""
    first, last = records[0], records[-1]
    return last['time'] - first['time']

def time_bin(logs):
    """Write ``time_bin.csv``: one success flag per log and time budget.

    For every log the elapsed time is taken from ``get_time`` when the log is
    correct, and set to 3600 otherwise (above the largest budget, so such logs
    never count as solved). Each log is then compared against 20 evenly spaced
    budgets between 0 and 3000.

    Output columns: ``n_utter`` (the budget, as text), ``lexicon_id`` (task),
    ``pair`` (robot) and ``success_rate`` (1 when the budget is at least the
    elapsed time, else 0).
    """
    budgets = np.linspace(0, 3000, 20)
    rows = []
    for log in logs:
        elapsed = get_time(log.records) if is_correct(log) else 3600
        for budget in budgets:
            rows.append({
                "n_utter": f"{budget}",
                "lexicon_id": log.task,
                "pair": log.robot,
                "success_rate": 0 if budget < elapsed else 1,
            })

    pd.DataFrame.from_records(rows).to_csv("time_bin.csv", index=False)

def get_trace(records):
    """Count user ``'trace'`` records whose content contains no ``':'``."""
    return sum(
        1
        for rec in records
        if rec["role"] == "user" and rec["action"] == "trace" and ':' not in rec["content"]
    )
def get_edit(records):
    """Count the records with role ``'user'`` and action ``'edit'``."""
    return sum(
        1
        for rec in records
        if rec["role"] == "user" and rec["action"] == "edit"
    )
def get_resyn(records):
    """Count the records with role ``'user'`` and action ``'resyn'``."""
    return sum(
        1
        for rec in records
        if rec["role"] == "user" and rec["action"] == "resyn"
    )
def get_chat(records):
    """Count the records with role ``'user'`` and action ``'chat'``."""
    return sum(
        1
        for rec in records
        if rec["role"] == "user" and rec["action"] == "chat"
    )
def get_check(records):
    """Count the records with role ``'user'`` and action ``'check'``."""
    return sum(
        1
        for rec in records
        if rec["role"] == "user" and rec["action"] == "check"
    )

def is_A_one_shot(log):
    """Return True when the log holds no user trace, edit or resyn record.

    Trace records are counted as in ``get_trace`` (content without ``':'``).
    """
    recs = log.records
    follow_ups = get_trace(recs) + get_edit(recs) + get_resyn(recs)
    return follow_ups == 0

def is_B_one_shot(log):
    """Return True when user check and chat records together number exactly 2."""
    recs = log.records
    turns = get_check(recs) + get_chat(recs)
    return turns == 2

def dispatch(log, fun_a, fun_b):
    """Apply the callable that matches the log's robot.

    Args:
        log: Object with a ``robot`` attribute.
        fun_a: Called with ``log`` when ``log.robot == 'A'``.
        fun_b: Called with ``log`` when ``log.robot == 'B'``.

    Returns:
        Whatever the selected callable returns.

    Raises:
        NotImplementedError: If the robot is neither ``'A'`` nor ``'B'``.
    """
    robot = log.robot
    if robot == 'A':
        return fun_a(log)
    if robot == 'B':
        return fun_b(log)
    raise NotImplementedError

def interaction_bin(logs):
    """Write ``interaction_bin.csv``: success flags over interaction budgets.

    Per log an interaction count is derived:

    * robot ``'A'``: ``1 + traces + edits + resyns``, plus a second entry
      labelled ``"C"`` that leaves the traces out;
    * robot ``'B'``: the number of user chat records;
    * logs that are not correct get the count 100 (above every budget).

    Logs of any other robot are ignored. Every entry is then expanded over the
    budgets ``0..46``. Output columns: ``n_utter`` (budget), ``lexicon_id``
    (task), ``pair`` (robot label) and ``success_rate`` (1 only when the budget
    is strictly greater than the interaction count).
    """
    entries = []
    for log in logs:
        if log.robot == 'A':
            with_trace = without_trace = 100
            if is_correct(log):
                without_trace = 1 + get_edit(log.records) + get_resyn(log.records)
                with_trace = without_trace + get_trace(log.records)
            entries.append({"task": log.task, "robot": log.robot, "interaction": with_trace})
            entries.append({"task": log.task, "robot": "C", "interaction": without_trace})
        elif log.robot == 'B':
            chats = get_chat(log.records) if is_correct(log) else 100
            entries.append({"task": log.task, "robot": log.robot, "interaction": chats})

    # NOTE(review): a budget equal to the interaction count still yields 0
    # here, whereas time_bin treats "budget == time" as success - confirm
    # that the strict comparison is intended.
    rows = [
        {
            "n_utter": budget,
            "lexicon_id": entry["task"],
            "pair": entry["robot"],
            "success_rate": 0 if budget <= entry["interaction"] else 1,
        }
        for entry in entries
        for budget in range(47)
    ]
    pd.DataFrame.from_records(rows).to_csv("interaction_bin.csv", index=False)

def time_distribution(logs):
    """Write ``time_distribution.csv`` for tasks that were solved exactly twice.

    First the number of correct logs per task id is counted. Then every log
    (correct or not) whose task has exactly two correct logs contributes one
    row with the columns ``task``, ``robot`` and ``time`` (from ``get_time``).

    Task ids are counted on demand, so any hashable id is accepted rather than
    only the integers ``0..399``.
    """
    correct_counts = {}
    for log in logs:
        if is_correct(log):
            correct_counts[log.task] = correct_counts.get(log.task, 0) + 1

    rows = [
        {"task": log.task, "robot": log.robot, "time": get_time(log.records)}
        for log in logs
        if correct_counts.get(log.task, 0) == 2
    ]
    df = pd.DataFrame.from_records(rows)
    df.to_csv("time_distribution.csv", index=False)

def bar_plot(logs):
    """Write ``bar.csv`` with two rows per log.

    Logs of robot ``"A"`` produce the modes ``"A"`` (log is correct) and
    ``"C"`` (correct and ``is_A_one_shot``). Every other log produces the modes
    ``"B"`` (correct) and ``"D"`` (correct and ``is_B_one_shot``).

    Output columns: ``chain_seed`` (task), ``mode`` and ``sat_items_n`` (0/1).
    """
    rows = []
    for log in logs:
        if log.robot == "A":
            plain_mode, strict_mode, one_shot_check = "A", "C", is_A_one_shot
        else:
            plain_mode, strict_mode, one_shot_check = "B", "D", is_B_one_shot

        solved = is_correct(log)
        first_try = one_shot_check(log)
        rows.append({"chain_seed": log.task, "mode": plain_mode, "sat_items_n": int(solved)})
        rows.append({"chain_seed": log.task, "mode": strict_mode, "sat_items_n": int(solved and first_try)})

    pd.DataFrame.from_records(rows).to_csv("bar.csv", index=False)

def is_edit_start(r):
    """Return True for a record with role ``"user"`` and action ``"edit"``."""
    if r["role"] != "user":
        return False
    return r["action"] == "edit"
def is_resyn_start(r):
    """Return True for a record with role ``"user"`` and action ``"resyn"``."""
    if r["role"] != "user":
        return False
    return r["action"] == "resyn"
def is_trace_start(r):
    """Return True for a user ``"trace"`` record whose content has no ``':'``."""
    if r["role"] != "user" or r["action"] != "trace":
        return False
    return ':' not in r["content"]
def is_anpl_check(r):
    """Return True for a record with role ``"system"`` and action ``"anpl_check"``."""
    if r["role"] != "system":
        return False
    return r["action"] == "anpl_check"

def get_interactions(records):
    """Summarise the edit / resyn / trace interactions of one record list.

    Each record recognised by ``is_edit_start``, ``is_resyn_start`` or
    ``is_trace_start`` opens a new interaction of that kind. The time between
    two consecutive openings is credited to the kind of the earlier one.

    Args:
        records: Record dicts in chronological order; opening records must
            carry a ``"time"`` value.

    Returns:
        A dict with accumulated durations under ``"edit"``, ``"resyn"`` and
        ``"trace"`` and the number of openings under ``"edit_num"``,
        ``"resyn_num"`` and ``"trace_num"``.
    """
    results = {"edit": 0, "edit_num": 0, "resyn": 0, "resyn_num": 0, "trace": 0, "trace_num": 0}
    # (duration key, counter key, predicate) - probed in this order.
    kinds = (
        ("edit", "edit_num", is_edit_start),
        ("resyn", "resyn_num", is_resyn_start),
        ("trace", "trace_num", is_trace_start),
    )
    current, opened_at = None, 0
    for record in records:
        match = next((kind for kind in kinds if kind[2](record)), None)
        if match is None:
            continue
        duration_key, counter_key, _ = match
        now = record["time"]
        results[counter_key] += 1
        if current:
            # Close the previous interaction at the moment the next one opens.
            results[current] += now - opened_at
        current, opened_at = duration_key, now
    # NOTE(review): the interaction that is still open when the records end
    # is never credited with any duration - confirm this is intended.
    return results

def interactions(logs):
    """Write ``interaction.csv`` with one row per usable robot-"A" log.

    Logs flagged as ``fake`` and logs of other robots are left out. Each row
    holds the task id followed by the fields returned by ``get_interactions``.
    """
    rows = [
        {"task": log.task, **get_interactions(log.records)}
        for log in logs
        if log.robot == "A" and not log.fake
    ]
    pd.DataFrame.from_records(rows).to_csv("interaction.csv", index=False)

def interaction_dis(logs):
    """Write ``interaction_distribution_tc.csv`` for tasks solved exactly twice.

    First the number of correct logs per task id is counted. Then every log
    whose task has exactly two correct logs contributes one row with the
    columns ``task``, ``robot`` and ``interaction``:

    * robot ``'A'``: ``1 + traces + edits + resyns``;
    * robot ``'B'``: the number of user chat records.

    Task ids are counted on demand, so any hashable id is accepted rather than
    only the integers ``0..399``.

    Raises:
        NotImplementedError: From ``dispatch`` when a selected log belongs to
            a robot other than ``'A'`` or ``'B'``.
    """
    def a_interaction(log):
        # The leading 1 is added on top of the counted follow-up actions.
        return 1 + get_trace(log.records) + get_edit(log.records) + get_resyn(log.records)

    def b_interaction(log):
        return get_chat(log.records)

    correct_counts = {}
    for log in logs:
        if is_correct(log):
            correct_counts[log.task] = correct_counts.get(log.task, 0) + 1

    rows = [
        {"task": log.task, "robot": log.robot, "interaction": dispatch(log, a_interaction, b_interaction)}
        for log in logs
        if correct_counts.get(log.task, 0) == 2
    ]
    df = pd.DataFrame.from_records(rows)
    df.to_csv("interaction_distribution_tc.csv", index=False)

def user_diff(logs):
    """Write ``user_diff.csv`` with each user's accuracy on both systems.

    Logs of robot ``"A"`` count towards ``A_acc``; all other logs count
    towards ``B_acc``. Accuracy is the share of correct logs. A user without
    any log for one system gets NaN for that column (written as an empty CSV
    field) so that the division cannot fail.

    Output columns: ``user``, ``A_acc``, ``B_acc``.
    """
    user_data = {}
    for log in logs:
        stats = user_data.setdefault(
            log.user, {"A_all": 0, "A_correct": 0, "B_all": 0, "B_correct": 0}
        )
        solved = is_correct(log)
        system = "A" if log.robot == "A" else "B"
        stats[f"{system}_all"] += 1
        if solved:
            stats[f"{system}_correct"] += 1

    def accuracy(correct, total):
        # No attempts on a system means the accuracy is undefined, not zero.
        return correct / total if total else float("nan")

    datas = [
        {
            "user": user,
            "A_acc": accuracy(stats["A_correct"], stats["A_all"]),
            "B_acc": accuracy(stats["B_correct"], stats["B_all"]),
        }
        for user, stats in user_data.items()
    ]
    df = pd.DataFrame.from_records(datas)
    df.to_csv("user_diff.csv", index=False)

def ab_ba(logs):
    """Print correct-log counts grouped by the order tag of each task.

    ``abba.csv`` (header-less, columns ``id`` and ``order``) is read from the
    working directory and maps a task id to a tag; the tally below provides
    the tags ``"AB"`` and ``"BA"``. For each tag the printed dict holds the
    number of logs (``"ALL"``) and the number of correct logs per robot
    (``"A"``, ``"B"``).

    The ``id`` column is read as text: the lookup key is ``str(log.task)``,
    and with inferred dtypes a purely numeric column would yield integer keys
    that never match it.

    Raises:
        KeyError: If a task id is missing from ``abba.csv``, its tag is not
            ``"AB"``/``"BA"``, or a log's robot is not ``"A"``/``"B"``.
    """
    table = pd.read_csv("abba.csv", names=["id", "order"], dtype={"id": str})
    order_by_task = dict(zip(table["id"], table["order"]))

    data = {"AB": {"A": 0, "B": 0, "ALL": 0}, "BA": {"A": 0, "B": 0, "ALL": 0}}
    for log in logs:
        tag = order_by_task[str(log.task)]
        data[tag]["ALL"] += 1
        data[tag][log.robot] += int(is_correct(log))
    print(data)


if __name__ == "__main__":
    # Parse all session logs below the data root (relative to the working directory).
    logs = read_logs(Path("../anpl_data"))
    print("Task Num")
    print(len(logs))

    # Number of distinct tasks with at least one correct log, per system.
    solved_with_a = {entry.task for entry in logs if entry.robot == 'A' and is_correct(entry)}
    print(f"A correct: {len(solved_with_a)}")

    solved_with_b = {entry.task for entry in logs if entry.robot == 'B' and is_correct(entry)}
    print(f"B correct: {len(solved_with_b)}")

    # Each step consumes the full log list; most of them write a CSV file.
    report_steps = (
        # Interaction
        interaction_dis,
        interaction_bin,
        interactions,
        # Time
        time_distribution,
        time_bin,
        # System A vs System B
        bar_plot,
        ab_ba,
        # User
        user_diff,
    )
    for step in report_steps:
        step(logs)
