import hydra
import torch
import torch.distributed as dist
from omegaconf import DictConfig
import logging
import os
import time
import json

from data import get_data, get_collators
from model import get_model
from trainer import load_trainer
from trainer.utils import seed_everything
from tqdm import tqdm
from data.utils import IGNORE_INDEX

# Module-level logger, named after this module so records can be filtered by origin.
logger = logging.getLogger(__name__)


def _resolve_dtype(dtype_name: str) -> torch.dtype:
    """Translate a (lower-cased) dtype name from the config into a ``torch.dtype``.

    Args:
        dtype_name: One of ``bfloat16``/``bf16``, ``float16``/``fp16``,
            ``float32``/``fp32``.

    Returns:
        The matching ``torch.dtype``.

    Raises:
        ValueError: If ``dtype_name`` is not a supported alias.
    """
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float16": torch.float16,
        "fp16": torch.float16,
        "float32": torch.float32,
        "fp32": torch.float32,
    }
    if dtype_name not in dtype_map:
        raise ValueError(f"Unsupported dtype '{dtype_name}'.")
    return dtype_map[dtype_name]


def _inputs_to_device(split_inputs, device: torch.device, target_dtype: torch.dtype) -> dict:
    """Copy one split's inputs into a new dict with tensors moved to ``device``.

    Non-tensor values are passed through unchanged. Floating-point tensors are
    additionally cast to ``target_dtype``, except the ``"labels"`` entry, which
    keeps its original dtype.
    """
    moved = {}
    for key, value in split_inputs.items():
        if not torch.is_tensor(value):
            moved[key] = value
            continue
        to_kwargs = {"device": device}
        if torch.is_floating_point(value) and key != "labels":
            to_kwargs["dtype"] = target_dtype
        moved[key] = value.to(**to_kwargs)
    return moved


def _save_outputs(payload: dict, path: str, elapsed_time: float) -> None:
    """Write the payload to ``path`` and a ``performance.json`` file next to it."""
    output_dir = os.path.dirname(path)
    # A bare filename has an empty dirname; os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torch.save(payload, path)

    performance_payload = {"elapsed_seconds": elapsed_time}
    with open(os.path.join(output_dir, "performance.json"), "w", encoding="utf-8") as f:
        json.dump(performance_payload, f)


@hydra.main(version_base=None, config_path="../configs", config_name="preprocess.yaml")
def main(cfg: DictConfig):
    """Entry point of the code to preprocess and save T3 model hidden states.

    Runs the base LM over the train dataloader, pools the hidden states of the
    configured extraction layer and, for every position whose next-token label
    is not ``IGNORE_INDEX``, stores the pooled state, a classifier label
    (1.0 for the ``retain`` split, 0.0 for the ``forget`` split) and the
    next-token id. The result is saved with ``torch.save`` to
    ``cfg.precomputed_path``; the loop timing is written to ``performance.json``
    in the same directory.

    Args:
        cfg: Hydra config. Reads ``trainer``, ``model``, ``data``, ``collator``,
            ``precomputed_path`` and the optional ``mode`` (default ``"train"``)
            and ``dtype`` (default ``"bfloat16"``).

    Raises:
        ValueError: On missing/invalid configuration (trainer, model, dtype,
            ``precomputed_path``), a missing train split, attention pooling,
            or when no valid token is found.
        RuntimeError: If CUDA is not available.
    """
    # --- Cheap configuration checks first, so a bad config fails before any
    # --- model loading or the (expensive) extraction loop.
    trainer_cfg = cfg.trainer
    if trainer_cfg is None:
        raise ValueError("Please set trainer")
    seed_everything(trainer_cfg.args.seed)

    path = cfg.get("precomputed_path", None)
    if path is None:
        raise ValueError("Please specify 'precomputed_path' in the config to save the precomputed data.")
    target_dtype = _resolve_dtype(str(cfg.get("dtype", "bfloat16")).lower())

    mode = cfg.get("mode", "train")
    model_cfg = cfg.model
    # Validate before dereferencing: explicit raise (not assert) so the check
    # survives ``python -O``.
    if model_cfg is None:
        raise ValueError("Invalid model yaml passed in train config.")
    template_args = model_cfg.template_args
    model, tokenizer = get_model(model_cfg)

    # Load Dataset
    data_cfg = cfg.data
    data = get_data(
        data_cfg, mode=mode, tokenizer=tokenizer, template_args=template_args
    )
    train_dataset = data.get("train", None)
    if train_dataset is None:
        raise ValueError("No train split available for preprocessing.")

    # Load collator
    collator_cfg = cfg.collator
    collator = get_collators(collator_cfg, tokenizer=tokenizer)

    # No evaluators are needed for preprocessing.
    evaluators = None

    # Only ``trainer`` is used below; the returned args are not needed here.
    trainer, _ = load_trainer(
        trainer_cfg=trainer_cfg,
        model=model,
        train_dataset=train_dataset,
        eval_dataset=data.get("eval", None),
        tokenizer=tokenizer,
        data_collator=collator,
        evaluators=evaluators,
        template_args=template_args,
    )

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required to preprocess T3 data.")
    device = torch.device("cuda")
    trainer.model.to(device=device, dtype=target_dtype)
    trainer.model.eval()

    train_dataloader = trainer.get_train_dataloader()
    if trainer.model.pooling_fn_name == "attn":
        raise ValueError("Attention pooling is not supported during preprocessing.")

    all_pooled, all_labels, all_token_ids = [], [], []
    loop_start_time = time.perf_counter()
    for batch in tqdm(train_dataloader, desc="Precomputing pooled states"):
        batch_samples = 0
        # Each batch may carry a "retain" and/or a "forget" split; the split
        # determines the classifier target (retain -> 1.0, forget -> 0.0).
        for split, class_value in (("retain", 1.0), ("forget", 0.0)):
            split_inputs = batch.get(split)
            if split_inputs is None:
                continue
            tensor_inputs = _inputs_to_device(split_inputs, device, target_dtype)
            # Labels are not fed to the base LM; they only select valid positions.
            labels = tensor_inputs.pop("labels", None)
            if labels is None:
                continue

            with torch.no_grad():
                base_outputs = trainer.model.base_lm(
                    output_hidden_states=True,
                    **tensor_inputs,
                )
                extracted_states = base_outputs.hidden_states[trainer.model.extraction_layer]
                pooled_states = trainer.model.pooling_fn(extracted_states)

            # Next-token alignment: the state at position t is paired with the
            # label at position t + 1, hence drop the last state / first label.
            pooled_states = pooled_states[:, :-1, :].detach().to(torch.float32).cpu()
            shifted_labels = labels[:, 1:].detach().cpu()
            valid_mask = shifted_labels != IGNORE_INDEX
            valid_count = int(valid_mask.sum().item())
            if valid_count == 0:
                continue

            cls_labels = torch.full(shifted_labels.shape, class_value, dtype=torch.float32)
            all_pooled.append(pooled_states[valid_mask])
            all_labels.append(cls_labels[valid_mask])
            all_token_ids.append(shifted_labels[valid_mask].long())
            batch_samples += valid_count

        logger.info("Extracted %d classifier samples from current batch", batch_samples)
    elapsed_time = time.perf_counter() - loop_start_time
    logger.info("Preprocessing loop completed in %.2f seconds", elapsed_time)
    if not all_pooled:
        raise ValueError("No valid tokens found to precompute pooled states.")

    payload = {
        "pooled_states": torch.cat(all_pooled, dim=0),
        "classifier_labels": torch.cat(all_labels, dim=0),
        "token_ids": torch.cat(all_token_ids, dim=0),
        "output_dim": trainer.model.base_lm.config.vocab_size,
        "base_lm_config": trainer.model.base_lm.config.to_dict(),
        "preprocessing_elapsed_seconds": elapsed_time,
    }
    _save_outputs(payload, path, elapsed_time)

    if dist.is_initialized():
        dist.destroy_process_group()

# Run the preprocessing entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
