import torch
from experiments.ar_ic.evaluate import evaluate_model
from tqdm import tqdm
import matplotlib.pyplot as plt
from exp_utils.utils import save_checkpoint


def train_model(
    model,
    train_loader,
    optimizer,
    loss_function,
    device,
    wandb_writer,
    iteration,
    remaining_iters,
    eval_interval=10,
    scheduler=None,
    out_dir=None,
    save_interval=0.5e6,
    plot_interval=200000,
    save_checkpoints=True,
) -> dict:
    """Run the training loop over ``train_loader`` and return the loss history.

    Each batch is split into inputs (all tokens but the last) and targets
    (all tokens but the first). Only the final position is scored: the
    model's output at the last step is compared against the last target.

    Args:
        model: Callable module; ``model(inputs)`` must return a tensor that
            can be indexed as ``outputs[:, -1, :]``.
        train_loader: Iterable yielding 2-D integer-convertible tensors.
            It is also handed to ``evaluate_model`` for periodic evaluation.
        optimizer: Optimizer stepped once per batch.
        loss_function: Called as ``loss_function(logits, target)`` with the
            last-step logits reshaped via ``unsqueeze(1).transpose(1, 2)``
            and the last target reshaped via ``unsqueeze(1)``.
        device: Device the batch is moved to before the forward pass and
            that is forwarded to ``evaluate_model``. ``None`` leaves the
            batch where the loader put it.
        wandb_writer: Object exposing ``log(dict, step=...)``, or ``None``
            to disable metric logging in this function.
        iteration: Iteration counter to resume from; it is incremented
            before each batch is processed.
        remaining_iters: Used only as the progress-bar ``total``; the loop
            itself ends when ``train_loader`` is exhausted.
        eval_interval: Evaluate and record training metrics every this many
            iterations.
        scheduler: Optional LR scheduler stepped after every optimizer step.
        out_dir: Output directory forwarded to ``evaluate_model`` and used
            as the root for checkpoints (``{out_dir}/models/``).
        save_interval: Save a checkpoint every this many iterations.
        plot_interval: Forwarded unchanged to ``evaluate_model``.
        save_checkpoints: Whether to write checkpoints at all.

    Returns:
        ``{"train_loss_history": [...]}`` where the list holds the scalar
        training loss sampled every ``eval_interval`` iterations.
    """
    import warnings

    train_loss_history = []
    # The batch index from enumerate() was never used; iterate directly.
    for sequences in tqdm(train_loader, total=remaining_iters, initial=iteration):
        iteration += 1
        is_eval_step = iteration % eval_interval == 0

        # Periodic evaluation happens before this iteration's update.
        # NOTE(review): evaluation is fed train_loader, not a held-out loader.
        if is_eval_step:
            evaluate_model(
                model,
                train_loader,
                loss_function,
                device,
                iteration,
                wandb_writer,
                out_dir,
                plot_interval,
            )
            torch.cuda.empty_cache()

        # Put the batch on the requested device. This is a no-op when the
        # loader already yields tensors there.
        if device is not None:
            sequences = sequences.to(device)

        # Next-token setup: inputs drop the last token, targets drop the first.
        inputs = sequences[:, :-1].long()
        targets = sequences[:, 1:].long()
        # Re-assert training mode every step; presumably evaluate_model
        # switches the model to eval mode (not visible from here).
        model.train()

        # Forward pass
        optimizer.zero_grad()
        outputs = model(inputs)
        # Only the final position contributes to the loss: last-step logits
        # become (batch, classes, 1) and the last target becomes (batch, 1).
        loss = loss_function(
            outputs[:, -1, :].unsqueeze(1).transpose(1, 2), targets[:, -1].unsqueeze(1)
        )

        # Backward and optimize
        loss.backward()

        optimizer.step()
        if scheduler is not None:
            scheduler.step()
            lr = scheduler.get_last_lr()
            if wandb_writer is not None:
                wandb_writer.log({"Learning Rate": lr[0]}, step=iteration)

        if is_eval_step:
            # Read the scalar once; each .item() forces a host sync.
            loss_value = loss.item()
            if wandb_writer is not None:
                wandb_writer.log({"Train Loss/Transformer": loss_value}, step=iteration)
                # Last-position accuracy on the current training batch.
                preds = torch.argmax(outputs[:, -1, :], dim=-1)
                correct_predictions = (preds == targets[:, -1]).float()
                accuracy = correct_predictions.sum() / len(correct_predictions)
                wandb_writer.log({"Train Accuracy/Transformer": accuracy.item()}, step=iteration)
            train_loss_history.append(loss_value)

        if save_checkpoints and iteration % save_interval == 0 and iteration > 0:
            if out_dir is None:
                # Without this guard the f-string below would produce the
                # literal path "None/models/...".
                warnings.warn(
                    "save_checkpoints is enabled but out_dir is None; "
                    f"skipping checkpoint at iteration {iteration}",
                    stacklevel=2,
                )
            else:
                save_checkpoint(
                    f"{out_dir}/models/model_{iteration}.pth", model, optimizer, iteration
                )

    return {"train_loss_history": train_loss_history}
