import json
import os
from typing import TYPE_CHECKING, List, Optional

import wandb

from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...model import load_model, load_tokenizer

from safetensors.torch import save_file

import torch
from torch.utils.data import DataLoader
from collections import defaultdict

if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback

    from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

from tqdm import tqdm


def register_ln_hooks(model):
    """Attach an output-recording forward hook to every decoder layer's ``input_layernorm``.

    Decoder layers are taken from ``model.model.layers`` (the layout used by
    Llama-style ``*ForCausalLM`` wrappers). Each time a hooked layer norm runs,
    its output is detached, moved to the CPU and appended to the list stored
    under that layer's index in the returned ``defaultdict(list)``. An index only
    appears as a key once the corresponding layer norm has actually run.

    The hook handles are discarded, so the hooks stay attached to ``model``.
    Presumably each recorded tensor is ``[batch_size, seq_len, hidden_dim]``;
    that depends on the model and is not checked here.
    """
    captured = defaultdict(list)

    def _recorder_for(layer_position):
        # A factory (instead of defining the hook directly in the loop) pins
        # ``layer_position`` per hook rather than late-binding the loop variable.
        def _record(_module, _inputs, output):
            captured[layer_position].append(output.detach().cpu())

        return _record

    decoder_layers = model.model.layers
    for layer_position, decoder_layer in enumerate(decoder_layers):
        decoder_layer.input_layernorm.register_forward_hook(_recorder_for(layer_position))

    return captured


def _unbatch(batch_tensor):
    """Split ``batch_tensor`` along dim 0 into a list of per-example tensors (views)."""
    return list(batch_tensor.unbind(0))


def _save_examples(examples, path):
    """Write ``examples`` to the safetensors file ``path`` under keys ``ex_0``, ``ex_1``, ..."""
    save_file({f"ex_{i}": ex for i, ex in enumerate(examples)}, path)


def save_hidden_states(
        model_args: "ModelArguments",
        data_args: "DataArguments",
        training_args: "Seq2SeqTrainingArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
        callbacks: Optional[List["TrainerCallback"]] = None,
):
    """Run the SFT training set through the model and dump per-example tensors to disk.

    Every example of the ``train_dataset`` produced by ``get_dataset(..., stage="sft")``
    is fed through the model one at a time. The following files are written under
    ``finetuning_args.memory_path`` (created if missing):

    * ``labels.safetensors``, ``input_ids.safetensors``,
      ``attention_masks.safetensors``: the collated labels and model inputs;
    * ``layer_{i}.ln_states.safetensors`` for each hooked decoder layer ``i``: the
      outputs of that layer's ``input_layernorm``, captured by the forward hooks
      installed by ``register_ln_hooks``.

    In every file, key ``ex_{k}`` holds the tensor for the k-th example in dataset
    order. ``generating_args`` and ``callbacks`` are accepted but not used here.
    """
    output_dir = finetuning_args.memory_path
    os.makedirs(output_dir, exist_ok=True)

    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    tokenizer.padding_side = 'left'

    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
    # Each item is already a one-example "batch"; the DataLoader below hands it to the
    # collator unchanged, so every forward pass sees exactly one example.
    train_dataset = [[e] for e in dataset_module["train_dataset"]]

    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
    # output_hidden_states is deliberately left off. Only the hook-captured layer-norm
    # outputs are saved, so having the model also return every layer's hidden state
    # would just waste accelerator memory on each forward pass.
    model.eval()

    # Forward hooks append each layer's input_layernorm output (on CPU) to this dict.
    ln_outputs = register_ln_hooks(model)

    data_collator = SFTDataCollatorWith4DAttentionMask(
        template=template,
        pad_to_multiple_of=8 if training_args.do_train else None,  # for shift short attention
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
        block_diag_attn=model_args.block_diag_attn,
        attn_implementation=getattr(model.config, "_attn_implementation", None),
        compute_dtype=model_args.compute_dtype,
        **tokenizer_module,
    )

    # batch_size=None disables the DataLoader's automatic batching, so each pre-wrapped
    # ``[example]`` goes straight to ``collate_fn``. An integer batch size would pass a
    # list of such lists to the collator, which it cannot handle.
    train_loader = DataLoader(
        train_dataset,
        batch_size=None,
        shuffle=False,
        collate_fn=data_collator,
    )

    all_input_ids = []
    all_attention_masks = []
    all_labels = []
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    with torch.no_grad():
        for batch in tqdm(train_loader, total=len(train_loader)):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            # The model output is not kept: the forward hooks record what we need, and
            # dropping the result right away frees it before the next forward pass.
            model(input_ids=input_ids, attention_mask=attention_mask)

            all_input_ids.extend(_unbatch(input_ids.cpu()))
            all_attention_masks.extend(_unbatch(attention_mask.cpu()))
            all_labels.extend(_unbatch(batch["labels"].cpu()))

    _save_examples(all_labels, os.path.join(output_dir, "labels.safetensors"))
    _save_examples(all_input_ids, os.path.join(output_dir, "input_ids.safetensors"))
    _save_examples(all_attention_masks, os.path.join(output_dir, "attention_masks.safetensors"))

    # ln_outputs[layer_idx] holds one tensor per forward pass. 3-D entries are split
    # along dim 0 so the saved keys enumerate examples rather than forward passes.
    # Other ranks are stored as a single example.
    for layer_idx, ln_list in ln_outputs.items():
        layer_examples = []
        for batch_tensor in ln_list:
            if batch_tensor.dim() == 3:
                layer_examples.extend(_unbatch(batch_tensor))
            else:
                layer_examples.append(batch_tensor)
        _save_examples(layer_examples, os.path.join(output_dir, f"layer_{layer_idx}.ln_states.safetensors"))

    print(f"Hidden states saved to {finetuning_args.memory_path}")
