#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import re
import json
import ast
import argparse
import random
import sys
import torch
import eval_clip
from eval_clip import Tester
import matplotlib.pyplot as plt  # New import for plotting

# ------------------------------------------------------------------------------
# 1) ARGUMENT PARSING WITH DEFAULT PATHS
# ------------------------------------------------------------------------------
def _str2bool(value):
    """
    Convert a command-line token to a bool for flag-style arguments.

    Accepts true/false, yes/no, on/off, 1/0, t/f, y/n (case-insensitive);
    real bools (e.g. argparse defaults) are returned unchanged.

    :raises argparse.ArgumentTypeError: for any other token, so argparse reports
        a usage error instead of treating e.g. the string "False" as truthy.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    if token in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """
    Parse command-line arguments and derive output paths for multi-epoch evaluation.

    Besides the parsed options, the returned Namespace gets:
      * ``data_dir`` and ``video_save_dir`` under ``data_local/<dataset>``;
      * ``report_dir`` / ``result_dir`` / ``html_dir`` under
        ``data_local/LLaVA-Video-178K[/eval/<debug>]/{reports,results,visualize}/<dataset>``
        (all three directories are created).
    ``torch`` and ``random`` are seeded with ``--seed``.

    Exits with a usage error (status 2) if ``--model-dir`` is not an existing directory.
    """
    model_name = 'ensemble-2025-02-10-14-57-22_0_fts10_lr1e-05'
    # No usable default checkpoint location is configured; --model-dir must be
    # supplied. It is validated after parsing so that --help still works.
    default_model_dir = ""

    parser = argparse.ArgumentParser("Eval Clip - Multi-Epoch")

    # ------------------------------
    #   Standard / existing args
    # ------------------------------
    parser.add_argument("--dataset", type=str, default="vidvrd-dataset",
                        choices=["vidvrd-dataset", "ActionGenome"],
                        help="Dataset to evaluate (VidVRD or ActionGenome).")
    parser.add_argument("--phase", type=str, default='eval',
                        help="Phase: 'eval', 'test', 'cache_test', etc.")
    # NOTE(review): store_true with default=True cannot be switched off from the CLI.
    parser.add_argument("--load-model", default=True, action="store_true",
                        help="Whether to load from checkpoint.")
    parser.add_argument("--save-model", default=False, action="store_true",
                        help="Whether to save the model (unused if just eval).")
    parser.add_argument("--clip-model-name", type=str, default="openai/clip-vit-base-patch32")
    parser.add_argument("--test-num-top-pairs", type=int, default=30)
    parser.add_argument("--max-video-len", type=int, default=999999)
    parser.add_argument("--train-num", type=int, default=5000)
    parser.add_argument("--val-num", type=int, default=1000)
    parser.add_argument("--test-percentage", type=int, default=100)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--latent-dim", type=float, default=64)
    parser.add_argument("--seed", type=int, default=123)
    parser.add_argument("--model-name", type=str, default=model_name)
    parser.add_argument("--model-dir", type=str, default=default_model_dir,
                        help="Path to the base directory containing checkpoints/logs.")
    # NOTE(review): store_true with default=True cannot be switched off from the CLI.
    parser.add_argument("--use-cuda", default=True, action="store_true")
    parser.add_argument("--use-half", action="store_true")
    parser.add_argument("--use-ddp", action="store_true")
    parser.add_argument("--gpu", type=int, default=0,
                        help="GPU device index; -1 for CPU.")
    parser.add_argument("--rel_top_k", type=int, default=1)
    parser.add_argument("--debug", default="better_alpha_with_optim_random",
                        help="Used to construct debug subfolder paths.")
    parser.add_argument("--splice-start", type=int, default=0)
    parser.add_argument("--splice-size", type=int, default=1)
    parser.add_argument("--sgdet", default=False, action="store_true",
                        help="Use SGDET mode in the dataset/tester pipeline.")

    # ------------------------------
    #   Multi-epoch args
    # ------------------------------
    parser.add_argument("--epoch-start", type=int, default=15,
                        help="Earliest epoch to consider in multi-epoch eval.")
    parser.add_argument("--epoch-interval", type=int, default=5,
                        help="Step size for interval-based epoch selection.")
    parser.add_argument("--epoch-end", type=int, default=126,
                        help="Last epoch to consider (inclusive); clamped to the largest epoch found in the log.")
    # Boolean options: accept `--flag`, `--flag true` or `--flag false`.
    parser.add_argument("--use-local-max", type=_str2bool, nargs="?", const=True, default=True,
                        help="Only evaluate epochs that reach a new highest training binary@1 so far "
                             "(true/false, default: true).")
    parser.add_argument("--use-interval", type=_str2bool, nargs="?", const=True, default=True,
                        help="Use interval-based epoch selection (true/false, default: true).")

    args = parser.parse_args()

    if not os.path.isdir(args.model_dir):
        parser.error(f"Model dir not found: {args.model_dir!r} (pass --model-dir)")

    # -------------------------------------------------------------------------
    #  Define data_dir and typical subfolders for reports/results
    # -------------------------------------------------------------------------
    args.data_dir = os.path.abspath(
        os.path.join(os.path.abspath(__file__), f"../../../../../data_local/{args.dataset}")
    )
    args.video_save_dir = os.path.join(args.data_dir, 'pred_video')

    # Optional "/eval/<debug>" segment inserted when --debug is non-empty.
    debug_segment = '/eval/' + args.debug if args.debug else ''
    base_reports_dir = os.path.abspath(
        os.path.join(os.path.abspath(__file__),
                     f"../../../../../data_local/LLaVA-Video-178K{debug_segment}/reports/{args.dataset}")
    )
    base_results_dir = os.path.abspath(
        os.path.join(os.path.abspath(__file__),
                     f"../../../../../data_local/LLaVA-Video-178K{debug_segment}/results/{args.dataset}")
    )
    base_vis_dir = os.path.abspath(
        os.path.join(os.path.abspath(__file__),
                     f"../../../../../data_local/LLaVA-Video-178K{debug_segment}/visualize/{args.dataset}")
    )

    args.report_dir = base_reports_dir
    args.result_dir = base_results_dir
    args.html_dir   = base_vis_dir

    os.makedirs(args.report_dir, exist_ok=True)
    os.makedirs(args.result_dir, exist_ok=True)
    os.makedirs(args.html_dir, exist_ok=True)

    print(f"[Paths] result_dir: {args.result_dir}")
    print(f"[Paths] report_dir: {args.report_dir}")
    print(f"[Paths] html_dir:   {args.html_dir}")

    torch.manual_seed(args.seed)
    random.seed(args.seed)

    return args


# ------------------------------------------------------------------------------
# 2) PARSE TRAINING LOG TO EXTRACT training_binary_at_1
# ------------------------------------------------------------------------------
def parse_training_metrics_from_log(log_path):
    """
    Extract per-epoch training binary@1 from a training log.

    The log is scanned line by line for a header such as
      End of Epoch EPOCH_NUM/TOTAL - Avg Loss: ...
    followed later by
      Epoch EPOCH_NUM metrics:
    whose NEXT line must hold a Python-literal dict containing
    ``{"precision": {"binary": {1: value}}}``. The key may be the int ``1`` or the
    string ``"1"``. Numpy scalar reprs such as ``np.float64(0.5)`` are unwrapped
    before parsing.

    :param log_path: path to the training log file.
    :returns: dict ``{epoch_number: float(training binary@1)}``. The dict is empty
        if the file is missing. Epochs whose metrics line cannot be parsed are
        skipped with a warning.
    """
    training_binary_at_1 = {}
    if not os.path.exists(log_path):
        print(f"[WARNING] Log file not found at: {log_path}")
        return training_binary_at_1

    with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
        lines = f.readlines()

    epoch_end_re = re.compile(r"End of Epoch\s+(\d+)\s*/\s*(\d+)")
    # Numpy 2.x prints scalars as e.g. np.float64(0.5) / np.int64(3); strip the
    # wrapper so ast.literal_eval only sees plain literals.
    np_scalar_re = re.compile(r"np\.\w+\(([^()]+)\)")

    current_epoch = None
    for i, line in enumerate(lines):
        m = epoch_end_re.search(line)
        if m:
            current_epoch = int(m.group(1))
            continue

        if current_epoch is not None and f"Epoch {current_epoch} metrics:" in line:
            if i + 1 < len(lines):
                dict_str = np_scalar_re.sub(r"\1", lines[i + 1].strip())
                try:
                    binary = ast.literal_eval(dict_str)["precision"]["binary"]
                    # Python-repr dicts use int keys; JSON-derived ones use "1".
                    train_bin_1 = binary[1] if 1 in binary else binary["1"]
                    training_binary_at_1[current_epoch] = float(train_bin_1)
                except (ValueError, SyntaxError, TypeError, KeyError, IndexError) as e:
                    print(f"[WARNING] Could not parse epoch {current_epoch} metrics: {e}")
            # Only the first metrics block after each "End of Epoch" line is used.
            current_epoch = None

    return training_binary_at_1


# ------------------------------------------------------------------------------
# 3) SELECT EPOCHS TO EVALUATE
# ------------------------------------------------------------------------------
def select_epochs_to_evaluate(training_binary_at_1, use_interval=False, epoch_start=1, epoch_interval=1, epoch_end=None, use_local_max=False):
    """
    Choose which logged epochs to evaluate.

    Two independent filters are applied to the logged epochs in
    [epoch_start, epoch_end]. The result is their INTERSECTION, so a disabled
    filter keeps every in-range epoch and has no effect.

    :param training_binary_at_1: dict ``{epoch: training binary@1}``. Only epochs
        present here can be selected.
    :param use_interval: if True, keep only epochs on the grid
        ``epoch_start, epoch_start + epoch_interval, ...``.
    :param epoch_start: first epoch considered (inclusive).
    :param epoch_interval: grid step used when ``use_interval`` is True.
    :param epoch_end: last epoch considered (inclusive). ``None`` or values larger
        than the largest logged epoch are clamped to that largest epoch.
    :param use_local_max: if True, keep only epochs whose training binary@1 is
        strictly greater than that of every earlier in-range epoch.
    :returns: sorted list of selected epoch numbers (empty if the dict is empty).
    :raises ValueError: if ``use_interval`` is True and ``epoch_interval <= 0``.
    """
    if not training_binary_at_1:
        print("[INFO] No epochs found in the log. Returning empty set.")
        return []

    all_epochs = sorted(training_binary_at_1)
    max_epoch = all_epochs[-1]
    epoch_end = max_epoch if epoch_end is None else min(epoch_end, max_epoch)

    if use_interval and epoch_interval <= 0:
        raise ValueError("epoch_interval must be positive if use_interval=True.")

    in_range = [e for e in all_epochs if epoch_start <= e <= epoch_end]
    selected = in_range

    if use_interval:
        # Same grid as stepping e = epoch_start, epoch_start + interval, ... <= epoch_end.
        selected = [e for e in selected if (e - epoch_start) % epoch_interval == 0]

    if use_local_max:
        # Running maximum is tracked only over the in-range epochs.
        local_max = set()
        best_so_far = float('-inf')
        for e in in_range:
            val = training_binary_at_1[e]
            if val > best_so_far:
                local_max.add(e)
                best_so_far = val
        selected = [e for e in selected if e in local_max]

    final_set = sorted(selected)
    print(f"Models to evaluate: {len(final_set)}")
    return final_set


# ------------------------------------------------------------------------------
# 4) EVALUATE A SINGLE EPOCH
# ------------------------------------------------------------------------------
def evaluate_epoch(args, epoch, loader):
    """
    Run the Tester once for checkpoint ``epoch`` using an already-built loader.

    ``args`` is modified in place:
      * ``model_epoch`` is set to ``epoch``;
      * each of ``report_dir`` / ``result_dir`` / ``html_dir`` becomes
        ``dirname(<previous value>)/<model_name>.<epoch>`` and is created if missing.
    Since only the dirname of the previous value is kept, repeated calls all
    write into the same parent folder.

    :param args: Namespace from parse_args(); mutated as described above.
    :param epoch: checkpoint epoch to load (forwarded to Tester as model_epoch).
    :param loader: DataLoader to evaluate on (validation or test split).
    """
    out_folder = f"{args.model_name or 'unnamed'}.{epoch}"
    parents = {
        "report_dir": os.path.dirname(args.report_dir),
        "result_dir": os.path.dirname(args.result_dir),
        "html_dir": os.path.dirname(args.html_dir),
    }

    args.model_epoch = epoch
    for attr_name, parent in parents.items():
        setattr(args, attr_name, os.path.join(parent, out_folder))
    for attr_name in parents:
        os.makedirs(getattr(args, attr_name), exist_ok=True)

    print(f"\n[EVAL] Evaluating epoch {epoch}:")
    print(f"       report_dir => {args.report_dir}")
    print(f"       result_dir => {args.result_dir}")
    print(f"       html_dir   => {args.html_dir}")

    # Fall back to CPU when no GPU index is given or CUDA is unavailable.
    device = args.gpu if (args.gpu >= 0 and torch.cuda.is_available()) else "cpu"

    tester_kwargs = dict(
        test_loader=loader,
        device=device,
        dataset=args.dataset,
        model_dir=args.model_dir,
        model_name=args.model_name,
        model_epoch=args.model_epoch,
        load_model=args.load_model,
        video_save_dir=args.video_save_dir,
        test_num_top_pairs=args.test_num_top_pairs,
        report_dir=args.report_dir,
        result_dir=args.result_dir,
        clip_model_name=args.clip_model_name,
        use_half=args.use_half,
        world_size=1,  # single-process evaluation
        use_ddp=args.use_ddp,
    )
    Tester(**tester_kwargs).eval()

    print(f"[EVAL] Finished evaluating epoch {epoch}")


# ------------------------------------------------------------------------------
# 5) GENERATE A TOP-LEVEL SUMMARY AND PLOT PERFORMANCE
# ------------------------------------------------------------------------------
def generate_multi_epoch_summary(args, evaluated_epochs):
    """
    Aggregate every evaluated epoch of the current model into one JSON summary and a plot.

    The function scans ``dirname(args.report_dir)`` for subfolders named
    ``<model_name>.<epoch>``. From each it reads
    ``<model_name>.<epoch>.metrics_report.txt`` as JSON, and it writes
    ``<model_name>_multi_epoch_summary.json`` into that same parent folder.

    Epochs are ranked by ``data["precision"]["binary"]["1"]`` (binary@1; a missing
    value counts as 0.0). The top-ranked epoch is reported as best, and ties go
    to the earliest epoch. ``best_epoch`` and ``best_metric_value`` are ``None``
    when no report could be read.

    Missing or unreadable reports are skipped with a warning, so one bad file does
    not abort a multi-epoch run. When at least one report was read,
    ``<model_name>_bin1_performance.png`` is also saved, plotting binary@1 per epoch.

    :param args: Namespace providing ``report_dir`` and ``model_name``.
    :param evaluated_epochs: unused; kept for call compatibility. The directory
        scan decides which epochs are included.
    """
    base_reports_parent = os.path.dirname(args.report_dir)
    mname = args.model_name if args.model_name else "unnamed"

    # Scan the base_reports_parent for any subfolders matching "<model_name>.<epoch>".
    pattern = re.compile(rf"^{re.escape(mname)}\.(\d+)$")
    found_epochs = []
    for item in os.listdir(base_reports_parent):
        match = pattern.match(item)
        if match and os.path.isdir(os.path.join(base_reports_parent, item)):
            found_epochs.append(int(match.group(1)))
    # os.listdir order is arbitrary; sort for a deterministic report and tie-breaking.
    found_epochs.sort()

    summary = {}
    ranking = []  # (epoch, binary@1) pairs in ascending epoch order
    for ep in found_epochs:
        final_report_path = os.path.join(
            base_reports_parent, f"{mname}.{ep}", f"{mname}.{ep}.metrics_report.txt"
        )
        if not os.path.exists(final_report_path):
            print(f"[WARNING] No metrics_report.txt found for epoch {ep} at {final_report_path}")
            continue
        try:
            with open(final_report_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            candidate_val = data["precision"]["binary"].get("1", 0.0)
        except (OSError, ValueError, KeyError, TypeError, AttributeError) as e:
            # E.g. a report still being written, or one without the expected keys.
            print(f"[WARNING] Could not read metrics for epoch {ep} at {final_report_path}: {e}")
            continue
        summary[ep] = data
        ranking.append((ep, candidate_val))

    if ranking:
        # max() returns the first maximal pair, i.e. the earliest epoch on ties.
        best_epoch, best_metric_val = max(ranking, key=lambda pair: pair[1])
    else:
        # None (JSON null) instead of -inf, which json.dump would emit as the
        # non-standard token -Infinity.
        best_epoch, best_metric_val = None, None

    # Stable sort keeps ascending-epoch order among equal scores.
    ranking_sorted = [ep for ep, _ in sorted(ranking, key=lambda pair: pair[1], reverse=True)]

    top_level_summary_path = os.path.join(
        base_reports_parent,
        f"{mname}_multi_epoch_summary.json"
    )
    out_data = {
        "evaluated_epochs": sorted(summary),
        "best_epoch": best_epoch,
        "best_metric_value": best_metric_val,
        "ranking": ranking_sorted,
        "metrics": summary
    }
    with open(top_level_summary_path, 'w', encoding='utf-8') as outf:
        json.dump(out_data, outf, indent=2)

    print(f"[SUMMARY] Wrote multi-epoch summary to {top_level_summary_path}")

    # ------------------------------
    # Plot bin@1 performance over epochs
    # ------------------------------
    if summary:
        epochs = [ep for ep, _ in ranking]
        performance = [val for _, val in ranking]

        fig = plt.figure()
        plt.plot(epochs, performance, marker='o', linestyle='-', color='blue')
        plt.xlabel("Epoch")
        plt.ylabel("Bin@1 Performance")
        plt.title(f"Bin@1 Performance Over Epochs for {mname}")
        plt.grid(True)
        plot_path = os.path.join(base_reports_parent, f"{mname}_bin1_performance.png")
        plt.savefig(plot_path)
        plt.close(fig)
        print(f"[PLOT] Saved bin@1 performance plot to {plot_path}")
    else:
        print("[PLOT] No summary data available to generate plot.")


# ------------------------------------------------------------------------------
# 6) MAIN SCRIPT
# ------------------------------------------------------------------------------
def main():
    """
    Script entry point: select checkpoint epochs from the training log and evaluate each one.

    Steps:
      1. Parse CLI arguments.
      2. Read per-epoch training binary@1 from
         ``<model_dir>/logs/<model_name>.model.log``.
      3. Choose epochs with select_epochs_to_evaluate().
      4. Build the data loaders once via ``vidvrd_dataset.open_vidvrd_loader``.
      5. For every selected epoch, run evaluate_epoch() and refresh the
         multi-epoch summary.

    Returns early (after printing a message) when no epochs are parsed or selected.
    """
    # (a) Parse arguments.
    args = parse_args()

    # (b) Read the training log to get training binary@1 for each epoch.
    log_path = os.path.join(args.model_dir, "logs", f"{args.model_name}.model.log")
    train_bin_dict = parse_training_metrics_from_log(log_path)
    if not train_bin_dict:
        print("[INFO] No epochs found in the log or no metrics parsed. Exiting.")
        return

    # (c) Determine which epochs to evaluate.
    selected_epochs = select_epochs_to_evaluate(
        training_binary_at_1=train_bin_dict,
        use_interval=args.use_interval,
        epoch_start=args.epoch_start,
        epoch_interval=args.epoch_interval,
        epoch_end=args.epoch_end,
        use_local_max=args.use_local_max
    )
    print(f"[INFO] Selected epochs: {selected_epochs}")
    if not selected_epochs:
        print("[INFO] No epochs selected; nothing to do.")
        return

    # (d) Build the dataset/loaders once and reuse them for all epochs.
    # NOTE(review): the VidVRD loader is used for every --dataset choice
    # (including ActionGenome); confirm open_vidvrd_loader supports that.
    from vidvrd_dataset import open_vidvrd_loader

    # Resolve the device the same way evaluate_epoch() does, so the loader and the
    # Tester do not disagree when a GPU index is given but CUDA is unavailable.
    device = args.gpu if args.gpu >= 0 and torch.cuda.is_available() else "cpu"

    loader_args = {
        "dataset_dir": args.data_dir,
        "batch_size": args.batch_size,
        "device": device,
        "training_percentage": 100,
        "testing_percentage": args.test_percentage,
        "max_video_len": args.max_video_len,
        "neg_kws": False,
        "neg_spec": False,
        "neg_example_ct": 0,
        "neg_example_file_name": "neg_examples.json",
        "backbone_model": "clip",
        "sampler": None,
        "splice_start": args.splice_start,
        "splice_size": args.splice_size,
    }
    (_train_dataset, _valid_dataset, _test_dataset,
     _train_loader, valid_loader, test_loader) = open_vidvrd_loader(**loader_args)

    # Phase "eval" uses the validation split; every other phase uses the test split.
    eval_loader = valid_loader if args.phase == "eval" else test_loader

    # (e) Evaluate each selected epoch and refresh the summary after each one, so
    #     results of finished epochs are already summarized if a later one fails.
    #     The summary scans the report directory itself; no epoch list is needed.
    for ep in selected_epochs:
        evaluate_epoch(args, ep, loader=eval_loader)
        generate_multi_epoch_summary(args, evaluated_epochs=[])

    print("[INFO] Multi-epoch evaluation complete.")


if __name__ == "__main__":
    # Run the multi-epoch evaluation only when executed as a script, not on import.
    main()
