import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR
from symo.utils import to_numpy


def mlp_kernel_init(tensor: torch.Tensor):
    """Overwrite ``tensor`` in place with samples from N(0, 1 / fan_in).

    ``fan_in`` is read from dimension 1 when the tensor has two or more
    dimensions, and from dimension 0 otherwise. For an ``nn.Linear`` weight,
    which PyTorch stores as ``(out_features, in_features)``, dimension 1 is
    the input axis. For tensors of rank > 2 only ``shape[1]`` is used; it is
    not multiplied by the trailing dimensions.

    The standard deviation ``1 / sqrt(fan_in)`` is evaluated in float32
    before being handed to ``normal_``.

    Args:
        tensor: Tensor (or Parameter) to fill; modified in place.
    """
    fan_axis = 1 if tensor.ndim >= 2 else 0
    fan_in = tensor.shape[fan_axis]

    # Evaluate in float32 so the resulting std matches the historical value.
    root = torch.sqrt(torch.tensor(fan_in, dtype=torch.float32))
    sigma = (1.0 / root).item()

    # no_grad lets this run on leaf Parameters that require grad.
    with torch.no_grad():
        tensor.normal_(0.0, sigma)


def mlp_bias_init(tensor: torch.Tensor):
    """Overwrite ``tensor`` in place with samples from N(0, (1/3)^2).

    The standard deviation is the fixed constant 1/3, regardless of the
    tensor's shape.

    Args:
        tensor: Tensor (or Parameter) to fill; modified in place.
    """
    # no_grad lets this run on leaf Parameters that require grad.
    with torch.no_grad():
        tensor.normal_(mean=0.0, std=1.0 / 3.0)


def _stack_metrics(metrics_per_epoch):
    """Transpose per-epoch metric records into one record of per-field lists.

    Args:
        metrics_per_epoch: Non-empty list of records exposing ``_fields`` and
            ``_asdict()`` (e.g. ``typing.NamedTuple`` instances); presumably
            all of the same class — TODO confirm with ``train_step_fn``.

    Returns:
        An instance of the first record's class in which each field holds the
        list of that field's values across ``metrics_per_epoch``, in order.
    """
    first = metrics_per_epoch[0]
    columns = {}
    for record in metrics_per_epoch:
        for name, value in record._asdict().items():
            columns.setdefault(name, []).append(value)
    # Field order (and the set of fields kept) follows the first record.
    return type(first)(*[columns[name] for name in first._fields])


def train_loop(
    model,
    train_data,
    val_data,
    optimizer,
    train_step_fn,
    eval_fn,
    num_epochs: int,
    record_metrics: bool = False,
    print_output: bool = False,
):
    """Alternate one training step and one evaluation for ``num_epochs`` epochs.

    Each epoch calls ``train_step_fn(model, optimizer, train_data)``, which
    must return a ``(loss, metrics)`` pair, then ``eval_fn(model, val_data)``,
    whose result is recorded as the validation loss. Both losses are passed
    through ``to_numpy`` before being stored.

    Args:
        model: Passed unchanged to ``train_step_fn`` and ``eval_fn``.
        train_data: Passed to ``train_step_fn`` every epoch.
        val_data: Passed to ``eval_fn`` every epoch.
        optimizer: Passed to ``train_step_fn`` every epoch.
        train_step_fn: ``(model, optimizer, train_data) -> (loss, metrics)``.
        eval_fn: ``(model, val_data) -> val_loss``.
        num_epochs: Number of train/eval iterations to run.
        record_metrics: If true, collect the ``metrics`` returned by every
            training step.
        print_output: If true, print a one-line progress summary per epoch.

    Returns:
        ``(train_losses, val_losses, metrics)``. The two loss lists hold one
        ``to_numpy``-converted value per epoch. ``metrics`` is ``None`` when
        ``record_metrics`` is false. Otherwise it is the per-epoch records
        stacked into a single record of per-field lists (see
        ``_stack_metrics``), or an empty list if no epochs ran.
    """
    train_losses = []
    val_losses = []
    per_epoch_metrics = [] if record_metrics else None

    for epoch in range(num_epochs):
        loss, step_metrics = train_step_fn(model, optimizer, train_data)
        if per_epoch_metrics is not None:
            # NOTE(review): if these metrics are tensors still attached to
            # the autograd graph, storing them keeps the graphs alive.
            per_epoch_metrics.append(step_metrics)

        train_losses.append(to_numpy(loss))
        val_losses.append(to_numpy(eval_fn(model, val_data)))

        if print_output:
            # Format the already-converted values rather than the raw
            # framework objects. float() also accepts 1-element arrays,
            # which a ':.4f' spec on a non-0-d tensor rejects.
            print(
                f"Epoch [{epoch+1:4d}/{num_epochs}] | "
                f"Train Loss: {float(train_losses[-1]):.4f} | "
                f"Val Loss: {float(val_losses[-1]):.4f}"
            )

    if per_epoch_metrics:
        return train_losses, val_losses, _stack_metrics(per_epoch_metrics)
    return train_losses, val_losses, per_epoch_metrics
