import argparse
import glob
import logging
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from torch.optim import AdamW
import torch.nn.functional as F
from torch import masked_select
import torch.nn as nn
from transformers import PhiForCausalLM, PhiConfig, AutoModelForCausalLM
from transformers import AutoConfig, AutoTokenizer
from datasets import load_dataset, interleave_datasets, load_from_disk, concatenate_datasets
from evaluate import load
from transformers.models.bart.modeling_bart import shift_tokens_right
from einops import rearrange
from torch.nn.parallel import DistributedDataParallel as DDP
import os, sys



def _copy_linear(dst, src, rows=None, cols=None):
    """Copy ``src``'s weight, optionally sliced, and its bias (if any) into ``dst``.

    ``rows`` / ``cols`` keep only the leading output / input features of the
    source weight.  A bias follows the output dimension, so it is trimmed by
    ``rows`` only.  The bias is copied only when both modules expose one.
    """
    weight = src.weight.detach()
    if rows is not None:
        weight = weight[:rows, :]
    if cols is not None:
        weight = weight[:, :cols]
    dst.weight.copy_(weight)

    src_bias = getattr(src, "bias", None)
    dst_bias = getattr(dst, "bias", None)
    if src_bias is not None and dst_bias is not None:
        bias = src_bias.detach()
        if rows is not None:
            bias = bias[:rows]
        dst_bias.copy_(bias)


def Create_Student(student_id,teacher, student_layers, student_intermediate_size):
    """Build a shallower, narrower student and seed it from every 2nd teacher layer.

    The student is instantiated from the checkpoint ``student_id`` with its
    depth and MLP width overridden, then student layer ``i`` is overwritten
    with teacher layer ``2*i``: attention projections and layer norms are
    copied as-is, and the MLP projections are truncated to the first
    ``student_intermediate_size`` intermediate features.  Everything outside
    the decoder layers is left as ``from_pretrained`` produced it.

    Args:
        student_id: Checkpoint name or path passed to ``from_pretrained``.
        teacher: Loaded teacher model exposing ``teacher.model.layers``.
        student_layers: Number of decoder layers in the student.
        student_intermediate_size: MLP intermediate width of the student.

    Returns:
        The initialised student model.

    Raises:
        ValueError: If the teacher has too few layers for the 2x layer mapping.
    """
    # Load from the checkpoint named by the argument; the previous version
    # ignored ``student_id`` and read the module-level ``teacher_id`` instead.
    student = AutoModelForCausalLM.from_pretrained(
        student_id,
        num_hidden_layers=student_layers,
        ignore_mismatched_sizes=True,
        intermediate_size=student_intermediate_size,
        attn_implementation="flash_attention_2",
        output_hidden_states=True,
    )

    print(sum(p.numel() for p in student.parameters() if p.requires_grad))

    layer_stride = 2  # student layer i mirrors teacher layer 2*i
    n_teacher_layers = len(teacher.model.layers)
    if student_layers > 0 and layer_stride * (student_layers - 1) >= n_teacher_layers:
        raise ValueError(
            "teacher has %d layers; need at least %d to initialise %d student layers"
            % (n_teacher_layers, layer_stride * (student_layers - 1) + 1, student_layers)
        )

    # In-place writes into parameters that require grad are only legal with
    # autograd disabled; do it here instead of relying on the caller's global
    # grad mode.
    with torch.no_grad():
        for i in range(student_layers):
            dst = student.model.layers[i]
            src = teacher.model.layers[layer_stride * i]

            # Attention projections.  Biases are copied too when present, so a
            # projection's weight and bias always come from the same teacher
            # layer (previously only the weights were overwritten).
            for proj in ("q_proj", "k_proj", "v_proj", "o_proj"):
                _copy_linear(getattr(dst.self_attn, proj), getattr(src.self_attn, proj))

            # MLP: keep the first ``student_intermediate_size`` intermediate
            # features (rows of gate/up, columns of down).
            _copy_linear(dst.mlp.gate_proj, src.mlp.gate_proj, rows=student_intermediate_size)
            _copy_linear(dst.mlp.up_proj, src.mlp.up_proj, rows=student_intermediate_size)
            _copy_linear(dst.mlp.down_proj, src.mlp.down_proj, cols=student_intermediate_size)

            dst.input_layernorm.weight.copy_(src.input_layernorm.weight.detach())
            dst.post_attention_layernorm.weight.copy_(src.post_attention_layernorm.weight.detach())

    return student




def Eval_Student(student,teacher,test_dataset, temperature = 1.0, batch_size = 32, max_batches = 5000):
    """Average the distillation losses of ``student`` against ``teacher``.

    Relies on names defined at module level by the training script:
    ``device``, ``pad_token_id``, ``CELoss``, ``STLoss``, ``EMBLoss`` and
    ``CSLoss``.  Both models are left in eval mode.

    Args:
        student: Model under evaluation; ``student.config.num_hidden_layers``
            determines how many per-layer terms are reported.
        teacher: Reference model; hidden state ``2*(i+1)`` of the teacher is
            compared with hidden state ``i+1`` of the student.
        test_dataset: Dataset whose batches carry ``input_ids`` and
            ``attention_mask`` tensors.
        temperature: Softmax temperature of the logit-distillation term.
        batch_size: Evaluation batch size.
        max_batches: Stop after this many NaN-free batches (positive int), or
            ``None`` to consume the whole loader.  The loader shuffles, so the
            evaluated subset can differ between calls when the cap is hit.

    Returns:
        A list of ``num_hidden_layers + 3`` floats averaged over the NaN-free
        batches: ``[cross-entropy, temperature-scaled KL, embedding loss,
        per-layer hidden-state loss ...]``.  If no batch was usable the list
        is filled with ``nan``.
    """
    eval_dataloader = DataLoader(test_dataset, shuffle=True, batch_size = batch_size)
    student.eval()
    teacher.eval()
    n_layers = student.config.num_hidden_layers
    n_terms = n_layers + 3
    sumloss = None  # running sum of per-term losses over the accepted batches
    nBatch = 0

    # Evaluation never needs gradients; do not depend on the caller's grad mode.
    with torch.no_grad():
        for batch in eval_dataloader:
            ########### In Device 0 ###############
            batch = {k: v.to(device) for k, v in batch.items() if not isinstance(v,list)}

            torch.cuda.empty_cache()
            loss = [0]*n_terms
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                # Inputs are the labels shifted right by one position.
                decoder_input_ids = shift_tokens_right(batch['input_ids'], pad_token_id, pad_token_id)
                dec_mask = decoder_input_ids.ne(pad_token_id)

                teacher_out = teacher(input_ids = decoder_input_ids,attention_mask = dec_mask)
                student_out = student(input_ids = decoder_input_ids,attention_mask = dec_mask)
                dV = student_out.logits.size(-1)

                ########### Logit losses on attended positions ###############
                logit_mask = batch['attention_mask'].unsqueeze(-1).expand_as(teacher_out.logits).bool()
                SL = masked_select(student_out.logits,logit_mask).view(-1,dV)
                TL = masked_select(teacher_out.logits,logit_mask).view(-1,dV)
                loss[0] = CELoss(SL,masked_select(batch['input_ids'],batch['attention_mask'].bool()))
                loss[1] = ((temperature)**2)* STLoss(F.log_softmax(SL/ temperature,dim=-1),
                                                     F.softmax(TL/temperature, dim=-1))

                ############## Hidden-state alignment losses #######
                teacher_mask = dec_mask.unsqueeze(-1).expand_as(teacher_out.hidden_states[-1]).bool()
                student_mask = dec_mask.unsqueeze(-1).expand_as(student_out.hidden_states[-1]).bool()
                dT = teacher_out.hidden_states[-1].size(-1)
                dS = student_out.hidden_states[-1].size(-1)

                # One "+1" target per selected position, as expected by CSLoss.
                target = torch.ones(int(dec_mask.sum()), device=device)

                loss[2] = EMBLoss(masked_select(student_out.hidden_states[0],student_mask).view(-1,dS),
                                  masked_select(teacher_out.hidden_states[0],teacher_mask).view(-1,dT))

                for i in range(n_layers):
                    torch.cuda.empty_cache()
                    Tix = 2*(i+1)
                    loss[i+3] = CSLoss(masked_select(student_out.hidden_states[i+1],student_mask).view(-1,dS),
                                       masked_select(teacher_out.hidden_states[Tix],teacher_mask).view(-1,dT),
                                       target)

            # Skip batches where any term is NaN so one bad batch cannot poison
            # the averages.
            batch_loss = torch.stack([term.detach().float() for term in loss])
            if not torch.isnan(batch_loss).any():
                sumloss = batch_loss if sumloss is None else sumloss + batch_loss
                nBatch += 1

            # The previous ``nBatch % 5000 == 0`` test was also true while
            # nBatch was still 0, so a NaN first batch ended evaluation at once.
            if max_batches is not None and nBatch >= max_batches:
                break

    if nBatch == 0:
        # Nothing usable (empty loader or every batch NaN): report NaN rather
        # than failing on a division by zero.
        return [float("nan")] * n_terms

    return [total / nBatch for total in sumloss.tolist()]


os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

############################ Command Line ################################
# Usage: <script> <student_intermediate_size> <LR> <alpha> <seed> <K>
# Parsed before anything expensive happens so that a bad invocation fails
# immediately instead of after the teacher checkpoint has been loaded.
if len(sys.argv) < 6:
    sys.exit("usage: %s <student_intermediate_size> <LR> <alpha> <seed> <K>" % sys.argv[0])
student_intermediate_size = int(sys.argv[1])  # MLP width of the student
LR = float(sys.argv[2])                        # learning rate in units of 1e-5 (see optimizer)
alpha = float(sys.argv[3])                     # weight of the non-top-K KL term in the training loss
seed = int(sys.argv[4])                        # torch RNG seed
K = int(sys.argv[5])                           # number of top teacher classes kept individually

############################ Teacher / Student ###########################
teacher_id = "Qwen/Qwen2.5-3B"
teacher = AutoModelForCausalLM.from_pretrained(teacher_id,attn_implementation="flash_attention_2",output_hidden_states = True)
# Gradients are disabled globally; the training loop re-enables them locally
# around the student forward/backward pass.
torch.set_grad_enabled(False)
torch.set_float32_matmul_precision('high')
tokenizer = AutoTokenizer.from_pretrained(teacher_id, truncation_side = 'left',padding_side="left")

for param in teacher.parameters(): param.requires_grad = False

print("Number of Teacher Parameters: ", sum(p.numel() for p in teacher.parameters()))
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
student = Create_Student(teacher_id,teacher,18,student_intermediate_size)

print(student.config)
print(sum(p.numel() for p in student.parameters() if p.requires_grad))

############################ Load Datasets From Disk #####################
# Seeded here (after the student is built, before the shuffling loader is
# created), exactly where the original script seeded.
torch.manual_seed(seed)
train_dataset = load_from_disk('/Dump/Train_CultureX_Qwen')

batch_size = 1
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size = batch_size)
test_dataset = load_from_disk('/Dump/Valid_Qwen')

############################ Training Starts ############################
teacher.to(device)
student = student.to(device)
# The input embeddings are frozen and excluded from both optimizer groups.
for p in student.model.embed_tokens.parameters(): p.requires_grad = False

teacher.eval()
no_decay = ["bias"]
optimizer_grouped_parameters = [
    # Trainable non-bias parameters: weight decay applied.
    {"params": [p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad and "embed_tokens" not in n],
     "weight_decay": 1e-1,
    },
    # Trainable bias parameters: no weight decay.
    {"params": [p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad and "embed_tokens" not in n],
     "weight_decay": 0.0,
    }
]

temperature = 1.0
optimizer = AdamW(optimizer_grouped_parameters , lr=LR*1e-5, betas = (0.9,0.999), eps = 1e-6)
# NOTE(review): LR and alpha are floats, and "%d" truncates them toward zero,
# so runs that differ only in a fractional part share the same file names.
# Names are kept unchanged here to stay compatible with existing outputs.
f = open("../../Checkpoints/LOGIT-KD-Full-Qwen-MC4-xl-l-%d-%d-%d-%d.txt" % (student_intermediate_size,LR,alpha,K), "w+", buffering= 50)
f1 = open("../../Checkpoints/LOGIT-KD-Eval-Full-Qwen-MC4-xl-l-%d-%d-%d-%d.txt" % (student_intermediate_size,LR,alpha,K), "w+", buffering= 1)
from transformers import get_scheduler, get_cosine_schedule_with_warmup
num_epochs=50
batch_per_epoch = 40000
# NOTE(review): the horizon below counts individual batches, while the
# training loop steps the scheduler only once per 32 batches -- confirm the
# resulting (much slower) cosine decay is intended.
scheduler = get_cosine_schedule_with_warmup(optimizer = optimizer,num_warmup_steps=0, num_training_steps=num_epochs*batch_per_epoch)

pad_token_id = tokenizer.pad_token_id
STLoss = torch.nn.KLDivLoss(reduction = 'batchmean')   # KL averaged over rows
STLoss2 = torch.nn.KLDivLoss(reduction = 'none')       # element-wise KL terms
CSLoss = nn.CosineEmbeddingLoss() #CKALoss(eps = 1e-8)
CELoss = nn.CrossEntropyLoss(ignore_index=pad_token_id)
EMBLoss = nn.MSELoss()
scaler = torch.cuda.amp.GradScaler()
  

nBatch = 0
nTokens = 0
grad_accum_steps = 32  # micro-batches accumulated between optimizer steps
# NOTE(review): "%d" truncates the float LR/alpha values in this name.
checkpoint_path = "../../Checkpoints/Qwen/Qwen-MC4-xl-l-Full-%d-%d-%d-%d.pt" % (student_intermediate_size,LR,alpha,K)
teacher.eval()
try:
    for batch in train_dataloader:
        ########### In Device 0 ###############
        batch = {k: v.to(device) for k, v in batch.items() if not isinstance(v,list)}

        # Eval_Student leaves the student in eval mode, so switch back each step.
        student.train()

        # [cross-entropy, top-K KD, embedding loss, one hidden-state loss per layer]
        loss = [0]*(student.config.num_hidden_layers +3)

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            # Inputs are the labels shifted right by one position.
            decoder_input_ids = shift_tokens_right(batch['input_ids'], pad_token_id, pad_token_id)
            dec_mask = decoder_input_ids.ne(pad_token_id)

            ############## Teacher pass (no gradients) #######
            teacher_out = teacher(input_ids = decoder_input_ids,attention_mask = dec_mask)
            logit_mask = batch['attention_mask'].unsqueeze(-1).expand_as(teacher_out.logits).bool()
            dV = teacher_out.logits.size(-1)

            TL = masked_select(teacher_out.logits,logit_mask).view(-1,dV)
            pT = F.softmax(TL.float(), dim=-1)

            # Rank the vocabulary per token by teacher probability and split it
            # into the K most likely classes and the remainder.  The sorted
            # values are exactly what gathering pT at the sorted indices gives.
            sorted_pT,sorted_idx = torch.sort(pT,dim=-1,descending = True)
            top_idx, rest_idx = sorted_idx[:,:K], sorted_idx[:,K:]
            pTmax, pT_nc = sorted_pT[:,:K], sorted_pT[:,K:]
            pT_not_max = pT_nc.sum(-1)  # teacher mass outside the top K

            with torch.enable_grad():
                student_out = student(input_ids = decoder_input_ids,attention_mask = dec_mask)

                SL = masked_select(student_out.logits,logit_mask).view(-1,dV)
                loss[0] = CELoss(SL,masked_select(batch['input_ids'],batch['attention_mask'].bool()))

                # Work in log space: log_softmax / logsumexp equal
                # log(softmax) / log(sum) mathematically but do not produce
                # -inf (and then NaN) when a student probability underflows.
                log_pS = F.log_softmax(SL.float(), dim=-1)
                log_pS_max = torch.gather(log_pS,1,top_idx)
                log_pS_nc = torch.gather(log_pS,1,rest_idx)
                log_pS_not_max = torch.logsumexp(log_pS_nc, dim=-1)

                # KL over the (K+1)-way distribution: top-K classes plus one
                # bucket holding all remaining mass.
                MKLD = STLoss(torch.cat((log_pS_max,log_pS_not_max.unsqueeze(-1)),1),
                              torch.cat((pTmax,pT_not_max.unsqueeze(-1)),1))

                # KL between the renormalised distributions over the remaining
                # classes.  The clamp only matters when the teacher's remaining
                # mass is exactly zero, where it avoids a 0/0.
                pT_nc_norm = pT_nc/pT_not_max.unsqueeze(-1).clamp_min(torch.finfo(pT.dtype).tiny)
                NKLD = STLoss2(log_pS_nc - log_pS_not_max.unsqueeze(-1),pT_nc_norm)

                loss[1] = (MKLD + (alpha/pT_not_max.mean())*(pT_not_max*NKLD.sum(-1)).mean()) #Mean over Tokens

                ############## Hidden-state alignment losses #######
                teacher_mask = dec_mask.unsqueeze(-1).expand_as(teacher_out.hidden_states[-1]).bool()
                student_mask = dec_mask.unsqueeze(-1).expand_as(student_out.hidden_states[-1]).bool()
                dT = teacher_out.hidden_states[-1].size(-1)
                dS = student_out.hidden_states[-1].size(-1)

                # One "+1" target per selected position, as expected by CSLoss.
                target = torch.ones(int(dec_mask.sum()), device=device)

                loss[2] = EMBLoss(masked_select(student_out.hidden_states[0],student_mask).view(-1,dS),
                                  masked_select(teacher_out.hidden_states[0],teacher_mask).view(-1,dT))

                # Student hidden state i+1 is matched to teacher hidden state 2*(i+1).
                for i in range(student.config.num_hidden_layers):
                    torch.cuda.empty_cache()
                    Tix = 2*(i+1)
                    loss[i+3] = CSLoss(masked_select(student_out.hidden_states[i+1],student_mask).view(-1,dS),
                                       masked_select(teacher_out.hidden_states[Tix],teacher_mask).view(-1,dT),
                                       target)

                loss_sum = sum(loss)
                torch.cuda.empty_cache()
                scaler.scale(loss_sum).backward()

                nBatch+=1
                nTokens += batch['attention_mask'].sum().item()
                # Gradients accumulate across micro-batches; step once per group.
                # Only the last micro-batch's losses are logged.
                if nBatch % grad_accum_steps == 0:
                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()
                    optimizer.zero_grad()
                    f.write(str(['%.3f' % l.item() for l in loss])+'\n')

        # Periodic evaluation and checkpoint.
        if nBatch % batch_per_epoch == 0:
            eval_loss = Eval_Student(student,teacher,test_dataset, temperature, batch_size)
            f1.write(str(nTokens) + str(['%.3f' % l for l in eval_loss]) +  '\n')
            f1.flush()
            torch.save(student.state_dict(),checkpoint_path)

    # NOTE(review): gradients from a trailing group of fewer than
    # grad_accum_steps batches are never applied -- confirm this is acceptable.
    torch.save(student.state_dict(),checkpoint_path)
finally:
    # Close the logs even if training is interrupted by an exception.
    f.close()
    f1.close()