# ==============================================================================
# Part 1: Basic Imports and Setup
# ==============================================================================
import os
import json
import random
import pandas as pd
import numpy as np
import torch
import importlib.util
from pathlib import Path
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer,
    DataCollatorForLanguageModeling, TrainerCallback,
)
from peft import get_peft_model, LoraConfig, get_peft_model_state_dict

# Reaching this line means every dependency above imported without error.
print("Libraries imported successfully.")

# Select torch's "high" float32 matmul precision setting for the whole process
# (presumably to speed up training on supporting hardware -- TODO confirm).
torch.set_float32_matmul_precision("high")

# ==============================================================================
# Part 2: Experiment Configuration (The Orchestrator's Control Panel)
# ==============================================================================

# --- Global Constants (EDIT THESE) ---
SEED = 42
MODEL_ID = "meta-llama/Llama-3.2-1B"
# Read the access token from the HF_TOKEN environment variable instead of
# hard-coding a secret in source. When the variable is unset (or empty) this is
# None, so no explicit token is sent and the Hugging Face libraries can fall
# back to credentials stored by a CLI login.
HF_TOKEN = os.environ.get("HF_TOKEN") or None
NUM_TRAIN_EPOCHS = 3

# --- Data & Evaluation Parameters ---
DATASET_TRAIN_FRACTION = 0.6  # Share of each dataset used for training; the rest is held out
EVAL_SAMPLE_SIZE = 100  # Number of samples to use for evaluation at each step

# --- Directories (EDIT THESE) ---
CONFIG_DIR = "./output/stage_2_generated_configs"
FINAL_BASE_OUTPUT_DIR = "./output/stage_3_targeted_results"
# Ensure these paths point to your JSON data files.
# Keys must match the "<name>" part of each "<name>_experiments.py" config file.
DATASET_PATHS = {
    "cpp": "./data/cpp.json",
    "python": "./data/python.json",
    "medical": "./data/medical.json",
    "finance": "./data/finance.json",
    "science": "./data/science.json",
    "math": "./data/math.json",
}

# --- Hyperparameter space ---
# Every combination of these three lists is run for each experiment.
LEARNING_RATES = [1e-3]
BATCH_SIZES = [8]
LORA_RANKS = [16]

# ==============================================================================
# Part 2.5: Helper Functions
# ==============================================================================

def set_seed(seed: int):
    """Seed every random number generator this script relies on.

    Applies ``seed`` to Python's ``random`` module, NumPy's global generator
    and PyTorch; the CUDA generators are seeded as well when CUDA is available.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def generate_target_modules(combination: list[str]) -> list[str]:
    """Translates a simple combination list into full module names for PEFT.

    Each item has the form ``"<component>_<block index>"`` (case-insensitive),
    where ``<component>`` is ``attn`` or ``mlp``, e.g. ``"attn_3"``.

    Args:
        combination: Component specifiers to expand.

    Returns:
        The sorted, de-duplicated module paths
        (``model.layers.<i>.self_attn.*`` / ``model.layers.<i>.mlp.*``).
        Items that cannot be interpreted are skipped after printing a warning,
        so the result may be empty.
    """
    # Component keyword -> (sub-module name, projection layers inside it).
    projections = {
        "attn": ("self_attn", ("q_proj", "k_proj", "v_proj", "o_proj")),
        "mlp": ("mlp", ("gate_proj", "up_proj", "down_proj")),
    }

    targets = set()
    for item in combination:
        try:
            component_type, block_index_str = item.lower().split("_")
            block_index = int(block_index_str)
        except (ValueError, IndexError, AttributeError):
            # Wrong number of "_" parts, non-integer index, or a non-string item.
            print(f"    [Warning] Could not parse '{item}'.")
            continue

        if block_index < 0:
            # "model.layers.-1" can never name a real module, so reject it here
            # instead of handing PEFT a path that matches nothing.
            print(f"    [Warning] Could not parse '{item}'.")
            continue

        if component_type not in projections:
            # Previously dropped without any message, hiding config typos.
            print(f"    [Warning] Unknown component type '{component_type}' in '{item}'.")
            continue

        submodule, layer_names = projections[component_type]
        targets.update(
            f"model.layers.{block_index}.{submodule}.{p}" for p in layer_names
        )
    return sorted(targets)

def load_experiments_from_file(filepath: Path) -> dict:
    """Execute a Python config file and return its EXPERIMENT_COMBINATIONS.

    The file is imported as a throw-away module named ``config_module``. Any
    failure (unreadable file, error while executing it, missing attribute) is
    reported on stdout and results in an empty dict rather than an exception.
    """
    combinations = {}
    try:
        module_spec = importlib.util.spec_from_file_location("config_module", filepath)
        loaded_module = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(loaded_module)
        combinations = loaded_module.EXPERIMENT_COMBINATIONS
    except Exception as e:
        print(f"  [ERROR] Could not load EXPERIMENT_COMBINATIONS from {filepath}. Error: {e}")
    return combinations

# ==============================================================================
# Part 3: The Callbacks for Analysis and Evaluation
# ==============================================================================

class EpochEndAnalysisCallback(TrainerCallback):
    """Saves LoRA tensors and their Frobenius norms after each epoch.

    For every ``lora_A``/``lora_B`` pair in the PEFT state dict, both matrices
    are written to ``<run>/tensors/epoch_<n>/`` and the Frobenius norm of the
    scaled update ``(B @ A) * lora_alpha / r`` is computed. The norms are then
    summed per layer and written to ``<run>/analysis/epoch_<n>_analysis.csv``.
    ``<run>`` is the parent of ``args.output_dir``.
    """

    def __init__(self, lora_config: LoraConfig):
        """Keep the LoRA config; its ``lora_alpha`` and ``r`` define the scaling."""
        self._lora_config = lora_config

    def on_epoch_end(self, args: TrainingArguments, state, control, **kwargs):
        """Dump the adapter matrices and per-layer update norms for this epoch."""
        model = kwargs["model"]
        output_dir = Path(args.output_dir).parent  # The experiment's root dir
        epoch = int(round(state.epoch))
        print(f"\n--- Analyzing weights for Epoch {epoch} ---")

        analysis_csv_path = output_dir / "analysis" / f"epoch_{epoch}_analysis.csv"
        epoch_tensors_dir = output_dir / "tensors" / f"epoch_{epoch}"
        # Create both destinations here so the callback does not rely on the
        # caller having prepared the "analysis" directory beforehand.
        analysis_csv_path.parent.mkdir(parents=True, exist_ok=True)
        epoch_tensors_dir.mkdir(parents=True, exist_ok=True)

        # Use the wrapped module when the model exposes one via `_orig_mod`
        # (e.g. a torch.compile wrapper -- TODO confirm which wrapper applies).
        unwrapped_model = getattr(model, "_orig_mod", model)
        current_state_dict = get_peft_model_state_dict(unwrapped_model)

        # Identical for every adapter, so compute it once outside the loop.
        scaling = self._lora_config.lora_alpha / self._lora_config.r

        analysis_results = []
        for name, tensor in current_state_dict.items():
            if 'lora_A.weight' not in name:
                continue

            name_b = name.replace('lora_A.weight', 'lora_B.weight')
            lora_A = tensor.float()
            lora_B = current_state_dict[name_b].float()
            base_filename = name.removesuffix('.lora_A.weight')

            # Save CPU copies so the files load on machines without the
            # training device and without needing `map_location`.
            torch.save(lora_A.cpu(), epoch_tensors_dir / f"{base_filename}.lora_A.pt")
            torch.save(lora_B.cpu(), epoch_tensors_dir / f"{base_filename}.lora_B.pt")

            delta_w = (lora_B @ lora_A) * scaling
            frobenius_norm = torch.linalg.norm(delta_w, ord='fro').item()

            # The layer index is the path component right after "layers";
            # the last component is the adapted projection's name.
            parts = base_filename.split(".")
            layer_num = int(parts[parts.index("layers") + 1])
            adapter_type = parts[-1]
            analysis_results.append({"layer_num": layer_num, "adapter_type": adapter_type, "frobenius_norm": frobenius_norm})

        if analysis_results:
            df_detailed = pd.DataFrame(analysis_results)
            # Collapse the per-projection norms into one total per layer.
            df_aggregated = df_detailed.groupby("layer_num")["frobenius_norm"].sum().reset_index()
            df_aggregated.rename(columns={"frobenius_norm": "total_frobenius_norm"}, inplace=True)
            df_aggregated.to_csv(analysis_csv_path, index=False, float_format="%.6e")
            print(f"--- Analysis for Epoch {epoch} saved. ---")


class EvaluationCallback(TrainerCallback):
    """Performs evaluation on training and eval samples before and after each epoch.

    A fixed, seeded sample is drawn from each dataset at construction time and
    re-scored at the start of training (recorded as epoch 0) and after every
    epoch. The collected history is written to
    ``<run>/analysis/evaluation_log.csv`` when training ends, where ``<run>``
    is the parent of ``args.output_dir``.
    """

    def __init__(self, train_dataset, eval_dataset, tokenizer, sample_size=100):
        """Draw the fixed evaluation samples.

        Args:
            train_dataset: Tokenized training split to sample from.
            eval_dataset: Tokenized held-out split to sample from.
            tokenizer: Tokenizer handed to the language-modeling collator.
            sample_size: Maximum number of rows to draw from each split.
        """
        self.train_sample = self._take_sample(train_dataset, sample_size)
        self.eval_sample = self._take_sample(eval_dataset, sample_size)
        self.tokenizer = tokenizer
        self.history = []

    @staticmethod
    def _take_sample(dataset, sample_size):
        """Return a seeded random subset of at most ``sample_size`` rows."""
        # Clamp to the dataset length: selecting indices past the end raises.
        return dataset.shuffle(seed=SEED).select(range(min(sample_size, len(dataset))))

    def _calculate_loss(self, model, dataset):
        """Return the average loss of ``model`` over ``dataset``.

        Each batch's loss is weighted by the number of sequences in the batch.
        Returns NaN for an empty dataset. The model's train/eval mode is
        restored to whatever it was on entry.
        """
        if len(dataset) == 0:
            return float("nan")

        was_training = model.training
        model.eval()
        total_loss = 0
        collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
        loader = DataLoader(dataset, batch_size=8, collate_fn=collator)

        try:
            with torch.no_grad():
                for batch in loader:
                    batch = {k: v.to(model.device) for k, v in batch.items()}
                    outputs = model(**batch)
                    total_loss += outputs.loss.item() * len(batch["input_ids"])
        finally:
            # Restore the previous mode rather than forcing train mode.
            model.train(was_training)
        return total_loss / len(dataset)

    def _evaluate_and_record(self, model, epoch):
        """Score both samples, append the result to the history and print it."""
        train_loss = self._calculate_loss(model, self.train_sample)
        eval_loss = self._calculate_loss(model, self.eval_sample)
        self.history.append({'epoch': epoch, 'train_loss_sample': train_loss, 'eval_loss_sample': eval_loss})
        print(f"--- Epoch {epoch}: Train Loss Sample={train_loss:.4f}, Eval Loss Sample={eval_loss:.4f} ---")

    def on_train_begin(self, args, state, control, **kwargs):
        """Record the baseline losses as epoch 0."""
        print("\n--- Running baseline evaluation (Epoch 0) ---")
        self._evaluate_and_record(kwargs['model'], 0)

    def on_epoch_end(self, args, state, control, **kwargs):
        """Record the losses reached after the epoch that just finished."""
        epoch = int(round(state.epoch))
        print(f"\n--- Running evaluation for Epoch {epoch} ---")
        self._evaluate_and_record(kwargs['model'], epoch)

    def on_train_end(self, args, state, control, **kwargs):
        """Write the accumulated loss history to the run's analysis directory."""
        output_dir = Path(args.output_dir).parent
        log_path = output_dir / "analysis" / "evaluation_log.csv"
        # Make sure the destination exists so the log is not lost at the very end.
        log_path.parent.mkdir(parents=True, exist_ok=True)
        pd.DataFrame(self.history).to_csv(log_path, index=False)
        print(f"\nEvaluation log saved to: {log_path}")

# ==============================================================================
# Part 4: The Orchestrator and Engine
# ==============================================================================

def run_experiment_engine(config: dict, tokenizer):
    """Run one LoRA fine-tuning experiment described by ``config``.

    Loads and splits the dataset, tokenizes both splits, wraps the base model
    with a LoRA adapter on ``config["target_modules"]`` and trains it with the
    analysis and evaluation callbacks attached.

    Args:
        config: Run settings. Keys read here: ``output_dir``, ``dataset_path``,
            ``lora_rank``, ``target_modules``, ``num_epochs``, ``batch_size``,
            ``learning_rate`` and ``run_name``.
        tokenizer: Tokenizer used for tokenization and batch collation.

    Returns:
        None. If data loading or tokenization fails, the error is printed and
        the run is skipped.
    """
    import gc  # Local import: only needed for the cleanup at the end of a run.

    output_dir = Path(config["output_dir"])
    analysis_dir = output_dir / "analysis"
    tensors_dir = output_dir / "tensors"
    analysis_dir.mkdir(parents=True, exist_ok=True)
    tensors_dir.mkdir(parents=True, exist_ok=True)
    print(f"  - Results will be saved in: {output_dir}")

    set_seed(SEED)

    def tokenize(examples):
        """Tokenize a batch of rows, truncating each text to 256 tokens."""
        return tokenizer(examples["text"], truncation=True, max_length=256)

    # --- Data Loading and Splitting ---
    try:
        # Explicit encoding so decoding does not depend on the platform default.
        with open(config['dataset_path'], "r", encoding="utf-8") as f:
            # Presumably a JSON list of strings; it becomes the "text" column.
            raw_data = json.load(f)
        full_dataset = Dataset.from_dict({"text": raw_data}).shuffle(seed=SEED)

        split_datasets = full_dataset.train_test_split(test_size=1-DATASET_TRAIN_FRACTION, seed=SEED)
        train_dataset = split_datasets['train']
        eval_dataset = split_datasets['test']

        num_cpu_cores = os.cpu_count() or 1
        # Both splits are tokenized with identical settings.
        map_kwargs = dict(
            batched=True, remove_columns=["text"], num_proc=max(1, num_cpu_cores // 2)
        )
        tokenized_train_dataset = train_dataset.map(tokenize, **map_kwargs)
        tokenized_eval_dataset = eval_dataset.map(tokenize, **map_kwargs)
    except Exception as e:
        # Deliberate best-effort: a bad dataset skips this run, not the sweep.
        print(f"  [CRITICAL ERROR] Failed during data processing. Error: {e}. Skipping run.")
        return

    # --- Model and LoRA Setup ---
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto", token=HF_TOKEN
    )
    lora_config = LoraConfig(
        r=config["lora_rank"], lora_alpha=2 * config["lora_rank"],
        target_modules=config["target_modules"], lora_dropout=0.05,
        bias="none", task_type="CAUSAL_LM"
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # --- Trainer Setup ---
    # Checkpoints go in a sub-directory; the callbacks use its parent as the
    # run's root directory.
    training_args = TrainingArguments(
        output_dir=str(output_dir / "training_checkpoints"),
        num_train_epochs=config["num_epochs"],
        per_device_train_batch_size=config["batch_size"],
        learning_rate=config["learning_rate"],
        logging_steps=50, bf16=True, save_strategy="no", report_to="none", seed=SEED,
        remove_unused_columns=False, dataloader_num_workers=4
    )

    trainer = Trainer(
        model=model, args=training_args, train_dataset=tokenized_train_dataset,
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
        callbacks=[
            EpochEndAnalysisCallback(lora_config),
            EvaluationCallback(tokenized_train_dataset, tokenized_eval_dataset, tokenizer, sample_size=EVAL_SAMPLE_SIZE)
        ]
    )

    trainer.train()
    print(f"--- Fine-tuning for run '{config['run_name']}' is complete. ---\n")

    # Drop the large objects and collect reference cycles now, so that a
    # subsequent torch.cuda.empty_cache() by the caller can reclaim the memory
    # before the next run loads a fresh model.
    del trainer, model
    gc.collect()

# ==============================================================================
# Part 5: Main Orchestration
# ==============================================================================

def main():
    """Discover experiment configs and run every configured fine-tuning job.

    For each ``<dataset>_experiments.py`` file in ``CONFIG_DIR`` the
    experiment combinations are loaded, and each combination is run once per
    (learning rate, LoRA rank, batch size) triple from the hyperparameter
    lists. Datasets or combinations that cannot be resolved are reported and
    skipped. Only runs actually handed to the engine are counted.
    """
    config_path = Path(CONFIG_DIR)
    if not config_path.exists():
        print(f"[CRITICAL ERROR] Config directory not found: {CONFIG_DIR}")
        return

    config_files = sorted(config_path.glob("*_experiments.py"))
    if not config_files:
        print(f"[CRITICAL ERROR] No '*_experiments.py' files found in {CONFIG_DIR}")
        return

    print(f"Found {len(config_files)} experiment configs. Starting execution...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
    # The collator needs a pad token; fall back to EOS when none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Same iteration order as nested loops: learning rate, then rank, then batch size.
    hyperparameter_grid = [
        (lr, rank, bs)
        for lr in LEARNING_RATES
        for rank in LORA_RANKS
        for bs in BATCH_SIZES
    ]

    overall_experiment_count = 0
    for config_file in config_files:
        # Strip only the trailing marker so a dataset name that itself
        # contains "_experiments" is left intact.
        dataset_name = config_file.stem.removesuffix("_experiments")
        print(f"\n{'='*80}\nLOADING CONFIGS FOR DATASET: {dataset_name.upper()}\n{'='*80}\n")

        experiment_combinations = load_experiments_from_file(config_file)
        dataset_path = DATASET_PATHS.get(dataset_name, None)

        if not experiment_combinations:
            print(f"  Skipping {dataset_name} due to missing config.")
            continue

        if dataset_path is None:
            print(f"  [ERROR] No dataset path specified for {dataset_name} in DATASET_PATHS. Skipping.")
            continue

        if not Path(dataset_path).exists():
            print(f"  [ERROR] Dataset file not found: {dataset_path}. Skipping.")
            continue

        for combo_name, combo_list in experiment_combinations.items():
            # The module list depends only on the combination, so resolve it
            # once instead of once per hyperparameter triple (which also
            # repeated any parse warnings).
            target_modules = generate_target_modules(combo_list)
            if not target_modules:
                print(f"  [ERROR] No valid target modules for {combo_name}. Skipping.")
                continue

            for lr, rank, bs in hyperparameter_grid:
                overall_experiment_count += 1
                print(f"\n--- EXPERIMENT {overall_experiment_count}: {combo_name} ---")

                output_dir = Path(FINAL_BASE_OUTPUT_DIR) / dataset_name / combo_name / f"lr_{lr:.0e}_bs_{bs}_rank_{rank}"
                run_config = {
                    "dataset_path": str(dataset_path), "output_dir": str(output_dir),
                    "target_modules": target_modules, "learning_rate": lr,
                    "batch_size": bs, "lora_rank": rank, "num_epochs": NUM_TRAIN_EPOCHS,
                    "run_name": f"{dataset_name}_{combo_name}"
                }

                run_experiment_engine(run_config, tokenizer)
                # Return cached GPU memory before the next run loads a model.
                torch.cuda.empty_cache()

    print(f"\n\nAll {overall_experiment_count} experiments completed.")

# Run the sweep only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()