


# Import packages
import os 
os.environ["CUDA_VISIBLE_DEVICES"]= "0"  # choose the GPU ID
device ="cuda"
import torch
import numpy as np
import evaluate
import datasets
import torch.nn as nn
import argparse
import tqdm
import transformers
import torch.nn.functional as F
datasets.logging.set_verbosity_error()
transformers.logging.set_verbosity_error()
from datasets import load_dataset
from transformers import (AutoTokenizer, AutoModelForSequenceClassification, 
                         TrainingArguments, Trainer, DataCollatorWithPadding)


# Set device: use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
# This rebinds the module-level `device` name that was assigned the string "cuda" above.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Report the library versions and the device selected for this run.
print(f"Using transformers v{transformers.__version__} and datasets v{datasets.__version__}")
print(f"Running on device: {device}")


def scores_neurons(teacher, data_iterator, student_hidden_size=312):
    """Rank the teacher's last-layer hidden units by accumulated gradient score.

    Every sample is run through ``teacher`` on its own (a leading axis of size
    one is added to its tensors), ``localiseNN`` scores the last hidden state,
    and the per-sample scores are added up. The scores are accumulated as a
    running sum so that memory use does not grow with the number of samples.

    Args:
        teacher: Callable model invoked with ``input_ids``, ``attention_mask``
            and ``output_hidden_states=True``; its output must expose
            ``logits`` and ``hidden_states``.
        data_iterator: Iterable of samples indexable by ``"input_ids"`` and
            ``"attention_mask"``, each convertible with ``torch.tensor``.
        student_hidden_size: Number of top-ranked indices to keep.

    Returns:
        A tensor with the indices of the ``student_hidden_size`` highest
        accumulated scores, in descending score order.

    Raises:
        ValueError: If ``data_iterator`` yields no samples.
    """
    running_total = None
    for sample in data_iterator:
        input_ids = torch.tensor(sample["input_ids"]).unsqueeze(0).to(device)
        attention_mask = torch.tensor(sample["attention_mask"]).unsqueeze(0).to(device)
        teacher_output = teacher(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # localiseNN returns a list; only its first entry (the score vector
        # for the last hidden state) is used.
        sample_scores = localiseNN(
            teacher_output.logits,
            teacher_output.hidden_states[-1],
            attention_mask,
            teacher,
        )[0]
        if running_total is None:
            running_total = sample_scores
        else:
            running_total = running_total + sample_scores

    if running_total is None:
        raise ValueError("data_iterator yielded no samples; cannot rank teacher neurons")

    ranked = torch.argsort(running_total.to(device), dim=-1, descending=True)
    return ranked[:student_hidden_size]


# Evaluation metric callback handed to the Trainer
def compute_metrics(eval_pred):
    """Compute accuracy for an ``(predictions, labels)`` evaluation pair.

    The predicted class of each row is the arg-max along axis 1 of the
    predictions array; the result is whatever the ``accuracy`` metric of the
    ``evaluate`` library returns for those predictions and the labels.
    """
    metric = evaluate.load("accuracy")
    raw_scores, references = eval_pred
    predicted_classes = np.argmax(raw_scores, axis=1)
    return metric.compute(predictions=predicted_classes, references=references)


def localiseNNC(logits, last_hidden_state, mask):
    """Score hidden units by the gradient of the top class probability.

    The soft-max of ``logits`` is taken over the last axis, the largest
    probability of each entry is selected, and the gradient of those values
    with respect to ``last_hidden_state`` is summed over its first two axes.
    ``mask`` is accepted but not used. The autograd graph is retained.

    Returns:
        A one-element list holding the summed gradient tensor.
    """
    class_probs = F.softmax(logits, dim=-1)
    top_prob = torch.max(class_probs, dim=-1).values
    per_entry_outputs = torch.unbind(top_prob)
    gradient = torch.autograd.grad(
        outputs=per_entry_outputs,
        inputs=last_hidden_state,
        retain_graph=True,
    )[0]
    summed = gradient.sum(dim=0)
    return [summed.sum(dim=0)]


def random(last_hidden_state, student_hidden_size):
    """Draw ``student_hidden_size`` random indices into the last axis of ``last_hidden_state``.

    Indices are sampled with ``torch.randint`` from ``[0, last_axis_size)``,
    so the same index may appear more than once.
    """
    upper_bound = last_hidden_state.size(-1)
    return torch.randint(0, upper_bound, (student_hidden_size,))


def localiseNN(logits, last_hidden_state, mask, teacher):
    """Score hidden units by the mean absolute gradient of the log-probabilities.

    ``teacher.zero_grad()`` is called and ``last_hidden_state`` is flagged as
    requiring gradients. The gradient of the log-soft-max of ``logits``
    (weighted by ones, i.e. of the sum of all log-probabilities) is then taken
    with respect to ``last_hidden_state``, axis 0 is squeezed, and the absolute
    values are averaged over the new axis 0. ``mask`` is accepted but not used.
    The autograd graph is retained.

    Returns:
        A one-element list holding the averaged absolute gradient.
    """
    teacher.zero_grad()
    last_hidden_state.requires_grad_(True)
    log_probabilities = F.log_softmax(logits, dim=-1)
    (gradient,) = torch.autograd.grad(
        outputs=log_probabilities,
        inputs=last_hidden_state,
        grad_outputs=torch.ones_like(log_probabilities),
        retain_graph=True,
    )
    magnitude = torch.abs(gradient.squeeze(0))
    return [magnitude.mean(dim=0)]


def localiseNNactivation(last_hidden_state, mask):
    """Score hidden units by their summed absolute activation.

    Axis 0 of ``last_hidden_state`` is squeezed and the absolute values are
    summed over the resulting axis 0. ``mask`` is accepted but not used.

    Returns:
        A one-element list holding the summed absolute activations.
    """
    magnitudes = torch.abs(last_hidden_state.squeeze(0))
    return [magnitudes.sum(dim=0)]


# Helper used by the correlation loss
def off_diagonal(x):
    """Return the off-diagonal entries of a square matrix as a flat tensor.

    Dropping the last element of the flattened matrix and reshaping to
    ``(n - 1, n + 1)`` puts every diagonal entry in column 0, so slicing that
    column away leaves only the off-diagonal entries, in row-major order.
    """
    rows, cols = x.shape
    assert rows == cols
    without_last = x.flatten()[:-1]
    shifted = without_last.view(rows - 1, rows + 1)
    return shifted[:, 1:].flatten()


class Projector(nn.Module):
    """Single linear layer mapping ``teacher_hidden_size`` inputs to ``student_hidden_size`` outputs."""

    def __init__(self, teacher_hidden_size, student_hidden_size):
        super().__init__()
        self.linear = nn.Linear(
            in_features=teacher_hidden_size,
            out_features=student_hidden_size,
        )

    def forward(self, teacher_hidden_states):
        """Apply the linear layer and move the result to the module-level ``device``."""
        projected = self.linear(teacher_hidden_states)
        return projected.to(device)
    

class CKALoss(nn.Module):
    """
    Similarity loss between two sets of feature vectors.

    Both inputs are flattened to rows over their last axis, cast to float64 and
    centred per column. The loss is one minus the Frobenius norm of the
    cross-product ``SH^T TH`` divided by the square root of the product of the
    (``eps``-shifted) Frobenius norms of ``SH^T SH`` and ``TH^T TH``.
    """

    def __init__(self, eps):
        super().__init__()
        self.eps = eps

    def forward(self, SH, TH):
        # Flatten to (rows, features); TH follows SH's device.
        student = SH.view(-1, SH.size(-1)).to(SH.device, torch.float64)
        teacher = TH.view(-1, TH.size(-1)).to(student.device, torch.float64)

        # Centre every feature column.
        student = student - student.mean(0, keepdim=True)
        teacher = teacher - teacher.mean(0, keepdim=True)

        cross_norm = torch.norm(student.t().matmul(teacher), 'fro')
        student_norm = torch.norm(student.t().matmul(student), 'fro') + self.eps
        teacher_norm = torch.norm(teacher.t().matmul(teacher), 'fro') + self.eps

        similarity = cross_norm / torch.sqrt(student_norm * teacher_norm)
        return 1 - similarity
    
# Distillation Trainer
class DistillationTrainer(Trainer):
    """``Trainer`` subclass whose loss mixes the student's task loss with distillation terms.

    Per batch, ``compute_loss`` returns::

        alpha_imdb * student task loss
        + alpha_ce   * temperature**2 * KLDiv(student / T, teacher / T)   (if alpha_ce > 0)
        + alpha_corr * correlation loss on the last hidden states
        + alpha_CKA  * CKALoss on the last hidden states                  (if alpha_CKA > 0)

    The correlation term compares the student's last hidden state with either
    the teacher units selected once by ``scores_neurons`` (``do_projector``
    false and ``alpha_corr > 0``) or the full teacher state after passing the
    student state through a linear ``Projector`` (``do_projector`` true).

    ``alpha_mse`` and ``alpha_cos`` are validated and stored, but no term in
    ``compute_loss`` uses them.

    The projector is created once and trained together with the student: it is
    appended to the optimizer in ``create_optimizer`` and its gradients are
    cleared after each optimizer step by a callback. It is not part of the
    student model and is therefore not written by ``save_model``.
    """

    class _ProjectorGradReset(transformers.TrainerCallback):
        """Clears the projector's gradients at the end of every training step.

        The projector is not a sub-module of the trained model, so gradient
        resets that go through the model do not reach it.
        """

        def __init__(self, projector):
            self._projector = projector

        def on_step_end(self, args, state, control, **kwargs):
            self._projector.zero_grad(set_to_none=True)

    def __init__(self, *args, teacher_model=None, alpha_imdb, alpha_ce, alpha_mse, alpha_cos, alpha_corr, alpha_CKA, temperature, student_hidden_size, teacher_hidden_size, do_projector, **kwargs):
        """Validate the loss weights and set up the teacher-side helpers.

        Args:
            teacher_model: Teacher network; put in eval mode and only run
                under ``torch.no_grad()`` during training.
            alpha_imdb: Weight of the student's own task loss.
            alpha_ce: Weight of the logit distillation (KL) term.
            alpha_mse: Stored only; not used by ``compute_loss``.
            alpha_cos: Stored only; not used by ``compute_loss``.
            alpha_corr: Weight of the correlation term.
            alpha_CKA: Weight of the CKA term.
            temperature: Soft-max temperature of the logit distillation term.
            student_hidden_size: Width of the student's hidden state.
            teacher_hidden_size: Width of the teacher's hidden state.
            do_projector: Whether to project the student state to the
                teacher width instead of selecting teacher units.
            *args, **kwargs: Forwarded to ``Trainer.__init__``.

        Raises:
            ValueError: If ``teacher_model`` is missing, a weight is negative,
                or all weights are zero.
        """
        if teacher_model is None:
            raise ValueError("DistillationTrainer requires a teacher_model")

        loss_weights = {
            "alpha_imdb": alpha_imdb,
            "alpha_ce": alpha_ce,
            "alpha_mse": alpha_mse,
            "alpha_cos": alpha_cos,
            "alpha_corr": alpha_corr,
            "alpha_CKA": alpha_CKA,
        }
        negative = [name for name, value in loss_weights.items() if value < 0.0]
        if negative:
            raise ValueError(f"Loss weights must be non-negative, got negative values for: {negative}")
        # alpha_CKA counts here too, so that a CKA-only run is accepted.
        if sum(loss_weights.values()) <= 0.0:
            raise ValueError("At least one loss weight must be greater than zero")

        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        self.alpha_imdb = alpha_imdb
        self.alpha_ce = alpha_ce
        self.alpha_mse = alpha_mse
        self.alpha_cos = alpha_cos
        self.alpha_corr = alpha_corr
        self.alpha_CKA = alpha_CKA
        self.temperature = temperature
        self.do_projector = bool(do_projector)
        self.student_hidden_size = student_hidden_size
        self.teacher_hidden_size = teacher_hidden_size
        self.teacher.eval()  # teacher is in the eval mode
        self.cka_loss = CKALoss(eps=1e-8)
        self.train_dataset.set_format(
            type=self.train_dataset.format["type"], columns=list(self.train_dataset.features.keys()))

        self.top_k = None
        self.projector = None
        if self.do_projector:
            # Built once so that every step optimises against the same
            # (trainable) projection. The projector maps the student width to
            # the teacher width, hence the crossed keyword names.
            self.projector = Projector(
                teacher_hidden_size=self.student_hidden_size,
                student_hidden_size=self.teacher_hidden_size,
            ).to(self.args.device)
            self.add_callback(self._ProjectorGradReset(self.projector))
        elif self.alpha_corr > 0:
            self.top_k = scores_neurons(self.teacher, self.train_dataset, student_hidden_size=self.student_hidden_size)

    def create_optimizer(self, *args, **kwargs):
        """Create the optimizer as ``Trainer`` does, then register the projector's parameters."""
        optimizer = super().create_optimizer(*args, **kwargs)
        projector = getattr(self, "projector", None)
        if projector is not None and self.optimizer is not None:
            already_tracked = {
                id(param)
                for group in self.optimizer.param_groups
                for param in group["params"]
            }
            missing = [param for param in projector.parameters() if id(param) not in already_tracked]
            if missing:
                self.optimizer.add_param_group(
                    {"params": missing, "weight_decay": self.args.weight_decay}
                )
        return optimizer

    def _masked_rows(self, hidden_states, attention_mask):
        """Return the hidden vectors at positions with a non-zero mask, one per row."""
        keep = attention_mask.unsqueeze(-1).expand_as(hidden_states)
        keep = keep.to(torch.bool).to(hidden_states.device)
        return torch.masked_select(hidden_states, keep).view(-1, hidden_states.size(-1))

    def _correlation_loss(self, s_hidden_states, t_hidden_states, attention_mask):
        """Cross-correlation loss between masked student and teacher hidden states.

        Both inputs are standardised per column; the loss is the squared
        distance of the cross-correlation diagonal from one plus ``5e-3`` times
        the squared off-diagonal entries.

        Raises:
            ValueError: If the two masked matrices differ in shape.
        """
        z1 = self._masked_rows(s_hidden_states, attention_mask)
        z2 = self._masked_rows(t_hidden_states, attention_mask)
        if z1.size() != z2.size():
            raise ValueError(
                f"Student and teacher features must have the same shape, got {tuple(z1.size())} and {tuple(z2.size())}"
            )

        z1_norm = (z1 - torch.mean(z1, dim=0)) / torch.std(z1, dim=0)
        z2_norm = (z2 - torch.mean(z2, dim=0)) / torch.std(z2, dim=0)
        # NOTE(review): the divisor is the number of selected *elements*
        # (rows * dim), as in the original code, not the number of rows.
        # Kept unchanged because it sets the scale of the loss — confirm
        # whether dividing by z2.size(0) was intended.
        cross_corr = torch.matmul(z1_norm.T, z2_norm) / z2.numel()
        on_diag = (torch.diagonal(cross_corr) - 1).pow(2).sum()
        off_diag = off_diagonal(cross_corr).pow(2).sum()
        return on_diag + (5e-3 * off_diag)

    # Function to compute the distillation loss function
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the weighted sum of the task loss and the enabled distillation terms.

        Args:
            model: Student model being trained.
            inputs: Batch mapping with ``input_ids``, ``attention_mask`` and ``labels``.
            return_outputs: When true, also return the student's outputs.
            num_items_in_batch: Accepted for compatibility with ``Trainer``; unused.

        Returns:
            The loss, or ``(loss, student_outputs)`` when ``return_outputs`` is true.
        """
        attention_mask = inputs["attention_mask"]
        # Labels are passed with an extra leading axis, as in the original implementation.
        outputs_stu = model(
            input_ids=inputs["input_ids"],
            attention_mask=attention_mask,
            labels=inputs["labels"].unsqueeze(0),
            output_hidden_states=True,
        )
        with torch.no_grad():
            outputs_tea = self.teacher(
                input_ids=inputs["input_ids"],
                attention_mask=attention_mask,
                output_hidden_states=True,
            )

        loss = self.alpha_imdb * outputs_stu.loss

        if self.alpha_ce > 0.0:
            loss_fct = nn.KLDivLoss(reduction="batchmean")
            loss_logits = (loss_fct(
                F.log_softmax(outputs_stu.logits / self.temperature, dim=-1),
                F.softmax(outputs_tea.logits / self.temperature, dim=-1)) * (self.temperature ** 2))
            loss = loss + self.alpha_ce * loss_logits

        s_hidden_states = outputs_stu.hidden_states[-1]
        t_hidden_states = outputs_tea.hidden_states[-1]

        if self.alpha_corr > 0.0 and not self.do_projector:
            # Compare against the teacher units selected at construction time.
            selected_teacher = t_hidden_states[:, :, self.top_k]
            loss = loss + self.alpha_corr * self._correlation_loss(
                s_hidden_states, selected_teacher, attention_mask)

        if self.alpha_CKA > 0.0:
            loss = loss + self.alpha_CKA * self.cka_loss(
                self._masked_rows(s_hidden_states, attention_mask),
                self._masked_rows(t_hidden_states, attention_mask))

        if self.do_projector:
            projected_student = self.projector(s_hidden_states)
            loss = loss + self.alpha_corr * self._correlation_loss(
                projected_student, t_hidden_states, attention_mask)

        return (loss, outputs_stu) if return_outputs else loss
   



def main():
    """
    Distil a teacher sequence classifier into a student on IMDB and evaluate the student.

    Parses the command-line arguments, tokenizes the IMDB train and test
    splits with the student's tokenizer, trains the student with
    ``DistillationTrainer``, saves it to ``--output_dir``, reloads it from
    there and prints the evaluation result on the test split.
    """
    def _str_to_bool(value):
        # ``type=bool`` would turn every non-empty string (including "False")
        # into True, so boolean options are parsed explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in {"true", "t", "yes", "y", "1"}:
            return True
        if lowered in {"false", "f", "no", "n", "0", ""}:
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--student_model_name_or_path",
        default=None,
        type=str,
        required=True,
        help= "Path to the student model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--teacher_model_name_or_path",
        default=None,
        type=str,
        required=True,
        help= "Path to the teacher model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help= "The output directory where the final model will be saved.",
    )
    parser.add_argument(
        "--output_dir_intermed",
        default="./",
        type=str,
        help= "The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
    )
    parser.add_argument(
        "--batch_size", default=8, type=int, help="The training and testing batch size."
    )
    parser.add_argument(
        "--alpha_imdb", default=0.5, type=float, help="True imdb cross entropy loss."
    )
    parser.add_argument(
        "--temperature", default=2.0, type=float, help="Distillation temperature. Only for distillation."
    )
    parser.add_argument(
        "--alpha_ce", default=0.0, type=float, help="the weight of the logit distillation loss."
    )
    parser.add_argument(
        "--alpha_CKA", default=0.0, type=float, help="The weight of the CKA loss for feature distillation."
    )
    parser.add_argument(
        "--alpha_corr", default=0.0, type=float, help="The weight of the our proposed correlation loss for feature distillation."
    )
    parser.add_argument(
        "--alpha_mse", default=0.0, type=float, help="The mean square error loss for feature distillation."
    )
    parser.add_argument(
        "--alpha_cos", default=0.0, type=float, help="The cosine distance loss for feature distillation."
    )
    parser.add_argument(
        "--student_hidden_size", default=768, type=int, help="The student hidden size."
    )
    parser.add_argument(
        "--teacher_hidden_size", default=1024, type=int, help="The teacher hidden size."
    )
    # Accepts "--do_projector", "--do_projector True" and "--do_projector False".
    parser.add_argument(
        "--do_projector", default=False, type=_str_to_bool, nargs="?", const=True,
        help="The bool to do projector or not."
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    args = parser.parse_args()

    # Process the data
    # Download the data and the student tokenizer
    imdb_ds = load_dataset("imdb")
    student_checkpoint = args.student_model_name_or_path
    student_tokenizer = AutoTokenizer.from_pretrained(student_checkpoint)
    student_tokenizer.add_special_tokens({'pad_token': '<pad>'})
    data_collator = DataCollatorWithPadding(tokenizer=student_tokenizer)

    # Functions to process the data
    def preprocess_function(examples):
        return student_tokenizer(examples["text"], truncation=True, max_length=512, padding='max_length')

    def convert_examples_to_features(imdb, num_train_examples, num_eval_examples):
        train_dataset = (imdb['train']
                         .select(range(num_train_examples))
                         .map(preprocess_function, batched=True))
        eval_examples = imdb['test'].select(range(num_eval_examples))
        eval_dataset = eval_examples.map(preprocess_function, batched=True)
        return train_dataset, eval_dataset, eval_examples

    id2label = {0: "NEGATIVE", 1: "POSITIVE"}
    label2id = {"NEGATIVE": 0, "POSITIVE": 1}
    num_train_examples = len(imdb_ds['train'])
    num_eval_examples = len(imdb_ds['test'])
    train_ds, eval_ds, eval_examples = convert_examples_to_features(imdb_ds, num_train_examples, num_eval_examples)

    # Train
    # Log about once per epoch; never let the interval drop to zero.
    logging_steps = max(1, len(train_ds) // args.batch_size)
    student_training_args = TrainingArguments(
        output_dir=args.output_dir_intermed,
        save_strategy="no",
        overwrite_output_dir = True,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.num_train_epochs,
        weight_decay=0,
        logging_steps=logging_steps,
        do_eval=False,
    )
    print(f"Number of training examples: {train_ds.num_rows}")
    print(f"Number of validation examples: {eval_ds.num_rows}")
    print(f"Number of raw validation examples: {eval_examples.num_rows}")
    print(f"Logging steps: {logging_steps}")

    # Load the teacher and the student; both get the student's tokenizer size
    # and pad token so that they can consume the same tokenized batches.
    teacher_checkpoint = args.teacher_model_name_or_path
    teacher_model = AutoModelForSequenceClassification.from_pretrained(teacher_checkpoint).to(device)
    student_model = AutoModelForSequenceClassification.from_pretrained(student_checkpoint).to(device)
    student_model.resize_token_embeddings(len(student_tokenizer))
    teacher_model.resize_token_embeddings(len(student_tokenizer))
    student_model.config.pad_token_id = student_tokenizer.pad_token_id
    teacher_model.config.pad_token_id = student_tokenizer.pad_token_id

    distil_trainer = DistillationTrainer(
        model=student_model,
        teacher_model=teacher_model,
        alpha_imdb=args.alpha_imdb,
        alpha_ce=args.alpha_ce,
        alpha_mse=args.alpha_mse,
        alpha_cos=args.alpha_cos,
        alpha_corr=args.alpha_corr,
        alpha_CKA=args.alpha_CKA,
        temperature=args.temperature,
        student_hidden_size=args.student_hidden_size,
        teacher_hidden_size=args.teacher_hidden_size,
        do_projector=args.do_projector,
        args=student_training_args,
        train_dataset=train_ds,
        tokenizer=student_tokenizer,
    )
    distil_trainer.train()  # train
    distil_trainer.save_model(args.output_dir)  # save the model to the specified path

    # Eval: reload the saved student (and its tokenizer) from the output directory
    trained_dir = args.output_dir
    student_tokenizer = AutoTokenizer.from_pretrained(trained_dir)
    student_trained = AutoModelForSequenceClassification.from_pretrained(trained_dir, num_labels=2,
                                                                           id2label=id2label, label2id=label2id)
    test_args = TrainingArguments(
        output_dir=args.output_dir_intermed,
        do_train=False,
        do_eval=True,
        overwrite_output_dir=True,
        per_device_eval_batch_size=16,
        dataloader_drop_last=False)
    student_trainer = Trainer(
        model=student_trained,
        args=test_args,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        tokenizer=student_tokenizer)
    # Print the results
    print(student_trainer.evaluate(eval_ds))


# Run the distillation pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()