#!/usr/bin/env python
# coding: utf-8
"""
Instruction Tuning with Knowledge Distillation

This script implements instruction tuning for language models using knowledge distillation.
It supports various similarity measures and loss functions for transferring knowledge
from teacher to student models on instruction-following tasks.

Usage:
    python inst_tuning.py <loss_function>
    where loss_function is one of: shape, k_frob, cka, k_frob_only.
    Any other value trains the student with the cross-entropy loss only.

"""

import wandb
wandb.login()

import similarity_measures
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset, Dataset
from tqdm import tqdm
import numpy as np
import similarity_measures
from transformers import AutoTokenizer, AutoModelForCausalLM
import sys
from datasets import load_dataset
import evaluate

# ROUGE scorer used by the periodic evaluation in eval_rogue()
rouge = evaluate.load("rouge")

# The student and the teacher each live on their own GPU
device_stud = torch.device("cuda", 0)
device_teach = torch.device("cuda", 1)

# CKA similarity measure provided by the similarity_measures module
cka = similarity_measures.CKA()

# A single tokenizer (taken from the teacher checkpoint) serves both models
tokenizer = AutoTokenizer.from_pretrained("MiniLLM/teacher-OPT-13B")

# Student: loaded first, then moved to its GPU
student_model = AutoModelForCausalLM.from_pretrained("MiniLLM/init-OPT-1.3B")
student_model = student_model.to(device_stud)

# Teacher: loaded first, then moved to its GPU
teacher_model = AutoModelForCausalLM.from_pretrained("MiniLLM/teacher-OPT-13B")
teacher_model = teacher_model.to(device_teach)


def format_train(example):
    """
    Build the full training prompt (instruction, optional input, response).

    The "### Input:" section is emitted only when ``example["context"]`` is
    truthy. Every line after the first is prefixed with eight spaces and the
    prompt ends with a newline followed by eight spaces, exactly as the
    template has always been rendered.

    Args:
        example: Mapping with the keys "instruction", "context" and "response"

    Returns:
        Dict with a single "prompt" key holding the rendered text
    """
    context = example["context"]

    sections = [
        "Below is an instruction that describes a task.",
        "Write a response that appropriately completes the request.",
        "### Instruction:",
        f"{example['instruction']}",
    ]
    if context:
        sections.extend(["### Input:", f"{context}"])
    # The trailing empty entry reproduces the template's final indented line.
    sections.extend(["### Response:", f"{example['response']}", ""])

    # Joining with newline + 8 spaces recreates the template's continuation indent.
    line_break = "\n" + " " * 8
    return {"prompt": line_break.join(sections)}


def tokenize_train(example):
    """
    Tokenize training examples and build language-modeling labels.

    The prompts are tokenized with the module-level ``tokenizer``, truncated
    to 1024 tokens and padded to the longest sequence in the batch. Labels
    are a copy of ``input_ids`` in which every padding position (attention
    mask 0) is replaced by -100, the index Hugging Face causal-LM heads
    ignore in the cross-entropy loss. Without this, padding tokens would be
    trained on as if they were real targets.

    Args:
        example: Formatted example (or batch) with a "prompt" entry

    Returns:
        Tokenized example with "input_ids", "attention_mask" and "labels"
    """
    tokenized = tokenizer(
        example["prompt"],
        truncation=True,
        padding=True,
        max_length=1024,
        return_tensors="pt"
    )
    labels = tokenized["input_ids"].clone()
    # Exclude padding from the loss; -100 is the ignore index of the LM loss.
    labels[tokenized["attention_mask"] == 0] = -100
    tokenized["labels"] = labels
    return tokenized


def format_eval(example):
    """
    Build the evaluation prompt: same template as training, minus the response.

    The "### Input:" section is emitted only when ``example["context"]`` is
    truthy. The prompt stops right after the "### Response:" header (followed
    by a newline and eight spaces), leaving the answer to the model.

    Args:
        example: Mapping with the keys "instruction" and "context"

    Returns:
        Dict with a single "prompt" key holding the rendered text
    """
    context = example["context"]

    sections = [
        "Below is an instruction that describes a task.",
        "Write a response that appropriately completes the request.",
        "### Instruction:",
        f"{example['instruction']}",
    ]
    if context:
        sections.extend(["### Input:", f"{context}"])
    # The trailing empty entry reproduces the template's final indented line.
    sections.extend(["### Response:", ""])

    # Joining with newline + 8 spaces recreates the template's continuation indent.
    line_break = "\n" + " " * 8
    return {"prompt": line_break.join(sections)}


def tokenize_eval(example):
    """
    Tokenize evaluation examples for batched generation.

    The prompts are padded on the LEFT: the result is fed to
    ``model.generate`` on a decoder-only model, which continues from the last
    position of each row, so that position must hold a real prompt token
    rather than padding. The tokenizer's original ``padding_side`` is
    restored afterwards, so other callers are unaffected.

    The reference responses are tokenized separately and returned under
    "labels"; they are only decoded back to text for scoring.

    Args:
        example: Formatted example (or batch) with "prompt" and "response"

    Returns:
        Tokenized prompts ("input_ids", "attention_mask") plus "labels"
        holding the token ids of the reference responses
    """
    previous_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        tokenized = tokenizer(
            example["prompt"],
            truncation=True,
            padding=True,
            max_length=1024,
            return_tensors="pt"
        )
    finally:
        # Always restore the shared tokenizer's configuration.
        tokenizer.padding_side = previous_padding_side

    tokenized["labels"] = tokenizer(
        example["response"],
        truncation=True,
        padding=True,
        max_length=1024,
        return_tensors="pt"
    )["input_ids"]
    return tokenized


# Load Dolly-15k; all three subsets below are carved out of its "train" split
ds = load_dataset("databricks/databricks-dolly-15k")
len_ds = len(ds["train"])

# The seed is fixed, so a single shuffle yields the same permutation the three
# separate shuffles produced; the slices below are therefore disjoint.
_shuffled = ds["train"].shuffle(seed=69)

# First 500 rows -> test, last 1000 rows -> validation, everything between -> train
ds_test = _shuffled.select(range(500)).map(format_eval)
ds_train = _shuffled.select(range(500, len_ds - 1000)).map(format_train)
ds_val = _shuffled.select(range(len_ds - 1000, len_ds)).map(format_eval)


def shape(x, y):
    """
    Compute the "shape" term between two 2-D representations.

    Each input is centered by subtracting its mean along dim 1. The result is
    trace(x.T @ x_centered) + trace(y.T @ y_centered) plus the nuclear norm
    of x.T @ y_centered.

    NOTE(review): the Procrustes shape distance is usually written as
    tr(.) + tr(.) - 2 * nuclear_norm(.) with centering over samples; the plus
    sign, unit weight and dim=1 centering here are kept exactly as they
    were - confirm this is intended.

    Args:
        x: First representation (2-D tensor; same number of rows as y)
        y: Second representation (2-D tensor)

    Returns:
        Scalar tensor with the combined value
    """
    x_centered = x - x.mean(dim=1, keepdim=True)
    y_centered = y - y.mean(dim=1, keepdim=True)

    self_terms = torch.trace(x.T @ x_centered) + torch.trace(y.T @ y_centered)
    cross_term = torch.norm(x.T @ y_centered, p="nuc")

    return self_terms + cross_term


def kernel_frob(x, y, eps=1e-12):
    """
    Compute the Frobenius norm between the cosine kernels of two representations.

    Rows of both inputs are scaled to unit L2 norm, the Gram matrices
    x @ x.T and y @ y.T are formed, and the Frobenius norm of their
    difference is returned.

    Args:
        x: First representation (2-D tensor, one row per sample)
        y: Second representation (2-D tensor with the same number of rows)
        eps: Lower bound applied to each row norm before dividing, so that an
            all-zero row yields a zero row instead of NaN. Too small to have
            an effect on rows with a non-negligible norm.

    Returns:
        Scalar tensor with the kernel Frobenius norm
    """
    # Clamp the norms: a zero-norm row would otherwise produce 0/0 = NaN and
    # poison the whole loss.
    x = x / x.norm(dim=1, keepdim=True).clamp_min(eps)
    y = y / y.norm(dim=1, keepdim=True).clamp_min(eps)
    return torch.norm(x @ x.T - y @ y.T)


def eval_rogue(model, dataset):
    """
    Evaluate model using ROUGE metrics on a random sample of the dataset.

    Up to 64 randomly chosen examples are tokenized with ``tokenize_eval``,
    completed with ``model.generate`` and scored against the reference
    responses with the module-level ``rouge`` metric.

    For decoder-only models ``generate`` returns the prompt followed by the
    continuation, so the prompt tokens are cut off before decoding; otherwise
    the prompt text would be scored as part of the prediction.

    The model is switched to eval mode and left in it; callers that continue
    training must call ``model.train()`` afterwards.

    Args:
        model: Model to evaluate
        dataset: Evaluation dataset providing "prompt" and "response" columns

    Returns:
        ROUGE evaluation results
    """
    all_preds = []
    all_labels = []

    # Never ask for more examples than the dataset holds.
    sample_size = min(64, len(dataset))
    eval_dataloader = DataLoader(
        dataset.shuffle().select(range(sample_size)), batch_size=16
    )

    model.eval()
    device = model.device
    strip_prompt = not getattr(model.config, "is_encoder_decoder", False)
    print("------Now Evaluating----")

    with torch.no_grad():
        for batch_text in tqdm(eval_dataloader):
            batch = tokenize_eval(batch_text)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=1024
            )
            if strip_prompt:
                # Keep only the newly generated tokens.
                generated_ids = generated_ids[:, input_ids.shape[1]:]

            decoded_preds = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            # References are only decoded, so they can stay on the CPU.
            decoded_labels = tokenizer.batch_decode(batch['labels'], skip_special_tokens=True)

            all_preds.extend(pred.strip() for pred in decoded_preds)
            all_labels.extend(label.strip() for label in decoded_labels)

    return rouge.compute(
        predictions=all_preds,
        references=all_labels,
        use_stemmer=True
    )


def train(student_model, teacher_model, loss_fn, dl, layers=None):
    """
    Train student model using knowledge distillation.

    Runs 5 epochs with Adafactor. Each step tokenizes a text batch with
    ``tokenize_train``, runs the frozen teacher and the student, gathers the
    hidden states of the non-padding tokens from the configured layers and
    combines the student's cross-entropy loss with the selected similarity
    term. Only the term required by ``loss_fn`` is computed.

    Every 100 steps (including step 0) the student is evaluated with ROUGE on
    ``ds_val``; every 500 steps and at the end of each epoch a checkpoint is
    written under ``inst/``.

    Args:
        student_model: Model to be trained (student)
        teacher_model: Model to distill knowledge from (teacher)
        loss_fn: Loss function type ('shape', 'k_frob', 'cka', 'k_frob_only');
            any other value trains with the cross-entropy loss only
        dl: Training dataloader yielding batches with a "prompt" entry
        layers: Layer alignment configuration [[student_layers], [teacher_layers]],
            given as indices into each model's ``hidden_states``. Defaults to
            [[-1], [-1]], i.e. the last hidden state of both models.

    Raises:
        ValueError: If the student and teacher layer lists differ in length.
    """
    if layers is None:
        layers = [[-1], [-1]]
    student_layers, teacher_layers = layers[0], layers[1]

    if len(student_layers) != len(teacher_layers):
        raise ValueError(f"Student Layers {layers[0]} unequal to teacher layer {layers[1]}")

    optimizer = torch.optim.Adafactor(student_model.parameters())
    num_epochs = 5

    # Each model runs on whatever device it was placed on by the caller.
    student_device = student_model.device
    teacher_device = teacher_model.device

    # Initialize wandb logging (commented out for reproducibility)
    # wandb.init(project="", entity="", name=f"{loss_fn}")
    # wandb.watch(student_model, log="all")

    student_model.train()
    teacher_model.eval()

    for epoch in range(num_epochs):
        loop = tqdm(dl, desc=f"Epoch {epoch+1}/{num_epochs}")

        # Keeps the end-of-epoch checkpoint name defined for an empty dataloader.
        step_count = 0

        for step_count, batch_text in enumerate(loop):
            batch = tokenize_train(batch_text)
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]
            # Flat boolean mask selecting the real (non-padding) tokens.
            token_mask = attention_mask.view(-1) == 1

            # Get teacher outputs (no gradients)
            with torch.no_grad():
                teacher_outputs = teacher_model(
                    input_ids=input_ids.to(teacher_device),
                    attention_mask=attention_mask.to(teacher_device),
                    output_hidden_states=True
                )
                teacher_mask = token_mask.to(teacher_device)
                teacher_hs = torch.cat([
                    teacher_outputs.hidden_states[i].flatten(end_dim=-2)[teacher_mask]
                    for i in teacher_layers
                ]).to(student_device)  # moved once; all losses run on the student device

            # Get student outputs
            student_outputs = student_model(
                input_ids=input_ids.to(student_device),
                attention_mask=attention_mask.to(student_device),
                labels=batch["labels"].to(student_device),
                output_hidden_states=True
            )
            ce_loss = student_outputs.loss
            student_mask = token_mask.to(student_device)
            # The student uses its own layer indices (layers[0]), not the teacher's.
            student_hs = torch.cat([
                student_outputs.hidden_states[i].flatten(end_dim=-2)[student_mask]
                for i in student_layers
            ])

            # Compute total loss based on loss function type; only the
            # similarity term that is actually needed is evaluated.
            if loss_fn == "shape":
                loss = ce_loss + shape(teacher_hs, student_hs)
            elif loss_fn == "k_frob":
                loss = kernel_frob(teacher_hs, student_hs) + ce_loss
            elif loss_fn == "cka":
                loss = ce_loss + 1 - cka(student_hs, teacher_hs)
            elif loss_fn == "k_frob_only":
                loss = kernel_frob(teacher_hs, student_hs)
            else:
                loss = ce_loss

            # Backward pass and optimization
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Save model periodically
            if step_count % 500 == 0:
                student_model.save_pretrained(
                    f"inst/{loss_fn}:{epoch}:{step_count}",
                    from_pt=True
                )

            # Evaluate periodically
            if step_count % 100 == 0:
                rouge_vals = eval_rogue(student_model, ds_val)
                # wandb logging is disabled, so report the scores on stdout.
                print(f"ROUGE (epoch {epoch}, step {step_count}): {rouge_vals}")
                # eval_rogue leaves the model in eval mode.
                student_model.train()

        # Save model at end of epoch
        student_model.save_pretrained(f"inst/{loss_fn}:{epoch}:{step_count}", from_pt=True)

    # wandb.finish()


# Prepare training dataloader
train_dl = DataLoader(ds_train, batch_size=4, shuffle=True)
eval_dl = DataLoader(ds_val, batch_size=16)

# Get loss function from command line argument; fail with a usage message
# instead of a bare IndexError when it is missing.
if len(sys.argv) < 2:
    sys.exit(
        "Usage: python inst_tuning.py <loss_function>\n"
        "    where loss_function is one of: shape, k_frob, cka, k_frob_only"
    )
loss_fn = sys.argv[1]

# Start training
train(student_model, teacher_model, loss_fn, train_dl)

# wandb.finish()




