import os
import numpy as np
from pdb import set_trace as pds
import json
import os
import sys
import time

class Timer(object):
    """Minimal stopwatch that reports the seconds elapsed since the last start.

    The timer starts on construction; call ``start()`` to restart it and
    ``end()`` to read the elapsed time. ``end()`` does not stop or reset the
    timer, so it can be called repeatedly.
    """

    def __init__(self):
        self.start()

    def start(self):
        """Record the current instant as the new reference point."""
        # perf_counter() is monotonic, so the elapsed value cannot go
        # negative or jump when the system wall clock is adjusted.
        self.v = time.perf_counter()

    def end(self):
        """Return the seconds (float) elapsed since the last ``start()``."""
        return time.perf_counter() - self.v


def time_str(t):
    """Format a duration given in seconds as a short human-readable string.

    Args:
        t: Duration in seconds.

    Returns:
        The duration with one decimal place, expressed in hours (``h``) when
        ``t`` is at least 3600, in minutes (``m``) when it is at least 60, and
        in seconds (``s``) otherwise.
    """
    if t >= 3600:
        return '{:.1f}h'.format(t / 3600)
    # Inclusive bound, matching the hour check above, so exactly one minute
    # is rendered as '1.0m' rather than '60.0s'.
    if t >= 60:
        return '{:.1f}m'.format(t / 60)
    return '{:.1f}s'.format(t)


def ensure_path(path, early_exit = False):
    """Create ``path`` (including parents) unless it already exists.

    Args:
        path: Directory path to create.
        early_exit: When true and ``path`` already exists, ask on stdin
            whether to continue; answering ``n`` terminates the process with
            exit status 0. Any other answer continues.
    """
    if not os.path.exists(path):
        os.makedirs(path)
        return

    # The path is already present; only prompt when the caller asked for it.
    if not early_exit:
        return

    answer = input('{:s} exists, continue? ([y]/n): '.format(path))
    if answer == 'n':
        sys.exit(0)

# Function to load JSON data from a file
def load_json(file_path):
    """Read the file at ``file_path`` and return its parsed JSON content."""
    with open(file_path, 'r') as handle:
        parsed = json.load(handle)
    return parsed
    
def save_json(data, file_path, indent = 4):
    """Write ``data`` as JSON to ``file_path``, replacing any existing file.

    Args:
        data: JSON-serializable object to store.
        file_path: Destination file.
        indent: Indentation passed through to ``json.dump``.
    """
    with open(file_path, 'w') as sink:
        json.dump(data, fp = sink, indent = indent)


def check_and_compare(model_id = "llama3_8b_ins", seq_len = 2048,
                      k_values = None, index_ranges = None):
    """Gather per-range generation results into one JSON summary file.

    For every index range, the first record of ``naive_<range>.json`` and of
    each ``conv_k_<k>_<range>.json`` found under
    ``<model_id>/seq_len<seq_len>/generation`` is tagged with ``k``, ``attn``
    and ``idx_range`` and appended to a single list. The list is written to
    ``<model_id>/seq_len<seq_len>/gen_res.json``.

    Files that do not exist are reported on stdout and skipped.

    Args:
        model_id: Name of the top-level results directory.
        seq_len: Selects the ``seq_len<seq_len>`` subdirectory; also stored as
            ``k`` on the naive records.
        k_values: Iterable of k values whose ``conv_k_*`` files are collected.
            Defaults to ``[32, 64, 128, 256, 512, 768, 1024, 2048]``.
        index_ranges: Iterable of ``"<start>_<end>"`` strings naming the file
            suffixes to collect. Defaults to
            ``["0_100", "200_300", "400_500", "600_700"]``.
    """
    # Materialize both iterables: k_values is walked once per index range and
    # len(index_ranges) is needed for the time estimate below.
    if k_values is None:
        k_values = [32, 64, 128, 256, 512, 768, 1024, 2048]
    else:
        k_values = list(k_values)
    if index_ranges is None:
        index_ranges = ["0_100", "200_300", "400_500", "600_700"]
    else:
        index_ranges = list(index_ranges)

    base_dir = f"{model_id}/seq_len{seq_len}"
    directory = f"{base_dir}/generation"

    results = []

    timer = Timer()
    for idx, idx_range in enumerate(index_ranges):
        naive_filename = f"naive_{idx_range}.json"
        naive_filepath = os.path.join(directory, naive_filename)
        if os.path.exists(naive_filepath):
            item = load_json(naive_filepath)[0]
            item.update({
                "k": seq_len,
                "attn": "naive",
                "idx_range": idx_range,
            })
            results.append(item)
        else:
            # Only report the gap. Falling through to use the naive record
            # here would raise NameError on the first range, or re-tag and
            # re-append the previous range's record on later ones.
            print("    No corresponding naive file found for comparison")

        for k in k_values:
            print(f"\nAnalyzing files for k = {k}")

            filename = f"conv_k_{k}_{idx_range}.json"
            filepath = os.path.join(directory, filename)
            if not os.path.exists(filepath):
                print(f"  File not found: {filename}")
                continue

            item = load_json(filepath)[0]
            item.update({
                "k": k,
                "attn": "conv",
                "idx_range": idx_range,
            })
            results.append(item)

        # Progress is per index range, so report the range (not the last k of
        # the inner loop) with the elapsed time and the projected total.
        time_elapsed = timer.end()
        print(f"idx_range {idx_range}, time elapsed {time_str(time_elapsed)} | {time_str(time_elapsed/(idx+1)*len(index_ranges))}")

    save_json(results, f"{base_dir}/gen_res.json")

if __name__ == "__main__":
    # Script entry point: aggregate the generation results of one model at
    # one sequence length.
    # Alternative model directory: "llama3_8b_ins"
    model_id = "mistral_7b_ins_v03"
    seq_len = 4096
    check_and_compare(model_id = model_id, seq_len = seq_len)