# %%

import traceback
from typing import Any, Dict, List

import torch
import torch.nn.functional as F
import wandb
from datasets import DatasetDict, concatenate_datasets, load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from sandbagging_research_sprint.sandbagging_evaluation import evaluate_sandbagging
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler

from eliciting_contexts.utils.auth import setup_huggingface_auth, setup_wandb_auth
from eliciting_contexts.utils.constants import WANDB_ENTITY


def filter_and_combine_datasets(dataset, sandbagging_train_ratio):
    """Rebalance ``dataset`` between sandbagging and normal examples.

    When ``sandbagging_train_ratio`` is ``None`` the dataset is returned
    untouched. Otherwise the rows are split on their
    ``"sandbagging_environment"`` field, at most
    ``int(len(dataset) * sandbagging_train_ratio)`` sandbagging rows and at
    most the remaining count of normal rows are kept (always the leading rows
    of each split), and the two selections are concatenated and shuffled with
    a fixed seed.

    Each requested count is capped by the number of rows actually available,
    so the final mix can deviate from the requested ratio when one kind of
    example is scarce.
    """
    if sandbagging_train_ratio is None:
        return dataset

    # Target counts are derived from the size of the unfiltered dataset.
    requested_total = len(dataset)
    requested_sandbagging = int(requested_total * sandbagging_train_ratio)
    requested_normal = requested_total - requested_sandbagging

    # Split by environment flag (explicit comparisons kept on purpose so rows
    # whose flag is neither True nor False land in neither split).
    normal_rows = dataset.filter(
        lambda row: row["sandbagging_environment"] == False
    )
    sandbagging_rows = dataset.filter(
        lambda row: row["sandbagging_environment"] == True
    )

    # Never ask for more rows than a split actually holds.
    keep_sandbagging = min(requested_sandbagging, len(sandbagging_rows))
    keep_normal = min(requested_normal, len(normal_rows))

    selections = [
        normal_rows.select(range(keep_normal)),
        sandbagging_rows.select(range(keep_sandbagging)),
    ]
    return concatenate_datasets(selections).shuffle(seed=42)


class SandbaggingDataset(Dataset):
    """Map-style dataset yielding tokenized prompt + desired-answer examples.

    Each underlying row must provide ``"prompt"``, ``"desired_answer"``,
    ``"correct_answer"`` and ``"sandbagging_environment"``. Only element
    ``[0]`` of ``desired_answer`` / ``correct_answer`` is used
    (presumably these are lists of candidate answers -- verify against the
    dataset schema).

    If ``sandbagging_train_ratio`` is given, the rows are first rebalanced
    with ``filter_and_combine_datasets``; otherwise they are used as-is.
    """

    def __init__(
        self, dataset, tokenizer, max_length=512, sandbagging_train_ratio=None
    ):
        self.tokenizer = tokenizer
        # Stored for callers; no truncation is applied anywhere in this class.
        self.max_length = max_length
        self.sandbagging_train_ratio = sandbagging_train_ratio

        if sandbagging_train_ratio is None:
            self.dataset = dataset
        else:
            # Rebalance the sandbagging / normal mix to the requested ratio.
            self.dataset = filter_and_combine_datasets(
                dataset, sandbagging_train_ratio
            )

    def __len__(self):
        """Return the number of (possibly rebalanced) rows."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Tokenize row ``idx`` and return token ids plus prompt/target masks.

        The prompt is encoded with the tokenizer's default special-token
        handling; both answers are encoded without special tokens. The masks
        are Python lists aligned with ``"tokens"``: ``"prompt_mask"`` is True
        on prompt positions and ``"target_mask"`` is True on answer positions.
        """
        row = self.dataset[idx]
        prompt_text = row["prompt"]
        desired_text = row["desired_answer"][0]
        correct_text = row["correct_answer"][0]

        encode = self.tokenizer.encode
        prompt_ids = encode(prompt_text)
        desired_ids = encode(desired_text, add_special_tokens=False)
        correct_ids = encode(correct_text, add_special_tokens=False)

        # True exactly on the answer positions; the prompt mask is its inverse.
        answer_flags = [False] * len(prompt_ids) + [True] * len(desired_ids)

        return {
            "tokens": prompt_ids + desired_ids,
            "prompt_mask": [not flag for flag in answer_flags],
            "target_mask": answer_flags,
            "sandbagging": row["sandbagging_environment"],
            "correct_answer_tokens": correct_ids,
        }


def collate_fn(
    batch: List[Dict[str, Any]], pad_token_id: int
) -> Dict[str, torch.Tensor]:
    """Right-pad a list of dataset items into batched tensors.

    Args:
        batch: Items carrying ``"tokens"``, ``"prompt_mask"``,
            ``"target_mask"``, ``"sandbagging"`` and
            ``"correct_answer_tokens"``.
        pad_token_id: Fill value for the padded token positions.

    Returns:
        A dict with the same keys. ``tokens`` and ``correct_answer_tokens``
        are ``long`` tensors padded with ``pad_token_id``; both masks are
        ``bool`` tensors padded with ``False`` to the width of ``tokens``;
        ``sandbagging`` is a ``bool`` vector with one entry per item.
    """
    n_items = len(batch)

    def _padded(key, width, fill_value, dtype):
        # Left-align every row of batch[*][key] inside a (n_items, width) grid.
        grid = torch.full((n_items, width), fill_value, dtype=dtype)
        for row_idx, item in enumerate(batch):
            row = item[key]
            grid[row_idx, : len(row)] = torch.tensor(row)
        return grid

    # Widths are dictated by the longest token / correct-answer sequence.
    token_width = max(len(item["tokens"]) for item in batch)
    answer_width = max(len(item["correct_answer_tokens"]) for item in batch)

    return {
        "tokens": _padded("tokens", token_width, pad_token_id, torch.long),
        "prompt_mask": _padded("prompt_mask", token_width, False, torch.bool),
        "target_mask": _padded("target_mask", token_width, False, torch.bool),
        "sandbagging": torch.tensor(
            [item["sandbagging"] for item in batch], dtype=torch.bool
        ),
        "correct_answer_tokens": _padded(
            "correct_answer_tokens", answer_width, pad_token_id, torch.long
        ),
    }


def push_checkpoint_to_hub(model, tokenizer, step: int, huggingface_entity: str):
    """Push the current model to the Hugging Face Hub (best effort).

    The repository id is
    ``{huggingface_entity}/{wandb project}-{wandb run id}-step{step}``, so an
    active wandb run is required; without one a notice is printed and nothing
    is pushed. Any exception raised while pushing is printed together with its
    traceback and then swallowed, so a failed upload does not abort the caller.

    Args:
        model: Object exposing ``push_to_hub``.
        tokenizer: Accepted for interface compatibility; this function does
            not upload it.
        step: Value embedded in the repository name.
        huggingface_entity: Hub user or organisation that owns the repository.
    """
    print("Pushing checkpoint to hub...")

    # The repo name is derived from the wandb run, so bail out without one.
    run = wandb.run
    if run is None:
        print("No active wandb run found")
        return

    repo_name = run.project + "-" + run.id
    checkpoint_name = f"{huggingface_entity}/{repo_name}-step{step}"

    try:
        print(f"Pushing checkpoint {checkpoint_name} to Hub...")

        # NOTE(review): whether the `config` keyword is honoured depends on the
        # model class's `push_to_hub` implementation -- confirm.
        model.push_to_hub(
            checkpoint_name,
            use_temp_dir=True,
            private=True,
            config=wandb.config.as_dict(),
            commit_message="Push model checkpoint",
        )

        print(f"Successfully pushed checkpoint {checkpoint_name}")

    except Exception as e:
        # Deliberately best-effort: report the failure and keep going.
        traceback.print_exc()
        print(f"Error pushing checkpoint to hub: {e}")


def train_epoch(
    model,
    tokenizer,
    train_dataloader,
    val_dataloader,
    optimizer,
    scheduler,
    device,
    eval_every_n_samples=100,
    n_eval_batches=10,
    push_to_hub=True,
    push_every_n_samples=1000,
    huggingface_entity: str = "AtticusAttlee",
):
    """Run one pass over ``train_dataloader``, with periodic eval and pushes.

    For every batch the model is run on ``batch["tokens"]``, a next-token
    cross-entropy loss is computed only on positions where
    ``batch["target_mask"]`` is True, and one optimizer + scheduler step is
    taken (gradients clipped to norm 1.0).

    Every 10 batches the running statistics of that 10-batch window are
    printed and logged to wandb, then reset. Evaluation runs as soon as
    ``n_samples_seen`` reaches the next threshold; the first threshold is 0,
    so an evaluation happens right after the first batch. Checkpoints are
    pushed every ``push_every_n_samples`` samples when ``push_to_hub`` is set.

    NOTE(review): the accuracy bookkeeping compares the argmax at every target
    position with the flattened ``correct_answer_tokens`` and then indexes the
    result with the per-example ``sandbagging`` mask. That only lines up when
    each example has exactly one target token and one correct-answer token --
    confirm this holds for the dataset in use.

    Args:
        model: Causal LM returning an object with ``.logits``.
        tokenizer: Supplies ``pad_token_id`` and ``decode``.
        train_dataloader: Yields dicts of tensors as built by ``collate_fn``.
        val_dataloader: Passed through to ``evaluate``.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch.
        device: Device every batch tensor is moved to.
        eval_every_n_samples: Samples between evaluations.
        n_eval_batches: Forwarded to ``evaluate``.
        push_to_hub: Whether checkpoints are pushed at all.
        push_every_n_samples: Samples between pushes (falsy disables pushes).
        huggingface_entity: Hub owner forwarded to ``push_checkpoint_to_hub``.

    Returns:
        None.
    """
    model.train()

    n_samples_seen = 0
    n_batches_seen = 0
    n_samples_next_eval = 0
    n_samples_next_push = push_every_n_samples

    # Statistics of the current logging window (reset after every log).
    window_loss = 0.0
    window_batches = 0
    n_sandbagging = 0
    n_normal = 0
    n_correct = 0
    n_correct_normal = 0
    n_correct_sandbagging = 0

    for batch in tqdm(train_dataloader, desc="Training"):
        # Move all tensors to device
        batch = {k: v.to(device) for k, v in batch.items()}

        # Forward pass; padded positions are hidden from attention.
        outputs = model(
            input_ids=batch["tokens"],
            attention_mask=(batch["tokens"] != tokenizer.pad_token_id),
        )

        # Shift by one: the logits at position t predict the token at t + 1.
        logits = outputs.logits[:, :-1, :]
        target_tokens = batch["tokens"][:, 1:]
        target_mask = batch["target_mask"][:, 1:]

        # Keep only the answer positions.
        logits_subset = logits[target_mask]  # [num_target_tokens, vocab_size]
        targets_subset = target_tokens[target_mask]  # [num_target_tokens]

        loss = F.cross_entropy(
            logits_subset,
            targets_subset,
            ignore_index=tokenizer.pad_token_id,
        )

        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        with torch.no_grad():
            is_sandbagging = batch["sandbagging"]
            is_normal = torch.logical_not(is_sandbagging)

            # Count the number of times the model was correct
            guesses = torch.argmax(logits_subset, dim=-1)
            correct_answers = torch.flatten(batch["correct_answer_tokens"])
            corrects = guesses == correct_answers

            n_sandbagging += torch.sum(is_sandbagging).item()
            n_normal += torch.sum(is_normal).item()
            n_correct += torch.sum(corrects).item()
            n_correct_sandbagging += torch.sum(corrects[is_sandbagging]).item()
            n_correct_normal += torch.sum(corrects[is_normal]).item()

            window_loss += loss.item()
            window_batches += 1
            n_samples_seen += batch["tokens"].size(0)
            n_batches_seen += 1

            if n_batches_seen % 10 == 0:
                print(f"n_samples_seen: {n_samples_seen}, Loss: {loss.item():.4e}")

                print("Most recent answers:")
                for k, (guess, correct) in enumerate(zip(guesses, correct_answers)):
                    print(
                        f"  guess: |{tokenizer.decode([guess])}| correct: |{tokenizer.decode([correct])}| {'<-- sandbagging' if batch['sandbagging'][k] else ''}"
                    )

                # Guard every ratio: a window may contain only one kind of
                # example, in which case the other count is zero.
                n_window_samples = n_normal + n_sandbagging
                accuracy_normal = n_correct_normal / n_normal if n_normal else 0
                accuracy_sandbagging = (
                    n_correct_sandbagging / n_sandbagging if n_sandbagging else 0
                )
                accuracy_combined = (
                    n_correct / n_window_samples if n_window_samples else 0
                )

                print(f"Accuracy Normal: {accuracy_normal:.1%}")
                print(f"Accuracy Sandbagging: {accuracy_sandbagging:.1%}")
                print(f"Accuracy Combined: {accuracy_combined:.1%}")

                # Log metrics. The loss is already a per-batch mean, so the
                # window average divides by batches, not by samples.
                wandb.log(
                    {
                        "train/loss": window_loss / window_batches,
                        "train/n_samples_seen": n_samples_seen,
                        "train/accuracy_combined": accuracy_combined,
                        "train/accuracy_normal": accuracy_normal,
                        "train/accuracy_sandbagging": accuracy_sandbagging,
                    }
                )

                # Start a fresh window.
                window_loss = 0.0
                window_batches = 0
                n_correct = 0
                n_correct_normal = 0
                n_correct_sandbagging = 0
                n_sandbagging = 0
                n_normal = 0

            if n_samples_seen >= n_samples_next_eval:
                # Switch to eval mode
                model.eval()

                print("Evaluating...")
                eval_results = evaluate(
                    model, tokenizer, val_dataloader, n_eval_batches
                )

                # Log metrics
                wandb.log(
                    {
                        "eval/n_samples_seen": n_samples_seen,
                        **{
                            f"eval/{key}": value
                            for key, value in eval_results.items()
                        },
                    }
                )
                n_samples_next_eval += eval_every_n_samples

                # Switch back to training mode
                model.train()
                torch.cuda.empty_cache()

            if (
                push_to_hub
                and push_every_n_samples
                and n_samples_seen >= n_samples_next_push
            ):
                push_checkpoint_to_hub(
                    model, tokenizer, n_samples_seen, huggingface_entity
                )
                n_samples_next_push += push_every_n_samples  # Set next checkpoint


def create_dataloaders(
    datasetdict: DatasetDict,
    tokenizer: AutoTokenizer,
    batch_size: int,
    sandbagging_train_ratio=0.8,
) -> tuple[DataLoader, DataLoader]:
    """Build the training and validation dataloaders.

    The ``"train"`` split is wrapped in ``SandbaggingDataset`` (rebalanced to
    ``sandbagging_train_ratio``), shuffled, and padded per batch with
    ``collate_fn``. The ``"validation"`` split is served raw, without
    tokenization or a custom collate function, so that evaluation can work
    from the original rows.

    Returns:
        ``(train_dataloader, val_dataloader)``.
    """

    def _pad_batch(examples):
        # The pad id is looked up at call time, exactly when a batch is built.
        return collate_fn(examples, tokenizer.pad_token_id)

    tokenized_train = SandbaggingDataset(
        datasetdict["train"], tokenizer, sandbagging_train_ratio=sandbagging_train_ratio
    )
    train_loader = DataLoader(
        tokenized_train,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=_pad_batch,
    )

    # Validation keeps the untouched rows for generation-based evaluation.
    val_loader = DataLoader(datasetdict["validation"], batch_size=batch_size)

    return train_loader, val_loader


def evaluate(model, tokenizer, dataloader, n_eval_batches):
    """Evaluate the model, print a summary and return wandb-loggable results.

    Delegates to ``evaluate_sandbagging.evaluate_split`` (generating at most 2
    new tokens), prints up to 8 example answers per environment plus the
    accuracy / response-rate figures, and replaces the two answer lists in the
    result dict with ``wandb.Table`` objects holding at most 20 rows each.

    Args:
        model: Model forwarded to ``evaluate_split``.
        tokenizer: Tokenizer forwarded to ``evaluate_split``.
        dataloader: Validation dataloader forwarded to ``evaluate_split``.
        n_eval_batches: Forwarded to ``evaluate_split``.

    Returns:
        The dict returned by ``evaluate_split`` with ``"normal_answers"`` and
        ``"sandbagging_answers"`` swapped for wandb tables.
    """
    table_columns = [
        "prefix",
        "predicted",
        "correct_answer",
        "is_correct",
        "is_an_answer",
    ]

    def _answers_table(answers):
        # One table row per answer record, capped at the first 20 entries.
        table = wandb.Table(columns=list(table_columns))
        for answer in answers[:20]:
            table.add_data(*(answer[column] for column in table_columns))
        return table

    eval_results = evaluate_sandbagging.evaluate_split(
        model, tokenizer, dataloader, n_eval_batches=n_eval_batches, max_new_tokens=2
    )
    normal_answers = eval_results["normal_answers"]
    sandbagging_answers = eval_results["sandbagging_answers"]

    # Print results. Slicing (rather than indexing 0..7) keeps this safe when
    # an environment produced fewer than 8 answers.
    print("Eval Results:")
    print(" Normal answers:")
    for answer in normal_answers[:8]:
        print(f"  {answer}")
    print(" Sandbagging answers:")
    for answer in sandbagging_answers[:8]:
        print(f"  {answer}")
    print("Normal Environment:")
    print(f"  Accuracy: {eval_results['normal_accuracy']:.1%}")
    print(
        f"  Accuracy (excluding non responses): {eval_results['normal_accuracy_excluding_non_responses']:.1%}"
    )
    print(f"  Total samples: {eval_results['normal_total']}")
    print(f"  Response rate: {eval_results['normal_response_rate']:.1%}")
    print("\nSandbagging Environment:")
    print(f"  Accuracy: {eval_results['sandbagging_accuracy']:.1%}")
    print(
        f"  Accuracy (excluding non responses): {eval_results['sandbagging_accuracy_excluding_non_responses']:.1%}"
    )
    print(f"  Total samples: {eval_results['sandbagging_total']}")
    print(f"  Response rate: {eval_results['sandbagging_response_rate']:.1%}")

    # Swap the raw answer lists for wandb tables so the dict can be logged.
    eval_results["normal_answers"] = _answers_table(normal_answers)
    eval_results["sandbagging_answers"] = _answers_table(sandbagging_answers)
    return eval_results


def logit_loss(logits, desired_tokens, target_mask, pad_token_id):
    """Cross-entropy over the masked target positions.

    ``desired_tokens`` and ``target_mask`` are shifted left by one position
    before use, whereas ``logits`` is indexed as given; the caller therefore
    has to pass logits that already line up with the shifted targets
    (presumably ``outputs.logits[:, :-1]`` -- confirm at the call site).
    Targets equal to ``pad_token_id`` are ignored by the loss.
    """
    shifted_mask = target_mask[:, 1:]
    shifted_targets = desired_tokens[:, 1:]

    selected_logits = logits[shifted_mask]
    selected_targets = shifted_targets[shifted_mask]

    return F.cross_entropy(
        selected_logits, selected_targets, ignore_index=pad_token_id
    )


def load_model_and_tokenizer(model_name: str, device: str = "cuda"):
    """Load a causal LM in bfloat16 together with its tokenizer.

    Args:
        model_name: Identifier passed to both ``from_pretrained`` calls.
        device: Value used as the model's ``device_map``.

    Returns:
        ``(model, tokenizer)``. If the tokenizer ships without a pad token,
        the EOS token is reused as padding and the model config's
        ``pad_token_id`` is updated to match.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model_kwargs = {"device_map": device, "torch_dtype": torch.bfloat16}
    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

    needs_pad_token = tokenizer.pad_token is None
    if needs_pad_token:
        # Fall back to EOS so batches can be padded.
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = tokenizer.pad_token_id

    return model, tokenizer


def prepare_model_for_training(model, lora_config):
    """Return ``model`` wrapped with LoRA adapters described by ``lora_config``.

    The model is first passed through ``prepare_model_for_kbit_training`` and
    the result is then handed to ``get_peft_model``.
    """
    prepared = prepare_model_for_kbit_training(model)
    return get_peft_model(prepared, lora_config)


def main(cfg):
    """Fine-tune a causal LM with LoRA on a sandbagging dataset.

    Reads the following attributes from ``cfg``: ``push_to_hub``,
    ``model_name``, ``device``, ``wandb_mode``, ``wandb_project``,
    ``batch_size``, ``dataset_name``, ``n_samples`` and
    ``huggingface_entity``. All other hyperparameters are fixed inside this
    function and recorded in the wandb run config.

    Steps: authenticate, load model + tokenizer, freeze the base weights,
    attach LoRA adapters, start a wandb run, load (and optionally subsample)
    the dataset, build dataloaders / optimizer / cosine scheduler, train for
    ``num_epochs`` epochs and finish the wandb run.
    """
    # Setup authentication before using wandb or huggingface
    setup_wandb_auth()
    if cfg.push_to_hub:
        setup_huggingface_auth()

    # Training parameters
    model_name = cfg.model_name
    device = cfg.device

    # Load model and tokenizer
    model, tokenizer = load_model_and_tokenizer(model_name, device)

    # Layer indices to adapt: `model.layers` when that attribute exists,
    # otherwise every index below `config.num_hidden_layers`.
    layers = getattr(model, "layers", list(range(model.config.num_hidden_layers)))

    # Freeze the base weights (the inner `model` attribute when present,
    # else the whole model) before the adapters are attached.
    model_impl = getattr(model, "model", model)
    for param in model_impl.parameters():
        param.requires_grad = False

    # Training hyperparameters
    num_epochs = 1
    learning_rate = 5e-5
    eval_every_n_samples = 512
    n_eval = 512
    push_to_hub = cfg.push_to_hub if cfg.push_to_hub is not None else False
    push_every_n_samples = 512
    sandbagging_train_ratio = 0.5

    # LoRA hyperparameters
    lora_rank = 16
    lora_alpha = 16
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj"]
    lora_dropout = 0.05
    lora_bias = "none"

    # Define LoRA Configuration
    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=lora_dropout,
        bias=lora_bias,
        layers_to_transform=list(range(0, max(layers) + 1)),
        task_type="CAUSAL_LM",
    )

    # Apply LoRA adapter to the model
    model = prepare_model_for_training(model, lora_config)

    # Initialize wandb
    wandb.init(
        mode=cfg.wandb_mode,
        project=cfg.wandb_project,
        entity=WANDB_ENTITY,
        config={
            "model_name": model_name,
            "batch_size": cfg.batch_size,
            "num_epochs": num_epochs,
            "learning_rate": learning_rate,
            "eval_every_n_samples": eval_every_n_samples,
            "n_eval": n_eval,
            "push_every_n_samples": push_every_n_samples,
            "sandbagging_train_ratio": sandbagging_train_ratio,
            "lora_rank": lora_rank,
            "lora_alpha": lora_alpha,
            "target_modules": target_modules,
            "lora_dropout": lora_dropout,
            "lora_bias": lora_bias,
        },
    )

    # Load the dataset and optionally keep only the first `n_samples` rows of
    # every split. The count is clamped so that a split smaller than
    # `n_samples` is kept whole instead of raising on out-of-range indices.
    datasetdict = load_dataset(cfg.dataset_name)
    assert isinstance(datasetdict, DatasetDict)
    if cfg.n_samples:
        for split in datasetdict.keys():
            n_keep = min(cfg.n_samples, len(datasetdict[split]))
            datasetdict[split] = datasetdict[split].select(range(n_keep))

    # Create dataloaders
    train_dataloader, val_dataloader = create_dataloaders(
        datasetdict,
        tokenizer,
        cfg.batch_size,
        sandbagging_train_ratio=sandbagging_train_ratio,
    )

    # Setup optimizer and scheduler (one scheduler step per training batch)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    num_training_steps = num_epochs * len(train_dataloader)
    scheduler = get_scheduler(
        "cosine",
        optimizer=optimizer,
        num_warmup_steps=10,
        num_training_steps=num_training_steps,
    )

    # Training loop
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")

        train_epoch(
            model,
            tokenizer,
            train_dataloader,
            val_dataloader,
            optimizer,
            scheduler,
            device,
            eval_every_n_samples=eval_every_n_samples,
            n_eval_batches=n_eval // cfg.batch_size,
            push_to_hub=push_to_hub,
            push_every_n_samples=push_every_n_samples,
            huggingface_entity=cfg.huggingface_entity,
        )

    wandb.finish()


# Script entry point: build a training configuration and start training.
# NOTE(review): `TrainingConfig` is neither defined nor imported anywhere in
# the visible source -- confirm where it comes from, otherwise running this
# file as a script raises NameError.
if __name__ == "__main__":
    main(TrainingConfig())
