
# Import packages
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader 
from torch.optim import AdamW
import torch.nn.functional as F
from torch import masked_select
import torch.nn as nn
from transformers import BartForConditionalGeneration, BartConfig
from transformers import BartTokenizer
from datasets import load_from_disk, load_dataset
from transformers.models.bart.modeling_bart import shift_tokens_right
from evaluate import load   
from transformers import get_scheduler

class CKALoss(nn.Module):
    """
    Linear-alignment loss between student and teacher hidden states.

    Both inputs are flattened to 2-D ``(rows, features)``, cast to float64,
    centred per feature, and compared through Frobenius norms::

        1 - ||S^T T||_F / sqrt((||S^T S||_F + eps) * (||T^T T||_F + eps))

    The result is ~0 (up to ``eps``) when the two inputs are identical and
    1 when their cross-product ``S^T T`` is zero.  The feature sizes of the
    two inputs may differ; the number of rows after flattening must match.
    """
    def __init__(self, eps ):
        """Store ``eps``, the constant added to each denominator norm."""
        super().__init__()
        self.eps = eps

    def forward(self, SH, TH):
        """
        Return the scalar loss for student states ``SH`` and teacher states ``TH``.

        Every leading dimension is merged into one "rows" axis; the last
        dimension is treated as the feature axis.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # sliced hidden states, are accepted instead of raising.
        SH = SH.reshape(-1, SH.size(-1)).to(SH.device, torch.float64)
        TH = TH.reshape(-1, TH.size(-1)).to(SH.device, torch.float64)

        # Centre every feature over the rows before comparing.
        SH = SH - SH.mean(0, keepdim=True)
        TH = TH - TH.mean(0, keepdim=True)

        num = torch.norm(SH.t().matmul(TH), 'fro')
        # eps keeps the ratio finite when an input is constant (all-zero
        # after centring).
        den1 = torch.norm(SH.t().matmul(SH), 'fro') + self.eps
        den2 = torch.norm(TH.t().matmul(TH), 'fro') + self.eps

        return 1 - num / torch.sqrt(den1 * den2)

def localiseNN(logits, hidden_state,teacher):
    """
    Score hidden units by how strongly the log-probabilities depend on them.

    Differentiates the log-softmax of ``logits`` (every entry weighted by 1)
    with respect to ``hidden_state``, averages that gradient over its first
    two dimensions, and takes the absolute value.

    Args:
        logits: tensor computed from ``hidden_state`` through an autograd
            graph; softmax is taken over its last dimension.
        hidden_state: tensor the gradient is taken with respect to
            (presumably (batch, seq, hidden) -- verify against caller).
        teacher: module whose ``zero_grad()`` is called before the gradient
            is computed.

    Returns:
        A one-element list holding the per-unit score tensor.
    """
    logits.requires_grad_(True)
    hidden_state.requires_grad_(True)
    teacher.zero_grad()

    log_probs = F.log_softmax(logits, dim=-1)
    # retain_graph=True: the caller differentiates the same forward pass
    # more than once (encoder and decoder states).
    (grad_wrt_hidden,) = torch.autograd.grad(
        log_probs,
        hidden_state,
        grad_outputs=torch.ones_like(log_probs),
        retain_graph=True,
    )
    per_unit = grad_wrt_hidden.mean(dim=0).mean(dim=0)
    return [torch.abs(per_unit)]


from itertools import islice
def scores_neurons(teacher, data_iterator, device, student_layer, student_hidden_size, num_batches=40):
    """
    Pick the teacher hidden units that matter most for the label log-probabilities.

    For up to ``num_batches`` batches the teacher is run with labels, the
    logits at non-padded label positions are selected, and ``localiseNN``
    scores every unit of the last encoder and last decoder hidden state.
    Scores are summed over batches and the ``student_hidden_size`` units
    with the largest totals are kept.

    Args:
        teacher: model called with ``input_ids``, ``attention_mask``,
            ``labels`` and ``output_hidden_states=True``; its output must
            expose ``logits``, ``encoder_hidden_states`` and
            ``decoder_hidden_states``.
        data_iterator: iterable of batches with the keys ``'input_ids'``,
            ``'attention_mask'``, ``'label_ids'`` and
            ``'label_attention_mask'``.
        device: device the batch tensors are moved to.
        student_layer: unused; kept for interface compatibility.
        student_hidden_size: number of unit indices to keep.
        num_batches: maximum number of batches to score (default 40, the
            value that used to be hard-coded).

    Returns:
        ``(encoder_indices, decoder_indices)``: two lists with one index
        tensor per scored layer (only layer ``-1`` is scored), ordered by
        descending score.

    Raises:
        ValueError: if ``data_iterator`` yields no batch.
    """
    all_scores_encoder = []
    all_scores_decoder = []
    for layer in [-1]:
        encoder_total = None
        decoder_total = None
        for batch in islice(data_iterator, num_batches):
            teacher_output = teacher(input_ids = batch['input_ids'].to(device),attention_mask = batch['attention_mask'].to(device), labels = batch['label_ids'].to(device),output_hidden_states=True)
            logits = teacher_output.logits
            logit_mask = batch['label_attention_mask'].unsqueeze(-1).expand_as(logits).bool().to(device)
            # Logits at real (non-padded) label positions; built once and
            # shared by the encoder and decoder scoring passes.
            label_logits = masked_select(logits, logit_mask).view(-1, logits.size(-1))
            (encoder_scores,) = localiseNN(label_logits, teacher_output.encoder_hidden_states[layer], teacher)
            (decoder_scores,) = localiseNN(label_logits, teacher_output.decoder_hidden_states[layer], teacher)
            encoder_total = encoder_scores if encoder_total is None else encoder_total + encoder_scores
            decoder_total = decoder_scores if decoder_total is None else decoder_total + decoder_scores
        if encoder_total is None:
            raise ValueError("scores_neurons: data_iterator yielded no batches to score")
        result_encoder = encoder_total.to(device)
        all_scores_encoder.append(torch.argsort(result_encoder, dim=-1, descending=True)[:student_hidden_size])
        result_decoder = decoder_total.to(device)
        all_scores_decoder.append(torch.argsort(result_decoder, dim=-1, descending=True)[:student_hidden_size])
    return all_scores_encoder, all_scores_decoder


def off_diagonal(x):
    """
    Return the off-diagonal elements of a square matrix as a 1-D tensor.

    Elements come out in row-major order.  A non-square input fails the
    assertion.
    """
    rows, cols = x.shape
    assert rows == cols
    # In the flattened matrix the diagonal sits at every (rows + 1)-th
    # position.  Dropping the last element (itself a diagonal entry) and
    # re-rowing to (rows - 1, rows + 1) lines the remaining diagonal entries
    # up in column 0, which is then sliced away.
    without_last = x.flatten()[:-1]
    realigned = without_last.view(rows - 1, rows + 1)
    return realigned[:, 1:].flatten()

    
def compute_corr(s_hidden_states, t_hidden_states, off_diag_weight=5e-3):
    """
    Cross-correlation loss between student and teacher features.

    Each input is standardised per feature over dim 0 (mean removed, divided
    by ``torch.std``).  The cross-correlation matrix of the two standardised
    inputs is then pushed towards the identity: diagonal entries are
    penalised for differing from 1, off-diagonal entries for differing
    from 0.

    Args:
        s_hidden_states: 2-D ``(rows, features)`` student states.
        t_hidden_states: 2-D ``(rows, features)`` teacher states; the
            feature count must equal the student's because the
            cross-correlation matrix has to be square.
        off_diag_weight: weight of the off-diagonal term (default 5e-3, the
            value that used to be hard-coded).

    Returns:
        Scalar tensor ``on_diag + off_diag_weight * off_diag``.

    Note:
        No epsilon is added to the standard deviation, so a feature that is
        constant over the rows produces non-finite values.
    """
    z1_norm = (s_hidden_states - torch.mean(s_hidden_states, dim=0)) / torch.std(s_hidden_states, dim=0)
    z2_norm = (t_hidden_states - torch.mean(t_hidden_states, dim=0)) / torch.std(t_hidden_states, dim=0)
    cross_corr = torch.matmul(z1_norm.T, z2_norm) / t_hidden_states.size(0)
    # Out-of-place ops on purpose: in-place add_/pow_ do not go through
    # autocast (the squaring would stay in the matmul's reduced precision)
    # and they silently rewrote the diagonal of cross_corr.
    on_diag = (torch.diagonal(cross_corr) - 1).pow(2).sum()
    off_diag = off_diagonal(cross_corr).pow(2).sum()
    return on_diag + (off_diag_weight * off_diag)

    
def _copy_truncated_params(dst_module, src_module):
    """Copy each parameter of src_module into the same-named parameter of dst_module, keeping the leading slice that fits."""
    src_params = dict(src_module.named_parameters())
    for name, dst_param in dst_module.named_parameters():
        leading = tuple(slice(0, size) for size in dst_param.shape)
        dst_param.copy_(src_params[name][leading])


def Create_Student(student_id,student_dim,student_layer,teacher, teacher_layer_stride=4):
    """
    Build a smaller BART student and initialise it from the teacher's weights.

    The student configuration is loaded from ``student_id`` and overridden
    with the teacher's vocabulary size, ``student_layer`` encoder and decoder
    layers and a model width of ``student_dim``.  Every student parameter is
    then filled with the leading slice of the corresponding teacher
    parameter (first ``student_dim`` rows/columns, and the first rows/columns
    that fit for the feed-forward matrices).

    Student layer ``i`` is initialised from teacher layer
    ``teacher_layer_stride * (i + 1) - 1``; with the default stride of 4
    that is teacher layer ``4*i + 3``.

    Differences from the previous implementation:
      * the decoder ``layernorm_embedding`` is copied from the teacher
        *decoder* (it used to be copied from the teacher encoder);
      * the decoder position (and token) embeddings are copied as well,
        mirroring what is done for the encoder;
      * ``student_id`` is honoured instead of a hard-coded model id;
      * the copies run under ``torch.no_grad()`` so the function no longer
        depends on the caller having disabled autograd globally.

    Args:
        student_id: model id or path the student configuration is loaded from.
        student_dim: student model width (``d_model``).
        student_layer: number of encoder layers and of decoder layers.
        teacher: teacher ``BartForConditionalGeneration`` to copy from.
        teacher_layer_stride: spacing of the teacher layers used for
            initialisation.

    Returns:
        The initialised student model.

    Raises:
        ValueError: if the stride is smaller than 1 or the mapping would
            address a teacher layer that does not exist.
    """
    config = BartConfig.from_pretrained(student_id, vocab_size = teacher.config.vocab_size, encoder_layers=student_layer,
                                        decoder_layers = student_layer, d_model=student_dim,output_hidden_states = True, use_cache = False)
    student = BartForConditionalGeneration(config)
    print("Student Parameters", sum(p.numel() for p in student.parameters() if p.requires_grad))

    deepest_teacher_layer = teacher_layer_stride * max(student.config.encoder_layers, student.config.decoder_layers) - 1
    available_teacher_layers = min(len(teacher.model.encoder.layers), len(teacher.model.decoder.layers))
    if teacher_layer_stride < 1 or deepest_teacher_layer >= available_teacher_layers:
        raise ValueError(
            "teacher_layer_stride=%d with %d student layers needs teacher layer %d, but the teacher has %d layers"
            % (teacher_layer_stride, student_layer, deepest_teacher_layer, available_teacher_layers)
        )

    # copy_ is an in-place write into parameters that require grad, which
    # autograd only allows while gradient tracking is off.
    with torch.no_grad():
        _copy_truncated_params(student.model.shared, teacher.model.shared)

        for side in ("encoder", "decoder"):
            student_stack = getattr(student.model, side)
            teacher_stack = getattr(teacher.model, side)
            _copy_truncated_params(student_stack.embed_tokens, teacher_stack.embed_tokens)
            _copy_truncated_params(student_stack.embed_positions, teacher_stack.embed_positions)
            _copy_truncated_params(student_stack.layernorm_embedding, teacher_stack.layernorm_embedding)
            # Covers attention projections, layer norms and feed-forward
            # weights of the layer (plus cross-attention on the decoder side).
            for i, student_block in enumerate(student_stack.layers):
                teacher_block = teacher_stack.layers[teacher_layer_stride * (i + 1) - 1]
                _copy_truncated_params(student_block, teacher_block)

        _copy_truncated_params(student.lm_head, teacher.lm_head)
    return student



def Eval_Student(student,teacher,test_dataset, batch_size = 32):
    """
    Evaluate the student on ``test_dataset`` and collect ROUGE inputs.

    For every full batch (the last partial batch is dropped) this
    accumulates four losses -- cross-entropy on the labels, the
    temperature-scaled KL term between student and teacher distributions,
    and the encoder and decoder hidden-state correlation losses -- then
    generates summaries with the student and feeds them to the ROUGE metric.

    Reads these module-level names: ``device``, ``pad_token_id``,
    ``tokenizer``, ``CELoss``, ``STLoss``, ``temperature``,
    ``selected_indices_encoder`` and ``selected_indices_decoder``.

    The student is switched to eval mode for the duration of the call and
    put back into whatever mode it was in afterwards, so calling this in the
    middle of an epoch no longer leaves training running with dropout off.
    The teacher is left in eval mode.

    Args:
        student: student model (must return hidden states).
        teacher: teacher model (must return hidden states).
        test_dataset: dataset yielding ``input_ids``, ``attention_mask``,
            ``label_ids`` and ``label_attention_mask``.
        batch_size: evaluation batch size.

    Returns:
        ``(losses, rouge)``: the four losses averaged over batches as
        floats, and the metric object with all batches added
        (call ``.compute()`` on it).

    Raises:
        ValueError: if the dataset yields no full batch.
    """
    rouge = load('rouge', experiment_id = "Bart-CKA-CNN-%d-%d.txt" % (student.config.encoder_layers,student.config.d_model), use_stemmer = True)
    alpha = 0.05
    eval_dataloader = DataLoader(test_dataset, batch_size=batch_size, drop_last = True)
    was_training = student.training
    student.eval()
    teacher.eval()
    loss = [0]*4
    nBatch = 0
    enc_idx = selected_indices_encoder[0]
    dec_idx = selected_indices_decoder[0]
    try:
        # Explicit no_grad so evaluation never builds a graph, whatever the
        # global autograd switch is set to.
        with torch.no_grad():
            for batch in eval_dataloader:
                with torch.cuda.amp.autocast():
                    batch = {k: v.to(device) for k, v in batch.items() if not isinstance(v,list)}
                    # NOTE(review): the start token passed here is pad_token_id,
                    # so position 0 is excluded by dec_mask -- confirm intended.
                    decoder_input_ids = shift_tokens_right(batch['label_ids'], pad_token_id, pad_token_id)
                    dec_mask = decoder_input_ids.ne(pad_token_id)
                    torch.cuda.empty_cache()
                    student_out = student(input_ids = batch['input_ids'],attention_mask = batch['attention_mask'], labels = batch['label_ids'])
                    teacher_out = teacher(input_ids = batch['input_ids'],attention_mask = batch['attention_mask'], labels = batch['label_ids'])

                    # Token-level losses on the non-padded label positions.
                    dV = student_out.logits.size(-1)
                    label_mask = batch['label_attention_mask'].bool()
                    logit_mask = batch['label_attention_mask'].unsqueeze(-1).expand_as(student_out.logits).bool()
                    SL = masked_select(student_out.logits,logit_mask).view(-1,dV)
                    TL = masked_select(teacher_out.logits,logit_mask).view(-1,dV)
                    loss[0] += CELoss(SL,masked_select(batch['label_ids'],label_mask))
                    # Teacher log-probs as input and student probs as target.
                    loss[1] += ((temperature)**2)*STLoss(F.log_softmax(TL/ temperature, dim=-1), F.softmax(SL/ temperature,dim=-1))

                    # Encoder hidden-state loss on the selected teacher units.
                    student_enc = student_out.encoder_hidden_states[-1]
                    teacher_enc = teacher_out.encoder_hidden_states[-1][:, :, enc_idx]
                    enc_mask = batch['attention_mask'].unsqueeze(-1)
                    loss[2] += alpha*compute_corr(
                        masked_select(student_enc, enc_mask.expand_as(student_enc).bool()).view(-1, student_enc.size(-1)),
                        masked_select(teacher_enc, enc_mask.expand_as(teacher_enc).bool()).view(-1, teacher_enc.size(-1)))

                    # Decoder hidden-state loss on the selected teacher units.
                    student_dec = student_out.decoder_hidden_states[-1]
                    teacher_dec = teacher_out.decoder_hidden_states[-1][:, :, dec_idx]
                    dec_mask3 = dec_mask.unsqueeze(-1)
                    loss[3] += alpha*compute_corr(
                        masked_select(student_dec, dec_mask3.expand_as(student_dec).bool()).view(-1, student_dec.size(-1)),
                        masked_select(teacher_dec, dec_mask3.expand_as(teacher_dec).bool()).view(-1, teacher_dec.size(-1)))

                    predictions = student.generate(input_ids = batch['input_ids'],attention_mask = batch['attention_mask'], max_new_tokens = 32)
                    # generate() returns either a plain tensor of token ids or
                    # an output object carrying them in `.sequences`.
                    sequences = getattr(predictions, "sequences", predictions)
                    decoded_preds = tokenizer.batch_decode(sequences, skip_special_tokens=True)
                    label_ids = batch['label_ids'].cpu().numpy()
                    labels = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
                    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
                    rouge.add_batch(predictions=decoded_preds, references=decoded_labels)

                    nBatch+=1
    finally:
        # Hand the student back in the mode the caller had it in.
        student.train(was_training)
    if nBatch == 0:
        raise ValueError("Eval_Student: test_dataset produced no full batch of size %d" % batch_size)
    return [l.item()/nBatch for l in loss], rouge


# Process-wide settings, applied before the tokenizer and model are created:
# turn off tokenizer-internal parallelism and set the CUDA allocator's
# max_split_size_mb option.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"



# The teacher checkpoint id; its tokenizer is used for all text in this script.
teacher_id = "facebook/bart-large-cnn"
tokenizer = BartTokenizer.from_pretrained(teacher_id)
def encode(examples):
    """
    Tokenize ``examples['document']`` and ``examples['summary']``.

    Documents are truncated/padded to 1024 tokens and summaries to 64, both
    with the module-level ``tokenizer``.

    Returns:
        dict with ``input_ids`` / ``attention_mask`` for the document and
        ``label_ids`` / ``label_attention_mask`` for the summary.
    """
    source = tokenizer(examples['document'], truncation=True, padding='max_length', max_length = 1024)
    target = tokenizer(examples['summary'], truncation=True, padding='max_length', max_length = 64)
    return {
        'input_ids': source['input_ids'],
        'attention_mask': source['attention_mask'],
        'label_ids': target['input_ids'],
        'label_attention_mask': target['attention_mask'],
    }
# Load CNN/DailyMail and rename its columns to the names encode() expects.
dataset = load_dataset('cnn_dailymail','3.0.0').rename_column('article','document').rename_column('highlights','summary')
 
# Tokenize every split; the raw text and id columns are dropped so only
# the tensors produced by encode() remain.
tokenized_datasets = dataset.map(encode, remove_columns = ["document", "id", "summary"], num_proc = 24)
tokenized_datasets.set_format('torch')

batch_size = 16 
train_dataloader = DataLoader(tokenized_datasets['train'], shuffle=True, batch_size=batch_size)
# The validation split is what Eval_Student is run on during training.
test_dataset = tokenized_datasets['validation']
# Number of training batches in a quarter of an epoch; the training loop
# checkpoints and evaluates every time this many batches have been processed.
# (The name is misspelled but is referenced as-is by the training loop.)
vaild_sample_interval = len(train_dataloader)//4
print(vaild_sample_interval)


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
teacher = BartForConditionalGeneration.from_pretrained(teacher_id,output_hidden_states = True, output_past = False, use_cache = False)
student_dim = 768
student_layer = 3
teacher.to(device)
# Neuron scoring uses torch.autograd.grad, so it has to run before gradient
# tracking is switched off globally on the next line.
selected_indices_encoder, selected_indices_decoder  = scores_neurons(teacher, train_dataloader, device, student_layer, student_hidden_size = student_dim)
# From here on autograd is off by default; the training loop re-enables it
# locally around the student forward/backward pass.
torch.set_grad_enabled(False)
# The teacher is never updated.
for param in teacher.parameters(): param.requires_grad = False
student = Create_Student('facebook/bart-large',student_dim,student_layer,teacher)
print("Student: ", student)
print(sum(p.numel() for p in student.parameters() if p.requires_grad))
student.to(device)

# Linear map from the student hidden size to the teacher hidden size. The
# optimizer/scheduler steps for it are commented out in the training loop.
Eh = nn.Linear(student.config.hidden_size,teacher.config.hidden_size)
Eh.to(device)

############################ Training Starts ############################
# Parameters whose name contains one of these fragments are exempt from
# weight decay.  BART names its layer norms `*_layer_norm` and
# `layernorm_embedding`, so the BERT-style "LayerNorm.weight" fragment on its
# own never matched and the layer-norm weights were being decayed.
no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "layernorm_embedding.weight"]
optimizer_grouped_parameters = [
    {"params": [p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad],
     "weight_decay": 5e-3,
    },
    {"params": [p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
     "weight_decay": 0.0,
    }
]


optimizer = AdamW(optimizer_grouped_parameters , lr=1e-4, betas = (0.9,0.999), eps = 1e-8)
# Make sure the log/checkpoint directory exists; open(..., "w+") does not
# create missing parent directories.
os.makedirs("./Checkpoints6CNN768", exist_ok=True)
# f: per-batch training losses; f1: periodic evaluation losses and ROUGE.
f = open("./Checkpoints6CNN768/LOGIT-KD-CKA-Bart-CNN-%d-%d.txt" % (student_layer,student_dim), "w+", buffering= 50)
f1 = open("./Checkpoints6CNN768/LOGIT-KD-Eval-CKA-Bart-CNN-%d-%d.txt" % (student_layer,student_dim), "w+", buffering= 1)
num_epochs=10
num_training_steps = num_epochs * len(train_dataloader)
# Linear decay of the learning rate over the whole run, no warm-up.
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)
# Optimizer/scheduler for the Eh projection; their step calls are commented
# out in the training loop.
optimizer1 = AdamW(Eh.parameters() , lr=2e-4, betas = (0.9,0.999), eps = 1e-6, weight_decay = 5e-4)
lr_scheduler1 = get_scheduler(name="linear", optimizer=optimizer1, num_warmup_steps=0, num_training_steps = num_training_steps)


# Loss modules and hyper-parameters shared by the training loop and Eval_Student.
STLoss = nn.KLDivLoss(reduction = 'batchmean')
CSLoss = CKALoss(eps = 1e-8)
CELoss = nn.CrossEntropyLoss()
pad_token_id = tokenizer.pad_token_id
scaler = torch.cuda.amp.GradScaler()
teacher.eval()
temperature = 1.0
lambdaH = 1.0
alpha = 0.05
torch.manual_seed(42)


# Distillation loop.  Per batch: teacher forward (no graph), student forward
# with autograd re-enabled, four losses (label cross-entropy, KL between
# student and teacher distributions, encoder and decoder hidden-state
# correlation), one AMP-scaled optimizer step.  Every `vaild_sample_interval`
# batches the student is checkpointed and evaluated.
enc_idx = selected_indices_encoder[0]
dec_idx = selected_indices_decoder[0]
try:
    for epoch in range(num_epochs):
        student.train()
        nBatch = 0
        for batch in train_dataloader:
            with torch.cuda.amp.autocast():
                torch.cuda.empty_cache()
                batch = {k: v.to(device) for k, v in batch.items() if not isinstance(v,list)}
                # NOTE(review): the start token passed here is pad_token_id,
                # so position 0 is excluded by dec_mask -- confirm intended.
                decoder_input_ids = shift_tokens_right(batch['label_ids'], pad_token_id, pad_token_id)
                dec_mask = decoder_input_ids.ne(pad_token_id)
                teacher_out = teacher(input_ids = batch['input_ids'],attention_mask = batch['attention_mask'], labels = batch['label_ids'])
                loss = [0]*4

                # Autograd is disabled globally; turn it on only for the student.
                with torch.enable_grad():
                    student_out = student(input_ids = batch['input_ids'],attention_mask = batch['attention_mask'], labels = batch['label_ids'])
                    torch.cuda.empty_cache()
                    dV = student_out.logits.size(-1)
                    logit_mask = batch['label_attention_mask'].unsqueeze(-1).expand_as(student_out.logits).bool()
                    SL = masked_select(student_out.logits,logit_mask).view(-1,dV)
                    loss[0] = CELoss(SL,masked_select(batch['label_ids'],batch['label_attention_mask'].bool()))/lambdaH
                    ######## Reverse KL #########
                    # Teacher log-probs are the input and student probs the target.
                    loss[1] = ((temperature)**2)*STLoss(F.log_softmax(masked_select(teacher_out.logits,logit_mask).view(-1,dV)/ temperature, dim=-1), F.softmax(SL/ temperature,dim=-1))

                    # Encoder states: the selected teacher units are sliced once
                    # per batch instead of once per use.
                    student_enc = student_out.encoder_hidden_states[-1]
                    teacher_enc = teacher_out.encoder_hidden_states[-1][:, :, enc_idx]
                    enc_mask = batch['attention_mask'].unsqueeze(-1)
                    loss[2]  = alpha*compute_corr(
                        masked_select(student_enc, enc_mask.expand_as(student_enc).bool()).view(-1, student_enc.size(-1)),
                        masked_select(teacher_enc, enc_mask.expand_as(teacher_enc).bool()).view(-1, teacher_enc.size(-1)))

                    # Decoder states, same treatment.
                    student_dec = student_out.decoder_hidden_states[-1]
                    teacher_dec = teacher_out.decoder_hidden_states[-1][:, :, dec_idx]
                    dec_mask3 = dec_mask.unsqueeze(-1)
                    loss[3] = alpha*compute_corr(
                        masked_select(student_dec, dec_mask3.expand_as(student_dec).bool()).view(-1, student_dec.size(-1)),
                        masked_select(teacher_dec, dec_mask3.expand_as(teacher_dec).bool()).view(-1, teacher_dec.size(-1)))

                    loss_sum= sum(loss)
                    scaler.scale(loss_sum).backward()
                    torch.cuda.empty_cache()
                    scaler.step(optimizer)
                    # scaler.step(optimizer1)
                    scaler.update()
                    lr_scheduler.step()
                    # lr_scheduler1.step()
                    optimizer.zero_grad()
                    # optimizer1.zero_grad()

            f.write(str(['%.3f' % l.item() for l in loss])+'\n')
            nBatch+=1
            if(nBatch%vaild_sample_interval==0):
                torch.save(student.state_dict(),"./Checkpoints6CNN768/Bart-%d-%d-CNN-CKA.pt" % (student_layer,student_dim))
                eval_loss, result = Eval_Student(student,teacher,test_dataset, batch_size)
                # Evaluation puts the student in eval mode; switch back so
                # the rest of the epoch still trains with dropout enabled.
                student.train()
                f1.write(str(['%.3f' % l for l in eval_loss]) +  '\n' + str(result.compute()) + '\n')
                f1.flush()


    torch.save(student.state_dict(),"./Checkpoints6CNN768/Bart-%d-%d-CNN-CKA.pt" % (student_layer,student_dim))
finally:
    # Close the logs even if training is interrupted or fails.
    f.close()
    f1.close()
