# Part 1: Basic Imports and Setup
import gc
import itertools
import json
import os
import random

import pandas as pd
import numpy as np
import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer,
    DataCollatorForLanguageModeling, TrainerCallback,
)
from peft import get_peft_model, LoraConfig, get_peft_model_state_dict

print("Libraries imported successfully.")
# Let float32 matmuls use faster reduced-precision internal math (e.g. TF32) where supported.
torch.set_float32_matmul_precision('high')
# ==============================================================================
# Part 2: Experiment Configuration (The Orchestrator's Control Panel)
# ==============================================================================

# --- Global Constants (EDIT THESE) ---
SEED = 42
MODEL_ID = "meta-llama/Llama-3.2-1B"
# Hugging Face access token: taken from the HF_TOKEN environment variable, or paste a
# token string here. Leave it as None to use the token stored by `huggingface-cli login`.
# (A placeholder string would be sent verbatim as the token and rejected by the Hub.)
HF_TOKEN = os.environ.get("HF_TOKEN") or None
BASE_OUTPUT_DIR = "./output/stage_1_sweep_results"
NUM_TRAIN_EPOCHS = 10
DATASET_FRACTION = 0.6  # Fraction of each dataset to train on; 1.0 uses the full dataset.

# --- Experiment Grid Definition (EDIT THESE PATHS) ---
# Each path must point to a JSON file holding a list of training text strings.
DATASETS = {
    "cpp": "./data/cpp.json",
    "python": "./data/python.json",
    "medical": "./data/medical.json",
    "finance": "./data/finance.json",
    "science": "./data/science.json",
    "math": "./data/math.json",
}

# LoRA target-module groups (module names of the Llama projection layers).
COMPONENTS = {
    "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
    "mlp": ["gate_proj", "up_proj", "down_proj"],
}
# "both" is derived from the two groups above so the three entries cannot drift apart.
COMPONENTS["both"] = COMPONENTS["attention"] + COMPONENTS["mlp"]

# --- Hyperparameter space ---
# Every combination of these values is run for every dataset and component group.
LEARNING_RATES = [1e-3]
BATCH_SIZES = [8]
LORA_RANKS = [16]


# ==============================================================================
# Part 3: The Orchestrator
# ==============================================================================

def main():
    """Run the whole experiment grid, one fine-tuning run per combination.

    The grid is the Cartesian product DATASETS x LEARNING_RATES x LORA_RANKS x
    BATCH_SIZES x COMPONENTS, visited in that nesting order (the component group
    varies fastest). Each run writes to
    BASE_OUTPUT_DIR/<dataset>/<component>/lr_<lr>_bs_<bs>_rank_<rank>.
    """
    os.makedirs(BASE_OUTPUT_DIR, exist_ok=True)

    grid = itertools.product(
        DATASETS.items(), LEARNING_RATES, LORA_RANKS, BATCH_SIZES, COMPONENTS.items()
    )

    experiment_count = 0  # stays 0 if the grid is empty
    for experiment_count, combo in enumerate(grid, start=1):
        (dataset_name, dataset_path), lr, rank, bs, (component_name, target_modules) = combo

        print(f"\n{'='*80}")
        print(f"STARTING EXPERIMENT {experiment_count}")
        print(f"  Dataset: {dataset_name}, Component: {component_name}")
        print(f"  Params: LR={lr}, BatchSize={bs}, Rank={rank}, Epochs={NUM_TRAIN_EPOCHS}")
        print(f"{'='*80}\n")

        run_name = f"lr_{lr:.0e}_bs_{bs}_rank_{rank}"
        output_dir = os.path.join(BASE_OUTPUT_DIR, dataset_name, component_name, run_name)

        run_config = {
            "dataset_path": dataset_path,
            "output_dir": output_dir,
            "target_modules": target_modules,
            "learning_rate": lr,
            "batch_size": bs,
            "lora_rank": rank,
            "num_epochs": NUM_TRAIN_EPOCHS,
            "dataset_fraction": DATASET_FRACTION,
            "run_name": f"{dataset_name}_{component_name}_{run_name}"
        }

        run_experiment_engine(run_config)

        # The finished run's model/Trainer may linger in reference cycles; collect them
        # first so empty_cache() can return their CUDA memory before the next run.
        gc.collect()
        torch.cuda.empty_cache()

    print(f"\n\nAll {experiment_count} experiments completed.")


# ==============================================================================
# Part 4: The Engine (Worker Function) & Core Implementation
# ==============================================================================

def set_seed(seed: int):
    """Seed Python's `random`, NumPy's global RNG and PyTorch with the same value.

    When CUDA is available, every visible CUDA device is seeded as well.
    """
    seeders = [random.seed, np.random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)

def _load_text_dataset(dataset_path, dataset_fraction):
    """Read a JSON list of strings into a one-column ("text") Dataset, optionally subsampled.

    Returns None (after printing why) when the file cannot be read or parsed, or when
    no samples remain after subsampling, so the caller can skip the run.
    """
    print(f"Loading data from: {dataset_path}")
    try:
        # JSON is UTF-8 by spec; don't depend on the platform's default text encoding.
        with open(dataset_path, "r", encoding="utf-8") as f:
            raw_data = json.load(f)
        hf_dataset = Dataset.from_dict({"text": raw_data})

        if dataset_fraction < 1.0:
            num_samples = int(len(hf_dataset) * dataset_fraction)
            hf_dataset = hf_dataset.shuffle(seed=SEED).select(range(num_samples))
            print(f"Using a {dataset_fraction*100:.0f}% fraction of the dataset ({num_samples} samples).")

    except Exception as e:
        # Best-effort by design: a bad dataset skips this run instead of aborting the sweep.
        print(f"CRITICAL ERROR: Failed to load data. Error: {e}. Skipping run.")
        return None

    if len(hf_dataset) == 0:
        print("CRITICAL ERROR: No training samples left after loading/subsampling. Skipping run.")
        return None
    return hf_dataset


def run_experiment_engine(config: dict):
    """Fine-tune MODEL_ID with LoRA for one grid point and analyze adapters every epoch.

    Args:
        config: run settings (built by ``main``). Keys read: "output_dir",
            "dataset_path", "dataset_fraction", "target_modules", "lora_rank",
            "num_epochs", "batch_size", "learning_rate", "run_name", and the optional
            "max_length" (tokenizer truncation length, default 256).

    Creates ``<output_dir>/analysis`` and ``<output_dir>/tensors``; per-epoch results
    are written there by EpochEndAnalysisCallback. If the dataset cannot be loaded
    (or is empty) the run is skipped with a printed message instead of raising.
    """
    output_dir = config["output_dir"]
    analysis_dir = os.path.join(output_dir, "analysis")
    tensors_dir = os.path.join(output_dir, "tensors")
    os.makedirs(analysis_dir, exist_ok=True)
    os.makedirs(tensors_dir, exist_ok=True)
    print(f"Results will be saved in: {output_dir}")

    set_seed(SEED)

    # Load the data before the model so a bad dataset is skipped without first paying
    # for a model load and its GPU memory.
    hf_dataset = _load_text_dataset(config["dataset_path"], config["dataset_fraction"])
    if hf_dataset is None:
        return

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
    if tokenizer.pad_token is None:
        # Batching needs a pad token; reuse EOS when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=HF_TOKEN
    )

    max_length = config.get("max_length", 256)
    # Half the CPU cores, but never more worker processes than there are rows.
    num_proc = max(1, min((os.cpu_count() or 1) // 2, len(hf_dataset)))
    tokenized_dataset = hf_dataset.map(
        lambda examples: tokenizer(examples["text"], truncation=True, max_length=max_length),
        batched=True,
        remove_columns=["text"],
        num_proc=num_proc,
    )
    print("Model and data prepared.")

    lora_config = LoraConfig(
        r=config["lora_rank"],
        lora_alpha=2 * config["lora_rank"],
        target_modules=config["target_modules"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)

    print("Compiling model with torch.compile()...")
    model = torch.compile(model)
    print("Model compiled.")

    model.print_trainable_parameters()

    analysis_callback = EpochEndAnalysisCallback(lora_config, output_dir)

    training_args = TrainingArguments(
        output_dir=os.path.join(output_dir, "training_checkpoints"),
        num_train_epochs=config["num_epochs"],
        per_device_train_batch_size=config["batch_size"],
        learning_rate=config["learning_rate"],
        logging_steps=50,
        bf16=True,
        save_strategy="no",
        report_to="none",
        seed=SEED,
        remove_unused_columns=False,
        dataloader_num_workers=4,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        # mlm=False: causal-LM labels are the input ids themselves.
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
        callbacks=[analysis_callback],
    )

    trainer.train()
    print(f"\nFine-tuning for run '{config['run_name']}' is complete.")


def analyze_and_save_weights(model, epoch, output_dir, lora_config):
    """Save this epoch's LoRA A/B matrices and write a per-layer update-norm summary.

    For every adapted module the effective weight update
    ``delta_W = (lora_B @ lora_A) * (lora_alpha / r)`` is formed and its Frobenius
    norm taken. Norms are summed per layer index (the number following ``layers`` in
    the module name) and written to ``<output_dir>/analysis/epoch_<epoch>_analysis.csv``
    with columns ``layer_num`` and ``total_frobenius_norm``. The raw matrices are saved
    as float32 CPU tensors to ``<output_dir>/tensors/epoch_<epoch>/<module>.lora_A.pt``
    and ``.lora_B.pt``.

    Args:
        model: PEFT model, possibly wrapped by ``torch.compile`` (unwrapped via ``_orig_mod``).
        epoch: epoch counter from the Trainer state; rounded to an int for file names.
        output_dir: run directory under which ``analysis/`` and ``tensors/`` are written.
        lora_config: LoraConfig of the run, supplying ``lora_alpha`` and ``r``.
    """
    epoch = int(round(epoch))
    print(f"\n--- Starting analysis and saving for Epoch {epoch} ---")

    analysis_dir = os.path.join(output_dir, "analysis")
    analysis_csv_path = os.path.join(analysis_dir, f"epoch_{epoch}_analysis.csv")
    epoch_tensors_dir = os.path.join(output_dir, "tensors", f"epoch_{epoch}")
    os.makedirs(analysis_dir, exist_ok=True)
    os.makedirs(epoch_tensors_dir, exist_ok=True)

    # torch.compile wraps the module; the PEFT state-dict helper needs the original.
    unwrapped_model = model._orig_mod if hasattr(model, "_orig_mod") else model
    state_dict = get_peft_model_state_dict(unwrapped_model)

    # One global scaling factor: assumes lora_config has no per-module rank/alpha patterns.
    scaling = lora_config.lora_alpha / lora_config.r
    a_suffix = ".lora_A.weight"

    analysis_results = []
    for name_a, weight_a in state_dict.items():
        if not name_a.endswith(a_suffix):
            continue
        base_filename = name_a[: -len(a_suffix)]
        name_b = f"{base_filename}.lora_B.weight"
        if name_b not in state_dict:
            print(f"WARNING: {name_a} has no matching lora_B weight; skipping it.")
            continue

        lora_A = weight_a.float()
        lora_B = state_dict[name_b].float()

        # Save CPU copies so the files load without a GPU or a map_location argument.
        torch.save(lora_A.cpu(), os.path.join(epoch_tensors_dir, f"{base_filename}.lora_A.pt"))
        torch.save(lora_B.cpu(), os.path.join(epoch_tensors_dir, f"{base_filename}.lora_B.pt"))

        # Norm is computed on the tensors' own device.
        delta_w = (lora_B @ lora_A) * scaling
        frobenius_norm = torch.linalg.norm(delta_w, ord="fro").item()

        parts = base_filename.split(".")
        try:
            layer_num = int(parts[parts.index("layers") + 1])
        except (ValueError, IndexError):
            # Adapter not under a numbered layer: tensors were saved, but it is not aggregated.
            continue

        analysis_results.append({
            "layer_num": layer_num,
            "adapter_type": parts[-1],
            "frobenius_norm": frobenius_norm
        })

    print(f"Saved all adapter tensors for Epoch {epoch} to {epoch_tensors_dir}")

    if analysis_results:
        df_detailed = pd.DataFrame(analysis_results)
        df_aggregated = df_detailed.groupby("layer_num")["frobenius_norm"].sum().reset_index()
        df_aggregated.rename(columns={"frobenius_norm": "total_frobenius_norm"}, inplace=True)
        df_aggregated.to_csv(analysis_csv_path, index=False, float_format="%.6e")
        print(f"Epoch {epoch} aggregated analysis saved to {analysis_csv_path}")
    else:
        print(f"Epoch {epoch}: no layer-indexed LoRA adapters found; no analysis CSV written.")


class EpochEndAnalysisCallback(TrainerCallback):
    """Trainer callback that runs ``analyze_and_save_weights`` at the end of every epoch."""

    def __init__(self, lora_config, output_dir):
        """Store what the per-epoch analysis needs.

        Args:
            lora_config: LoraConfig of the run (its lora_alpha / r give the update scaling).
            output_dir: run directory under which analysis/ and tensors/ are written.
        """
        super().__init__()
        self._lora_config = lora_config
        self._output_dir = output_dir

    def on_epoch_end(self, args, state, control, **kwargs):
        """Analyze and save the LoRA weights of the model passed in by the Trainer."""
        # Only the main process writes files, so multi-process launches don't clobber paths.
        if not state.is_world_process_zero:
            return
        model = kwargs.get("model")
        # Explicit None test: nn.Module truthiness can be changed by __len__/__bool__.
        if model is not None:
            analyze_and_save_weights(model, state.epoch, self._output_dir, self._lora_config)


# ==============================================================================
# Part 5: Script Execution
# ==============================================================================
if __name__ == "__main__":
    # Launch the full sweep only when run as a script, never on import.
    main()