import argparse
import math
import time
import io
import random
import json
import numpy as np
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

#####################################
# Utility: measure time & peak GPU memory (absolute)
#####################################
def measure_time_and_memory(func, device="cuda"):
    """Run ``func()`` once and report elapsed time and absolute peak GPU memory.

    Args:
        func: Zero-argument callable to benchmark; its return value is discarded.
        device: Device string (``"cuda"``, ``"cuda:0"``, ``"cpu"``, ...) or a
            ``torch.device``. CUDA bookkeeping is only done for CUDA devices.

    Returns:
        ``(elapsed_ms, peak_mb)``. ``elapsed_ms`` is the wall time of the call in
        milliseconds. ``peak_mb`` is the peak memory held by tensors on
        ``device`` during the call, in MiB. It includes memory that was already
        allocated before the call (absolute, not a delta). It is ``0.0`` for
        non-CUDA devices.
    """
    # str() lets callers pass either a plain string or a torch.device.
    is_cuda = str(device).startswith("cuda")
    if is_cuda:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        # Drain previously queued kernels so they are not billed to func().
        torch.cuda.synchronize(device)
    # perf_counter is monotonic and high-resolution; time.time() is neither.
    start = time.perf_counter()
    func()
    if is_cuda:
        # Kernel launches are asynchronous: wait until func's work is finished.
        torch.cuda.synchronize(device)
    elapsed_s = time.perf_counter() - start
    peak_mb = 0.0
    if is_cuda:
        peak_mb = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
    return elapsed_s * 1000.0, peak_mb

#####################################
# Utility: measure merge overhead only
#####################################
def measure_merge_overhead(func, device="cuda"):
    """Run ``func()`` once and report elapsed time and the *extra* GPU memory it needed.

    Unlike ``measure_time_and_memory``, the reported memory excludes whatever
    was already allocated when the call started. It is the peak during the call
    minus the allocation just before it.

    Args:
        func: Zero-argument callable to benchmark; its return value is discarded.
        device: Device string (``"cuda"``, ``"cuda:0"``, ``"cpu"``, ...) or a
            ``torch.device``. CUDA bookkeeping is only done for CUDA devices.

    Returns:
        ``(elapsed_ms, overhead_mb)``. The overhead is in MiB, clamped at 0, and
        is ``0.0`` for non-CUDA devices.
    """
    is_cuda = str(device).startswith("cuda")
    baseline_bytes = 0
    if is_cuda:
        torch.cuda.empty_cache()
        torch.cuda.synchronize(device)
        # The baseline must be the memory allocated *right now*. Resetting the
        # peak makes it equal to the current allocation, so the post-call peak
        # minus this baseline is exactly func's extra memory. (Using the
        # previous peak as the baseline under-reports, because an earlier
        # allocation spike, e.g. adapter creation, would be subtracted.)
        torch.cuda.reset_peak_memory_stats(device)
        baseline_bytes = torch.cuda.memory_allocated(device)
    start = time.perf_counter()
    func()
    if is_cuda:
        # Kernel launches are asynchronous: wait until func's work is finished.
        torch.cuda.synchronize(device)
    elapsed_s = time.perf_counter() - start
    overhead_bytes = 0
    if is_cuda:
        peak_bytes = torch.cuda.max_memory_allocated(device)
        overhead_bytes = max(peak_bytes - baseline_bytes, 0)
    overhead_mb = overhead_bytes / (1024 ** 2)
    return elapsed_s * 1000.0, overhead_mb

#####################################
# Utility: extract LoRA parameters from a PEFT model
#####################################
def extract_lora_params(model):
    """Group the LoRA tensors found in ``model.state_dict()`` by module prefix.

    A state-dict key is classified by the first marker it contains, checked in
    this order: ``".lora_A."`` -> ``"A"``, ``".lora_B."`` -> ``"B"``,
    ``".lora_alpha"`` -> ``"alpha"``. Its prefix is the text before that
    marker. Keys with no marker are ignored.

    Returns:
        dict mapping prefix -> ``{"A": tensor|None, "B": tensor|None,
        "alpha": tensor|None}``. Prefixes are in order of first appearance.

    NOTE(review): PEFT appears to keep ``lora_alpha`` as a plain module
    attribute rather than a state-dict entry, so ``"alpha"`` is probably
    always None here. Confirm for the installed PEFT version.
    """
    markers = (
        (".lora_A.", "A"),
        (".lora_B.", "B"),
        (".lora_alpha", "alpha"),
    )
    grouped = {}
    for key, tensor in model.state_dict().items():
        for marker, slot in markers:
            if marker not in key:
                continue
            module_prefix = key.partition(marker)[0]
            entry = grouped.setdefault(module_prefix, {"A": None, "B": None, "alpha": None})
            entry[slot] = tensor
            break  # first matching marker wins
    return grouped

#####################################
# Merge Approaches
#####################################
def manual_iterative_merge(lora_dicts, device, dtype):
    """Merge LoRA adapters per module by accumulating one matmul per adapter.

    For every module prefix, computes ``sum_i (alpha_i / r_i) * B_i @ A_i``.
    ``B_i`` has shape (out_dim, r_i) and ``A_i`` has shape (r_i, in_dim), the
    layout PEFT uses for ``lora_B`` / ``lora_A`` weights. ``r_i`` is each
    adapter's own rank, so adapters of different ranks can be merged. A missing
    alpha is treated as 1.0.

    Args:
        lora_dicts: Sequence of dicts as returned by ``extract_lora_params``.
        device: Target device for the computation.
        dtype: Target dtype; inputs are cast to it and accumulation runs in it.

    Returns:
        dict mapping prefix -> merged delta of shape (out_dim, in_dim), in order
        of first appearance. Prefixes whose entries lack an A/B pair in every
        adapter are omitted.
    """
    # dict.fromkeys keeps a deterministic first-seen order (a set does not).
    prefixes = dict.fromkeys(prefix for d in lora_dicts for prefix in d)
    default_alpha = torch.tensor(1.0, device=device, dtype=dtype)
    merged = {}
    for prefix in prefixes:
        A_list, B_list, alpha_list = [], [], []
        for d in lora_dicts:
            entry = d.get(prefix)
            # Skip adapters that do not touch this module, and entries that
            # carry no A/B weight pair (e.g. alpha-only entries).
            if entry is None or entry["A"] is None or entry["B"] is None:
                continue
            A_list.append(entry["A"].to(device=device, dtype=dtype))
            B_list.append(entry["B"].to(device=device, dtype=dtype))
            alpha = entry["alpha"]
            alpha_list.append(default_alpha if alpha is None else alpha.to(device=device, dtype=dtype))
        if not A_list:
            continue
        out_dim = B_list[0].shape[0]
        in_dim = A_list[0].shape[1]
        delta = torch.zeros(out_dim, in_dim, device=device, dtype=dtype)
        for A, B, alpha in zip(A_list, B_list, alpha_list):
            rank = B.shape[1]  # per-adapter rank, not the first adapter's
            delta += (B * (alpha / rank)) @ A
        merged[prefix] = delta
    return merged

def einsum_merge(lora_dicts, device, dtype):
    """Merge LoRA adapters per module with a single batched ``torch.einsum``.

    Computes the same ``sum_i (alpha_i / r) * B_i @ A_i`` as
    ``manual_iterative_merge``. All adapters of one module are stacked and the
    contraction over adapter and rank dimensions is done at once. ``B_i`` has
    shape (out_dim, r) and ``A_i`` has shape (r, in_dim). A missing alpha is
    treated as 1.0.

    Args:
        lora_dicts: Sequence of dicts as returned by ``extract_lora_params``.
        device: Target device for the computation.
        dtype: Target dtype; inputs are cast to it before stacking.

    Returns:
        dict mapping prefix -> merged delta of shape (out_dim, in_dim), in order
        of first appearance. Prefixes whose entries lack an A/B pair in every
        adapter are omitted.

    Raises:
        RuntimeError: from ``torch.stack`` if the adapters of one prefix have
            different shapes (e.g. different ranks); stacking requires them equal.
    """
    # dict.fromkeys keeps a deterministic first-seen order (a set does not).
    prefixes = dict.fromkeys(prefix for d in lora_dicts for prefix in d)
    default_alpha = torch.tensor(1.0, device=device, dtype=dtype)
    merged = {}
    for prefix in prefixes:
        A_list, B_list, alpha_list = [], [], []
        for d in lora_dicts:
            entry = d.get(prefix)
            # Skip adapters that do not touch this module, and entries that
            # carry no A/B weight pair (e.g. alpha-only entries).
            if entry is None or entry["A"] is None or entry["B"] is None:
                continue
            A_list.append(entry["A"].to(device, dtype=dtype))
            B_list.append(entry["B"].to(device, dtype=dtype))
            alpha = entry["alpha"]
            alpha_list.append(default_alpha if alpha is None else alpha.to(device, dtype=dtype))
        if not A_list:
            continue
        B_stack = torch.stack(B_list, dim=0)  # (n, out_dim, r)
        A_stack = torch.stack(A_list, dim=0)  # (n, r, in_dim)
        rank = B_stack.shape[-1]
        # One scale per adapter, broadcast over its B matrix. In-place, so no
        # second copy of B_stack is allocated. B_stack is a fresh tensor from
        # torch.stack, so the caller's tensors are not modified.
        scales = torch.stack(alpha_list) / rank
        B_stack.mul_(scales.reshape(-1, 1, 1))
        merged[prefix] = torch.einsum("nor,nri->oi", B_stack, A_stack)
    return merged

#####################################
# Run a single merge experiment
#####################################
def run_merge_test(base_model, N=1, rank=64, device="cuda", dtype=torch.bfloat16, seed=None):
    """Create ``N`` LoRA adapters on ``base_model`` and benchmark both merge strategies.

    Args:
        base_model: The model the adapters are attached to. It is handed to
            ``get_peft_model``, which injects LoRA layers into it in place; the
            layers are removed again with ``unload()`` after each adapter's
            tensors have been extracted. This way every adapter, and every later
            call, starts from the unmodified base model.
        N: Number of adapters to create and merge.
        rank: LoRA rank ``r`` of every adapter.
        device: Device used for timing/memory bookkeeping and for the merges.
        dtype: Dtype the merges run in.
        seed: If given, seeds ``random``, NumPy and torch (and CUDA) first.

    Returns:
        dict with ``create_time``/``create_mem`` (ms / absolute peak MiB for
        adapter creation, including the unload step) and
        ``manual_time``/``manual_mem``/``einsum_time``/``einsum_mem`` (ms / extra
        MiB of each merge strategy).
    """
    is_cuda = str(device).startswith("cuda")
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if is_cuda:
            torch.cuda.manual_seed_all(seed)

    def create_n_loras():
        adapters = []
        for _ in range(N):
            config = LoraConfig(
                r=rank,
                lora_alpha=32,
                target_modules=[
                    "k_proj",
                    "q_proj",
                    "v_proj",
                    "o_proj",
                    "gate_proj",
                    "down_proj",
                    "up_proj",
                ],
                lora_dropout=0.0,
                bias="none",
                task_type="CAUSAL_LM",
            )
            peft_model = get_peft_model(base_model, config)
            adapters.append(extract_lora_params(peft_model))
            # Strip the injected layers again so the next adapter is created on
            # the clean base model rather than on an already-wrapped one. The
            # extracted tensors stay alive through the references held above.
            peft_model.unload()
        return adapters

    lora_dicts = []

    def just_create():
        nonlocal lora_dicts
        lora_dicts = create_n_loras()

    create_time, create_mem = measure_time_and_memory(just_create, device=device)
    # The lambdas read lora_dicts when called, i.e. after creation has run.
    manual_time, manual_mem = measure_merge_overhead(
        lambda: manual_iterative_merge(lora_dicts, device=device, dtype=dtype), device=device
    )
    einsum_time, einsum_mem = measure_merge_overhead(
        lambda: einsum_merge(lora_dicts, device=device, dtype=dtype), device=device
    )
    return {
        "create_time": create_time,
        "create_mem": create_mem,
        "manual_time": manual_time,
        "manual_mem": manual_mem,
        "einsum_time": einsum_time,
        "einsum_mem": einsum_mem,
    }

#####################################
# Benchmark Across Values (store raw data too)
#####################################
def _mean_se(samples):
    """Return ``(mean, standard error of the mean)`` of ``samples`` as Python floats.

    The standard error uses the sample standard deviation (ddof=1). It is
    undefined for a single sample; 0.0 is reported then instead of NaN, so the
    JSON output stays standards-compliant (JSON has no NaN).
    """
    arr = np.asarray(samples, dtype=float)
    mean = float(arr.mean())
    if arr.size < 2:
        return mean, 0.0
    return mean, float(arr.std(ddof=1) / math.sqrt(arr.size))


def benchmark_across_values(base_model, x_values, vary="N", seeds=5, device="cuda",
                            dtype=torch.bfloat16, fixed_N=10, fixed_rank=64):
    """Sweep one benchmark parameter and aggregate merge time/memory over seeds.

    Args:
        base_model: Model passed through to ``run_merge_test``.
        x_values: Values to sweep.
        vary: ``"N"`` sweeps the number of adapters (rank fixed to
            ``fixed_rank``). Any other value sweeps the rank (N fixed to
            ``fixed_N``).
        seeds: Number of seeded repetitions per value (seeds ``0..seeds-1``).
        device, dtype: Passed through to ``run_merge_test``.
        fixed_N, fixed_rank: The parameter held constant while the other varies.

    Returns:
        dict with ``x_vals`` plus, for each of ``manual_time``, ``manual_mem``,
        ``einsum_time`` and ``einsum_mem``, lists ``<metric>_mean`` and
        ``<metric>_se``. Each list has one entry per x value. It also has
        ``<metric>_raw`` (per-seed samples, kept for bootstrap CIs).

    Raises:
        ValueError: if ``seeds`` < 1 (no samples to aggregate).
    """
    if seeds < 1:
        raise ValueError(f"seeds must be >= 1, got {seeds}")
    metrics = ("manual_time", "manual_mem", "einsum_time", "einsum_mem")
    result = {"x_vals": []}
    for metric in metrics:
        result[f"{metric}_mean"] = []
        result[f"{metric}_se"] = []
    # Raw per-seed samples, stored for bootstrap confidence intervals.
    for metric in metrics:
        result[f"{metric}_raw"] = []

    for x in x_values:
        n_loras, lora_rank = (x, fixed_rank) if vary == "N" else (fixed_N, x)
        samples = {metric: [] for metric in metrics}
        for seed_i in range(seeds):
            stats = run_merge_test(base_model=base_model, N=n_loras, rank=lora_rank,
                                   device=device, dtype=dtype, seed=seed_i)
            for metric in metrics:
                samples[metric].append(stats[metric])
        result["x_vals"].append(x)
        for metric in metrics:
            mean, se = _mean_se(samples[metric])
            result[f"{metric}_mean"].append(mean)
            result[f"{metric}_se"].append(se)
            result[f"{metric}_raw"].append(samples[metric])
    return result

#####################################
# Main: Run experiments and store results to disk
#####################################
def main():
    """Run the LoRA-merge benchmarks (varying N and rank) and write results as JSON.

    Loads the base model once and reports the size of a single rank-64 adapter.
    It then removes that adapter again and sweeps the number of adapters and the
    adapter rank via ``benchmark_across_values``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"], help="Compute device.")
    parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "float32", "bfloat16"],
                        help="Data type for merges & model.")
    parser.add_argument("--seeds", type=int, default=10, help="Number of seeds per data point.")
    parser.add_argument("--output", type=str, default="llama_lora_results.json", help="Output JSON file for results.")
    parser.add_argument("--model", type=str, default="meta-llama/Llama-3.2-1B",
                        help="Hugging Face model id or local path of the base model.")
    args = parser.parse_args()
    if args.seeds < 1:
        parser.error("--seeds must be at least 1")

    dtype_map = {"float16": torch.float16, "float32": torch.float32, "bfloat16": torch.bfloat16}
    chosen_dtype = dtype_map[args.dtype]

    print(f"Loading base model '{args.model}' in {args.dtype} on {args.device}...")
    base_model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=chosen_dtype)
    base_model.to(args.device)
    base_model.eval()

    # Create a single LoRA adapter and compute its total parameter size.
    config = LoraConfig(
        r=64,
        lora_alpha=32,
        target_modules=[
            "k_proj",
            "q_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "down_proj",
            "up_proj",
        ],
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
    )
    peft_model = get_peft_model(base_model, config)
    lora_params = extract_lora_params(peft_model)
    lora_tensors = [
        tensor
        for params in lora_params.values()
        for tensor in (params["A"], params["B"], params["alpha"])
        if tensor is not None
    ]
    total_params = sum(t.numel() for t in lora_tensors)
    total_bytes = sum(t.element_size() * t.numel() for t in lora_tensors)
    print(f"Total LoRA parameters: {total_params}")
    print(f"Total LoRA bytes: {total_bytes}")
    total_mb = total_bytes / (1024 ** 2)
    print(f"Size of a single LoRA adapter: {total_mb:.2f} MB")
    # get_peft_model injected LoRA layers into base_model in place. Remove them
    # and drop the tensors so the benchmarks start from the unmodified base
    # model without this adapter's memory still allocated.
    peft_model.unload()
    del peft_model, lora_params, lora_tensors

    # Benchmark vs. N
    n_values = [1, 3, 10, 32, 100]
    print("Running benchmark varying Number of LoRAs (N)...")
    res_n = benchmark_across_values(base_model, x_values=n_values, vary="N",
                                    seeds=args.seeds, device=args.device, dtype=chosen_dtype)

    # Benchmark vs. Rank
    rank_values = [8, 64, 512]
    print("Running benchmark varying LoRA Rank...")
    res_r = benchmark_across_values(base_model, x_values=rank_values, vary="rank",
                                    seeds=args.seeds, device=args.device, dtype=chosen_dtype)

    # Store results to disk as JSON.
    all_results = {
        "settings": {
            "device": args.device,
            "dtype": args.dtype,
            "seeds": args.seeds,
            "n_values": n_values,
            "rank_values": rank_values,
            "model": args.model,
        },
        "res_n": res_n,
        "res_r": res_r,
    }
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=4)
    print(f"Results stored in {args.output}")

# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
