import torch
import torch.nn as nn
import torch.optim as optim
import utils.eval_utils as eval_utils
from tqdm import tqdm
from torch.nn import functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
import json
import os


def _log_perplexity(student, valid_loader, device, save_path, label, **extra):
    """Evaluate the student, print the result and append it to ``ppl.jsonl``.

    Args:
        student: Model passed to ``eval_utils.calculate_perplexity``.
        valid_loader: Validation loader passed to the same helper.
        device: Device passed to the same helper.
        save_path: Directory that contains (or will contain) ``ppl.jsonl``.
        label: Text printed in front of the perplexity value.
        **extra: Additional JSON-serialisable fields (e.g. ``step``, ``epoch``)
            written after ``"ppl"`` in the log record, in the order given.

    Returns:
        The perplexity object returned by ``eval_utils.calculate_perplexity``.
    """
    ppl = eval_utils.calculate_perplexity(student, valid_loader, device)
    print(f'{label}: {ppl:.4f}')
    with open(f"{save_path}/ppl.jsonl", "a") as f:
        f.write(json.dumps({"ppl": ppl.item(), **extra}) + "\n")
    return ppl


def train(teacher, student, train_loader, valid_loader, device, save_path, lr, num_epochs, accumulation_steps=4, kd_loss_scale=0.01, eval_every=2000):
    """Distil ``teacher`` into ``student`` with gradient accumulation.

    Each micro-batch loss is
    ``kd_loss_scale * KL(student || teacher) + (1 - kd_loss_scale) * student_loss``
    where ``student_loss`` is the ``.loss`` attribute returned by the student's
    forward pass. The optimiser is AdamW (no weight decay) driven by a cosine
    annealing schedule that spans every optimiser step of the run.

    Side effects: appends to ``{save_path}/loss.jsonl`` once per optimiser step,
    appends to ``{save_path}/ppl.jsonl`` at every evaluation, and writes a
    ``ppl{value}.pth`` state-dict checkpoint at the end of every epoch.

    Args:
        teacher: Frozen model; called as ``teacher(**inputs)`` and must return
            an object with a ``.logits`` attribute.
        student: Model being trained; its output must expose ``.logits`` and
            ``.loss``.
        train_loader: Sized iterable of dict batches holding ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        valid_loader: Loader handed to ``eval_utils.calculate_perplexity``.
        device: Device the batch tensors are moved to.
        save_path: Output directory (created if missing).
        lr: Initial learning rate for AdamW.
        num_epochs: Number of passes over ``train_loader``.
        accumulation_steps: Micro-batches accumulated per optimiser step.
        kd_loss_scale: Weight of the distillation (KL) term.
        eval_every: Evaluate on ``valid_loader`` every this many batches within
            an epoch; a falsy value disables mid-epoch evaluation.

    Raises:
        ValueError: If ``accumulation_steps`` is smaller than 1.
    """
    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")

    os.makedirs(f'{save_path}', exist_ok=True)

    student.train()
    teacher.eval()

    # NOTE(review): 'batchmean' divides the summed KL by the size of dim 0 only,
    # so with 3-D logits the KD term is summed over dim 1 and is not masked by
    # attention_mask. Left unchanged because kd_loss_scale was presumably chosen
    # against this scale -- confirm this is intended.
    kl_loss = nn.KLDivLoss(reduction='batchmean')

    num_batches = len(train_loader)
    # The trailing partial group of every epoch also triggers an optimiser
    # step, so the per-epoch step count is a ceiling division. Using the exact
    # total keeps the cosine schedule from running past T_max (where the LR
    # would start rising again) and avoids T_max == 0 on tiny loaders.
    steps_per_epoch = (num_batches + accumulation_steps - 1) // accumulation_steps

    optimizer = optim.AdamW(student.parameters(), lr=lr, weight_decay=0.)
    scheduler = CosineAnnealingLR(optimizer, T_max=max(1, steps_per_epoch * num_epochs))

    _log_perplexity(student, valid_loader, device, save_path,
                    'Initial Validation Perplexity')

    for epoch in range(num_epochs):
        # The evaluation helper may leave the model in eval mode; make sure
        # every epoch actually trains in train mode.
        student.train()

        step = 0
        optimizer.zero_grad()

        accumulated_loss = 0
        accumulation_counter = 0

        for batch_idx, inputs in enumerate(tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')):
            inputs['input_ids'] = inputs['input_ids'].to(device)
            inputs['attention_mask'] = inputs['attention_mask'].to(device)
            inputs['labels'] = inputs['labels'].to(device)

            # Teacher targets: no graph is needed for the frozen model.
            with torch.no_grad():
                teacher_outputs = teacher(**inputs).logits
                teacher_probs_T = F.softmax(teacher_outputs, dim=2)

            # Drop the raw logits early to lower peak memory.
            del teacher_outputs
            teacher_probs_T = teacher_probs_T.to(device)

            student_out = student(**inputs)
            student_outputs = student_out.logits
            student_log_probs_T = F.log_softmax(student_outputs, dim=2)
            student_loss = student_out.loss

            del student_out
            del student_outputs

            loss = kd_loss_scale * kl_loss(student_log_probs_T, teacher_probs_T) + (1-kd_loss_scale) * student_loss

            # Log the unscaled loss; only the backward pass uses the scaled one.
            accumulated_loss += loss.item()
            accumulation_counter += 1

            # Size of the accumulation group this batch belongs to. The last
            # group of an epoch can be shorter than accumulation_steps;
            # dividing by its real size keeps its gradient a true average.
            group_start = batch_idx - batch_idx % accumulation_steps
            group_size = min(accumulation_steps, num_batches - group_start)

            (loss / group_size).backward()

            if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx == num_batches - 1):
                avg_accumulated_loss = accumulated_loss / accumulation_counter

                with open(f"{save_path}/loss.jsonl", "a") as f:
                    f.write(json.dumps({
                        "loss": avg_accumulated_loss,
                        "step": step,
                        "batch_idx": batch_idx,
                        "accumulation_steps": accumulation_counter,
                        # `step` restarts every epoch, so record the epoch too.
                        "epoch": epoch,
                    }) + "\n")

                accumulated_loss = 0
                accumulation_counter = 0

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            step += 1
            if eval_every and step % eval_every == 0:
                _log_perplexity(student, valid_loader, device, save_path,
                                f'Step {step}, Validation Perplexity',
                                step=step, epoch=epoch)
                # Back to train mode for the remainder of the epoch.
                student.train()
                torch.cuda.empty_cache()

        ppl = _log_perplexity(student, valid_loader, device, save_path,
                              'Validation Perplexity', epoch=epoch)

        # Re-create the directory in case it was removed while training ran.
        os.makedirs(f'{save_path}', exist_ok=True)
        torch.save(student.state_dict(), f'{save_path}/ppl{ppl:.4f}.pth')
