import torch
from experiments.single_location_linear_regression.evaluate import evaluate_model  # Adjusted import
from tqdm import tqdm
from functools import partial
from exp_utils import utils
import torch.nn.functional as F  # Need loss function like MSE
import wandb
import collections  # Added for defaultdict


def _log_adam_rescaled_lr(model, optimizer, wandb_writer, iteration):
    """
    Log the per-parameter Adam-rescaled learning rate to WandB.

    For every parameter of ``model`` that has a gradient and an ``exp_avg_sq``
    entry in the optimizer state, logs ``lr / mean(sqrt(exp_avg_sq) + eps)``
    under ``rescaled_lr/<param name>``. The ``lr`` and ``eps`` come from the
    param group that actually owns the parameter, so optimizers with several
    param groups are reported correctly.

    Note: ``exp_avg_sq`` is Adam's raw second-moment estimate; bias correction
    is not applied here, so values in the first iterations are approximate.

    Args:
        model: Model whose named parameters are inspected.
        optimizer: Adam-family optimizer holding ``exp_avg_sq`` state.
        wandb_writer: Object exposing ``log(dict, step=...)``.
        iteration: Step passed to ``wandb_writer.log``.
    """
    # Map each parameter (by identity) to the param group that owns it.
    group_by_param = {
        id(p): group for group in optimizer.param_groups for p in group["params"]
    }

    metrics = {}
    with torch.no_grad():
        for name, p in model.named_parameters():
            if p.grad is None:
                continue
            group = group_by_param.get(id(p))
            # .get() avoids inserting empty entries into the state defaultdict.
            state = optimizer.state.get(p)
            if group is None or not state or "exp_avg_sq" not in state:
                continue
            # Magnitude of Adam's denominator sqrt(v_t) + eps, averaged over the tensor.
            rescaling_magnitude = (
                (torch.sqrt(state["exp_avg_sq"]) + group["eps"]).mean().item()
            )
            metrics[f"rescaled_lr/{name}"] = group["lr"] / rescaling_magnitude

    if metrics:  # Only log when at least one parameter had Adam state.
        wandb_writer.log({"Learning rate metrics": metrics}, step=iteration)


def train_model(
    model,
    train_loader,
    optimizer,
    device,
    wandb_writer,
    iteration,
    remaining_iters,
    eval_interval=100,
    scheduler=None,
    out_dir=None,
    save_interval=0.5e6,
    plot_interval=200000,
    eval_loader_generalization=None,
    config=None,
    save_checkpoint=True,
):
    """
    Train the model for the Single Location Linear Regression task.

    Each step feeds ``X_batch`` through the model and uses the output at the
    last sequence position (``outputs[:, -1, :]``) as the prediction for
    ``Y_batch``. The loss is half the squared error, summed over the last
    dimension and averaged over the batch. Every ``eval_interval`` iterations,
    the model is evaluated (before that step's update) on ``train_loader``
    ("ID") and, if given, on ``eval_loader_generalization`` ("OOD"). On those
    steps the training loss is also recorded. Training stops early, without
    applying the update, if the loss becomes NaN.

    Args:
        model: The model to train (expects input shape [batch, seq_len, dim]).
        train_loader: Iterable yielding (X, Y) batches; must expose an
            ``iters`` attribute. The loop runs until this iterable is exhausted.
        optimizer: The optimizer.
        device: Torch device batches are moved to.
        wandb_writer: WandB writer object (or None to disable logging).
        iteration: Starting iteration number (used for logging, eval and
            checkpoint scheduling).
        remaining_iters: Passed to tqdm as ``total`` for the progress bar only;
            it does not bound the loop.
        eval_interval: How often (in iterations) to evaluate and log.
        scheduler: Learning rate scheduler (optional), stepped every iteration.
        out_dir: Directory to save checkpoints and plots.
        save_interval: How often (in iterations) to save model checkpoints.
        plot_interval: Forwarded to ``evaluate_model`` to control plotting.
        eval_loader_generalization: Optional secondary evaluation dataloader
            for the specified burstiness_eval and p_repeat_eval.
        config: Optional configuration forwarded to ``evaluate_model``.
        save_checkpoint: Flag to save checkpoints.

    Returns:
        dict: A dictionary containing loss histories:
              'train_loss_history': List of training losses recorded at eval intervals.
              'eval_loss_history': Dict with keys "ID" and "OOD" mapping to lists
              of evaluation losses.
    """
    _mse = partial(F.mse_loss, reduction="none")

    def loss_function(x, y):
        # Halve the loss: PyTorch's MSE omits the conventional 1/2 factor.
        return _mse(x, y) / 2

    def run_eval(loader, prefix, step):
        """Evaluate ``model`` on ``loader`` at ``step``, logging under ``prefix``."""
        return evaluate_model(
            model,
            loader,
            loss_function,
            device,
            step,
            wandb_writer,
            out_dir,
            plot_interval=plot_interval,
            prefix=prefix,
            config=config,
        )

    train_loss_history = []
    eval_loss_history = {"ID": [], "OOD": []}  # Eval losses for ID and OOD data.

    for X_batch, Y_batch in tqdm(train_loader, total=remaining_iters, initial=iteration):
        X_batch = X_batch.to(device)
        Y_batch = Y_batch.to(device)
        is_log_step = iteration % eval_interval == 0

        # Evaluation (runs before this iteration's parameter update).
        if is_log_step:
            eval_loss_history["ID"].append(run_eval(train_loader, "ID", iteration))
            if eval_loader_generalization is not None:
                eval_loss_history["OOD"].append(
                    run_eval(eval_loader_generalization, "OOD", iteration)
                )
            torch.cuda.empty_cache()

        # Forward pass; evaluation may have switched the model to eval mode.
        model.train()
        optimizer.zero_grad()
        predicted_Y = model(X_batch)[:, -1, :]
        loss = loss_function(predicted_Y, Y_batch).sum(axis=-1).mean()
        current_train_loss = loss.item()

        # Stop before backward/step so a NaN loss never corrupts the weights
        # (or lands in a checkpoint), and report the iteration it occurred at.
        if torch.isnan(loss).any():
            print(f"NaN loss encountered at iteration {iteration}. Stopping training.")
            break

        loss.backward()
        optimizer.step()

        if (
            wandb_writer is not None
            and is_log_step
            and isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW))
        ):
            _log_adam_rescaled_lr(model, optimizer, wandb_writer, iteration)

        if scheduler is not None:
            scheduler.step()
            if wandb_writer is not None:
                wandb_writer.log(
                    {"Learning Rate": scheduler.get_last_lr()[0]}, step=iteration
                )

        if is_log_step:
            train_loss_history.append(current_train_loss)
            if wandb_writer is not None:
                wandb_writer.log({"Train Loss/MSE": current_train_loss}, step=iteration)

        if save_checkpoint and iteration > 0 and iteration % save_interval == 0:
            utils.save_checkpoint(
                f"{out_dir}/models/model_{iteration}.pth", model, optimizer, iteration
            )

        # NOTE(review): with `>=` this prints on every iteration past
        # train_loader.iters rather than once; confirm the meaning of `.iters`.
        if iteration >= train_loader.iters:
            print(f"Completed {iteration} iterations (end of dataloader epoch).")
            # The dataloader will automatically regenerate if loop continues

        iteration += 1

    return {
        "train_loss_history": train_loss_history,
        "eval_loss_history": eval_loss_history,
    }
