# Start main imports
import argparse
import random
from pathlib import Path
import os
# Set this for deterministic behavior in CUDA operations
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import numpy as np
import torch
from compression.svd_core import ModelFactorizer

from llm_utils.data_utils import *
from llm_utils.model_utils import *
from llm_utils.evaluater import *

def get_args_parser():
    """Build the command-line parser for the compression/evaluation script.

    The parser is created with ``add_help=False`` so that it can be passed as
    a parent (``parents=[...]``) to another ``ArgumentParser`` without
    registering ``-h/--help`` twice.

    Returns:
        argparse.ArgumentParser: parser exposing all script options.
    """
    parser = argparse.ArgumentParser(
        "LEMS and KFAC-SVD compression and evaluation script", add_help=False
    )
    parser.add_argument("--batch-size", default=256, type=int)
    parser.add_argument("--calib_bs", default=128, type=int, help="Batchsize for data based svd calibration.")
    parser.add_argument("--search_samples", default=32, type=int, help="Batchsize for data based svd search.")
    parser.add_argument(
        "--calib_dataset",
        default="wikitext2",
        type=str,
        help="Dataset to use for calibration. Default is wikitext2.",
        choices=["wikitext2", "fineweb_16M"],
    )
    parser.add_argument("--seq_len", default=256, type=int, help="Sequence length for LLMs. Default is 256.")
    parser.add_argument("--extended_eval", action="store_true", default=False, help="Whether to do extended evaluation on LLMs.")
    parser.add_argument(
        "--calib_data_mode",
        default="v1",
        type=str,
        help="Which version of calibration data to use.",
        choices=["v1", "v2"],
    )

    # Model parameters
    parser.add_argument(
        "--model",
        default="unsloth/llama-3-8b",
        type=str,
        metavar="MODEL",
        help="Name of model to compress and evaluate",
    )
    # use this for large models.
    parser.add_argument("--gradient_ckpt", action="store_true", default=False, help="Use gradient checkpointing")

    # Output / runtime parameters
    parser.add_argument(
        "--output_dir", default="./results/", help="path where to save, empty for no saving"
    )
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=42, type=int)

    parser.add_argument(
        "--sensitivity_loss",
        default="kl",
        type=str,
        choices=[
            "kl", "energy", "energy1", "energy2", "energy2_normal",
            "energy2_normal_klscaled", "energy2_normal_msescaled",
            "energy2_normal_pplscaled", "mse", "ce", "ppl",
        ],
    )
    parser.add_argument(
        "--measurements_points",
        default="asvd_default",
        type=str,
        choices=[
            "0.1", "0.3", "0.5", "0.7", "0.85", "0.1-0.9", "0.2-0.9",
            "0.1-0.9uneven", "asvd_default", "gfwsvd",
        ],
    )

    # SVD parameters
    parser.add_argument(
        "--svd_method",
        default="kfac_svd",
        choices=["svd_llm", "svd_llmv2", "svd", "fwsvd", "gfwsvd", "asvd", "kfac_svd", "dobi_svd", "baseline"],
        type=str,
        help="SVD compression method to use.",
    )
    parser.add_argument(
        "--search_method",
        default="lems",
        choices=["asvd", "uniform", "lems", "svd_llmv2", "memvit", "loadconfig", "atp"],
        type=str,
        help="Rank search method to use.",
    )
    # NOTE: ``type=list`` would turn a command-line value such as "norm" into
    # ['n', 'o', 'r', 'm']; ``nargs="+"`` collects whole strings instead.
    parser.add_argument(
        "--name_omit",
        default=["norm", "patch_embed", "head", "downsample", "decoder.project_"],
        nargs="+",
        type=str,
        help="Space-separated name patterns forwarded to the factorizer as name_omit.",
    )
    # SVD settings
    parser.add_argument("--compression_target", default=0.8, type=float, help="compression target ratio")
    parser.add_argument(
        "--target_metric",
        default="params",
        choices=["params", "flops"],
        type=str,
        help="Metric to optimize based on target.",
    )
    parser.add_argument("--progressive_comp", action="store_true", help="Use progressive compression")
    parser.add_argument(
        "--do_post_calibration",
        default="default",
        type=str,
        help="Whether to do post search calibration or not. Default depends on method.",
        choices=["default", "True", "False"],
    )
    parser.add_argument("-uc", "--use_cache", action="store_true", default=False, help="whether to use cache for asvd_plus")
    parser.add_argument("--re_model", action="store_true", default=False, help="whether to use restore entire model if checkpoint exists.")
    # ASVD settings
    parser.add_argument("--asvd_alpha", default=1.0, type=float, help="alpha for ASVD")
    # lems settings
    parser.add_argument(
        "--crosslayer_term",
        default="harmonicv2",
        type=str,
        choices=["harmonic", "harmonicv2", "linear", "constant"],
        help="cross layer term",
    )
    parser.add_argument("--halpha", default=1.0, type=float, help="alpha for harmonic cross layer term")
    parser.add_argument("--hgamma", default=1.0, type=float, help="gamma for harmonic cross layer term")
    parser.add_argument("--enforce_rank_multiples_of", default=8, type=int, help="If set, will enforce ranks to be multiples of this number.")
    parser.add_argument("--solver", default="gurobi", type=str, choices=["cbc", "gurobi"], help="Solver to use for LEMS optimization.")
    # ATP search settings
    parser.add_argument("--beta", default=0.01, type=float, help="Slope for atp's linearly decreasing values.")
    # Use config instead of search settings
    parser.add_argument("--sconfig_path", default="gfwsvd_llama_20.json", type=str, help="Path to config file for layer compression. If set, will use this config instead of search method.")

    return parser

def enforce_strict_determinism(seed=42):
    """Seed all random number generators and put PyTorch in deterministic mode.

    Args:
        seed: Value given to the Python ``random``, NumPy and PyTorch
            (CPU and all CUDA devices) generators.
    """
    # Seed every generator with the same value, in a fixed order.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN must not benchmark candidate kernels: the "fastest" pick can
    # differ from run to run.
    torch.backends.cudnn.benchmark = False

    # Restrict cuDNN and PyTorch to deterministic implementations; PyTorch
    # raises an error for operations that have none.
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)


def main(args):
    """Load a model, optionally compress it, and evaluate it.

    Steps performed:
      1. Seed all RNGs and load model/tokenizer via ``get_model_from_huggingface``.
      2. Build the keyword dictionaries for the SVD method and the search method
         from the parsed arguments.
      3. Load calibration and search/evaluation samples with
         ``get_calib_train_data``.
      4. Unless ``--svd_method baseline`` is selected, run
         ``ModelFactorizer.factorize_and_search`` and use the returned model.
      5. Evaluate perplexity; with ``--extended_eval`` also run zero-shot tasks
         and a sample text generation (both best-effort).

    Args:
        args: Namespace produced by the parser from ``get_args_parser``.
    """
    device = torch.device(args.device)
    # Fix the seed for reproducibility. enforce_strict_determinism already
    # seeds random, numpy and torch, so no further seeding is needed here.
    enforce_strict_determinism(seed=args.seed)

    print(f"Creating model: {args.model}")
    model, tokenizer = get_model_from_huggingface(
        model_id=args.model,
        seq_len=args.seq_len,
        grad_ckpt=args.gradient_ckpt,  # signals efficient compression required to pipeline.
        cache_dir=os.path.join(args.output_dir, "huggingface/llm"),
    )

    # Options forwarded to the SVD method.
    svd_method_args = {"vision": False}
    if args.svd_method in ("asvd", "fwsvd"):
        svd_method_args["alpha"] = args.asvd_alpha
    svd_method_args["use_cache"] = args.use_cache
    svd_method_args["progressive_compression"] = args.progressive_comp
    svd_method_args["do_post_calibration"] = args.do_post_calibration

    # Options forwarded to the search method.
    search_args = {"ratio_target": args.compression_target}
    # sensitivity based approaches
    if args.search_method in ("asvd", "lems"):
        search_args["sensitivity_loss"] = args.sensitivity_loss
        search_args["measurements_points"] = args.measurements_points
        search_args["sequence_length"] = args.seq_len
        search_args["use_cache"] = args.use_cache
    if args.search_method == "lems":
        search_args["solver"] = args.solver
        search_args["crosslayer_term"] = args.crosslayer_term
        search_args["enforce_rank_multiples_of"] = args.enforce_rank_multiples_of
        search_args["halpha"] = args.halpha
        search_args["hgamma"] = args.hgamma
    if args.search_method == "loadconfig":
        search_args["layer_compression_json_path"] = args.sconfig_path
    if args.search_method == "atp":
        search_args["beta"] = args.beta

    model = model.eval()
    calib_data = get_calib_train_data(
        args.calib_dataset,
        tokenizer,
        nsamples=args.calib_bs,
        seqlen=args.seq_len,
        seed=args.seed,
        mode=args.calib_data_mode,
    )

    eval_data = get_calib_train_data(
        args.calib_dataset,
        tokenizer,
        nsamples=args.search_samples,  # increase this for more robust evaluation
        seqlen=args.seq_len,
        seed=args.seed + 1000,  # Ensure different seed for eval data
        batch_size=1,
        mode=args.calib_data_mode,
    )

    if args.svd_method != "baseline":
        factorizer = ModelFactorizer(
            svd_method=args.svd_method,
            svd_method_args=svd_method_args,
            search_method=args.search_method,
            search_method_args=search_args,
        )
        # Non-default calibration data modes get their own dataset name suffix.
        calib_dataset_name = args.calib_dataset
        if args.calib_data_mode != "v1":
            calib_dataset_name += "_" + args.calib_data_mode
        time_taken, search_time, num_params, model = factorizer.factorize_and_search(
            model=model,
            calib_data=calib_data,
            eval_data=eval_data,
            calib_dataset_name=calib_dataset_name,
            mixup_fn=None,
            name_omit=args.name_omit,
        )
    else:
        time_taken = 0.0
        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Baseline model has {num_params/1e6:.2f}M parameters.")

    model = model.to(device).eval()
    extended_eval_results = {}
    print(model)
    if args.extended_eval:
        zero_shot_tasks = ["piqa", "openbookqa", "hellaswag", "arc_challenge", "arc_easy", "winogrande"]
        # Zero-shot evaluation is best-effort: a failure is reported and skipped.
        try:
            results = zero_shot_eval(model, tokenizer, device=device, tasks=zero_shot_tasks)
            extended_eval_results = {
                task: results[task] if task in results else "N/A"
                for task in zero_shot_tasks
            }
            print(extended_eval_results)
        except Exception as exc:  # do not swallow KeyboardInterrupt/SystemExit
            print("loading lm_eval failed. Skipping extended evaluation.")
            print(f"Reason: {exc!r}")
            extended_eval_results = {}
        # Sample generation is best-effort as well.
        try:
            prompt = "What is the responsibility of an AI assistant?"
            inputs = tokenizer(prompt, return_tensors="pt")
            inputs = inputs.to(device)
            # len(inputs.input_ids) is the batch size, not the prompt length;
            # the number of prompt tokens is the size of the last dimension.
            prompt_len = inputs.input_ids.shape[-1]
            generate_ids = model.generate(
                **inputs,
                max_length=prompt_len + 256,
                do_sample=True,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                top_k=50,
                top_p=0.95,
                temperature=0.97,
                no_repeat_ngram_size=2,
            )
            answer = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
            extended_eval_results["answer"] = answer
            print("Answer to prompt '" + prompt + "': " + answer)
        except Exception as exc:  # do not swallow KeyboardInterrupt/SystemExit
            print("Answer generation failed. Skipping.")
            print(f"Reason: {exc!r}")
            extended_eval_results["answer"] = "none"
        ppls = ppl_eval(model, tokenizer, datasets=['wikitext2', 'ptb', 'c4'], model_seq_len=args.seq_len,
             batch_size=1, device=device)
        print(f"Eval done. Perplexity on wikitext2: {ppls['wikitext2']}, ptb: {ppls['ptb']}, c4: {ppls['c4']}")
    else:
        ppls = ppl_eval(model, tokenizer, datasets=['wikitext2'], model_seq_len=args.seq_len,
             batch_size=1, device=device)
        print(f"Eval done. Perplexity on wikitext2: {ppls['wikitext2']}")


if __name__ == "__main__":
    # get_args_parser() is built with add_help=False, so this wrapping parser
    # is the one that contributes -h/--help.
    parser = argparse.ArgumentParser(
        "LEMS and KFAC-SVD compression and evaluation framework", parents=[get_args_parser()]
    )
    args = parser.parse_args()
    # An empty --output_dir means "do not create an output directory".
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    try:
        main(args)
    except Exception:
        # Catch only real errors (not Ctrl-C / SystemExit), report them, and
        # exit non-zero so calling shells and job schedulers see the failure.
        import traceback
        print("An error occurred during execution:")
        traceback.print_exc()
        raise SystemExit(1)
        