import argparse
import code
import datetime
import json
import os
from pytz import timezone
import time

import pandas as pd
from tqdm import tqdm


def get_log_files(max_num_files=None):
    dates = []
    for month in [4, 5]:
        for day in range(1, 32):
            dates.append(f"2023-{month:02d}-{day:02d}")

    num_servers = 14
    filenames = []
    for d in dates:
        for i in range(num_servers):
            name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
            if os.path.exists(name):
                filenames.append(name)
    max_num_files = max_num_files or len(filenames)
    filenames = filenames[-max_num_files:]
    return filenames


def pretty_print_conversation(messages):
    for role, msg in messages:
        print(f"[[{role}]]: {msg}")


def inspect_convs(log_files):
    data = []
    for filename in tqdm(log_files, desc="read files"):
        for retry in range(5):
            try:
                lines = open(filename).readlines()
                break
            except FileNotFoundError:
                time.sleep(2)

        for l in lines:
            row = json.loads(l)

            if "states" not in row:
                continue
            if row["type"] not in ["leftvote", "rightvote", "bothbad_vote"]:
                continue

            model_names = row["states"][0]["model_name"], row["states"][1]["model_name"]
            if row["type"] == "leftvote":
                winner, loser = model_names[0], model_names[1]
                winner_conv, loser_conv = row["states"][0], row["states"][1]
            elif row["type"] == "rightvote":
                loser, winner = model_names[0], model_names[1]
                loser_conv, winner_conv = row["states"][0], row["states"][1]

            if loser == "bard" and winner == "vicuna-13b":
                print("=" * 20)
                print(f"Winner: {winner}")
                pretty_print_conversation(winner_conv["messages"])
                print(f"Loser: {loser}")
                pretty_print_conversation(loser_conv["messages"])
                print("=" * 20)
                input()



if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-num-files", type=int)
    args = parser.parse_args()

    log_files = get_log_files(args.max_num_files)
    inspect_convs(log_files)
