def train_gaze_attention(args, model, epoch, custom_loader, tokenizer, optimizer, lr_scheduler, device_id, wandb,
                         reg_value=100):
    """Run one training epoch of the gaze-supervised attention model.

    Each batch is optimized on the model's own loss plus ``reg_value`` times a
    KL-divergence term between a gaze-derived target distribution and the
    attention weights returned by the model. Gradients are accumulated over
    ``args.gradient_accumulation_steps`` batches and clipped to a global norm of
    1.0 before every optimizer step. Any partial accumulation window left at the
    end of the epoch is flushed with one final step. The LR scheduler is stepped
    once per epoch.

    Args:
        args: Config object; reads ``precision``, ``gradient_accumulation_steps``
            and ``epochs``.
        model: Called as ``model(images, overlays, input_ids, attention_mask,
            labels)`` and must return ``(outputs, attn_weights)``, where
            ``outputs`` exposes ``.loss`` and ``.logits``.
        epoch: Current epoch index. It is printed as ``epoch + 1``, so it is
            treated as 0-based.
        custom_loader: Iterable of batches accepted by
            ``prepare_batch_gaze_attention``. ``batch[0]`` must be a tensor
            whose first dimension is the batch size.
        tokenizer: Used to look up the ``<image>`` / ``<|endofchunk|>`` token
            ids and to decode one sample prediction at the end of the epoch.
        optimizer: Stepped after every ``gradient_accumulation_steps`` batches.
        lr_scheduler: Stepped once at the end of the epoch.
        device_id: Local device index. Per-step logging is gated on
            ``device_id == 0``.
        wandb: Object providing ``log(dict)``. It is used on global rank 0 only.
        reg_value: Weight of the KL gaze-regularization term (default 100).

    Returns:
        float: Sample-weighted mean of the per-batch training objective
        (model loss + ``reg_value`` * KL), aggregated over all processes.

    Raises:
        ValueError: If ``custom_loader`` yields no batches.
    """
    accum_steps = args.gradient_accumulation_steps
    total_loss = 0.0  # sum over all ranks of (unscaled batch objective * batch size)
    total_samples = 0  # number of samples processed, summed over all ranks
    # The prepare step needs these special-token ids. The last id is taken in
    # case the tokenizer splits the string into several pieces.
    image_token_id = tokenizer("<image>", add_special_tokens=False)["input_ids"][-1]
    endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]

    model.train()
    print("training started")

    num_batches = 0
    for num_steps, batch in enumerate(custom_loader):
        num_batches = num_steps + 1
        batch_size = batch[0].size(0)

        # NOTE(review): fp16 autocast without a GradScaler can underflow small
        # gradients; consider torch.cuda.amp.GradScaler when precision is fp16.
        with torch.cuda.amp.autocast(enabled=args.precision != "fp32"):
            images, overlays, gaze, input_ids, attention_mask, labels = prepare_batch_gaze_attention(
                batch, tokenizer, device_id, image_token_id, endofchunk_token_id
            )
            outputs, attn_weights = model(images, overlays, input_ids, attention_mask, labels)
            # presumably a 16x16 spatial grid matching the attention map -- confirm
            target_dist = calculate_gaze_proportions_batch(gaze, 16, 16, device_id)
            if device_id == 0 and num_steps == 0:
                # One-time shape sanity check instead of printing on every step/rank.
                print("shape of target dist", target_dist.shape)
                print("shape of attn dist", attn_weights.shape)

            kl_loss = KL_divergence(target_dist, attn_weights, device_id)
            # Full objective. Previously this was overwritten by outputs.loss
            # alone, so the gaze regularizer never contributed to backward().
            step_loss = outputs.loss + reg_value * kl_loss
            # Scale so the gradients summed over accum_steps batches average them.
            loss = step_loss / accum_steps
        loss.backward()
        torch.cuda.empty_cache()

        if num_batches % accum_steps == 0:
            _optimizer_step(model, optimizer)

        # Aggregate the *unscaled* loss and the real sample count across
        # processes, so the epoch average does not depend on accum_steps or on
        # ranks having equal batch sizes.
        step_loss_value = step_loss.item()
        stats = torch.tensor(
            [step_loss_value * batch_size, float(batch_size)], dtype=torch.float64, device=device_id
        )
        torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.SUM)
        total_loss += stats[0].item()
        total_samples += int(stats[1].item())

        if device_id == 0 and num_steps % 20 == 0:
            print(
                f"Step {num_steps+1} of epoch {epoch+1}/{args.epochs} complete. Loss is: {step_loss_value:.3f}")
            print("KL loss is ", kl_loss.item())
            print("loss from outputs is ", outputs.loss.item())

    # Apply gradients from a trailing partial accumulation window. Otherwise
    # they would be dropped here and leak into the next epoch's first step.
    if num_batches % accum_steps != 0:
        _optimizer_step(model, optimizer)

    if total_samples == 0:
        raise ValueError("custom_loader yielded no batches; cannot compute epoch loss")

    lr_scheduler.step()
    avg_loss = total_loss / total_samples
    if torch.distributed.get_rank() == 0:
        print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")
        wandb.log({"loss_epoch": avg_loss})
        # Print the decoded prediction for the first sample of the last batch.
        token_ids = logits_to_token_ids(outputs.logits)
        predicted_text = decode_token_ids(tokenizer, token_ids)
        new_text = decode_token_ids(tokenizer, input_ids)
        first_pair = next(zip(predicted_text, new_text), None)
        if first_pair is not None:
            print(first_pair[0])

    torch.cuda.empty_cache()
    return avg_loss


def _optimizer_step(model, optimizer):
    """Clip gradients to global norm 1.0, step the optimizer, then clear gradients."""
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
