from tqdm import tqdm
from moelayer import MoELoRA,LoRA
import torch
import time
import os
from evaluate import valid_path,eval_main_qwen,eval_main_qwen2 # type: ignore

def _mean_loss(losses, device):
    """Return the mean of ``losses`` moved to ``device`` (a zero tensor if the list is empty)."""
    if not losses:
        return torch.tensor(0.0, device=device)
    return sum(item.to(device) for item in losses) / len(losses)


def train_epoch(model, tokenizer, dataloader, optimizer, lr_scheduler, device, selected_tasks, valid_dataset,
                valid_batchsize=16, label_length=8, beta=0.01, lambda_1=0.01,
                use_aux_loss=False, use_orth_loss=False, save_steps=2500, save_path="./checkpoints",
                gradient_accumulation_steps=4):
    """Run one training pass over ``dataloader`` with periodic validation.

    For every batch the model's own loss (``outputs.loss``) is optionally
    combined with the mean auxiliary loss and the mean orthogonality loss
    collected from every ``MoELoRA`` module:

        loss = task_loss + beta * aux_loss + lambda_1 * orth_loss

    The optimizer and LR scheduler are stepped once every
    ``gradient_accumulation_steps`` batches, after gradient sparsification of
    the ``MoELoRA`` modules and gradient-norm clipping to 1.0.

    Args:
        model: Model to train. Must accept ``input_ids``/``attention_mask``/
            ``labels`` and expose ``current_task_ids`` and
            ``set_current_task_ids_to_layers()``.
        tokenizer: Forwarded to ``eval_main_qwen2`` during validation.
        dataloader: Iterable of batches with ``input_ids``, ``attention_mask``,
            ``labels`` and ``task_id`` tensors; must support ``len()``.
        optimizer: Optimizer to step.
        lr_scheduler: Scheduler stepped together with the optimizer.
        device: Device the batch tensors and loss terms are moved to.
        selected_tasks: Forwarded to ``eval_main_qwen2``.
        valid_dataset: Forwarded to ``eval_main_qwen2``.
        valid_batchsize: Validation batch size.
        label_length: Forwarded to ``eval_main_qwen2``.
        beta: Weight of the auxiliary loss.
        lambda_1: Weight of the orthogonality loss.
        use_aux_loss: Whether to add the auxiliary loss.
        use_orth_loss: Whether to add the orthogonality loss.
        save_steps: Validation is run every ``save_steps`` batches.
        save_path: Currently unused by this function (no checkpoint is written).
        gradient_accumulation_steps: Number of batches whose gradients are
            accumulated before each optimizer step.

    Returns:
        Tuple ``(avg_task_loss, avg_aux_loss, avg_orth_loss)``, each averaged
        over ``len(dataloader)``. The auxiliary/orthogonality entries are ``0``
        when the corresponding loss is disabled or no such loss was collected.
        An empty dataloader yields ``(0.0, 0, 0)``.

    Raises:
        ValueError: If ``gradient_accumulation_steps`` is smaller than 1.
    """
    if gradient_accumulation_steps < 1:
        raise ValueError("gradient_accumulation_steps must be >= 1")

    model.train()
    best_score = 0  # only tracked for logging; nothing is saved here
    total_loss = 0
    total_aux_loss = 0
    total_orth_loss = 0
    start_time = time.time()
    num_batches = len(dataloader)

    # The set of MoELoRA layers is collected once instead of re-walking the
    # module tree on every batch.
    moe_layers = [module for module in model.modules() if isinstance(module, MoELoRA)]

    # Pre-initialised so the final averages are well defined for an empty loader.
    aux_losses = []
    orth_losses = []

    progress_bar = tqdm(dataloader, desc="Training")
    for i, batch in enumerate(progress_bar):
        step = i + 1
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        task_ids = batch['task_id'].to(device)

        model.current_task_ids = task_ids
        model.set_current_task_ids_to_layers()

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        task_loss = outputs.loss

        aux_losses = []
        orth_losses = []
        for module in moe_layers:
            if use_aux_loss:
                aux_loss = module.get_aux_loss()
                if aux_loss is not None:
                    aux_losses.append(aux_loss)
            if use_orth_loss:
                orth_losses.append(module.compute_orthogonality_loss())

        # Build the total loss out-of-place: an in-place ``+=`` on an alias of
        # ``task_loss`` would silently add the regularisers to the logged
        # task loss as well.
        loss = task_loss
        avg_aux_value = 0.0
        avg_orth_value = 0.0
        if use_aux_loss:
            avg_aux_loss = _mean_loss(aux_losses, device)
            avg_aux_value = avg_aux_loss.item()  # type:ignore
            total_aux_loss += avg_aux_value
            loss = loss + beta * avg_aux_loss
        if use_orth_loss:
            avg_orth_loss = _mean_loss(orth_losses, device)
            avg_orth_value = avg_orth_loss.item()  # type:ignore
            total_orth_loss += avg_orth_value
            loss = loss + lambda_1 * avg_orth_loss

        # NOTE(review): the loss is not divided by gradient_accumulation_steps,
        # so accumulated gradients are a sum over micro-batches rather than a
        # mean. Left unchanged to keep the existing training dynamics.
        loss.backward()
        if step % gradient_accumulation_steps == 0:
            for module in moe_layers:
                module.sparsify_gradients_by_momentum_with_random(optimizer, step)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        task_loss_value = task_loss.item()
        total_loss += task_loss_value

        elapsed_time = time.time() - start_time
        avg_time_per_batch = elapsed_time / step
        remaining_time = avg_time_per_batch * (num_batches - step)
        remaining_time_str = time.strftime("%H:%M:%S", time.gmtime(remaining_time))

        progress_bar.set_postfix({
            "task_loss": f"{task_loss_value:.4f}",
            "aux_loss": f"{avg_aux_value:.4f}" if use_aux_loss else "OFF",
            "orth_loss": f"{avg_orth_value:.4f}" if use_orth_loss else "OFF",
            "total_avg_task_loss": f"{total_loss / step:.4f}",
            "total_avg_aux_loss": f"{total_aux_loss / step:.4f}" if use_aux_loss and aux_losses else "N/A",
            "total_avg_orth_loss": f"{total_orth_loss / step:.4f}" if use_orth_loss and orth_losses else "N/A",
            "ETA": remaining_time_str
        })

        if step % save_steps == 0:
            # Periodic validation; the model is switched back to train mode afterwards.
            model.eval()
            score = eval_main_qwen2(model, tokenizer, selected_tasks,
                                    valid_dataset, batch_size=valid_batchsize, label_length=label_length)
            model.train()
            print(f"Step {step}, Score: {score}")
            if score > best_score:
                best_score = score
                print(f"the best score is {best_score}")

    # NOTE(review): when the number of batches is not a multiple of
    # gradient_accumulation_steps, the gradients of the trailing batches are
    # neither applied nor cleared by this function.

    if num_batches == 0:
        return 0.0, 0, 0

    avg_task_loss = total_loss / num_batches
    avg_aux_loss = total_aux_loss / num_batches if use_aux_loss and aux_losses else 0
    avg_orth_loss = total_orth_loss / num_batches if use_orth_loss and orth_losses else 0
    return avg_task_loss, avg_aux_loss, avg_orth_loss
