import argparse
import glob
import logging
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from torch.optim import AdamW
import torch.nn.functional as F
from torch import masked_select
import torch.nn as nn
from transformers import PhiForCausalLM, PhiConfig, AutoModelForCausalLM
from transformers import AutoConfig, AutoTokenizer
from datasets import load_dataset, interleave_datasets, load_from_disk, concatenate_datasets
from evaluate import load
from transformers.models.bart.modeling_bart import shift_tokens_right
from einops import rearrange
from torch.nn.parallel import DistributedDataParallel as DDP
import os, sys
import torch._dynamo


def Create_Student(student_id,teacher, student_layers, student_intermediate_size, layer_stride = 3):
    """Build a shallower, narrower student and initialise it from ``teacher``.

    The student is instantiated with ``AutoModelForCausalLM.from_pretrained``
    from ``student_id``, overriding ``num_hidden_layers`` and
    ``intermediate_size``. Student layer ``i`` is then overwritten with the
    weights of teacher layer ``layer_stride * i``:

    * the q/k/v/o attention projections are copied whole;
    * ``gate_proj`` / ``up_proj`` keep the first ``student_intermediate_size``
      rows, ``down_proj`` keeps the first ``student_intermediate_size`` columns;
    * the four per-layer layernorm weights are copied whole.

    All other parameters are left exactly as ``from_pretrained`` produced them.

    Args:
        student_id: Model id or path handed to ``from_pretrained``.
        teacher: Loaded teacher model exposing ``teacher.model.layers``.
        student_layers: Number of decoder layers in the student.
        student_intermediate_size: MLP width of the student.
        layer_stride: Distance between the teacher layers that seed
            consecutive student layers (default 3, the previous hard-coded value).

    Returns:
        The initialised student model.

    Raises:
        ValueError: If the requested stride would index past the last teacher layer.
    """
    # Check the layer mapping before paying for the (slow) student load.
    n_teacher_layers = len(teacher.model.layers)
    if student_layers > 0 and layer_stride * (student_layers - 1) >= n_teacher_layers:
        raise ValueError(
            "teacher has %d layers; cannot seed %d student layers with stride %d"
            % (n_teacher_layers, student_layers, layer_stride)
        )

    # Use the id passed by the caller (the original ignored it and read the
    # module-level ``teacher_id``).
    student = AutoModelForCausalLM.from_pretrained(
        student_id,
        attn_implementation="eager",
        output_hidden_states=True,
        torch_dtype=torch.float32,
        num_hidden_layers=student_layers,
        intermediate_size=student_intermediate_size,
        ignore_mismatched_sizes=True,
    )

    attention_projections = ("q_proj", "k_proj", "v_proj", "o_proj")
    layer_norms = (
        "input_layernorm",
        "pre_feedforward_layernorm",
        "post_feedforward_layernorm",
        "post_attention_layernorm",
    )

    # In-place writes into parameters must happen with autograd disabled;
    # do it locally rather than depending on the caller's global grad mode.
    with torch.no_grad():
        for i in range(student_layers):
            src = teacher.model.layers[layer_stride * i]
            dst = student.model.layers[i]

            for proj in attention_projections:
                getattr(dst.self_attn, proj).weight.copy_(getattr(src.self_attn, proj).weight)

            # Narrow the MLP by keeping only the leading hidden units.
            dst.mlp.gate_proj.weight.copy_(src.mlp.gate_proj.weight[:student_intermediate_size, :])
            dst.mlp.up_proj.weight.copy_(src.mlp.up_proj.weight[:student_intermediate_size, :])
            dst.mlp.down_proj.weight.copy_(src.mlp.down_proj.weight[:, :student_intermediate_size])

            for norm in layer_norms:
                getattr(dst, norm).weight.copy_(getattr(src, norm).weight)

    return student



def Eval_Student(student,teacher,test_dataset, temperature = 1.0, batch_size = 32, layer_stride = 3, max_batches = 5000):
    """Average the distillation losses of ``student`` against ``teacher``.

    Reads these module-level names at call time: ``device0``, ``device1``,
    ``pad_token_id`` and the criteria ``CELoss``, ``STLoss``, ``EMBLoss``,
    ``CSLoss``. Inputs are fed to the teacher on ``device1`` and to the
    student on ``device0``; every loss is computed on ``device1``.

    Args:
        student: Student model; switched to eval mode.
        teacher: Teacher model; switched to eval mode.
        test_dataset: Dataset whose batches contain an ``'input_ids'`` tensor.
        temperature: Softmax temperature for the logit KL term.
        batch_size: Evaluation batch size.
        layer_stride: Student hidden state ``i + 1`` is compared with teacher
            hidden state ``layer_stride * (i + 1)``.
        max_batches: Stop once this many batches have been accumulated.

    Returns:
        A list of ``num_hidden_layers + 3`` floats, averaged over the
        accumulated batches: ``[cross-entropy, temperature-scaled KL,
        hidden_states[0] loss, cosine loss for each student layer...]``.
        Batches whose losses contain NaN are skipped; if no batch could be
        accumulated every entry is NaN.
    """
    eval_dataloader = DataLoader(test_dataset, batch_size = batch_size)
    student.eval()
    teacher.eval()
    n_layers = student.config.num_hidden_layers
    n_terms = n_layers + 3
    # Accumulate as Python floats: summing the reduced-precision loss tensors
    # directly would lose small increments once the running total grows.
    sumloss = [0.0] * n_terms
    nBatch = 0

    with torch.no_grad():
        for batch in eval_dataloader:
            # Tensor fields go to the teacher's device; list-valued fields are dropped.
            batch = {k: v.squeeze(1).to(device1) for k, v in batch.items() if not isinstance(v,list)}

            decoder_input_ids = shift_tokens_right(batch['input_ids'], pad_token_id, pad_token_id)
            dec_mask = decoder_input_ids.ne(pad_token_id)
            torch.cuda.empty_cache()
            loss = [0] * n_terms

            ########### Teacher forward (device1) ###############
            teacher_out = teacher(input_ids = decoder_input_ids,attention_mask = dec_mask)
            TL = teacher_out.logits.float()
            # Flatten to (tokens, vocab) for any batch size, not only batch size 1.
            TL = TL.reshape(-1, TL.size(-1))

            ########### Student forward (device0) ###############
            student_out = student(input_ids = decoder_input_ids.to(device0),attention_mask = dec_mask.to(device0))
            SL = student_out.logits.float().to(device1)
            SL = SL.reshape(-1, SL.size(-1))

            ########### Logit losses (device1) ###############
            loss[0] = CELoss(SL, batch['input_ids'].reshape(-1))
            loss[1] = ((temperature)**2) * STLoss(F.log_softmax(SL / temperature, dim=-1),
                                                  F.softmax(TL / temperature, dim=-1))

            ########### Hidden-state losses on non-pad positions ###############
            teacher_mask = dec_mask.unsqueeze(-1).expand_as(teacher_out.hidden_states[-1]).bool()
            student_mask = dec_mask.unsqueeze(-1).expand_as(student_out.hidden_states[-1]).bool().to(device0)
            dT = teacher_out.hidden_states[-1].size(-1)
            dS = student_out.hidden_states[-1].size(-1)

            # Target of +1 for every kept token: the cosine criterion then
            # rewards student/teacher vectors that point the same way.
            target = torch.ones(int(dec_mask.sum()), device=device1)

            loss[2] = EMBLoss(masked_select(student_out.hidden_states[0],student_mask).view(-1,dS).to(device1),
                              masked_select(teacher_out.hidden_states[0],teacher_mask).view(-1,dT))

            for i in range(n_layers):
                Tix = layer_stride * (i + 1)
                loss[i+3] = CSLoss(masked_select(student_out.hidden_states[i+1],student_mask).view(-1,dS).to(device1),
                                   masked_select(teacher_out.hidden_states[Tix],teacher_mask).view(-1,dT), target)

            # One transfer for all terms; skip the batch if any term is NaN.
            values = torch.stack([term.float() for term in loss])
            if not torch.isnan(values).any():
                for i, value in enumerate(values.tolist()):
                    sumloss[i] += value
                nBatch += 1

            # Cap the evaluation. The original ``nBatch % 5000 == 0`` test was
            # also true at nBatch == 0, i.e. right after a skipped first batch.
            if nBatch >= max_batches:
                break

    if nBatch == 0:
        return [float('nan')] * n_terms
    return [total / nBatch for total in sumloss]


# Process-wide settings; they must be in place before the tokenizer / CUDA
# allocator are first used.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

# This script reads sys.argv[1] .. sys.argv[5]. Check them up front so a
# missing argument is reported before the teacher checkpoint is loaded.
if len(sys.argv) < 6:
    raise SystemExit("usage: %s STUDENT_INTERMEDIATE_SIZE LR ALPHA SEED K" % sys.argv[0])

teacher_id = "google/gemma-2-9b"
# Autograd is off globally; code that needs gradients must opt back in with
# torch.enable_grad().
torch.set_grad_enabled(False)
tokenizer = AutoTokenizer.from_pretrained(teacher_id)
teacher = AutoModelForCausalLM.from_pretrained(
    teacher_id,
    torch_dtype=torch.float32,
    output_hidden_states=True,
    attn_implementation="eager",
)

# The teacher is never trained.
for param in teacher.parameters():
    param.requires_grad = False

print("Number of Teacher Parameters: ", sum(p.numel() for p in teacher.parameters()))

# device0 hosts the student, device1 the teacher. With a single GPU both
# share cuda:0 instead of failing on a non-existent cuda:1.
if torch.cuda.is_available():
    device0 = torch.device("cuda:0")
    device1 = torch.device("cuda:1") if torch.cuda.device_count() > 1 else device0
else:
    device0 = torch.device("cpu")
    device1 = torch.device("cpu")

student_intermediate_size = int(sys.argv[1])
# 14 student layers, seeded from the teacher by Create_Student.
student_raw = Create_Student(teacher_id,teacher,14,student_intermediate_size)

############################ Seeding and compiler configuration ###########################
# The 4th CLI argument seeds torch's global RNG.
torch.manual_seed(int(sys.argv[4]))
torch.set_float32_matmul_precision('high')
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.suppress_errors = True

############################ Datasets saved on disk ###########################
train_dataset, test_dataset = [
    load_from_disk(path) for path in ('/Dump/Train_Gemma', '/Dump/Valid_Gemma')
]

batch_size = 1
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

############################ Model placement ############################
# Teacher lives on device1, student on device0, both cast to bfloat16.
teacher.to(device=device1, dtype=torch.bfloat16)
student_raw.to(device=device0, dtype=torch.bfloat16)
student = torch.compile(student_raw)
teacher.eval()

############################ Optimizer parameter groups ############################
no_decay = ["bias"]
# The token embeddings stay frozen.
student.model.embed_tokens.requires_grad_(False)

# Split the trainable, non-embedding parameters in one pass: names matching
# an entry of `no_decay` get no weight decay, everything else gets 0.1.
decay_params = []
no_decay_params = []
for param_name, param in student.named_parameters():
    if not param.requires_grad or "embed_tokens" in param_name:
        continue
    if any(tag in param_name for tag in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer_grouped_parameters = [
    {"params": decay_params, "weight_decay": 1e-1},
    {"params": no_decay_params, "weight_decay": 0.0},
]

############################ Hyper-parameters from the command line ############################
student_intermediate_size = int(sys.argv[1])
LR = float(sys.argv[2])      # multiplied by 1e-5 to obtain the learning rate
alpha = float(sys.argv[3])   # weight of the non-top-K KL term in the training loss
K = int(sys.argv[5])         # number of top teacher tokens treated individually
temperature = 1.0

optimizer = AdamW(params=optimizer_grouped_parameters, lr=LR*1e-5, betas=(0.9, 0.999), eps=1e-6)

############################ Log files ############################
# NOTE(review): LR and alpha are floats and `%d` truncates them, so runs that
# differ only in the fractional part share a file name -- confirm intended.
run_tag = (student_intermediate_size, LR, alpha, K)
f = open("../../Checkpoints/LOGIT-KD-Full-Gem-C4-9b-l-%d-%d-%d-%d.txt" % run_tag, "w+", buffering=50)
f1 = open("../../Checkpoints/LOGIT-KD-Eval-Full-Gem-C4-9b-l-%d-%d-%d-%d.txt" % run_tag, "w+", buffering=1)

############################ LR schedule ############################
from transformers import get_scheduler, get_cosine_schedule_with_warmup
num_epochs = 60
batch_per_epoch = 50000
# NOTE(review): the horizon below counts batches, while the training loop
# calls scheduler.step() once every 32 batches -- confirm this is intended.
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=num_epochs * batch_per_epoch,
)

############################ Loss criteria ############################
pad_token_id = tokenizer.pad_token_id
STLoss = torch.compile(nn.KLDivLoss(reduction='batchmean'))   # reduced KL
STLoss2 = torch.compile(nn.KLDivLoss(reduction='none'))       # element-wise KL
CSLoss = torch.compile(nn.CosineEmbeddingLoss())              # hidden-state alignment
CELoss = nn.CrossEntropyLoss(ignore_index=pad_token_id)       # pad targets are ignored
EMBLoss = nn.MSELoss()

############################ Distillation training loop ############################
# One pass over train_dataloader. For every batch the teacher is run without
# gradients, the student with gradients, and the summed loss is backpropagated.
# The optimizer and scheduler step once every `grad_accum_steps` batches
# (gradients from a trailing partial window are never applied). Every
# `batch_per_epoch` batches the student is evaluated and checkpointed.
nBatch = 0
nTokens = 0
grad_accum_steps = 32
# NOTE(review): LR and alpha are floats and `%d` truncates them in the name.
checkpoint_path = "../../Checkpoints/Gemma/Gem-C4-9B-l-Full-%d-%d-%d-%d.pt" % (student_intermediate_size,LR,alpha,K)
teacher.eval()

# `with` guarantees both log files are closed even if training raises.
with f, f1:
    for batch in train_dataloader:
        # Tensor fields go to the teacher's device; list-valued fields are dropped.
        batch = {k: v.squeeze(1).to(device1) for k, v in batch.items() if not isinstance(v,list)}

        student.train()

        decoder_input_ids = shift_tokens_right(batch['input_ids'], pad_token_id, pad_token_id)
        dec_mask = decoder_input_ids.ne(pad_token_id)

        ########### Teacher forward (device1, no gradients) ###############
        teacher_out = teacher(input_ids = decoder_input_ids,attention_mask = dec_mask)
        TL = teacher_out.logits.squeeze(0).float()
        pT = F.softmax(TL.float(), dim=-1)

        # Rank the vocabulary by teacher probability, per token.
        _, sorted_idx = torch.sort(pT,dim=-1,descending = True)
        top_idx = sorted_idx[:,:K]     # K most likely tokens
        rest_idx = sorted_idx[:,K:]    # everything else

        pTmax = torch.gather(pT,1,top_idx)
        # Gather the teacher tail once and reuse it for both the tail mass
        # and the renormalised tail distribution.
        pT_rest = torch.gather(pT,1,rest_idx)
        pT_not_max = pT_rest.sum(-1)
        pT_nc = pT_rest/pT_not_max.unsqueeze(-1)

        # loss[0]: cross-entropy, loss[1]: logit distillation,
        # loss[2:]: one cosine term per student layer.
        loss = [0]*(student.config.num_hidden_layers +2)

        # Autograd is disabled globally; re-enable it for the student pass.
        with torch.enable_grad():
            ########### Student forward (device0) ###############
            student_out = student(input_ids = decoder_input_ids.to(device0),attention_mask = dec_mask.to(device0))
            SL = student_out.logits.squeeze(0).float().to(device1)

            ########### Losses (device1) ###############
            loss[0] = CELoss(SL,batch['input_ids'].view(-1))

            pS = F.softmax(SL.float(), dim=-1)
            pSmax = torch.gather(pS,1,top_idx)
            pS_not_max = torch.gather(pS,1,rest_idx).sum(-1)

            assert(pT.shape[0] == pT_not_max.shape[0])

            # KL over K+1 outcomes: the teacher's top-K tokens plus one
            # bucket holding all remaining probability mass.
            MKLD = STLoss(torch.log(torch.cat((pSmax,pS_not_max.unsqueeze(-1)),1)),torch.cat((pTmax,pT_not_max.unsqueeze(-1)),1))

            # Student distribution renormalised over the non-top-K tokens,
            # computed from the logits in double precision.
            pS_nc = torch.exp(torch.gather(SL.double(),1,rest_idx))
            pS_nc = pS_nc/pS_nc.sum(-1).unsqueeze(-1).expand_as(pS_nc)

            # Element-wise KL inside the tail, weighted per token by the
            # teacher's tail mass and scaled by alpha.
            NKLD = STLoss2(torch.log(pS_nc),pT_nc)

            loss[1] = (MKLD + (alpha/pT_not_max.mean())*(pT_not_max*NKLD.sum(-1)).mean()) #Mean over Tokens

            ############## Hidden-state cosine losses on non-pad positions ##############
            teacher_mask = dec_mask.unsqueeze(-1).expand_as(teacher_out.hidden_states[-1]).bool()
            student_mask = dec_mask.unsqueeze(-1).expand_as(student_out.hidden_states[-1]).bool().to(device0)
            dT = teacher_out.hidden_states[-1].size(-1)
            dS = student_out.hidden_states[-1].size(-1)

            target = torch.ones(dec_mask.sum()).to(device1)

            for i in range(student.config.num_hidden_layers):
                # Student hidden state i+1 is matched with teacher hidden state 3*(i+1).
                Tix =3*(i+1)
                loss[i+2] = CSLoss(masked_select(student_out.hidden_states[i+1],student_mask).view(-1,dS).to(device1), masked_select(teacher_out.hidden_states[Tix],teacher_mask).view(-1,dT),target)

            loss_sum = sum(loss)
            torch.cuda.empty_cache()
            loss_sum.backward()

            nBatch+=1
            nTokens += batch['attention_mask'].sum().item()
            if(nBatch%grad_accum_steps ==0):
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                # Logs the loss terms of the batch that closed the window.
                f.write(str(['%.3f' % l.item() for l in loss])+'\n')

        if(nBatch%batch_per_epoch==0):
            eval_loss = Eval_Student(student_raw,teacher,test_dataset, temperature, batch_size)
            f1.write(str(nTokens) + str(['%.3f' % l for l in eval_loss]) +  '\n')
            f1.flush()
            # NOTE(review): `student` is the torch.compile wrapper; its
            # state_dict keys presumably carry an `_orig_mod.` prefix --
            # confirm that checkpoint consumers expect that.
            torch.save(student.state_dict(),checkpoint_path)

    torch.save(student.state_dict(),checkpoint_path)