# SPDX-License-Identifier: MIT
from __future__ import annotations
from typing import Dict
import time
import torch
from torch.utils.data import DataLoader
from .dataset import PromptDataset
from .net import FunctionalReversalNet
from .loss_fns import custom_loss_plus
from .model_utils import get_model_device_dtype


def train_net(
    data: Dict[str, float],
    max_tokens: int,
    model,
    tokenizer,
    epochs: int = 10,
    batch_size: int = 8,
    learning_rate: float = 1e-4,
    device: str = "cuda:0",
    capture_mode: str = "token",
    alpha: float = 0.7,
    log_file: str = "",
    target_dtype=None,
) -> FunctionalReversalNet:
    """Train a ``FunctionalReversalNet`` on pairs drawn from ``PromptDataset``.

    A ``PromptDataset`` is built from ``data``, ``model`` and ``tokenizer``,
    wrapped in a shuffling ``DataLoader``, and a freshly constructed
    ``FunctionalReversalNet`` is optimised with Adam against
    ``custom_loss_plus`` for ``epochs`` passes over the data.

    Args:
        data: Mapping forwarded unchanged to ``PromptDataset``.
        max_tokens: Forwarded to both ``PromptDataset`` and
            ``FunctionalReversalNet``.
        model: Source model. Must expose ``config.hidden_size`` (used as the
            network's ``d_model``); also forwarded to ``PromptDataset`` and
            ``get_model_device_dtype``.
        tokenizer: Forwarded unchanged to ``PromptDataset``.
        epochs: Number of passes over the dataloader. A value ``<= 0`` skips
            training and returns the freshly initialised network.
        batch_size: Batch size for the ``DataLoader``.
        learning_rate: Learning rate for the Adam optimiser.
        device: Device string the network and the batch tensors are moved to.
        capture_mode: Forwarded unchanged to ``PromptDataset``.
        alpha: Forwarded unchanged to ``custom_loss_plus``.
        log_file: Optional path. When non-empty, one line per epoch (average
            batch loss and wall-clock duration) is appended to this file.
            When empty (the default) nothing is written.
        target_dtype: dtype used for the dataset, the network and the batch
            tensors. Defaults to the dtype reported by
            ``get_model_device_dtype(model)``.

    Returns:
        The trained ``FunctionalReversalNet``, located on ``device``.
    """
    # Only the dtype is needed; the model's own device plays no role here
    # because all training tensors are placed on ``device`` explicitly.
    _model_device, model_dtype = get_model_device_dtype(model)
    if target_dtype is None:
        target_dtype = model_dtype
    device_t = torch.device(device)
    d_model = model.config.hidden_size

    dataset = PromptDataset(
        data=data,
        max_tokens=max_tokens,
        model=model,
        tokenizer=tokenizer,
        device=device_t,
        capture_mode=capture_mode,
        target_dtype=target_dtype,
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    net = FunctionalReversalNet(
        max_tokens=max_tokens, d_model=d_model, dtype=target_dtype
    ).to(device_t)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

    # Loss bookkeeping is only needed for the optional log; skipping it
    # otherwise avoids a host/device synchronisation (``.item()``) per step.
    logging_enabled = bool(log_file)

    def _place(tensor):
        """Move a batch tensor to the training device and dtype."""
        return tensor.to(device=device_t, dtype=target_dtype)

    for epoch in range(1, epochs + 1):
        net.train()
        epoch_start = time.perf_counter()
        total_loss = 0.0
        num_batches = 0
        for batch in dataloader:
            seq_i, y_i, seq_j, y_j, flat_i, flat_j = batch
            seq_i, seq_j = _place(seq_i), _place(seq_j)
            flat_i, flat_j = _place(flat_i), _place(flat_j)
            # NOTE(review): y_i / y_j are handed to the loss exactly as the
            # DataLoader collated them; presumably custom_loss_plus takes care
            # of their device/dtype — confirm against loss_fns.
            optimizer.zero_grad(set_to_none=True)
            pred_i = net(seq_i)
            pred_j = net(seq_j)
            loss, _details = custom_loss_plus(
                pred_i,
                pred_j,
                y_i,
                y_j,
                embedding_1_flat=flat_i,
                embedding_2_flat=flat_j,
                alpha=alpha,
                fc_layer=net.fc,
            )
            loss.backward()
            optimizer.step()
            if logging_enabled:
                total_loss += loss.item()
                num_batches += 1

        if logging_enabled:
            elapsed = time.perf_counter() - epoch_start
            # An empty dataloader yields no batches; report NaN rather than
            # dividing by zero.
            avg_loss = total_loss / num_batches if num_batches else float("nan")
            # Append per epoch so a crash mid-training keeps earlier lines.
            with open(log_file, "a", encoding="utf-8") as fh:
                fh.write(
                    f"epoch {epoch}/{epochs} "
                    f"avg_loss={avg_loss:.6f} "
                    f"elapsed_s={elapsed:.2f}\n"
                )
    return net