import os
import re
import json
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
from typing import Dict, Tuple, Optional, List
import logging
from peft import (
    LoraConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training,
)

# Configure logging at INFO level and create this module's logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Registry of datasets selectable through `dataset_name` in HyperfittingDataset.
# Each entry holds the arguments forwarded to `load_dataset` ("path", an
# optional "config", and "split") plus "text_field", the column that is read
# as raw text.
# Original paper: trained on Fiction-Stories, evaluated on Wikipedia, Fiction-Stories, and BBC News
DATASET_CONFIGS = {
    "fiction-stories": {
        "path": "alasdairforsythe/text-english-code-fiction-nonfiction",
        "config": None,  # None -> load_dataset is called without a config name
        "split": "train",
        "text_field": "text",
    },
    # Evaluation datasets
    "wikitext": {
        "path": "wikitext",
        "config": "wikitext-103-raw-v1",
        "split": "train",
        "text_field": "text",
    },
    "bbc-news": {
        "path": "SetFit/bbc-news",
        "config": None,  # None -> load_dataset is called without a config name
        "split": "train",
        "text_field": "text",
    },
}


def _is_peft_model(model: nn.Module) -> bool:
    """Return True if `model` is a `PeftModel` or exposes a `peft_config` attribute."""
    if isinstance(model, PeftModel):
        return True
    return hasattr(model, "peft_config")


def _iter_candidate_models(model: nn.Module):
    yield model
    if hasattr(model, "model"):
        yield model.model
    if hasattr(model, "base_model"):
        yield model.base_model
    if hasattr(model, "base_model") and hasattr(model.base_model, "model"):
        yield model.base_model.model


def _get_num_layers(model: nn.Module) -> Optional[int]:
    """Return the length of the first `layers` or `h` container found on a candidate model.

    Candidates come from `_iter_candidate_models`; for each one `layers` is
    checked before `h`. Returns None when no candidate has either attribute.
    """
    for candidate in _iter_candidate_models(model):
        for container_name in ("layers", "h"):
            if hasattr(candidate, container_name):
                return len(getattr(candidate, container_name))
    return None


def _parse_lora_target_modules(target_modules: Optional[str]) -> List[str]:
    if not target_modules:
        return []
    if isinstance(target_modules, list):
        return [m for m in target_modules if m]
    return [m.strip() for m in target_modules.split(",") if m.strip()]


def _parse_lora_target_layers(layer_spec: Optional[str], num_layers: Optional[int]) -> Optional[List[int]]:
    if not layer_spec:
        return None
    if num_layers is None:
        logger.warning("Could not infer number of layers; ignoring lora_target_layers.")
        return None

    spec = layer_spec.strip().lower()
    if spec == "all":
        return list(range(num_layers))
    if spec.startswith("last"):
        count_str = spec[4:]
        if not count_str.isdigit():
            raise ValueError("lora_target_layers 'lastN' requires a numeric suffix, e.g. last8")
        count = int(count_str)
        if count <= 0:
            raise ValueError("lora_target_layers lastN must be positive")
        start = max(0, num_layers - count)
        return list(range(start, num_layers))

    layers: List[int] = []
    for token in spec.split(","):
        token = token.strip()
        if not token:
            continue
        if "-" in token:
            start_str, end_str = token.split("-", 1)
            start, end = int(start_str), int(end_str)
            if start > end:
                raise ValueError(f"Invalid layer range: {token}")
            layers.extend(range(start, end + 1))
        else:
            layers.append(int(token))

    layers = sorted({layer for layer in layers if 0 <= layer < num_layers})
    if not layers:
        raise ValueError("lora_target_layers did not resolve to any valid layer indices")
    return layers


def _filter_lora_layers(model: nn.Module, allowed_layers: List[int]) -> None:
    """Freeze LoRA parameters that belong to layers outside `allowed_layers`.

    Only parameters whose name contains ``"lora_"`` are considered. The layer
    index is read from the parameter name using the first pattern (tried in
    order) that matches; parameters with no recognizable layer index are left
    untouched. The number of frozen parameters is logged.
    """
    keep = set(allowed_layers)
    layer_patterns = [
        re.compile(expr)
        for expr in (
            r"model\.layers\.(\d+)\.",
            r"transformer\.h\.(\d+)\.",
            r"layers\.(\d+)\.",
            r"h\.(\d+)\.",
        )
    ]

    frozen_count = 0
    for param_name, parameter in model.named_parameters():
        if "lora_" not in param_name:
            continue
        # First pattern (in declaration order) that matches wins.
        hit = next(
            (found for found in (rx.search(param_name) for rx in layer_patterns) if found),
            None,
        )
        if hit is None or int(hit.group(1)) in keep:
            continue
        parameter.requires_grad = False
        frozen_count += 1

    logger.info(f"LoRA layer filter applied; frozen {frozen_count} LoRA params")


def _build_quantization_config(
    load_in_4bit: bool,
    load_in_8bit: bool,
    compute_dtype: str,
    quant_type: str,
    use_double_quant: bool,
) -> Optional[BitsAndBytesConfig]:
    """Build the bitsandbytes configuration for the requested precision.

    Returns None when neither 4-bit nor 8-bit loading is requested. For 4-bit
    loading, `compute_dtype` is looked up by name ("bfloat16", "float16",
    "float32"); any other name falls back to bfloat16. `quant_type` and
    `use_double_quant` are only used for 4-bit loading.

    Raises:
        ValueError: if both `load_in_4bit` and `load_in_8bit` are set.
    """
    if load_in_4bit and load_in_8bit:
        raise ValueError("Only one of load_in_4bit or load_in_8bit can be enabled.")
    if not (load_in_4bit or load_in_8bit):
        return None

    if load_in_8bit:
        return BitsAndBytesConfig(load_in_8bit=True)

    supported_dtypes = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=supported_dtypes.get(compute_dtype, torch.bfloat16),
        bnb_4bit_quant_type=quant_type,
        bnb_4bit_use_double_quant=use_double_quant,
    )


class HyperfittingDataset(Dataset):
    """Fixed-length token sequences for hyperfitting, built from a configured dataset.

    Each stored sample is a list of `sequence_length` token ids; indexing the
    dataset returns the shifted (input, label) pair used for next-token
    prediction.
    """

    # Processing modes accepted by the `mode` argument.
    SUPPORTED_MODES = ("filter", "concatenate")

    def __init__(
        self,
        tokenizer,
        num_samples: int = 2000,
        sequence_length: int = 256,
        seed: int = 42,
        shuffle: bool = True,
        dataset_name: str = "fiction-stories",
        mode: str = "filter",
    ):
        """
        Dataset for hyperfitting with two processing modes.

        Args:
            tokenizer: HuggingFace tokenizer
            num_samples: Number of training samples to create
            sequence_length: Length of each sequence (default 256 as in paper)
            seed: Random seed for reproducibility
            shuffle: Whether to shuffle dataset before sampling
            dataset_name: Key into DATASET_CONFIGS selecting the dataset to load
            mode: Data processing mode:
                - "filter": Original paper's approach - tokenize each text individually
                           (truncated to `sequence_length`) and keep only the texts that
                           yield exactly `sequence_length` tokens.
                - "concatenate": Alternative approach - concatenate all tokens then chunk
                                into fixed-length sequences.

        Raises:
            ValueError: if `mode` is not one of SUPPORTED_MODES.
            KeyError: if `dataset_name` is not a key of DATASET_CONFIGS.

        Original Paper Training:
            - Dataset: Fiction-Stories (Forsythe, 2024)
            - Mode: filter (only keep sequences >= seq_len after tokenization)

        Original Paper Evaluation:
            - Wikipedia (wikitext-103)
            - Fictional Stories
            - BBC News
        """
        # Fail fast: an unknown mode would otherwise only surface after the
        # dataset has been downloaded, as a missing `samples` attribute.
        if mode not in self.SUPPORTED_MODES:
            raise ValueError(
                f"Unknown mode '{mode}'; expected one of {self.SUPPORTED_MODES}"
            )

        self.tokenizer = tokenizer
        self.sequence_length = sequence_length
        self.mode = mode

        logger.info(f"Loading dataset with mode='{mode}'")
        np.random.seed(seed)

        config = DATASET_CONFIGS[dataset_name]
        ds_path = config["path"]
        ds_config = config["config"]
        ds_split = config["split"]
        text_field = config["text_field"]

        logger.info(f"Running on dataset: {ds_path} (config={ds_config})")

        # Load dataset
        # NOTE(review): trust_remote_code=True allows dataset repositories to run
        # their own loading code; only use with dataset sources you trust.
        load_args = (ds_path, ds_config) if ds_config else (ds_path,)
        try:
            ds = load_dataset(*load_args, split=ds_split, trust_remote_code=True)
        except Exception as e:
            logger.error(f"Failed to load dataset {ds_path}: {e}")
            raise

        # Randomize order to approximate the paper's random sample selection
        if seed is not None and shuffle:
            try:
                ds = ds.shuffle(seed=seed)
                logger.info(f"Shuffled dataset with seed={seed}")
            except Exception as e:
                # Best effort: continue with the original order if shuffling fails.
                logger.warning(f"Failed to shuffle dataset (seed={seed}): {e}")
        elif seed is not None and not shuffle:
            logger.info("Dataset shuffling disabled; using original order")

        # The original paper used filter mode
        if mode == "filter":
            self.samples = self._process_filter_mode(ds, text_field, num_samples)
        else:  # mode == "concatenate" (validated above)
            self.samples = self._process_concatenate_mode(ds, text_field, num_samples)

        logger.info(f"Created {len(self.samples)} samples of length {sequence_length}")

    def _process_filter_mode(self, ds, text_field: str, num_samples: int) -> list:
        """
        Tokenize each text on its own (truncated to `sequence_length`, special
        tokens included) and keep only the results that are exactly
        `sequence_length` tokens long, i.e. drop texts that are too short.
        Stops as soon as `num_samples` sequences have been collected.
        """
        logger.info("Processing with FILTER mode")
        samples = []

        progress_desc = (
            f"Tokenizing and filtering out sentences shorter than {self.sequence_length} tokens"
        )
        for item in tqdm(ds, desc=progress_desc):
            text = item[text_field]
            if not text or not text.strip():
                continue

            # Tokenize with truncation to max_length
            tokens = self.tokenizer(
                text,
                truncation=True,
                max_length=self.sequence_length,
                padding=False,
                add_special_tokens=True,
            )["input_ids"]

            # After truncation, anything not exactly sequence_length is too short.
            if len(tokens) != self.sequence_length:
                continue

            samples.append(tokens)
            if len(samples) >= num_samples:
                break

        if len(samples) < num_samples:
            logger.warning(
                f"Only got {len(samples)} samples after filtering. "
            )

        return samples

    def _process_concatenate_mode(self, ds, text_field: str, num_samples: int) -> list:
        """
        Tokenize texts without special tokens, concatenate them into one token
        stream and cut it into consecutive, non-overlapping chunks of
        `sequence_length` tokens. This may split sentences across chunks.
        """
        logger.info("Processing with CONCATENATE mode")
        all_tokens = []

        # multiply by 1.1 to make sure we have enough tokens for chunking
        token_budget = num_samples * self.sequence_length * 1.1

        # Tokenize and concatenate all texts
        for item in tqdm(ds, desc="Tokenizing and concatenating"):
            text = item[text_field]
            if not text or not text.strip():
                continue

            all_tokens.extend(self.tokenizer.encode(text, add_special_tokens=False))

            if len(all_tokens) >= token_budget:
                break

        logger.info(f"Total tokens collected: {len(all_tokens)}")

        if len(all_tokens) < num_samples * self.sequence_length:
            logger.warning(f"Only got {len(all_tokens)} tokens, may not create enough samples")

        # Chunk into fixed-length sequences; the range bound guarantees every
        # slice is a full chunk, so no per-chunk length check is needed.
        samples = []
        last_start = len(all_tokens) - self.sequence_length
        for start in range(0, last_start + 1, self.sequence_length):
            samples.append(all_tokens[start:start + self.sequence_length])
            if len(samples) >= num_samples:
                break

        return samples

    def __len__(self):
        """Number of stored token sequences."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the teacher-forcing pair for sample `idx`.

        `input_ids` is the sequence without its last token and `labels` is the
        sequence without its first token, so both have length
        `len(sample) - 1` and `labels[i]` is the token following `input_ids[i]`.
        """
        tokens = self.samples[idx]
        return {
            "input_ids": torch.tensor(tokens[:-1], dtype=torch.long),
            "labels": torch.tensor(tokens[1:], dtype=torch.long),
        }


class FixedSamplesDataset(Dataset):
    """
    Teacher-forcing view over pre-tokenized sequences (for example the
    `samples` built by HyperfittingDataset).

    Args:
        samples: List of tokenized sequences

    Each item is a dict with:
        input_ids: the sequence without its last token
        labels: the sequence without its first token
    Both tensors therefore have one element fewer than the stored sequence.
    """

    def __init__(self, samples: List[List[int]]):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sequence = self.samples[idx]
        inputs = torch.tensor(sequence[:-1], dtype=torch.long)
        targets = torch.tensor(sequence[1:], dtype=torch.long)
        return {"input_ids": inputs, "labels": targets}


class RawTokenDataset(Dataset):
    """
    Untouched token sequences for generation validation (no input/label shift).

    Args:
        samples: List of tokenized sequences

    Each item is a 1-D long tensor holding the full stored sequence.
    """

    def __init__(self, samples: List[List[int]]):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sequence = self.samples[idx]
        return torch.tensor(sequence, dtype=torch.long)


def collate_fn(batch):
    """Stack the per-sample `input_ids` and `labels` tensors into batch tensors."""
    stacked = {}
    for key in ("input_ids", "labels"):
        stacked[key] = torch.stack([example[key] for example in batch])
    return stacked


def raw_collate_fn(batch):
    """Stack a list of equally-shaped tensors along a new leading batch dimension."""
    stacked = torch.stack(batch, dim=0)
    return stacked


class HyperfittingTrainer:    
    def __init__(
        self,
        model: nn.Module,
        tokenizer,
        train_dataset: Dataset,
        learning_rate: float = 1e-6,
        batch_size: int = 8,
        val_dataset: Optional[Dataset] = None,
        val_gen_dataset: Optional[Dataset] = None,
        val_batch_size: int = 8,
        validation_freq: int = 250,
        gen_context_len: int = 32,
        gen_max_length: Optional[int] = None,
        gen_ttr_window_size: int = 96,
        num_epochs: int = 20,
        weight_decay: float = 0.0,
        gradient_accumulation_steps: int = 1, # not working with a 4090 machine for a 2B model, weird
        max_grad_norm: float = 50.0,
        save_dir: str = "./checkpoints",
        device: str = "cuda",
        num_workers: int = 4,
    ):
        """Set up data loaders, the optimizer and bookkeeping for training.

        Args:
            model: Model to train. It is used as-is; this constructor does not
                move it to `device`.
            tokenizer: Tokenizer used for decoding generations and saved with checkpoints.
            train_dataset: Dataset yielding dicts with "input_ids" and "labels".
            learning_rate: Adam learning rate.
            batch_size: Training batch size (incomplete final batches are dropped).
            val_dataset: Optional dataset for validation loss/entropy.
            val_gen_dataset: Optional dataset of raw token tensors for generation validation.
            val_batch_size: Batch size for both validation loaders.
            validation_freq: Run validation every this many optimizer updates (<= 0 disables).
            gen_context_len: Number of leading tokens used as the generation prompt.
            gen_max_length: Total generation length; defaults to the batch sequence length.
            gen_ttr_window_size: Number of trailing generated tokens used for TTR.
            num_epochs: Number of passes over the training loader.
            weight_decay: Adam weight decay.
            gradient_accumulation_steps: Micro-batches accumulated per optimizer update.
            max_grad_norm: Gradient-norm clipping threshold.
            save_dir: Directory for checkpoints and the training history (created if missing).
            device: Device that batches are moved to.
            num_workers: Worker processes for every DataLoader created here.

        Raises:
            ValueError: if `train_dataset` is empty, or if it has fewer samples
                than `batch_size` (no batch would survive `drop_last=True`).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.val_gen_dataset = val_gen_dataset
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.validation_freq = validation_freq
        self.gen_context_len = gen_context_len
        self.gen_max_length = gen_max_length
        self.gen_ttr_window_size = gen_ttr_window_size
        self.num_epochs = num_epochs
        self.weight_decay = weight_decay
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.save_dir = save_dir
        self.device = device
        self.num_workers = num_workers

        # Verify dataset (explicit raise: `assert` is stripped under `python -O`)
        logger.info(f"Dataset size: {len(train_dataset)} samples")
        if len(train_dataset) == 0:
            raise ValueError("Dataset is empty!")

        # Verify a sample
        sample = train_dataset[0]
        logger.info(f"Sample input shape: {sample['input_ids'].shape}")
        logger.info(f"Sample label shape: {sample['labels'].shape}")

        # Settings shared by all loaders created below.
        loader_kwargs = {"num_workers": num_workers, "pin_memory": True}

        # Create dataloader
        self.train_dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=collate_fn,
            drop_last=True,
            **loader_kwargs,
        )

        logger.info(f"DataLoader batches: {len(self.train_dataloader)}")
        if len(self.train_dataloader) == 0:
            # drop_last=True discards the only (incomplete) batch; training
            # would silently do nothing and the epoch averages would divide by zero.
            raise ValueError(
                f"Training dataset has {len(train_dataset)} samples, fewer than "
                f"batch_size={batch_size}; no training batches would be produced."
            )

        self.val_dataloader = None
        if self.val_dataset is not None and len(self.val_dataset) > 0:
            self.val_dataloader = DataLoader(
                self.val_dataset,
                batch_size=val_batch_size,
                shuffle=False,
                collate_fn=collate_fn,
                drop_last=False,
                **loader_kwargs,
            )
            logger.info(f"Val DataLoader batches: {len(self.val_dataloader)}")

        self.val_gen_dataloader = None
        if self.val_gen_dataset is not None and len(self.val_gen_dataset) > 0:
            self.val_gen_dataloader = DataLoader(
                self.val_gen_dataset,
                batch_size=val_batch_size,
                shuffle=False,
                collate_fn=raw_collate_fn,
                drop_last=False,
                **loader_kwargs,
            )
            logger.info(f"Val Gen DataLoader batches: {len(self.val_gen_dataloader)}")

        self.model.train()
        if not _is_peft_model(self.model):
            # Full fine-tuning: every parameter is trainable.
            for param in self.model.parameters():
                param.requires_grad = True
        else:
            # PEFT decides which parameters are trainable; just report them.
            try:
                self.model.print_trainable_parameters()
            except Exception:
                # Reporting is best effort and must not abort training setup.
                logger.info("LoRA model loaded; using trainable parameters from PEFT configuration.")

        # Optimizer as in the original paper
        self.optimizer = torch.optim.Adam(
            model.parameters(),
            lr=learning_rate,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=weight_decay, # no weight decay as in the original paper
        )

        self.training_history = {
            "epoch": [],
            "train_loss": [],
            "learning_rate": [],
            "grad_norm": [],
            "validation": [],
        }

        os.makedirs(save_dir, exist_ok=True)

        # Store initial weights for comparison after the training
        # TODO: check whether the commented out function is imported in other files
        self.initial_weight_sample = None
        for name, param in self.model.named_parameters():
            if 'weight' in name and param.requires_grad:
                self.initial_weight_sample = (name, param.data.clone().cpu())
                logger.info(f"Tracking weight changes in: {name}")
                break
    
    def compute_loss(self, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, float]:
        input_ids = batch["input_ids"].to(self.device)
        labels = batch["labels"].to(self.device)
        
        outputs = self.model(input_ids)
        logits = outputs.logits
        
        loss = nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
        )
        
        perplexity = torch.exp(loss).item()
        
        return loss, perplexity

    def _validate_model(self) -> Tuple[float, float]:
        """
        Compute validation loss and entropy during training.
        Pin to the original paper's validation setup, check figure 2 in the original paper.

        Both metrics are computed from float32 logits and averaged over all
        positions in the validation set (not over batches), so a smaller final
        batch does not skew the result.

        Returns:
            (average cross-entropy per label token, average predictive entropy
            per position). Returns (0.0, 0.0) when no validation loader exists
            or it yields no data. The model is left in eval mode.
        """
        if self.val_dataloader is None:
            logger.info("No validation dataset provided, defaulting to 0.0")
            return 0.0, 0.0

        self.model.eval()
        loss_sum = 0.0
        entropy_sum = 0.0
        loss_tokens = 0       # label positions that contribute to the loss
        entropy_positions = 0  # all predicted positions

        with torch.no_grad():
            for batch in tqdm(self.val_dataloader, desc="Validation"):
                input_ids = batch["input_ids"].to(self.device)
                labels = batch["labels"].to(self.device)

                # Upcast: in fp16/bf16 the metrics lose precision, and the
                # former `log(p + 1e-9)` became log(0) in fp16 (1e-9 underflows),
                # producing NaN entropy.
                logits = self.model(input_ids).logits.float()
                flat_logits = logits.reshape(-1, logits.size(-1))
                flat_labels = labels.reshape(-1)

                loss_sum += F.cross_entropy(flat_logits, flat_labels, reduction="sum").item()
                # cross_entropy skips labels equal to its default ignore_index (-100).
                loss_tokens += int((flat_labels != -100).sum().item())

                # entr(p) = -p * ln(p), defined as 0 at p == 0, so no epsilon is needed.
                probs = F.softmax(flat_logits, dim=-1)
                entropy_sum += torch.special.entr(probs).sum().item()
                entropy_positions += flat_logits.size(0)

        avg_loss = loss_sum / loss_tokens if loss_tokens else 0.0
        avg_entropy = entropy_sum / entropy_positions if entropy_positions else 0.0
        return avg_loss, avg_entropy

    def _generation_validation(self) -> Tuple[float, List[Dict[str, str]]]:
        """
        Compute TTR on generated continuations and return example generations.
        Pin to the original paper's validation setup, check figure 2 in the original paper.

        For every batch the first `gen_context_len` tokens are used as the
        prompt and the model continues greedily (`do_sample=False`) up to a
        total length of `gen_max_length` (or the batch's sequence length when
        that is unset). The type-token ratio (unique tokens / tokens) is taken
        over the last `gen_ttr_window_size` tokens of each continuation.

        Returns:
            (mean TTR over all sequences, list of dicts with the decoded
            "context" and "generated_continuation"). Returns (0.0, []) when no
            generation loader exists. The model is left in eval mode.
        """
        if self.val_gen_dataloader is None:
            logger.info("No generation validation dataset provided, defaulting to 0.0")
            return 0.0, []

        self.model.eval()
        ttrs = []
        results = []

        with torch.no_grad():
            for batch in tqdm(self.val_gen_dataloader, desc="Generation Validation"):
                max_length = self.gen_max_length or batch.shape[-1]
                contexts = batch[:, :self.gen_context_len].to(self.device) # paper used 32 tokens as context
                # The slice above is shorter than gen_context_len when the
                # stored sequences are; split prompt/continuation at the real length.
                context_len = contexts.shape[1]
                attention_mask = torch.ones_like(contexts, dtype=torch.long)

                # Work on a copy so the model's own generation settings are untouched.
                generation_config = copy.deepcopy(self.model.generation_config)
                generation_config.max_new_tokens = None
                generation_config.max_length = max_length

                generated_sequences = self.model.generate(
                    input_ids=contexts,
                    attention_mask=attention_mask,
                    generation_config=generation_config,
                    do_sample=False,
                    pad_token_id=self.tokenizer.eos_token_id,
                )

                for context, gen_seq in zip(contexts, generated_sequences):
                    continuation = gen_seq[context_len:]
                    context_text = self.tokenizer.decode(context.tolist(), skip_special_tokens=True)
                    gen_text = self.tokenizer.decode(
                        continuation.tolist(),
                        skip_special_tokens=True,
                    )
                    results.append({
                        "context": context_text,
                        "generated_continuation": gen_text,
                    })

                    # NOTE(review): sequences that stop early are padded with the
                    # EOS id (pad_token_id above), and those pads are counted here.
                    ttr_window = continuation[-self.gen_ttr_window_size:].tolist()
                    total_tokens = len(ttr_window)
                    ttr = len(set(ttr_window)) / total_tokens if total_tokens > 0 else 0
                    ttrs.append(ttr)

        avg_ttr = float(np.mean(ttrs)) if ttrs else 0.0
        return avg_ttr, results
    
    def train_epoch(self, epoch: int) -> Dict:
        """Run one pass over the training loader and return averaged metrics.

        Gradients are accumulated over `gradient_accumulation_steps`
        micro-batches; an optimizer update is also performed on the last batch
        of the epoch so that a trailing, incomplete accumulation group is not
        thrown away. (Its loss is still scaled by the full accumulation count.)

        Args:
            epoch: Zero-based epoch index, used only for the progress-bar label.

        Returns:
            Dict with "loss" and "perplexity" averaged over micro-batches and
            "grad_norm" (pre-clipping) averaged over optimizer updates. All
            values are 0 when the loader yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        total_perplexity = 0.0
        num_batches = 0
        grad_norms = []

        progress_bar = tqdm(
            self.train_dataloader,
            desc=f"Epoch {epoch + 1}/{self.num_epochs}",
        )
        batches_per_epoch = len(self.train_dataloader)

        self.optimizer.zero_grad()

        for step, batch in enumerate(progress_bar):
            loss, perplexity = self.compute_loss(batch)
            loss = loss / self.gradient_accumulation_steps
            loss.backward()

            at_boundary = (step + 1) % self.gradient_accumulation_steps == 0
            at_epoch_end = (step + 1) == batches_per_epoch
            if at_boundary or at_epoch_end:
                # clip_grad_norm_ returns the total norm measured *before*
                # clipping, so no separate per-parameter norm loop is needed.
                total_norm = float(
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        self.max_grad_norm,
                    )
                )
                grad_norms.append(total_norm)

                # Update weights
                self.optimizer.step()
                self.optimizer.zero_grad()

            # Undo the accumulation scaling for reporting.
            batch_loss = loss.item() * self.gradient_accumulation_steps
            total_loss += batch_loss
            total_perplexity += perplexity
            num_batches += 1

            progress_bar.set_postfix({
                "loss": f"{batch_loss:.4f}",
                "ppl": f"{perplexity:.2f}",
                "grad": f"{grad_norms[-1] if grad_norms else 0:.4f}"
            })

        if num_batches == 0:
            logger.warning("Training dataloader produced no batches; nothing was trained.")

        avg_loss = total_loss / num_batches if num_batches else 0.0
        avg_perplexity = total_perplexity / num_batches if num_batches else 0.0
        avg_grad_norm = sum(grad_norms) / len(grad_norms) if grad_norms else 0

        return {
            "loss": avg_loss,
            "perplexity": avg_perplexity,
            "grad_norm": avg_grad_norm,
        }
    
    # def verify_weight_change(self) -> float:
    #     """
    #     Comment this function out now, we will use it to check the weight change if needed
    #     """
    #     name, initial_weight = self.initial_weight_sample
        
    #     for n, param in self.model.named_parameters():
    #         if n == name:
    #             current_weight = param.data.cpu()
    #             diff = (initial_weight - current_weight).abs().mean().item()
    #             return diff
        
    #     return 0.0
    
    def train(self, checkpoint_every: int = 20) -> Dict:
        """Full training loop with periodic validation and checkpointing.

        Runs `num_epochs` passes over the training loader. When a validation
        loader exists, validation (loss/entropy plus generation TTR) runs once
        before training and then each time the optimizer-update counter
        reaches a multiple of `validation_freq`. Validation is evaluated only
        right after an update, so with gradient accumulation it runs once per
        trigger rather than once per micro-batch.

        An optimizer update is also performed on the last batch of each epoch
        so that a trailing, incomplete accumulation group is not discarded.

        After training, the final model is saved under `<save_dir>/final` and
        the history is written to `<save_dir>/training_history.json`.

        Args:
            checkpoint_every: Save a checkpoint every this many epochs
                (default 20); a value <= 0 disables intermediate checkpoints.

        Returns:
            The training history dict (per-epoch metrics and validation records).
        """
        logger.info("=" * 60)
        logger.info("STARTING HYPERFITTING TRAINING")
        logger.info("=" * 60)
        logger.info(f"Learning rate: {self.learning_rate}")
        logger.info(f"Batch size: {self.batch_size}")
        logger.info(f"Epochs: {self.num_epochs}")
        logger.info(f"Dataset size: {len(self.train_dataset)}")
        update_counter = 0
        temp_train_loss = []  # micro-batch losses since the last validation
        batches_per_epoch = len(self.train_dataloader)
        validation_enabled = self.val_dataloader is not None and self.validation_freq > 0

        if self.val_dataloader is not None:
            logger.info("Initial validation...")
            val_loss, val_entropy = self._validate_model()
            val_ttr, val_seqs = self._generation_validation()
            self.training_history["validation"].append({
                "train_loss": None,
                "update_counter": update_counter,
                "epoch": 0.0,
                "val_loss": val_loss,
                "val_entropy": val_entropy,
                "val_ttr": val_ttr,
                "val_gen_seqs": val_seqs[:10],
            })
            logger.info(f"Initial Val Loss={val_loss:.4f}, Entropy={val_entropy:.4f}, TTR={val_ttr:.4f}")

        for epoch in range(self.num_epochs):
            self.model.train()
            total_loss = 0.0
            total_perplexity = 0.0
            num_batches = 0
            grad_norms = []

            progress_bar = tqdm(
                self.train_dataloader,
                desc=f"Epoch {epoch + 1}/{self.num_epochs}",
            )

            self.optimizer.zero_grad()

            for step, batch in enumerate(progress_bar):
                loss, perplexity = self.compute_loss(batch)
                loss = loss / self.gradient_accumulation_steps
                loss.backward()

                at_boundary = (step + 1) % self.gradient_accumulation_steps == 0
                at_epoch_end = (step + 1) == batches_per_epoch
                did_update = at_boundary or at_epoch_end

                if did_update:
                    # clip_grad_norm_ returns the total norm measured *before*
                    # clipping, so no separate per-parameter norm loop is needed.
                    total_norm = float(
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(),
                            self.max_grad_norm,
                        )
                    )
                    grad_norms.append(total_norm)

                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    update_counter += 1

                # Undo the accumulation scaling for reporting.
                batch_loss = loss.item() * self.gradient_accumulation_steps
                total_loss += batch_loss
                total_perplexity += perplexity
                num_batches += 1

                temp_train_loss.append(batch_loss)

                # For fancy logs Lol
                progress_bar.set_postfix({
                    "loss": f"{batch_loss:.4f}",
                    "ppl": f"{perplexity:.2f}",
                    "grad": f"{grad_norms[-1] if grad_norms else 0:.4f}",
                })

                # Validating on the fly, as in the original paper. Gated on
                # `did_update` so the same update count cannot trigger twice.
                if validation_enabled and did_update and update_counter % self.validation_freq == 0:
                    train_loss = float(np.mean(temp_train_loss)) if temp_train_loss else 0.0
                    temp_train_loss = []

                    logger.info(f"\nValidation @ update {update_counter}: TrainLoss={train_loss:.4f}")
                    val_loss, val_entropy = self._validate_model()
                    val_ttr, val_seqs = self._generation_validation()
                    self.model.train()  # validation switches the model to eval mode

                    self.training_history["validation"].append({
                        "train_loss": train_loss,
                        "update_counter": update_counter,
                        # step + 1 batches of this epoch have been processed
                        "epoch": epoch + (step + 1) / max(batches_per_epoch, 1),
                        "val_loss": val_loss,
                        "val_entropy": val_entropy,
                        "val_ttr": val_ttr,
                        "val_gen_seqs": val_seqs[:10],
                    })

                    logger.info(
                        f"Val Loss={val_loss:.4f}, "
                        f"Entropy={val_entropy:.4f}, "
                        f"TTR={val_ttr:.4f}"
                    )

            if num_batches == 0:
                logger.warning("Training dataloader produced no batches; nothing was trained.")

            avg_loss = total_loss / num_batches if num_batches else 0.0
            avg_perplexity = total_perplexity / num_batches if num_batches else 0.0
            avg_grad_norm = sum(grad_norms) / len(grad_norms) if grad_norms else 0

            # Log epoch summary
            self.training_history["epoch"].append(epoch + 1)
            self.training_history["train_loss"].append(avg_loss)
            self.training_history["learning_rate"].append(self.learning_rate)
            self.training_history["grad_norm"].append(avg_grad_norm)

            logger.info(
                f"Epoch {epoch + 1}: "
                f"Loss={avg_loss:.4f}, "
                f"PPL={avg_perplexity:.2f}, "
                f"GradNorm={avg_grad_norm:.4f}, "
            )

            # Save checkpoint every `checkpoint_every` epochs
            # TODO: Larger models may need more epochs to be trained
            if checkpoint_every > 0 and (epoch + 1) % checkpoint_every == 0:
                checkpoint_path = os.path.join(self.save_dir, f"epoch_{epoch + 1}")
                self.save_checkpoint(checkpoint_path)

        # Save final trained model via `save_checkpoint` (model + tokenizer `save_pretrained`)
        final_path = os.path.join(self.save_dir, "final")
        self.save_checkpoint(final_path)

        # Save training history
        history_path = os.path.join(self.save_dir, "training_history.json")
        with open(history_path, "w", encoding="utf-8") as f:
            json.dump(self.training_history, f, indent=2)

        logger.info(f"Training complete. Final model saved to {final_path}")

        return self.training_history
    
    def save_checkpoint(self, path: str):
        """Write the model and tokenizer to `path` via `save_pretrained`, creating the directory if needed."""
        logger.info(f"Saving checkpoint to {path}")
        os.makedirs(path, exist_ok=True)
        # Model first, then tokenizer, both into the same directory.
        for artifact in (self.model, self.tokenizer):
            artifact.save_pretrained(path)


def hyperfit_model(
    model_name: str = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
    num_samples: int = 2000,
    num_val_samples: int = 128,
    num_val_gen_samples: int = 32,
    sequence_length: int = 256,
    num_epochs: int = 20,
    learning_rate: float = 1e-6, 
    batch_size: int = 8,
    val_batch_size: int = 8,
    validation_freq: int = 250,
    gen_context_len: int = 32,
    gen_max_length: Optional[int] = None,
    gen_ttr_window_size: int = 96,
    save_dir: str = "./checkpoints/hyperfitted_filter",
    device: str = "cuda",
    torch_dtype: str = "bfloat16",
    gradient_checkpointing: bool = True,
    dataset_name: str = "fiction-stories",
    dataset_mode: str = "filter",
    dataset_seed: int = 42,
    dataset_shuffle: bool = True,
    use_lora: bool = False,
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: str = "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj",
    lora_target_layers: Optional[str] = None,
    lora_bias: str = "none",
    load_in_4bit: bool = False,
    load_in_8bit: bool = False,
    bnb_4bit_compute_dtype: str = "bfloat16",
    bnb_4bit_quant_type: str = "nf4",
    bnb_4bit_use_double_quant: bool = True,
) -> Tuple[nn.Module, Dict]:
    """
    Main function to hyperfit a model

    Loads the tokenizer and model, optionally wraps the model with LoRA
    adapters (on top of an optionally quantized base model), builds the
    train/validation splits from a single dataset draw and runs
    `HyperfittingTrainer.train()`.

    Original Paper Setup:
    - Training: Fiction-Stories dataset (Forsythe, 2024) with filter mode
    - Evaluation: Wikipedia, Fictional Stories, BBC News

    Args:
        model_name: HuggingFace model name
        num_samples: Number of training samples (default 2000)
        num_val_samples: Number of validation samples for loss/entropy
        num_val_gen_samples: Number of validation samples for generation metrics
            (taken from the front of the validation split)
        sequence_length: Sequence length (default 256)
        num_epochs: Number of training epochs (default 20)
        learning_rate: Learning rate (default 1e-6)
        batch_size: Batch size (default 8)
        val_batch_size: Batch size for validation
        validation_freq: Steps between validation
        gen_context_len: Context length for generation validation
        gen_max_length: Max length for generation validation
        gen_ttr_window_size: Window size for generation TTR
        save_dir: Directory to save checkpoints
        device: Device forwarded to the trainer (cuda/cpu); the model itself
            is loaded with device_map="auto"
        torch_dtype: Model dtype (bfloat16/float16/float32); unknown values
            fall back to bfloat16 with a warning
        gradient_checkpointing: Enable gradient checkpointing to save VRAM
        dataset_name: Dataset to use for training. Options:
            - "fiction-stories": Original paper's training dataset (default)
            - "wikitext": Wikipedia text
            - "bbc-news": BBC News articles
        dataset_mode: Data processing mode:
            - "filter": Original paper's approach - only keep exact-length sequences
            - "concatenate": Concatenate tokens then chunk into sequences of length 256
        dataset_seed: Random seed for dataset shuffling/sampling
        dataset_shuffle: Whether to shuffle dataset before sampling
        use_lora: Enable LoRA adapters instead of full fine-tuning
        lora_r: LoRA rank
        lora_alpha: LoRA alpha
        lora_dropout: LoRA dropout
        lora_target_modules: Comma-separated module names to target with LoRA
        lora_target_layers: Layer selection (e.g., "last8" or "0-3,10-12")
        lora_bias: LoRA bias setting ("none", "all", "lora_only")
        load_in_4bit: Load base model in 4-bit (QLoRA-style)
        load_in_8bit: Load base model in 8-bit
        bnb_4bit_compute_dtype: Compute dtype for 4-bit
        bnb_4bit_quant_type: 4-bit quant type ("nf4" or "fp4")
        bnb_4bit_use_double_quant: Enable double quantization for 4-bit

    Returns:
        Tuple of (model, history) where history is the value returned by
        `HyperfittingTrainer.train()`

    Raises:
        ValueError: If use_lora=True and lora_target_modules parses to an
            empty list
    """
    # Determine dtype
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    if torch_dtype not in dtype_map:
        # Keep the historical fallback, but make it visible instead of silent
        logger.warning(
            f"Unknown torch_dtype '{torch_dtype}', falling back to bfloat16"
        )
    dtype = dtype_map.get(torch_dtype, torch.bfloat16)
    
    logger.info(f"Loading model: {model_name}")
    logger.info(f"Dataset: {dataset_name}, Mode: {dataset_mode}")
    
    quantization_config = _build_quantization_config(
        load_in_4bit=load_in_4bit,
        load_in_8bit=load_in_8bit,
        compute_dtype=bnb_4bit_compute_dtype,
        quant_type=bnb_4bit_quant_type,
        use_double_quant=bnb_4bit_use_double_quant,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # Some tokenizers ship without a pad token; reuse EOS so padding works
    # and no warning is raised
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map="auto",
        quantization_config=quantization_config,
        trust_remote_code=True,
    )
    
    if gradient_checkpointing and hasattr(model, 'gradient_checkpointing_enable'):
        model.gradient_checkpointing_enable()
    
    if use_lora:
        target_modules = _parse_lora_target_modules(lora_target_modules)
        if not target_modules:
            raise ValueError("lora_target_modules is empty while use_lora=True")

        if load_in_4bit or load_in_8bit:
            # prepare_model_for_kbit_training enables gradient checkpointing
            # by default; forward the caller's choice so that
            # gradient_checkpointing=False is actually honoured
            model = prepare_model_for_kbit_training(
                model,
                use_gradient_checkpointing=gradient_checkpointing,
            )

        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            target_modules=target_modules,
            bias=lora_bias,
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)

        # Optionally restrict the adapters to a subset of layers
        if lora_target_layers:
            num_layers = _get_num_layers(model)
            allowed_layers = _parse_lora_target_layers(lora_target_layers, num_layers)
            if allowed_layers:
                _filter_lora_layers(model, allowed_layers)

    model.train()
    logger.info(f"Model loaded. Device: {next(model.parameters()).device}")
    
    # Create dataset with specified mode (train + validation samples, as in the original paper)
    total_samples = num_samples + num_val_samples
    dataset = HyperfittingDataset(
        tokenizer=tokenizer,
        num_samples=total_samples,
        sequence_length=sequence_length,
        dataset_name=dataset_name,
        mode=dataset_mode,
        seed=dataset_seed,
        shuffle=dataset_shuffle,
    )

    # The slices below silently shrink when the dataset is short, so report it
    available_samples = len(dataset.samples)
    if available_samples < total_samples:
        logger.warning(
            f"Dataset produced {available_samples} samples but {total_samples} "
            f"were requested; train/validation splits will be smaller than configured"
        )

    # Train samples come first, validation samples directly after them
    train_samples = dataset.samples[:num_samples]
    val_samples = dataset.samples[num_samples:num_samples + num_val_samples]
    val_gen_samples = val_samples[:num_val_gen_samples] if num_val_gen_samples > 0 else []

    train_dataset = FixedSamplesDataset(train_samples)
    val_dataset = FixedSamplesDataset(val_samples) if val_samples else None
    val_gen_dataset = RawTokenDataset(val_gen_samples) if val_gen_samples else None
    
    trainer = HyperfittingTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        val_gen_dataset=val_gen_dataset,
        learning_rate=learning_rate,
        batch_size=batch_size,
        val_batch_size=val_batch_size,
        validation_freq=validation_freq,
        gen_context_len=gen_context_len,
        gen_max_length=gen_max_length,
        gen_ttr_window_size=gen_ttr_window_size,
        num_epochs=num_epochs,
        save_dir=save_dir,
        device=device,
    )
    
    history = trainer.train()
    
    return model, history


if __name__ == "__main__":
    # Command-line entry point: parse options, echo the configuration,
    # run hyperfit_model and print a short loss summary.
    import argparse
    parser = argparse.ArgumentParser(
        description="Hyperfit a language model",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Dataset Options:
  --dataset_name: Choose from predefined datasets
    - fiction-stories: Original paper's training dataset (Forsythe, 2024) [default]
    - wikitext: Wikipedia text (Merity et al., 2017)
    - bbc-news: BBC News articles (Li et al., 2024)

  --dataset_mode: Choose data processing mode
    - filter: Original paper's approach - only keep sequences that are exactly
              sequence_length tokens after tokenization [default]
    - concatenate: Alternative - concatenate all tokens then chunk into sequences

Original Paper Setup:
  Training: fiction-stories with filter mode
  Evaluation: wikitext, fiction-stories, bbc-news

Examples:
  # Train like the original paper (Fiction-Stories, filter mode)
  python hyperfitting_trainer.py --dataset_name fiction-stories --dataset_mode filter

  # Train with concatenate mode on Wikipedia
  python hyperfitting_trainer.py --dataset_name wikitext --dataset_mode concatenate
"""
    )
    parser.add_argument("--model_name", type=str, default="TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
                        help="HuggingFace model name")
    parser.add_argument("--num_samples", type=int, default=2000,
                        help="Number of training samples (default: 2000)")
    parser.add_argument("--sequence_length", type=int, default=256,
                        help="Sequence length (default: 256)")
    parser.add_argument("--num_epochs", type=int, default=20,
                        help="Number of training epochs (default: 20)")
    parser.add_argument("--learning_rate", type=float, default=1e-6,
                        help="Learning rate (default: 1e-6)")
    parser.add_argument("--batch_size", type=int, default=8,
                        help="Batch size (default: 8)")
    parser.add_argument("--val_batch_size", type=int, default=8,
                        help="Validation batch size (default: 8)")
    parser.add_argument("--num_val_samples", type=int, default=128,
                        help="Number of validation samples (default: 128)")
    parser.add_argument("--num_val_gen_samples", type=int, default=32,
                        help="Number of generation validation samples (default: 32)")
    parser.add_argument("--validation_freq", type=int, default=250,
                        help="Steps between validations (default: 250)")
    parser.add_argument("--gen_context_len", type=int, default=32,
                        help="Context length for generation validation (default: 32)")
    parser.add_argument("--gen_max_length", type=int, default=None,
                        help="Max length for generation validation (default: sequence length)")
    parser.add_argument("--gen_ttr_window_size", type=int, default=96,
                        help="Window size for TTR in generation validation (default: 96)")
    parser.add_argument("--save_dir", type=str, default="./checkpoints/hyperfitted_filter",
                        help="Directory to save checkpoints")
    parser.add_argument("--torch_dtype", type=str, default="bfloat16",
                        choices=["bfloat16", "float16", "float32"],
                        help="Model dtype (default: bfloat16)")
    # Paired on/off flags share one dest so the default stays True
    parser.add_argument("--gradient_checkpointing", action="store_true", default=True,
                        help="Enable gradient checkpointing to reduce VRAM usage (default: True)")
    parser.add_argument("--no_gradient_checkpointing", action="store_false", dest="gradient_checkpointing",
                        help="Disable gradient checkpointing")
    parser.add_argument("--dataset_name", type=str, default="fiction-stories",
                        choices=list(DATASET_CONFIGS.keys()),
                        help="Dataset to use for training (default: fiction-stories)")
    parser.add_argument("--dataset_mode", type=str, default="filter",
                        choices=["filter", "concatenate"],
                        help="Data processing mode: 'filter' (original paper) or 'concatenate' (default: filter)")
    parser.add_argument("--dataset_seed", type=int, default=42,
                        help="Random seed for dataset shuffling/sampling (default: 42)")
    parser.add_argument("--dataset_shuffle", action="store_true", default=True,
                        help="Shuffle dataset before sampling (default: True)")
    parser.add_argument("--no_dataset_shuffle", action="store_false", dest="dataset_shuffle",
                        help="Disable dataset shuffling")

    # LoRA / QLoRA options
    parser.add_argument("--use_lora", action="store_true", default=False,
                        help="Enable LoRA adapters instead of full fine-tuning")
    parser.add_argument("--lora_r", type=int, default=8, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=16, help="LoRA alpha")
    parser.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout")
    parser.add_argument("--lora_target_modules", type=str,
                        default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj",
                        help="Comma-separated module names to target with LoRA")
    parser.add_argument("--lora_target_layers", type=str, default=None,
                        help="Layer selection (e.g., 'last8' or '0-3,10-12')")
    parser.add_argument("--lora_bias", type=str, default="none",
                        choices=["none", "all", "lora_only"],
                        help="LoRA bias setting")

    # Quantization (for large models / QLoRA)
    parser.add_argument("--load_in_4bit", action="store_true", default=False,
                        help="Load base model in 4-bit")
    parser.add_argument("--load_in_8bit", action="store_true", default=False,
                        help="Load base model in 8-bit")
    parser.add_argument("--bnb_4bit_compute_dtype", type=str, default="bfloat16",
                        choices=["bfloat16", "float16", "float32"],
                        help="Compute dtype for 4-bit quantization")
    parser.add_argument("--bnb_4bit_quant_type", type=str, default="nf4",
                        choices=["nf4", "fp4"],
                        help="4-bit quantization type")
    parser.add_argument("--bnb_4bit_use_double_quant", action="store_true", default=True,
                        help="Enable double quantization for 4-bit (default: True)")
    parser.add_argument("--no_bnb_4bit_use_double_quant", action="store_false",
                        dest="bnb_4bit_use_double_quant",
                        help="Disable double quantization for 4-bit")
    
    args = parser.parse_args()
    
    print("\n" + "=" * 60)
    print("HYPERFITTING CONFIGURATION")
    print("=" * 60)
    print(f"Model: {args.model_name}")
    print(f"Dataset: {args.dataset_name}")
    print(f"Mode: {args.dataset_mode}")
    print(f"Dataset seed: {args.dataset_seed}")
    print(f"Dataset shuffle: {args.dataset_shuffle}")
    if args.use_lora:
        print(f"LoRA: enabled (r={args.lora_r}, alpha={args.lora_alpha}, dropout={args.lora_dropout})")
        print(f"LoRA targets: {args.lora_target_modules}")
        if args.lora_target_layers:
            print(f"LoRA layers: {args.lora_target_layers}")
    # Quantization flags are forwarded regardless of --use_lora, so report
    # them independently of the LoRA section
    if args.load_in_4bit or args.load_in_8bit:
        quant_mode = "4-bit" if args.load_in_4bit else "8-bit"
        print(f"Quantization: {quant_mode}")
    print(f"Samples: {args.num_samples}, Seq Length: {args.sequence_length}")
    print(f"Epochs: {args.num_epochs}, LR: {args.learning_rate}, Batch: {args.batch_size}")
    print("=" * 60 + "\n")
    
    model, history = hyperfit_model(
        model_name=args.model_name,
        num_samples=args.num_samples,
        num_val_samples=args.num_val_samples,
        num_val_gen_samples=args.num_val_gen_samples,
        sequence_length=args.sequence_length,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        val_batch_size=args.val_batch_size,
        validation_freq=args.validation_freq,
        gen_context_len=args.gen_context_len,
        gen_max_length=args.gen_max_length,
        gen_ttr_window_size=args.gen_ttr_window_size,
        save_dir=args.save_dir,
        torch_dtype=args.torch_dtype,
        gradient_checkpointing=args.gradient_checkpointing,
        dataset_name=args.dataset_name,
        dataset_mode=args.dataset_mode,
        dataset_seed=args.dataset_seed,
        dataset_shuffle=args.dataset_shuffle,
        use_lora=args.use_lora,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        lora_target_modules=args.lora_target_modules,
        lora_target_layers=args.lora_target_layers,
        lora_bias=args.lora_bias,
        load_in_4bit=args.load_in_4bit,
        load_in_8bit=args.load_in_8bit,
        bnb_4bit_compute_dtype=args.bnb_4bit_compute_dtype,
        bnb_4bit_quant_type=args.bnb_4bit_quant_type,
        bnb_4bit_use_double_quant=args.bnb_4bit_use_double_quant,
    )
    
    print("\n" + "=" * 60)
    print("HYPERFITTING COMPLETE")
    print("=" * 60)
    print(f"Dataset: {args.dataset_name} (mode: {args.dataset_mode})")
    # One loss entry is recorded per epoch; with zero epochs the list is
    # empty and indexing it would raise IndexError
    train_losses = history["train_loss"]
    if train_losses:
        print(f"Initial loss: {train_losses[0]:.4f}")
        print(f"Final loss: {train_losses[-1]:.4f}")
        print(f"Loss reduction: {train_losses[0] - train_losses[-1]:.4f}")
    else:
        print("No training epochs were run; no loss history to report.")
