import os
import random
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import List, MutableMapping, Tuple
from urllib.parse import urlparse

import boto3
import yaml
from tabulate import tabulate
from tqdm.auto import tqdm






def get_single_s3_size(s3_uri: str, s3_client=None) -> int:
    """Return the size in bytes of a single S3 object.

    Args:
        s3_uri: Object URI of the form ``s3://bucket/key``.
        s3_client: Optional pre-built boto3 S3 client. When omitted, a new
            client is created for this call; pass one in to reuse it across
            many lookups.

    Returns:
        The object's ``ContentLength`` as reported by ``head_object``.

    Raises:
        FileNotFoundError: If ``head_object`` fails with error code ``"404"``.
        Exception: Any other error raised by the client is re-raised unchanged.
    """
    parsed = urlparse(s3_uri)
    bucket_name = parsed.netloc
    # urlparse keeps the leading "/" of the path; S3 keys do not have one.
    object_key = parsed.path.lstrip("/")

    # Honour a caller-supplied client so batch callers can share one.
    if s3_client is None:
        s3_client = boto3.client("s3")

    try:
        response = s3_client.head_object(Bucket=bucket_name, Key=object_key)
    except Exception as e:
        # Client errors expose a ``response`` dict; other exceptions may not,
        # so look the code up defensively instead of indexing blindly.
        error_code = (getattr(e, "response", None) or {}).get("Error", {}).get("Code")
        if error_code == "404":
            raise FileNotFoundError(
                f"The object {object_key} does not exist in bucket {bucket_name}"
            ) from e
        raise
    return response["ContentLength"]


def get_batch_s3_size(s3_uris: List[str]):
    """Look up the byte size of many S3 objects concurrently.

    One boto3 S3 client is created here and handed to every
    ``get_single_s3_size`` call. The lookups run on a pool of 10 threads, with
    a tqdm progress bar advancing as results are collected in input order.

    Args:
        s3_uris: Object URIs of the form ``s3://bucket/key``.

    Returns:
        A dict mapping each URI to its size in bytes.

    Raises:
        Whatever the first failing lookup (in input order) raises, e.g.
        ``FileNotFoundError`` for a missing object.
    """
    shared_client = boto3.client("s3")

    with ThreadPoolExecutor(max_workers=10) as pool:
        jobs = [
            (uri, pool.submit(get_single_s3_size, uri, s3_client=shared_client))
            for uri in s3_uris
        ]
        # Wait on each job in submission order; the bar ticks once per result.
        pairs = [(uri, job.result()) for uri, job in tqdm(jobs, total=len(jobs))]

    return dict(pairs)


def list_s3_paths(
    s3_uri: str, extension: str = ".npy", s3_client=None
) -> List[Tuple[str, int]]:
    """List objects under an S3 "directory" whose keys end with ``extension``.

    Args:
        s3_uri: Prefix URI of the form ``s3://bucket/some/prefix``. A trailing
            ``/`` is added to a non-empty prefix, so the URI is treated as a
            directory (``data/foo`` does not match ``data/foobar/...``).
        extension: Required key suffix; a leading ``.`` is added if missing.
        s3_client: Optional pre-built boto3 S3 client. When omitted, a new
            client is created, matching ``get_single_s3_size``.

    Returns:
        ``(key, size_in_bytes)`` tuples in the order the ``list_objects_v2``
        paginator yields them. Keys are bucket-relative (no ``s3://bucket/``).
        If listing fails, the error is printed and ``[]`` is returned
        (best-effort), so callers cannot distinguish failure from "no matches".
    """
    parsed = urlparse(s3_uri)
    bucket_name = parsed.netloc
    # urlparse keeps the leading "/" of the path; S3 keys do not have one.
    prefix = parsed.path.lstrip("/")
    if prefix and not prefix.endswith("/"):
        prefix += "/"

    if not extension.startswith("."):
        extension = "." + extension

    if s3_client is None:
        s3_client = boto3.client("s3")
    paginator = s3_client.get_paginator("list_objects_v2")

    paths_and_sizes = []
    try:
        for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
            # Empty result pages carry no "Contents" key at all.
            for obj in page.get("Contents", []):
                key = obj["Key"]
                if key.endswith(extension):
                    paths_and_sizes.append((key, obj["Size"]))
    except Exception as e:
        print(f"Error listing objects: {str(e)}")
        return []
    return paths_and_sizes






BASE_YAML_STR = 


def human_format_number(num, decimal_places=2):
    """Format a number with a K/M/B/T suffix, e.g. ``1500 -> "1.50K"``.

    Values whose magnitude is below 1000 are returned as-is (``str`` of the
    absolute value with a ``-`` sign if negative), without rounding. Larger
    values are divided by 1000 per suffix step and shown with
    ``decimal_places`` digits after the point. ``"T"`` is the largest suffix.
    Anything beyond that stays in trillions (``5e15 -> "5000.00T"``).

    Args:
        num: The number to format (int or float).
        decimal_places: Digits after the decimal point for suffixed values.

    Returns:
        The formatted string.
    """
    abs_num = abs(num)
    sign = "-" if num < 0 else ""

    if abs_num < 1000:
        return f"{sign}{abs_num}"

    suffixes = ["", "K", "M", "B", "T"]
    magnitude = 0

    # Compare the *rounded* mantissa, not the raw one: e.g. 999.999 would
    # print as "1000.00", so it must be promoted to the next suffix instead.
    while (
        magnitude < len(suffixes) - 1
        and float(f"{abs_num:.{decimal_places}f}") >= 1000
    ):
        abs_num /= 1000
        magnitude += 1

    formatted = f"{abs_num:.{decimal_places}f}"
    return f"{sign}{formatted}{suffixes[magnitude]}"


def get_token_strs(token_source, bytes_per_token=4):
    """Build YAML lines listing a random subset of ``.npy`` files from one S3 source.

    Args:
        token_source: Either an S3 prefix URI (take all of its tokens) or an
            ``(s3_uri, ratio)`` pair (take roughly ``ratio`` of its tokens).
        bytes_per_token: Bytes per token, used to convert object sizes into
            token counts.

    Returns:
        A list of lines. The first is a ``# <s3_uri>`` header comment, which
        ``_read_path_comments`` reads back to group paths. It is followed by
        one ``- s3://bucket/key`` entry per selected file. Files are taken in
        random order until the cumulative token count reaches ``ratio`` of the
        source total. The file that crosses the target is kept, so the
        selection can slightly overshoot.
    """
    if isinstance(token_source, str):
        s3_source, ratio = token_source, 1.0
    else:
        s3_source, ratio = token_source

    bucket_name = urlparse(s3_source).netloc
    # list_s3_paths returns bucket-relative keys; turn them back into full URIs.
    paths_and_sizes = [
        (f"s3://{bucket_name}/{key}", size) for key, size in list_s3_paths(s3_source)
    ]
    random.shuffle(paths_and_sizes)
    total_tokens = sum(size for _, size in paths_and_sizes) // bytes_per_token
    target_tokens = total_tokens * ratio

    paths_to_add = []
    tokens_to_add = 0
    for path, size in paths_and_sizes:
        paths_to_add.append(path)
        tokens_to_add += size // bytes_per_token
        # Checked after appending: the file that reaches the target is included.
        if tokens_to_add >= target_tokens:
            break

    # The header comment records which source these paths came from.
    return [f"# {s3_source}"] + [f"- {path}" for path in paths_to_add]


def add_paths(token_sources, output_yaml_file, start_point="preanneal"):
    """Write a microanneal training config that mixes data from ``token_sources``.

    The config is ``BASE_YAML_STR`` with its ``REPLACE_RUN_NAME_HERE``,
    ``REPLACE_PATH_HERE`` and ``REPLACE_LR_HERE`` placeholders filled in. It is
    followed by the lines ``get_token_strs`` produces for each source, each on
    a new line indented by four spaces.

    Args:
        token_sources: Iterable of sources accepted by ``get_token_strs``
            (an S3 prefix URI, or an ``(s3_uri, ratio)`` pair).
        output_yaml_file: Destination path. Its basename must start with
            ``peteish7-weka-microanneal`` and it must end in ``.yaml``. The
            basename without extension becomes the run name.
        start_point: ``"preanneal"`` or ``"megamath5000"``. Selects the
            checkpoint path and learning rate substituted into the template.

    Raises:
        ValueError: If ``output_yaml_file`` or ``start_point`` is invalid.
    """
    file_name = os.path.basename(output_yaml_file)
    # Explicit checks rather than ``assert`` so validation survives ``python -O``.
    if not file_name.startswith("peteish7-weka-microanneal"):
        raise ValueError(
            f"output file name must start with 'peteish7-weka-microanneal', got {file_name!r}"
        )
    if not output_yaml_file.endswith(".yaml"):
        raise ValueError(f"output file must end with '.yaml', got {output_yaml_file!r}")

    # start_point -> (checkpoint path for REPLACE_PATH_HERE, LR for REPLACE_LR_HERE)
    start_configs = {
        "preanneal": (
            "/weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7/step928646",
            "0.000061499",
        ),
        "megamath5000": (
            "/weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-weka-anneal-from-928646-50B-megamath_v1.1.yaml/step5000/",
            # The LR a linear-to-zero decay over 11931 steps would have at step
            # 5000. NOTE(review): assumes that is the anneal schedule; confirm.
            "%.09f" % (0.000061499 * (1 - (5000 / 11931))),
        ),
    }
    if start_point not in start_configs:
        raise ValueError(
            f"start_point must be one of {sorted(start_configs)}, got {start_point!r}"
        )
    load_path, learning_rate = start_configs[start_point]

    base_config_str = (
        BASE_YAML_STR.replace("REPLACE_RUN_NAME_HERE", os.path.splitext(file_name)[0])
        .replace("REPLACE_PATH_HERE", load_path)
        .replace("REPLACE_LR_HERE", learning_rate)
    )

    lines_to_add = []
    for source in token_sources:
        lines_to_add.extend(get_token_strs(source))
    # Four-space indent presumably places these under the template's trailing
    # ``paths:`` key. TODO confirm against BASE_YAML_STR.
    output_str = base_config_str + "".join("\n    %s" % line for line in lines_to_add)
    with open(output_yaml_file, "w") as f:
        f.write(output_str)


def examine_config(yaml_file, bytes_per_token=4):
    """Print, per source group of a config, how many of its tokens are used.

    Steps:
    1. Read ``data.paths`` from ``yaml_file`` and size each listed S3 object.
    2. Attribute each object to a group, i.e. one of the ``# <prefix>``
       comments under ``paths:``.
    3. List each group's full contents on S3.
    4. Print a table of total tokens, fraction taken, and tokens taken per
       group.

    Sizes are converted to tokens by integer division by ``bytes_per_token``.

    Args:
        yaml_file: Path to a config written by ``add_paths``.
        bytes_per_token: Bytes per token used for the conversion.

    Raises:
        Exception: If a listed path does not start with any group prefix.
    """
    print("Getting tokens per input file...")
    with open(yaml_file, "r") as f:
        # safe_load returns None for an empty document.
        yaml_content = yaml.safe_load(f) or {}
    paths = (yaml_content.get("data") or {}).get("paths") or []
    paths_to_tokens = {k: v // bytes_per_token for k, v in get_batch_s3_size(paths).items()}

    print("Grouping output files into groups...")
    groups = set(_read_path_comments(yaml_file))

    def get_group(s3_uri):
        # Longest matching prefix wins: "s3://b/math_v2/x.npy" belongs to
        # "s3://b/math_v2", not to a sibling group "s3://b/math". Picking the
        # first match from a set would make this depend on iteration order.
        matches = [g for g in groups if s3_uri.startswith(g)]
        if not matches:
            raise Exception("UNKNOWN GROUP FOR %s" % s3_uri)
        return max(matches, key=len)

    tokens_taken: MutableMapping[str, int] = defaultdict(int)
    for p, tok in paths_to_tokens.items():
        tokens_taken[get_group(p)] += tok

    print("Getting total group sizes...")
    total_tokens = {}
    for g in tqdm(sorted(groups)):
        total_tokens[g] = sum(size for _, size in list_s3_paths(g)) // bytes_per_token
    print("TOTAL_TOKENS", total_tokens)

    # A group can total 0 tokens (empty listing, or a listing error that
    # list_s3_paths reports as []); report "n/a" instead of dividing by zero.
    ratios = {
        g: ("%.04f" % (tokens_taken[g] / total_tokens[g])) if total_tokens[g] else "n/a"
        for g in groups
    }

    rows = sorted([(g, total_tokens[g], ratios[g], tokens_taken[g]) for g in groups])
    print("Put this in your spreadsheet!")
    print(tabulate(rows, headers=["paths", "total_tokens", "percentage taken", "tokens taken"]))


def _read_path_comments(yaml_file):
    
    lines = open(yaml_file, "r").readlines()
    path_sources = []
    seen_paths = False
    for line in lines:
        if not seen_paths and line.strip() != "paths:":
            continue
        elif line.strip() == "paths:":
            seen_paths = True
        elif line.strip().startswith("
            path_sources.append(line.strip().split(" ")[1])
        else:
            pass
    return path_sources







if __name__ == "__main__":
    
