import argparse
import json
import logging
import os
import re

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM

import fla  # noqa
from ssmworkbench.text_datamodule import TextArrowFileModule


@torch.inference_mode()
def evaluate_length_extrapolation(
    model, data_module, device, max_length, max_batches=21
):
    """Evaluate how loss and accuracy behave as the context grows.

    For every evaluated sequence the per-token cross-entropy and top-1
    correctness are converted into running averages over the first ``t``
    tokens; those running averages are then averaged over all sequences.

    Args:
        model: Callable that, given the ``src_seq`` batch, returns an object
            with a ``logits`` attribute. The logits are transposed with
            ``transpose(1, 2)`` before the loss, so they are presumably laid
            out as (batch, position, vocab) -- confirm against the model.
        data_module: Object exposing ``val_dataloader()`` and
            ``train_dataloader()``. Batches must be mappings holding
            ``"src_seq"`` and ``"trg_seq"`` tensors.
        device: Device (string or ``torch.device``) for inputs and
            accumulators. Its device type also selects the autocast backend.
        max_length: The accumulators hold ``max_length - 1`` positions, so
            every batch is expected to provide that many target positions.
            NOTE(review): presumably inputs/targets are shifted by one
            token -- confirm against the data module.
        max_batches: Upper bound on the number of batches read from the
            loader; ``None`` reads the whole loader. The default (21) is the
            limit that used to be hard-coded.

    Returns:
        dict with three lists indexed by token position:
        ``"perplexities"`` (exp of the mean running-average loss),
        ``"accuracies"`` (mean running-average accuracy) and
        ``"token_losses"`` (mean loss at each individual position).

    Raises:
        ValueError: If no batch contributed, i.e. the loader was empty or
            every batch produced a NaN loss.
    """
    model.eval()
    total_loss_sum = torch.zeros(max_length - 1, device=device)
    total_accuracy_sum = torch.zeros(max_length - 1, device=device)
    total_count = 0
    per_token_losses = []

    # Build the validation loader only once; fall back to the training
    # loader when the validation split is empty.
    data_loader = data_module.val_dataloader()
    if len(data_loader) == 0:
        data_loader = data_module.train_dataloader()

    num_batches = len(data_loader)
    if max_batches is not None:
        num_batches = min(num_batches, max_batches)

    # Follow the requested device instead of assuming CUDA.
    autocast_device_type = torch.device(device).type

    for idx, batch in tqdm(
        enumerate(data_loader), desc="Evaluating lengths", total=num_batches
    ):
        if max_batches is not None and idx >= max_batches:
            break
        src_seq = batch["src_seq"].to(device)
        trg_seq = batch["trg_seq"].to(device)

        with torch.autocast(dtype=torch.bfloat16, device_type=autocast_device_type):
            outputs = model(src_seq)
        logits = outputs.logits

        # cross_entropy wants the class dimension second, hence the
        # transpose; the loss itself is computed in float32.
        loss = torch.nn.functional.cross_entropy(
            logits.transpose(1, 2).float(), trg_seq, reduction="none"
        )

        # A NaN anywhere would poison every accumulator, so drop the batch.
        if torch.isnan(loss).any():
            logging.error(f"Loss is nan for batch {idx}")
            continue

        predictions = torch.argmax(logits, dim=-1)
        correct_predictions = predictions == trg_seq

        # Running sums along the sequence dimension ...
        cum_loss = torch.cumsum(loss, dim=1)
        cum_correct = torch.cumsum(correct_predictions.float(), dim=1)

        # ... divided by the number of tokens seen so far (1, 2, 3, ...).
        token_positions = torch.arange(1, cum_loss.size(1) + 1, device=device)
        avg_cum_loss = cum_loss / token_positions
        avg_cum_accuracy = cum_correct / token_positions

        # Accumulate over the batch dimension.
        total_loss_sum += avg_cum_loss.sum(dim=0)
        total_accuracy_sum += avg_cum_accuracy.sum(dim=0)
        total_count += src_seq.size(0)
        per_token_losses.append(loss.cpu().float().numpy())

    if total_count == 0:
        raise ValueError(
            "No batches were evaluated: the data loader was empty or every "
            "batch produced a NaN loss"
        )

    mean_losses = (total_loss_sum / total_count).cpu().numpy()
    mean_accuracies = (total_accuracy_sum / total_count).cpu().numpy()
    per_token_losses_avg = np.concatenate(per_token_losses).mean(axis=0)

    perplexities = np.exp(mean_losses)

    return {
        "perplexities": perplexities.tolist(),
        "accuracies": mean_accuracies.tolist(),
        "token_losses": per_token_losses_avg.tolist(),
    }


def _extract_step_from_name(name):
    """Return the largest training-progress number embedded in ``name``.

    The name is scanned for ``checkpoint``/``global_step``/``step``/``iter``/
    ``epoch`` markers followed by digits. The greatest number found across
    all markers is returned, or ``None`` when nothing matches.
    """
    step_patterns = (
        r"checkpoint[-_](\d+)",
        r"global_step[-_]?(\d+)",
        r"step[-_=](\d+)",
        r"iter[-_=](\d+)",
        r"epoch[-_=](\d+)",
        r"epoch=(\d+)",
        r"step=(\d+)",
        r"iter=(\d+)",
    )
    best = None
    for pattern in step_patterns:
        for found in re.finditer(pattern, name):
            try:
                value = int(found.group(1))
            except Exception:
                # Unparseable digit runs are simply ignored.
                continue
            if best is None or value > best:
                best = value
    return best


def _looks_like_hf_checkpoint_dir(path):
    """Return True if ``path`` is a directory holding a Hugging Face checkpoint.

    A directory qualifies when it contains ``config.json`` together with at
    least one recognised weight artefact:

    * a single-file checkpoint (``pytorch_model.bin``, ``model.safetensors``,
      ``pytorch_model.safetensors`` or ``consolidated.safetensors``),
    * a shard index (``pytorch_model.bin.index.json`` or
      ``model.safetensors.index.json``),
    * a weight shard (``pytorch_model-*.bin``, ``pytorch_model-*.safetensors``
      or ``model-*.safetensors``).

    A missing path, a non-directory, an unreadable directory or an invalid
    path value all yield ``False`` instead of raising.
    """
    try:
        if not os.path.isdir(path):
            return False
        files = set(os.listdir(path))
    except (OSError, TypeError, ValueError):
        # Best-effort probe: anything we cannot list is not a checkpoint.
        return False
    if "config.json" not in files:
        return False

    exact_weight_names = {
        "pytorch_model.bin",
        "pytorch_model.safetensors",
        "model.safetensors",
        "consolidated.safetensors",
        # Index files written next to sharded checkpoints.
        "pytorch_model.bin.index.json",
        "model.safetensors.index.json",
    }
    if files & exact_weight_names:
        return True

    for fname in files:
        if fname.startswith("pytorch_model-") and fname.endswith(
            (".bin", ".safetensors", ".bin.index.json")
        ):
            return True
        # Sharded safetensors weights are named model-0000X-of-0000Y.safetensors.
        if fname.startswith("model-") and fname.endswith(".safetensors"):
            return True
    return False


def _find_latest_checkpoint_dir(run_dir):
    """Return the most advanced checkpoint directory of a run, or ``None``.

    Candidates are ``run_dir`` itself and its immediate subdirectories; each
    is accepted only if ``_looks_like_hf_checkpoint_dir`` recognises it.
    Candidates are ranked by the step number parsed from the directory name
    (directories without a step rank below all others) and, on equal steps,
    by modification time, newest first.

    Returns ``None`` when no candidate is found, including when ``run_dir``
    cannot be listed.
    """
    candidates = []

    # The run directory itself may already be a checkpoint.
    if _looks_like_hf_checkpoint_dir(run_dir):
        # normpath drops a trailing separator so "runs/checkpoint-500/" still
        # yields "checkpoint-500" instead of an empty basename.
        own_name = os.path.basename(os.path.normpath(run_dir))
        try:
            own_mtime = os.path.getmtime(run_dir)
        except OSError:
            own_mtime = 0.0
        candidates.append((_extract_step_from_name(own_name), own_mtime, run_dir))

    try:
        with os.scandir(run_dir) as scan:
            entries = list(scan)
    except (OSError, TypeError, ValueError):
        # Unlistable run directory: fall through with whatever we have.
        entries = []

    for entry in entries:
        if not entry.is_dir():
            continue
        if not _looks_like_hf_checkpoint_dir(entry.path):
            continue
        try:
            mtime = entry.stat().st_mtime
        except OSError:
            mtime = 0.0
        candidates.append((_extract_step_from_name(entry.name), mtime, entry.path))

    if not candidates:
        return None

    def rank(candidate):
        step, mtime, _path = candidate
        return (-1 if step is None else step, mtime)

    # max() keeps the first of several equally ranked candidates, matching a
    # stable descending sort.
    return max(candidates, key=rank)[2]


def _load_eval_model(checkpoint_path, device, dtype):
    """Load a causal-LM checkpoint onto ``device`` and put it in eval mode."""
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint_path, device_map={"": device}, torch_dtype=dtype
    )
    model.eval()
    return model


def _build_data_module(args, tokenizer):
    """Create the evaluation data module from the CLI args and a tokenizer id/path."""
    return TextArrowFileModule(
        tokenizer=tokenizer,
        dataset_name=args.data,
        batch_size=args.batch_size,
        num_cpu_worker=args.num_cpu_workers,
        max_sample_len=args.max_len,
        data_dir=os.getenv("HF_HOME"),
        cache_dir=os.getenv("HF_DATASETS_CACHE"),
        val_ratio=0.0005,
        val_split_seed=1337,
        seed=1337,
    )


def _write_results(results, out_path):
    """Dump ``results`` to ``out_path`` as ``{"results": ...}`` JSON."""
    with open(out_path, "w") as f:
        json.dump({"results": results}, f)


def main():
    """Command-line entry point for the length-extrapolation evaluation.

    Two modes are supported:

    * ``--logs_dir``: treat every subdirectory of the given directory (or the
      directory itself when it has no subdirectories) as a run, evaluate the
      latest checkpoint of each run and write the result to
      ``<run>/length_extrapolation/<data>_<max_len>.json``. Runs that already
      have that file are skipped.
    * ``--path`` (default mode): evaluate one checkpoint and write the result
      to ``<repo>/data/length_extrapolation/<data>_<max_len>_<model>.json``,
      skipping the work when that file already exists.
    """
    logging.basicConfig(level=logging.INFO)
    import git

    # The repository root anchors the default checkpoint path and the output
    # location of the single-checkpoint mode.
    repo_root = git.Repo(search_parent_directories=True).working_tree_dir

    parser = argparse.ArgumentParser(
        description="Evaluate trained models on LM benchmarks"
    )
    parser.add_argument("--max_len", type=int, default=4096)
    parser.add_argument(
        "-p",
        "--path",
        type=str,
        default=os.path.join(
            repo_root,
            "<path to checkpoint>",
        ),
    )
    parser.add_argument(
        "--logs_dir",
        type=str,
        default=None,
        help="Directory containing run subdirectories with checkpoints.",
    )
    parser.add_argument("--data", type=str, default="codeparrot")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_cpu_workers", type=int, default=8)
    args = parser.parse_args()

    device = "cuda"
    dtype = torch.float

    # Mode 1: iterate over runs and evaluate each run's latest checkpoint.
    if args.logs_dir:
        logs_dir = os.path.abspath(args.logs_dir)
        if not os.path.isdir(logs_dir):
            logging.error(f"Logs directory not found: {logs_dir}")
            return

        # Subdirectories are run dirs; without any, logs_dir itself is the run.
        child_run_dirs = [
            entry.path for entry in os.scandir(logs_dir) if entry.is_dir()
        ]
        run_dirs = child_run_dirs if child_run_dirs else [logs_dir]

        for run_dir in sorted(run_dirs):
            latest_ckpt = _find_latest_checkpoint_dir(run_dir)
            if latest_ckpt is None:
                logging.warning(f"No checkpoint found under {run_dir}; skipping")
                continue

            out_dir = os.path.join(run_dir, "length_extrapolation")
            os.makedirs(out_dir, exist_ok=True)
            out_path = os.path.join(out_dir, f"{args.data}_{args.max_len}.json")
            if os.path.exists(out_path):
                logging.info(f"Skipping {run_dir}; results already exist at {out_path}")
                continue

            logging.info(f"Evaluating latest checkpoint for {run_dir}: {latest_ckpt}")
            torch.manual_seed(1337)

            model = _load_eval_model(latest_ckpt, device, dtype)
            data_module = _build_data_module(
                args, tokenizer="mistralai/Mistral-7B-v0.1"
            )
            results = evaluate_length_extrapolation(
                model, data_module, device, args.max_len
            )
            _write_results(results, out_path)

            # Drop this run's model before the next one is loaded; otherwise
            # both would be resident on the GPU at the same time.
            del model, data_module, results
            torch.cuda.empty_cache()

        return

    # Mode 2: evaluate the single checkpoint given by --path.
    model_name = args.path.split("Paninetto-Array")[-1].split("/")[0]
    if not model_name:
        # Absolute paths (leading "/") make the expression above yield an
        # empty name, which would map every checkpoint onto the same result
        # file; fall back to the checkpoint directory's own name.
        model_name = os.path.basename(os.path.normpath(args.path))

    output_dir = os.path.join(repo_root, "data/length_extrapolation")
    out_path = os.path.join(
        output_dir, f"{args.data}_{args.max_len}_{model_name}.json"
    )
    if os.path.exists(out_path):
        logging.info(f"Skipping {args.data}_{args.max_len}_{model_name}.json")
        return

    # Create the output directory up front so a permission problem surfaces
    # before the expensive evaluation rather than after it.
    os.makedirs(output_dir, exist_ok=True)

    logging.info(f"Model Checkpoint: {args.path}")
    torch.manual_seed(1337)

    logging.info(f"Loading model {args.path}")
    model = _load_eval_model(args.path, device, dtype)
    logging.info(f"{model}")

    # Re-seed after loading, as before, so the RNG state seen by the data
    # pipeline does not depend on what model construction consumed.
    torch.manual_seed(1337)

    data_module = _build_data_module(args, tokenizer=args.path)
    results = evaluate_length_extrapolation(
        model, data_module, device, args.max_len
    )

    _write_results(results, out_path)


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
