from __future__ import annotations

import csv
import logging as py_logging
import math
import os
import re
from decimal import ROUND_FLOOR, Decimal
from typing import Dict, Optional

from absl import app, flags, logging


# CSV schema: one column per entry, in output order. This list is written as
# the header row and also fixes the column order of every data row.
COLUMN_NAMES = [
    "training_data",
    "imb_factor",
    "model",
    "norm",
    "reg_eps",
    "tau2",
    "w_pow",
    "coupling_imp",
    "fid(bal)",
    "fid(lt)",
    "prec(bal)",
    "recall(bal)",
    "F1(bal)",
    "prec(lt)",
    "recall(lt)",
    "F1(lt)",
]

# Model name mapping table - modify this to change model name transformations
# Keys are raw "model" values read from a log; values are what is written to
# the CSV "model" column. Names not listed here are written unchanged.
MODEL_NAME_MAPPING = {
    "sinkhorn_otcfm": "uot-cfm",
    "sinkhorn_otwfm": "uot-wfm(1/tnu)",
    "otcfm": "ot-cfm",
    # Add more mappings as needed:
    # "original_name": "display_name",
}


# Command-line flags (absl). Help strings describe the behaviour implemented
# in main(), including defaults that are applied there rather than here.
FLAGS = flags.FLAGS
flags.DEFINE_string(
    "root_dir",
    None,
    "Root directory to recursively search for log files (required).",
)
flags.DEFINE_string(
    "csv_path",
    None,
    "Output CSV path (default: gathered_logs.csv in the current working "
    "directory). If this is an existing directory or ends with a path "
    "separator, gathered_logs.csv is written inside it. The file is created "
    "with a header row if missing; rows are appended if it exists.",
)
flags.DEFINE_string(
    "log_filename",
    "log.txt",
    "Target log filename to search (default: log.txt).",
)
flags.DEFINE_string(
    "log_level",
    "INFO",
    "Logger verbosity (case-insensitive): DEBUG, INFO, WARNING, ERROR, "
    "CRITICAL/FATAL. Unknown values fall back to INFO.",
)


def ensure_csv_with_header(csv_path: str) -> None:
    """Write the COLUMN_NAMES header row when *csv_path* is missing or empty.

    Missing parent directories are created. A file that already exists and is
    non-empty is left untouched; its existing header is not inspected.
    """
    already_initialised = os.path.exists(csv_path) and os.path.getsize(csv_path) != 0
    if already_initialised:
        return

    parent_dir = os.path.dirname(csv_path) or "."
    os.makedirs(parent_dir, exist_ok=True)
    with open(csv_path, mode="a", newline="") as handle:
        csv.writer(handle).writerow(COLUMN_NAMES)


def floor_to_decimals(value: float, num_decimals: int = 4) -> float:
    """Truncate *value* toward negative infinity, keeping *num_decimals* decimals.

    The floor is taken on the decimal value of ``repr(value)``, the shortest
    string that round-trips the float, rather than on ``value * 10**n``. The
    binary product can land just below an integer (``0.57 * 10000`` evaluates
    to ``5699.999999999999``), and ``math.floor`` would then turn a logged
    0.57 into 0.5699.

    Args:
      value: Number to truncate. Non-finite values (inf/nan) are returned
        unchanged.
      num_decimals: Number of digits to keep after the decimal point (>= 0).

    Returns:
      The floored value as a float. ``-0.0`` is normalised to ``0.0``.

    Raises:
      ValueError: If ``num_decimals`` is negative.
    """
    if num_decimals < 0:
        raise ValueError("num_decimals must be >= 0")
    value = float(value)
    if not math.isfinite(value):
        return value
    # Shift the decimal point right, floor to an integer, then shift back.
    # All of these steps are exact in decimal arithmetic.
    scaled = Decimal(repr(value)).scaleb(num_decimals)
    floored = scaled.to_integral_value(rounding=ROUND_FLOOR)
    # Adding 0.0 turns a -0.0 result into 0.0, matching the old int-based floor.
    return float(floored.scaleb(-num_decimals)) + 0.0


def parse_float_safe(token: str) -> Optional[float]:
    """Convert *token* to a float, returning None when it cannot be parsed.

    Only the errors ``float()`` raises for unparseable input are swallowed:
    ValueError (malformed string), TypeError (e.g. ``None``) and OverflowError
    (an int too large for a float). Any other exception propagates instead of
    being silently hidden.

    Note: ``float()`` accepts strings such as "nan" and "inf", so non-finite
    values are returned as-is.
    """
    try:
        return float(token)
    except (TypeError, ValueError, OverflowError):
        return None


def parse_bool_safe(token: str) -> Optional[bool]:
    """Interpret *token* as a boolean flag value.

    Matching is case-insensitive and ignores surrounding whitespace:
    true/1/yes/y/t map to True, and false/0/no/n/f map to False. Anything else,
    including an object whose ``str()`` raises, yields None.
    """
    try:
        normalized = str(token).strip().lower()
    except Exception:
        return None

    lookup = dict.fromkeys(("true", "1", "yes", "y", "t"), True)
    lookup.update(dict.fromkeys(("false", "0", "no", "n", "f"), False))
    return lookup.get(normalized)


# Datasets whose metrics fill the "(bal)" columns; any dataset name ending in
# "_lt" fills the "(lt)" columns instead.
_BALANCED_DATASETS = frozenset({"cifar10", "cifar100"})

# "key: value" lines, e.g. a flag dump inside the log.
_KEY_VALUE_PATTERN = re.compile(r"^([a-zA-Z0-9_?]+):\s*(.*)\s*$")
# "FID measured on: <dataset> -> score: <float>"
_FID_PATTERN = re.compile(
    r"FID measured on:\s*([^\s]+)\s*->\s*score:\s*([+-]?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?)"
)
# "... measured on: <dataset> -> precision: <float> recall: <float>"
_PR_PATTERN = re.compile(
    r"measured on:\s*([^\s]+)\s*->\s*precision:\s*([0-9]*\.?[0-9]+)\s*recall:\s*([0-9]*\.?[0-9]+)"
)

# Log keys whose value is copied verbatim into a CSV column (log key -> column).
_DIRECT_LOG_KEYS: Dict[str, str] = {
    "dataset_name": "training_data",
    "imb_factor": "imb_factor",
    "model": "model",
    "data_norm": "norm",
    "reg": "reg_eps",  # may be blanked later by the model-name rule
    "tau_b": "tau2",  # may be blanked later by the model-name rule
}


def _split_suffix(measured_on: str) -> Optional[str]:
    """Map a dataset name to its metric-column suffix: "bal", "lt", or None."""
    if measured_on in _BALANCED_DATASETS:
        return "bal"
    if measured_on.endswith("_lt"):
        return "lt"
    return None


def _format_metric(value: float) -> str:
    """Floor *value* to 4 decimals and render it with exactly 4 decimals."""
    return f"{floor_to_decimals(value, 4):.4f}"


def _compute_f1(precision_str: str, recall_str: str) -> str:
    """Return the formatted F1 of two metric strings, or "" when undefined."""
    p = parse_float_safe(precision_str)
    r = parse_float_safe(recall_str)
    if p is None or r is None:
        return ""
    denom = p + r
    if denom == 0:
        return ""
    return _format_metric(2.0 * (p * r) / denom)


def _coupling_label(
    fixed_source: Optional[bool], fixed_target: Optional[bool]
) -> str:
    """Derive the coupling_imp column value from the fixed_* flags."""
    if fixed_source is True and fixed_target is True:
        return "fixed_both"
    if fixed_source is True:
        return "fixed_source"
    if fixed_target is True:
        return "fixed_target"
    return ""


def parse_log_file(log_path: str) -> Dict[str, Optional[str]]:
    """Parse one training log into a CSV row keyed by COLUMN_NAMES.

    Recognised lines:
      * ``key: value`` for dataset_name, imb_factor, model, data_norm, reg,
        tau_b, weight_power_factor, fixed_source and fixed_target.
      * ``FID measured on: <dataset> -> score: <x>``.
      * ``measured on: <dataset> -> precision: <p> recall: <r>``.
    When a key or metric appears more than once, the last occurrence wins.

    Post-processing (rules use the raw model name, before mapping):
      * model is renamed via MODEL_NAME_MAPPING when listed there.
      * reg_eps and tau2 are blanked for models "icfm" / "otcfm".
      * w_pow is filled only for models whose name contains "otwfm".
      * imb_factor is kept (default "0.01") only for "_lt" training data.
      * F1 is computed from the already-floored precision/recall strings.

    Args:
      log_path: Path of the log file. It is read as UTF-8 and undecodable
        bytes are dropped.

    Returns:
      Dict holding every COLUMN_NAMES entry; unavailable values are "".
    """
    # Single source of truth for the row layout: start from the CSV schema.
    result: Dict[str, Optional[str]] = dict.fromkeys(COLUMN_NAMES, "")
    w_pow_value: Optional[str] = None
    fixed_source_flag: Optional[bool] = None
    fixed_target_flag: Optional[bool] = None

    with open(log_path, "r", encoding="utf-8", errors="ignore") as f:
        for raw_line in f:
            line = raw_line.strip()

            # 1) Simple "key: value" lines (includes the FLAGS area).
            kv = _KEY_VALUE_PATTERN.match(line)
            if kv:
                key, value = kv.group(1), kv.group(2)
                column = _DIRECT_LOG_KEYS.get(key)
                if column is not None:
                    result[column] = value
                elif key == "weight_power_factor":
                    w_pow_value = value
                elif key == "fixed_source":
                    flag = parse_bool_safe(value)
                    if flag is not None:
                        fixed_source_flag = flag
                elif key == "fixed_target":
                    flag = parse_bool_safe(value)
                    if flag is not None:
                        fixed_target_flag = flag

            # 2) FID line.
            fid = _FID_PATTERN.search(line)
            if fid:
                suffix = _split_suffix(fid.group(1))
                score = parse_float_safe(fid.group(2))
                if suffix is not None and score is not None:
                    result[f"fid({suffix})"] = _format_metric(score)

            # 3) Precision/recall line.
            pr = _PR_PATTERN.search(line)
            if pr:
                suffix = _split_suffix(pr.group(1))
                p_val = parse_float_safe(pr.group(2))
                r_val = parse_float_safe(pr.group(3))
                if suffix is not None and p_val is not None and r_val is not None:
                    result[f"prec({suffix})"] = _format_metric(p_val)
                    result[f"recall({suffix})"] = _format_metric(r_val)

    # Model-name rules are keyed on the raw (pre-mapping) name.
    original_model = result["model"] or ""
    original_lower = original_model.lower()
    if original_model in MODEL_NAME_MAPPING:
        result["model"] = MODEL_NAME_MAPPING[original_model]

    # reg_eps, tau2: blank only when the model is exactly icfm or otcfm.
    if original_lower in {"icfm", "otcfm"}:
        result["reg_eps"] = ""
        result["tau2"] = ""

    # w_pow: only meaningful for otwfm-style models.
    if "otwfm" in original_lower and w_pow_value is not None:
        result["w_pow"] = w_pow_value
    else:
        result["w_pow"] = ""

    result["coupling_imp"] = _coupling_label(fixed_source_flag, fixed_target_flag)

    # imb_factor: default to 0.01 for LT datasets when missing; blank otherwise.
    if result["training_data"].endswith("_lt"):
        if not result["imb_factor"]:
            result["imb_factor"] = "0.01"
    else:
        result["imb_factor"] = ""

    for suffix in ("bal", "lt"):
        precision = result[f"prec({suffix})"]
        recall = result[f"recall({suffix})"]
        if precision and recall:
            result[f"F1({suffix})"] = _compute_f1(precision, recall)

    return result


def write_row(csv_path: str, row: Dict[str, Optional[str]]) -> None:
    """Append *row* to *csv_path* in COLUMN_NAMES order.

    A header is written first if the file is missing or empty. Columns absent
    from *row* are written as empty strings.
    """
    ensure_csv_with_header(csv_path)
    ordered_values = [row.get(name, "") for name in COLUMN_NAMES]
    with open(csv_path, mode="a", newline="") as out_file:
        csv.writer(out_file).writerow(ordered_values)


def find_all_logs(root_dir: str, log_filename: str) -> list[str]:
    """Return the path of every file named *log_filename* under *root_dir*.

    The tree is walked with ``os.walk``. Paths are returned in walk order, so
    callers should sort them if they need a deterministic order.
    """
    return [
        os.path.join(directory, log_filename)
        for directory, _subdirs, files in os.walk(root_dir)
        if log_filename in files
    ]


def main(argv) -> None:
    """Parse every log file under --root_dir and append one CSV row per file.

    Args:
      argv: Remaining command-line arguments passed by ``app.run``; unused.
    """
    del argv  # Unused.

    # absl's set_verbosity() takes *absl* levels (DEBUG=1, INFO=0, WARNING=-1,
    # ERROR=-2, FATAL=-3) and treats larger integers as vlog verbosity.
    # Passing stdlib levels (10, 20, 30, ...) therefore enabled every message
    # regardless of --log_level, so map names onto absl's own constants.
    level_map = {
        "DEBUG": logging.DEBUG,
        "INFO": logging.INFO,
        "WARNING": logging.WARNING,
        "WARN": logging.WARNING,
        "ERROR": logging.ERROR,
        "CRITICAL": logging.FATAL,
        "FATAL": logging.FATAL,
    }
    log_level_name = (FLAGS.log_level or "INFO").strip().upper()
    level = level_map.get(log_level_name)
    logging.set_verbosity(logging.INFO if level is None else level)
    if level is None:
        logging.warning(
            "Unknown --log_level=%r; falling back to INFO.", FLAGS.log_level
        )

    csv_path = FLAGS.csv_path or os.path.join(os.getcwd(), "gathered_logs.csv")
    # An existing directory, or a path spelled with a trailing separator
    # (os.altsep also covers "/" on Windows), gets the default filename.
    separators = tuple(sep for sep in (os.sep, os.altsep) if sep)
    if os.path.isdir(csv_path) or csv_path.endswith(separators):
        csv_path = os.path.join(csv_path, "gathered_logs.csv")

    logs = find_all_logs(FLAGS.root_dir, FLAGS.log_filename)
    if not logs:
        logging.warning("No log files found under: %s", FLAGS.root_dir)

    written = 0
    # Sorted so the CSV row order is deterministic across runs.
    for log_path in sorted(logs):
        try:
            row = parse_log_file(log_path)
            write_row(csv_path, row)
        except Exception:
            # Top-level boundary: record the failure and keep processing the rest.
            logging.exception("Error while parsing log file: %s", log_path)
            continue
        written += 1
        logging.info("Wrote row to CSV from: %s", log_path)

    logging.info("Wrote %d of %d row(s) to %s", written, len(logs), csv_path)


if __name__ == "__main__":
    # root_dir is marked required only on the script path; presumably this is
    # so that importing the module (e.g. to reuse parse_log_file) is not
    # constrained. TODO confirm intent.
    flags.mark_flag_as_required("root_dir")
    # app.run parses the command-line flags, then invokes main(argv).
    app.run(main)


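"""
Example usage. When --csv_path names a directory (as below), rows are
appended to gathered_logs.csv inside that directory:

python log_gathering.py \
  --root_dir=results_cifar100_lt_imb0.001 \
  --csv_path=./ \
  --log_filename=log.txt \
  --log_level=INFO
"""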