import argparse
import glob
import logging
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from torch.optim import AdamW
import torch.nn.functional as F
from torch import masked_select
import torch.nn as nn
from transformers import PhiForCausalLM, PhiConfig, AutoModelForCausalLM
from transformers import AutoConfig, AutoTokenizer
from datasets import load_dataset, interleave_datasets, load_from_disk, concatenate_datasets
from evaluate import load
from transformers.models.bart.modeling_bart import shift_tokens_right

from torch.nn.parallel import DistributedDataParallel as DDP
import os, sys




def Create_Student(student_id, teacher, student_layers, student_dim):
    """Build a shallower student model initialised from ``teacher``.

    The student is loaded from the ``student_id`` checkpoint with
    ``student_layers`` decoder layers and an MLP intermediate size of
    ``2 * student_dim``. Mismatched tensors are re-initialised because of
    ``ignore_mismatched_sizes=True``. Then:

    * student layer ``i`` receives teacher layer ``2 * i``'s input LayerNorm,
      Q/K/V projections and attention output projection (``self_attn.dense``);
    * the student MLP keeps the first ``2 * student_dim`` hidden units of the
      teacher MLP (rows of ``fc1``, columns of ``fc2``), plus ``fc2``'s bias;
    * ``lm_head`` is replaced by a bias-free ``nn.Linear`` whose weight is
      copied from the teacher's ``lm_head``.

    Args:
        student_id: model id / path the student architecture and any
            non-overwritten weights are loaded from.
        teacher: teacher model exposing ``model.layers[*]`` with
            ``input_layernorm``, ``self_attn.{q,k,v}_proj``,
            ``self_attn.dense``, ``mlp.fc1``, ``mlp.fc2`` and an ``lm_head``.
        student_layers: number of student decoder layers. The teacher must
            have at least ``2 * student_layers - 1`` layers.
        student_dim: sets the student MLP width (``2 * student_dim``) and the
            input width of the new ``lm_head``. NOTE(review): the student's
            hidden size is not overridden, so this must equal the teacher
            hidden size for the ``lm_head`` weight copy to succeed.

    Returns:
        The initialised student model.
    """
    student_intermediate_size = 2 * student_dim
    student = AutoModelForCausalLM.from_pretrained(
        student_id,
        num_hidden_layers=student_layers,
        attn_implementation="flash_attention_2",
        intermediate_size=student_intermediate_size,
        output_hidden_states=True,
        ignore_mismatched_sizes=True,
    )

    # In-place copies into leaf parameters must not be tracked by autograd.
    # Do not rely on the caller having disabled gradients globally.
    with torch.no_grad():
        for i in range(student_layers):
            s_layer = student.model.layers[i]
            t_layer = teacher.model.layers[2 * i]  # keep every other teacher layer

            # These submodules have identical shapes in teacher and student,
            # so weight and bias are copied verbatim.
            for path in (
                "input_layernorm",
                "self_attn.q_proj",
                "self_attn.k_proj",
                "self_attn.v_proj",
                "self_attn.dense",
            ):
                s_mod = s_layer.get_submodule(path)
                t_mod = t_layer.get_submodule(path)
                s_mod.weight.copy_(t_mod.weight)
                s_mod.bias.copy_(t_mod.bias)

            # MLP: keep the first `student_intermediate_size` hidden units.
            s_layer.mlp.fc1.weight.copy_(t_layer.mlp.fc1.weight[:student_intermediate_size, :])
            s_layer.mlp.fc1.bias.copy_(t_layer.mlp.fc1.bias[:student_intermediate_size])
            s_layer.mlp.fc2.weight.copy_(t_layer.mlp.fc2.weight[:, :student_intermediate_size])
            s_layer.mlp.fc2.bias.copy_(t_layer.mlp.fc2.bias)

        # The new head deliberately has no bias; only the teacher weight is reused.
        student.lm_head = nn.Linear(student_dim, student.config.vocab_size, bias=False)
        student.lm_head.weight.copy_(teacher.lm_head.weight)

    return student




def Eval_Student(student, teacher, test_dataset, temperature=1.0, batch_size=32, max_batches=2000):
    """Evaluate the distillation loss terms of ``student`` against ``teacher``.

    Iterates a shuffled DataLoader over ``test_dataset``. For each batch it
    computes:

    * ``[0]`` cross-entropy of the student logits against ``input_ids``
      (``CELoss``);
    * ``[1]`` ``temperature ** 2``-scaled KL divergence between the
      temperature-softened student and teacher distributions (``STLoss``);
    * ``[2]`` MSE between student and teacher ``hidden_states[0]``
      (``EMBLoss``);
    * ``[3 + i]`` cosine-embedding loss between student ``hidden_states[i + 1]``
      and teacher ``hidden_states[2 * (i + 1)]`` (``CSLoss``).

    Batches where any term is NaN are skipped. Evaluation stops once
    ``max_batches`` non-NaN batches have been accumulated. Both models are
    left in eval mode.

    Relies on the module-level globals ``device``, ``pad_token_id``,
    ``CELoss``, ``STLoss``, ``EMBLoss`` and ``CSLoss``.

    Args:
        student: student model; must return ``logits`` and ``hidden_states``.
        teacher: teacher model; must return ``logits`` and ``hidden_states``.
        test_dataset: dataset whose items contain ``input_ids`` and
            ``attention_mask`` tensors. List-valued fields are dropped.
        temperature: softmax temperature for the KL term.
        batch_size: evaluation batch size.
        max_batches: number of non-NaN batches after which evaluation stops.

    Returns:
        list[float]: per-term mean over the accepted batches, or a list of
        NaNs if no batch was accepted (empty dataset or all NaN).
    """
    eval_dataloader = DataLoader(test_dataset, shuffle=True, batch_size=batch_size)
    student.eval()
    teacher.eval()
    n_terms = student.config.num_hidden_layers + 3
    sumloss = [0] * n_terms
    nBatch = 0

    # Pure evaluation: never build an autograd graph, whatever the caller's grad mode.
    with torch.no_grad():
        for batch in eval_dataloader:
            batch = {k: v.to(device) for k, v in batch.items() if not isinstance(v, list)}

            torch.cuda.empty_cache()
            loss = [0] * n_terms
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                # Inputs shifted right by one position, using the pad token as the start token.
                decoder_input_ids = shift_tokens_right(batch['input_ids'], pad_token_id, pad_token_id)
                dec_mask = decoder_input_ids.ne(pad_token_id)

                teacher_out = teacher(input_ids=decoder_input_ids, attention_mask=dec_mask)
                student_out = student(input_ids=decoder_input_ids, attention_mask=dec_mask)
                dV = student_out.logits.size(-1)

                # Logits at positions where attention_mask is set, flattened to (n_tokens, dV).
                logit_mask = batch['attention_mask'].unsqueeze(-1).expand_as(teacher_out.logits).bool()
                SL = masked_select(student_out.logits, logit_mask).view(-1, dV)
                TL = masked_select(teacher_out.logits, logit_mask).view(-1, dV)

                loss[0] = CELoss(SL, masked_select(batch['input_ids'], batch['attention_mask'].bool()))
                loss[1] = (temperature ** 2) * STLoss(
                    F.log_softmax(SL / temperature, dim=-1),
                    F.softmax(TL / temperature, dim=-1),
                )

                ############## Hidden-state losses (MSE on entry 0, cosine per layer) #######
                teacher_mask = dec_mask.unsqueeze(-1).expand_as(teacher_out.hidden_states[-1]).bool()
                student_mask = dec_mask.unsqueeze(-1).expand_as(student_out.hidden_states[-1]).bool()
                dT = teacher_out.hidden_states[-1].size(-1)
                dS = student_out.hidden_states[-1].size(-1)

                # CosineEmbeddingLoss target of +1: push each pair towards similarity.
                target = torch.ones(dec_mask.sum()).to(device)

                loss[2] = EMBLoss(
                    masked_select(student_out.hidden_states[0], student_mask).view(-1, dS),
                    masked_select(teacher_out.hidden_states[0], teacher_mask).view(-1, dT),
                )

                for i in range(student.config.num_hidden_layers):
                    torch.cuda.empty_cache()
                    loss[i + 3] = CSLoss(
                        masked_select(student_out.hidden_states[i + 1], student_mask).view(-1, dS),
                        masked_select(teacher_out.hidden_states[2 * (i + 1)], teacher_mask).view(-1, dT),
                        target,
                    )

            # Only accumulate batches whose every term is a number.
            if not torch.isnan(torch.stack([l.float() for l in loss])).any():
                for i in range(n_terms):
                    sumloss[i] += loss[i]
                nBatch += 1
                # Checked only after a successful increment, so a NaN first
                # batch (nBatch == 0) can no longer end evaluation early.
                if nBatch >= max_batches:
                    break

    if nBatch == 0:
        return [float('nan')] * n_terms
    return [l.item() / nBatch for l in sumloss]



# ---------------------------------------------------------------------------
# Process-wide settings. PyTorch reads PYTORCH_CUDA_ALLOC_CONF when the CUDA
# caching allocator is first used, so it is set before any model is moved to GPU.
# ---------------------------------------------------------------------------
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512",
    }
)

teacher_id = "microsoft/phi-2"
# Autograd is disabled for the whole script; the training loop re-enables it
# locally around the student forward/backward pass.
torch.set_grad_enabled(False)

# Left-side truncation/padding; EOS doubles as the padding token.
tokenizer = AutoTokenizer.from_pretrained(
    teacher_id,
    truncation_side='left',
    padding_side="left",
)
tokenizer.pad_token_id = tokenizer.eos_token_id

teacher = AutoModelForCausalLM.from_pretrained(
    teacher_id,
    output_hidden_states=True,
    attn_implementation="flash_attention_2",
    use_cache=False,
)
# Freeze the teacher: none of its parameters will receive gradients.
teacher.requires_grad_(False)

print("Number of Teacher Parameters: ", sum(t.numel() for t in teacher.parameters()))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# argv[1]: student width (see Create_Student); the student keeps 16 layers.
student_dim = int(sys.argv[1])
student = Create_Student(teacher_id, teacher, 16, student_dim)

print(student.config)
print(sum(s.numel() for s in student.parameters() if s.requires_grad))

############################ Load pre-tokenized datasets from disk ###########################
torch.manual_seed(42)
torch.set_float32_matmul_precision('high')

train_dataset = load_from_disk('/Dump/Train_Phi')
test_dataset = load_from_disk('/Dump/Valid_Phi')

# One sequence per micro-batch; the training loop accumulates gradients.
batch_size = 1
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

############################ Training Setup ############################
teacher.to(device)
student.to(device)
teacher.eval()


def _weight_decay_groups(model, no_decay_tags, weight_decay):
    """Split ``model``'s trainable parameters into AdamW parameter groups.

    A parameter whose name contains any of ``no_decay_tags`` goes into the
    second group (weight decay 0.0). All other trainable parameters go into
    the first group (``weight_decay``). Frozen parameters are skipped.
    Within each group, parameters keep ``model.named_parameters()`` order.
    """
    decayed, undecayed = [], []
    for pname, p in model.named_parameters():
        if not p.requires_grad:
            continue
        bucket = undecayed if any(tag in pname for tag in no_decay_tags) else decayed
        bucket.append(p)
    return [
        {"params": decayed, "weight_decay": 1e-1 if weight_decay is None else weight_decay},
        {"params": undecayed, "weight_decay": 0.0},
    ]


no_decay = ["bias"]
optimizer_grouped_parameters = _weight_decay_groups(student, no_decay, 1e-1)

# Command line: argv[2] = LR (multiplied by 1e-5), argv[3] = alpha, argv[5] = K.
alpha = float(sys.argv[3])
LR = float(sys.argv[2])
K = int(sys.argv[5])
temperature = 1.0

optimizer = AdamW(
    optimizer_grouped_parameters,
    lr=LR * 1e-5,
    betas=(0.9, 0.999),
    eps=1e-6,
)
f = open("../../Checkpoints/LOGIT-KD-Full-Phi-C4-xl-16-%d-%d-%d-%d.txt" % (student_dim, LR, alpha, K), "w+", buffering=50)
f1 = open("../../Checkpoints/LOGIT-KD-Eval-Full-Phi-C4-xl-16-%d-%d-%d-%d.txt" % (student_dim, LR, alpha, K), "w+", buffering=1)
from transformers import get_scheduler, get_cosine_schedule_with_warmup
num_epochs = 60
batch_per_epoch = 50000
scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_epochs * batch_per_epoch,
)

# Loss functions shared by the training loop and Eval_Student.
STLoss = torch.nn.KLDivLoss(reduction='batchmean')
STLoss2 = torch.nn.KLDivLoss(reduction='none')
pad_token_id = tokenizer.pad_token_id

CSLoss = nn.CosineEmbeddingLoss()  # alternative previously considered: CKALoss(eps=1e-8)
CELoss = nn.CrossEntropyLoss()
EMBLoss = nn.MSELoss()
scaler = torch.cuda.amp.GradScaler()
  

############################ Training Loop ############################
# Each dataloader item is one micro-batch. Gradients are accumulated over
# `grad_accum_steps` micro-batches before every optimizer/scheduler step.
grad_accum_steps = 32
nBatch = 0
nTokens = 0
teacher.eval()
nEpoch = 0
for batch in train_dataloader:

    batch = {k: v.to(device) for k, v in batch.items() if not isinstance(v, list)}

    student.train()

    # [0] CE, [1] top-K KD loss, [2] hidden_states[0] MSE, [3:] per-layer cosine losses.
    loss = [0] * (student.config.num_hidden_layers + 3)

    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        # Inputs shifted right by one position, using the pad token as the start token.
        decoder_input_ids = shift_tokens_right(batch['input_ids'], pad_token_id, pad_token_id)
        dec_mask = decoder_input_ids.ne(pad_token_id)

        # Teacher forward runs under the script-wide no-grad mode.
        teacher_out = teacher(input_ids=decoder_input_ids, attention_mask=dec_mask)
        logit_mask = batch['attention_mask'].unsqueeze(-1).expand_as(teacher_out.logits).bool()
        dV = teacher_out.logits.size(-1)

        # Teacher logits at positions where attention_mask is set, flattened to (n_tokens, dV).
        TL = masked_select(teacher_out.logits, logit_mask).view(-1, dV)

        pT = F.softmax(TL.float(), dim=-1)
        # Vocabulary indices ordered by teacher probability, descending.
        # The first K are the teacher's top-K tokens; the rest form the "tail".
        _, sorted_idx = torch.sort(pT, dim=-1, descending=True)
        top_idx = sorted_idx[:, :K]
        tail_idx = sorted_idx[:, K:]

        pTmax = torch.gather(pT, 1, top_idx)
        pT_not_max = torch.gather(pT, 1, tail_idx).sum(-1)

        with torch.enable_grad():
            student_out = student(input_ids=decoder_input_ids, attention_mask=dec_mask)

            SL = masked_select(student_out.logits, logit_mask).view(-1, dV)
            loss[0] = CELoss(SL, masked_select(batch['input_ids'], batch['attention_mask'].bool()))

            # ---- MKLD: KL over (top-K tokens + one merged tail bucket) ----
            # Computed in log space. log(softmax(...)) turns an underflowed
            # student probability into -inf, which makes the loss inf/NaN.
            log_pS = F.log_softmax(SL.float(), dim=-1)
            log_pSmax = torch.gather(log_pS, 1, top_idx)
            log_pS_not_max = torch.logsumexp(torch.gather(log_pS, 1, tail_idx), dim=-1)

            MKLD = STLoss(
                torch.cat((log_pSmax, log_pS_not_max.unsqueeze(-1)), 1),
                torch.cat((pTmax, pT_not_max.unsqueeze(-1)), 1),
            )

            # ---- NKLD: pointwise KL between the renormalised tails ----
            pT_nc = torch.gather(pT, 1, tail_idx)
            tail_mass = pT_nc.sum(-1, keepdim=True)
            # A zero tail mass (all tail probabilities underflowed) would give 0/0 = NaN.
            # Such rows are all-zero targets and contribute nothing.
            pT_nc = pT_nc / torch.where(tail_mass > 0, tail_mass, torch.ones_like(tail_mass))

            # log_softmax of the tail logits == log(exp(x) / sum(exp(x))), computed stably.
            log_pS_nc = F.log_softmax(torch.gather(SL.double(), 1, tail_idx), dim=-1)

            NKLD = STLoss2(log_pS_nc, pT_nc)

            # Tail KL per token is weighted by that token's teacher tail mass and
            # normalised by the mean tail mass (guarded against 0/0).
            mean_tail = pT_not_max.mean()
            mean_tail = torch.where(mean_tail > 0, mean_tail, torch.ones_like(mean_tail))
            loss[1] = MKLD + (alpha / mean_tail) * (pT_not_max * NKLD.sum(-1)).mean()  # Mean over tokens

            ############## Hidden-state losses (MSE on entry 0, cosine per layer) #######
            teacher_mask = dec_mask.unsqueeze(-1).expand_as(teacher_out.hidden_states[-1]).bool()
            student_mask = dec_mask.unsqueeze(-1).expand_as(student_out.hidden_states[-1]).bool()
            dT = teacher_out.hidden_states[-1].size(-1)
            dS = student_out.hidden_states[-1].size(-1)

            # CosineEmbeddingLoss target of +1: push each pair towards similarity.
            target = torch.ones(dec_mask.sum()).to(device)

            loss[2] = EMBLoss(
                masked_select(student_out.hidden_states[0], student_mask).view(-1, dS),
                masked_select(teacher_out.hidden_states[0], teacher_mask).view(-1, dT),
            )

            # Student hidden state i+1 is matched to teacher hidden state 2*(i+1).
            for i in range(student.config.num_hidden_layers):
                torch.cuda.empty_cache()
                loss[i + 3] = CSLoss(
                    masked_select(student_out.hidden_states[i + 1], student_mask).view(-1, dS),
                    masked_select(teacher_out.hidden_states[2 * (i + 1)], teacher_mask).view(-1, dT),
                    target,
                )

            loss_sum = sum(loss)
            torch.cuda.empty_cache()
            scaler.scale(loss_sum).backward()

        nBatch += 1
        nTokens += batch['attention_mask'].sum().item()
        if nBatch % grad_accum_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()
            f.write(str(['%.3f' % l.item() for l in loss]) + '\n')

    if nBatch % batch_per_epoch == 0:
        loss = Eval_Student(student, teacher, test_dataset, temperature, batch_size)
        f1.write(str(nTokens) + str(['%.3f' % l for l in loss]) + '\n')
        f1.flush()
        torch.save(student.state_dict(), "../../Checkpoints/Phi/Phi-C4-xl-16-Full-%d-%d-%d-%d.pt" % (student_dim, LR, alpha, K))
        nEpoch += 1

# Apply gradients still accumulated from a final partial group of micro-batches.
# Otherwise they would be silently discarded.
if nBatch % grad_accum_steps != 0:
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()
    optimizer.zero_grad()

torch.save(student.state_dict(), "../../Checkpoints/Phi/Phi-C4-xl-16-Full-%d-%d-%d-%d.pt" % (student_dim, LR, alpha, K))

f.close()
f1.close()