import pandas as pd
import torch
from custom_dreamy.history import (
    HistoryColumns,
    get_pareto_frontier_df,
    plot_target_vs_entropy,
)
from custom_dreamy.i_runner import IRunner
from datasets import Dataset, load_dataset
from line_profiler import profile
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import PreTrainedTokenizer

from eliciting_contexts.fluent_dreaming.sae import SAERunnerTLens
from eliciting_contexts.utils.load_models import load_sae_and_model


def _results_to_df(
    texts: list[str],
    targets: list,
    xentropies: list,
    token_ids: list[list[str]],
) -> pd.DataFrame:
    """Assemble the accumulated per-sequence results into a DataFrame.

    Shared by the periodic checkpoint and the final upload in ``main`` so both
    always produce exactly the same column schema.
    """
    return pd.DataFrame(
        {
            HistoryColumns.TEXT: texts,
            HistoryColumns.TARGET: targets,
            HistoryColumns.XENTROPY: xentropies,
            HistoryColumns.TOKEN_IDS: token_ids,
        }
    )


# TODO: results still accumulate in memory for the whole run; the periodic
# checkpoint only rewrites the full history to disk each time. Replace with
# incremental (e.g. sharded / append-only) writes so memory stays bounded.
@profile
def main(
    runner: IRunner,
    tokenizer: PreTrainedTokenizer,
    dataloader: DataLoader,
    device: str = "cuda",
    dataset_name: str = "activation_results_feature_test",
    hf_token: str | None = None,
    save_every: int = 10000,  # Write a parquet checkpoint every N iterations
    checkpoint_filename: str = "checkpoint.parquet",
    debug_save_iteration: int | None = 5,
) -> pd.DataFrame:
    """Run every batch through the runner, collect results, and push them to the Hub.

    Args:
        runner: Runner interface for model execution. Used via
            ``int_ids_to_embed``, ``run_with_embeddings`` and ``calc_xentropy``.
        tokenizer: Not used by this function; kept for interface compatibility.
        dataloader: Data loader that yields tuples of:
            - input_ids: torch.Tensor[batch_size, seq_len] - Input token IDs
            - input_text: List[str] - Input text, one entry per sequence
            - nonfixed_positions: Optional[torch.Tensor[n]] - position indices
              forwarded (after moving to ``device``) to
              ``runner.calc_xentropy``; assumed to be the same count per
              batch element if provided
        device: Device to run the model on.
        dataset_name: Name of the dataset to push to the HuggingFace Hub.
        hf_token: HuggingFace API token. If None, the token from
            ``huggingface-cli login`` is used.
        save_every: Write a checkpoint after every ``save_every`` iterations.
            Values <= 0 disable periodic checkpoints.
        checkpoint_filename: Local parquet path the checkpoint is written to
            (the whole history so far is rewritten on every save).
        debug_save_iteration: Zero-based iteration at which one extra
            checkpoint is written, to catch serialization problems early.
            ``None`` disables it.

    Returns:
        DataFrame with one row per processed sequence and columns
        ``HistoryColumns.TEXT``, ``TARGET``, ``XENTROPY`` and ``TOKEN_IDS``.
        Token IDs are stored as lists of decimal strings.
    """
    all_targets = []
    all_xentropies = []
    all_texts = []
    all_token_ids = []

    with torch.no_grad():
        for iteration, (input_ids, input_text, nonfixed_positions) in enumerate(
            tqdm(dataloader, desc="Processing batches")
        ):
            input_ids = input_ids.to(device)
            if nonfixed_positions is not None:
                nonfixed_positions = nonfixed_positions.to(device)

            target, logits, _ = runner.run_with_embeddings(
                runner.int_ids_to_embed(input_ids)
            )
            xentropy = runner.calc_xentropy(
                logits,
                input_ids,
                nonfixed_positions=nonfixed_positions,
            )

            # Move results to host memory right away so GPU tensors from this
            # batch can be released before the next one.
            all_targets.extend(target.detach().cpu().numpy())
            all_xentropies.extend(xentropy.detach().float().cpu().numpy())
            all_texts.extend(input_text)
            all_token_ids.extend(
                [str(int(tid)) for tid in ids]
                for ids in input_ids.detach().cpu().numpy()
            )

            is_debug_save = iteration == debug_save_iteration
            is_periodic_save = save_every > 0 and (iteration + 1) % save_every == 0
            if is_debug_save or is_periodic_save:
                print(f"Saving dataset to disk at iteration {iteration + 1}")
                _results_to_df(
                    all_texts, all_targets, all_xentropies, all_token_ids
                ).to_parquet(checkpoint_filename)
                print(f"Saved to {checkpoint_filename}")

    df = _results_to_df(all_texts, all_targets, all_xentropies, all_token_ids)

    Dataset.from_pandas(df).push_to_hub(
        dataset_name,
        split="train",
        private=True,
        token=hf_token,
    )

    return df


@profile
def collate_fn_openwebtext(
    batch: list[dict],
    tokenizer,
    max_length: int = 512,
    max_batch_size: int | None = None,
) -> tuple[torch.Tensor, list[str], None]:
    """Turn a batch of raw documents into fixed-length token chunks.

    Each document's ``"text"`` is tokenized in full and split into consecutive
    chunks of ``max_length`` tokens. The last, shorter chunk of a document is
    right-padded to ``max_length``.

    Args:
        batch: Examples, each a mapping with a ``"text"`` entry.
        tokenizer: Hugging Face tokenizer used for encoding, decoding and
            padding.
        max_length: Number of tokens per output sequence.
        max_batch_size: Optional cap on the number of chunks returned.
            Once reached, the remaining chunks and documents are dropped.

    Returns:
        ``(tokens, texts, None)`` where ``tokens`` is a ``torch.long`` tensor
        of shape ``(n_chunks, max_length)`` and ``texts`` holds the decoded
        (unpadded) text of each chunk. The trailing ``None`` fills the
        ``nonfixed_positions`` slot that ``main`` expects.

    Raises:
        ValueError: If a chunk needs padding but the tokenizer defines neither
            a pad token nor an EOS token.
    """
    # GPT-2-style tokenizers often have no pad token; fall back to EOS rather
    # than passing None to torch.full.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    chunks: list[torch.Tensor] = []
    texts: list[str] = []

    def _is_full() -> bool:
        # True once the optional chunk cap has been reached.
        return max_batch_size is not None and len(chunks) >= max_batch_size

    for example in batch:
        if _is_full():
            break

        # Tokenize without truncation so the whole document is chunked.
        full_tokens = tokenizer(
            example["text"], return_tensors="pt", add_special_tokens=True
        )["input_ids"][0]

        for start in range(0, len(full_tokens), max_length):
            if _is_full():
                break

            chunk = full_tokens[start : start + max_length]
            # Decode before padding so pad tokens never leak into the stored
            # text, even when the pad token is not a special token.
            texts.append(tokenizer.decode(chunk, skip_special_tokens=True))

            n_pad = max_length - len(chunk)
            if n_pad > 0:
                if pad_id is None:
                    raise ValueError(
                        "Tokenizer defines neither pad_token_id nor eos_token_id; "
                        "cannot pad chunk to max_length"
                    )
                padding = torch.full((n_pad,), pad_id, dtype=torch.long)
                chunk = torch.cat([chunk, padding])

            chunks.append(chunk)

    if chunks:
        tokens = torch.stack(chunks)
    else:
        tokens = torch.empty((0, max_length), dtype=torch.long)

    return tokens, texts, None


# TODO: expose these parameters through a CLI instead of editing the call site.
@profile
def main_sae(
    num_sample: int | None = 1000,
    device: str = "cuda",
    feature: int = 321,
    max_length: int = 72,
    batch_size: int = 4,
    sequence_batch_size: int = 128,
    dataset_name: str = "activation_results_feature_test",
    hf_token: str | None = None,
    save_every: int = 10000,
) -> None:
    """Score OpenWebText chunks on one SAE feature, upload them, and plot the frontier.

    Args:
        num_sample: Number of documents to take from the (shuffled) stream.
            ``None`` processes the whole stream.
        device: Device to run the model on ("cuda" or "cpu").
        feature: Feature index to analyze in the SAE.
        max_length: Tokens per chunk; documents are split into chunks of this
            length.
        batch_size: Number of documents per DataLoader batch. Each document
            may expand into several chunks.
        sequence_batch_size: Maximum number of chunks kept per batch. Extra
            chunks are dropped by the collate function.
        dataset_name: Name of the dataset to push to the HuggingFace Hub.
        hf_token: HuggingFace API token. If None, the token from
            ``huggingface-cli login`` is used.
        save_every: Forwarded to ``main``; write a local checkpoint every N
            batches (<= 0 disables periodic checkpoints).
    """
    model, tokenizer, sae, _, _ = load_sae_and_model(device=device)
    runner = SAERunnerTLens(model=model, sae=sae, feature=feature)

    # Stream the corpus, shuffle with a fixed seed for reproducibility, then
    # optionally cap the number of documents.
    dataset = load_dataset("Skylion007/openwebtext", split="train", streaming=True)
    dataset = dataset.shuffle(seed=42)
    if num_sample is not None:
        dataset = dataset.take(num_sample)

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # ordering comes from the dataset-level shuffle above
        collate_fn=lambda batch: collate_fn_openwebtext(
            batch,
            tokenizer=tokenizer,
            max_length=max_length,
            max_batch_size=sequence_batch_size,
        ),
    )

    df = main(
        runner,
        tokenizer,
        dataloader,
        device=device,
        dataset_name=dataset_name,
        hf_token=hf_token,
        save_every=save_every,
    )

    # Also compute and plot the target-vs-cross-entropy Pareto frontier.
    pareto_df = get_pareto_frontier_df(df)
    plot_target_vs_entropy(pareto_df)


def upload_checkpoint(
    checkpoint_filename: str, dataset_name: str, hf_token: str | None = None
) -> None:
    """Push a local parquet checkpoint to the HuggingFace Hub as a private dataset.

    Useful for uploading the ``checkpoint.parquet`` file that ``main`` writes
    periodically, without re-running the model.

    Args:
        checkpoint_filename: Path of the parquet file to read.
        dataset_name: Hub dataset name; written as the "train" split.
        hf_token: HuggingFace API token. If None, the token from
            ``huggingface-cli login`` is used.
    """
    print("loading checkpoint")
    frame = pd.read_parquet(checkpoint_filename)

    print("converting to dataset")
    dataset = Dataset.from_pandas(frame)

    print("pushing to hub")
    dataset.push_to_hub(
        dataset_name,
        split="train",
        private=True,
        token=hf_token,
    )
    print("done")


if __name__ == "__main__":
    # Upload the local checkpoint written by main()'s periodic saves.
    # To run the full pipeline instead, call:
    #   main_sae(num_sample=None, batch_size=4, dataset_name="run_on_openwebtext_start")
    upload_checkpoint(
        checkpoint_filename="checkpoint.parquet",
        dataset_name="run_on_openwebtext_start",
    )
