import torch  
import torch.nn as nn  
import torch.nn.functional as F  

import transformers

from datetime import datetime  
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer  

from datasets import load_from_disk

class AttentionConsistencyTrainer(Trainer):
    """Trainer whose objective is a per-layer KL divergence between attention maps.

    Every batch is run through ``self.model`` twice: once on prompt + response
    with ``use_lookahead=False`` (the target pass) and once on the prompt alone
    with ``use_lookahead=True``.  The loss pulls the attention returned by the
    lookahead pass toward the attention returned by the target pass.
    """

    def forward(self, input_ids_prompt, input_ids_response, use_lookahead=False, **kwargs):
        """Run ``self.model`` on the prompt, optionally followed by the response.

        Args:
            input_ids_prompt: 2-D tensor of prompt token ids; its second
                dimension is passed to the model as ``len_prompt``.
            input_ids_response: Tensor of response token ids, concatenated
                after the prompt along dim 1 when ``use_lookahead`` is False
                and unused otherwise.
            use_lookahead: When True only the prompt is fed to the model.
            **kwargs: Forwarded unchanged to the model call.

        Returns:
            Whatever ``self.model(...)`` returns.
        """
        # Unpacking (rather than ``size(1)``) keeps the implicit check that
        # the prompt tensor is exactly 2-D.
        _, len_prompt = input_ids_prompt.size()
        if use_lookahead:
            input_ids = input_ids_prompt
        else:
            input_ids = torch.cat([input_ids_prompt, input_ids_response], dim=1)

        return self.model(
            input_ids=input_ids,
            len_prompt=len_prompt,
            use_lookahead=use_lookahead,
            input_ids_prompt=input_ids_prompt,
            **kwargs,
        )

    def _attention_kl_loss(self, target_attentions, lookahead_attentions):
        """Sum, over layers, the KL divergence from lookahead to target attention.

        Shared by ``compute_loss`` and ``prediction_step`` so the training and
        evaluation objectives cannot drift apart.

        Args:
            target_attentions: Per-layer attention tensors from the
                ``use_lookahead=False`` pass.
            lookahead_attentions: Per-layer attention tensors from the
                ``use_lookahead=True`` pass.

        Returns:
            The summed loss (``0.0`` if there are no layers to iterate over).
        """
        # assumes each attention tensor is (batch, heads, query, key) -- TODO
        # confirm against the LookaheadKV modeling code.
        lookahead_size = self.model.model.lookahead_size

        loss = 0.0
        for attn_resp, attn_pred in zip(target_attentions, lookahead_attentions):
            len_res = attn_resp.size(2)

            # Drop the trailing key columns of each map: as many as the target
            # has rows along dim 2, and ``lookahead_size`` for the prediction
            # (presumably leaving only prompt key positions in both -- verify).
            attn_resp = attn_resp[:, :, :, :-len_res]
            attn_pred = attn_pred[:, :, :, :-lookahead_size]

            # Average over the query dimension.
            attn_resp = attn_resp.mean(dim=2)
            attn_pred = attn_pred.mean(dim=2)

            # Renormalise over the remaining keys and flatten the leading
            # dimensions so every row is one distribution.
            num_keys_resp = attn_resp.size(-1)
            num_keys_pred = attn_pred.size(-1)
            attn_resp = (attn_resp / attn_resp.sum(dim=-1, keepdim=True)).view(-1, num_keys_resp)
            attn_pred = (attn_pred / attn_pred.sum(dim=-1, keepdim=True)).view(-1, num_keys_pred)

            # F.kl_div takes log-probabilities as input and probabilities as
            # target; 'batchmean' divides the summed divergence by the row count.
            # NOTE(review): an exact zero in attn_pred makes log() return -inf;
            # confirm whether an epsilon is wanted here.
            loss += F.kl_div(attn_pred.log(), attn_resp, reduction='batchmean')

        return loss

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the attention-consistency loss for one training batch.

        Args:
            model: Model handed over by the caller.  It is not used here; both
                passes go through ``self.model`` via ``self.forward``.
                NOTE(review): confirm this is intended when the caller passes
                a wrapped model.
            inputs: Mapping unpacked into ``self.forward``.
            return_outputs: When True, also return the target-pass outputs.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is True,
            where ``outputs`` are those of the ``use_lookahead=False`` pass.
        """
        # The full-sequence pass only provides targets, so it needs no graph.
        with torch.no_grad():
            outputs1 = self.forward(**inputs, use_lookahead=False, output_attentions=True)
        outputs2 = self.forward(**inputs, use_lookahead=True, output_attentions=True)

        loss = self._attention_kl_loss(outputs1.attentions, outputs2.attentions)

        if return_outputs:
            return loss, outputs1
        return loss

    def prediction_step(
        self, model, inputs, prediction_loss_only, ignore_keys=None,
    ):
        """Evaluate the attention-consistency loss on one batch, without gradients.

        Args:
            model: Model handed over by the caller; unused, both passes go
                through ``self.model`` via ``self.forward``.
            inputs: Mapping unpacked into ``self.forward``.
            prediction_loss_only: Accepted for interface compatibility; the
                logits and labels slots are always None.
            ignore_keys: Accepted for interface compatibility; unused.

        Returns:
            ``(loss, None, None)``.
        """
        # The base-class prediction_step prepares the inputs itself, so an
        # override has to do the same before touching the batch.
        inputs = self._prepare_inputs(inputs)

        with torch.no_grad():
            outputs1 = self.forward(**inputs, use_lookahead=False, output_attentions=True)
            outputs2 = self.forward(**inputs, use_lookahead=True, output_attentions=True)
            loss = self._attention_kl_loss(outputs1.attentions, outputs2.attentions)

        return (loss, None, None)


from .modeling_llama_lookaheadkv_training import (
    LlamaModelLookaheadKV,
    LlamaDecoderLayerLookahead,
    LlamaMLPLookaheadKV,
    LlamaAttentionLookaheadKV,
    LlamaForCausalLMLookahead,
)

# Rebind the stock Llama classes inside transformers' modeling module to the
# LookaheadKV variants imported above. This runs before the model is loaded
# below (presumably so the loaded model is built from these classes -- verify).
_llama_modeling = transformers.models.llama.modeling_llama
for _attr_name, _replacement in (
    ("LlamaAttention", LlamaAttentionLookaheadKV),
    ("LlamaMLP", LlamaMLPLookaheadKV),
    ("LlamaModel", LlamaModelLookaheadKV),
    ("LlamaDecoderLayer", LlamaDecoderLayerLookahead),
    ("LlamaForCausalLM", LlamaForCausalLMLookahead),
):
    setattr(_llama_modeling, _attr_name, _replacement)

# Load the model (bfloat16 weights) and its tokenizer from the same checkpoint.
model_path = "Llama-3.2-1B-Instruct-train"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Only parameters whose name contains 'lookahead' stay trainable; their names
# are printed, and every other parameter is frozen.
for name, param in model.named_parameters():
    if 'lookahead' in name:
        print(name)
    else:
        param.requires_grad = False

# Load the dataset from disk.
dataset_path = "tulu_v3"
tokenized_dataset = load_from_disk(dataset_path)

# Print the model, the dataset and the lookahead size, each framed by a rule.
_separator = "-"*100
print(_separator)
print(model)
print(_separator)
print(tokenized_dataset)
print(_separator)
print(model.model.lookahead_size)
print(_separator)

# Timestamp taken when the run starts.
# NOTE(review): current_time is not used by any code visible in this file.
now = datetime.now()
current_time = now.strftime("%Y-%m-%d_%H:%M:%S")

# Run name: <dataset>_<method>_ls<lookahead size>
dataset_name = dataset_path.rsplit("/", 1)[-1]
method_name = "KL"
lookahead_size = model.config.lookahead_size
save_dir = "_".join((dataset_name, method_name, f"ls{lookahead_size}"))

# Training configuration, collected in one mapping and unpacked into
# TrainingArguments below.
_training_config = {
    "output_dir": f"../results/{save_dir}",

    # Core hyperparameters
    "num_train_epochs": 2,
    "per_device_train_batch_size": 1,
    "per_device_eval_batch_size": 1,
    "gradient_accumulation_steps": 8,
    "learning_rate": 1e-4,
    "bf16": True,

    # Optimizer
    "optim": "adamw_torch",
    "weight_decay": 0.1,
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
    "adam_epsilon": 1e-8,
    "max_grad_norm": 1.0,

    # Learning-rate schedule
    "lr_scheduler_type": 'cosine',
    "warmup_ratio": 0.02,

    # Logging and checkpointing
    "logging_steps": 1,
    "save_strategy": "steps",
    "save_steps": 100,
    "report_to": "tensorboard",
    "save_safetensors": True,

    # Evaluation
    "do_eval": True,
    "eval_strategy": "steps",
    "eval_steps": 50,
    "eval_on_start": True,
    "eval_accumulation_steps": 8,

    # Misc
    "remove_unused_columns": False,  # This is necessary for collation
    "torch_empty_cache_steps": 1,

    # DeepSpeed config file
    "deepspeed": "./deepspeed/ds_z3_offload_config.json",
}
training_args = TrainingArguments(**_training_config)

# Build the trainer on the "train" and "test" splits of the loaded dataset.
trainer = AttentionConsistencyTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
)

# Run training.
trainer.train()
