from contextlib import nullcontext
from functools import partial
from itertools import cycle
import json
import math
import os
from pathlib import Path
import pprint
import sys
import time
from types import SimpleNamespace
from typing import Literal, Tuple, Union, Optional

from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.utils import ProjectConfiguration
from datasets import load_dataset
import numpy as np
from peft import get_peft_model
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM
import wandb

from core.llm import LLM
from core.utils import num_parameters, DualOutput
from training.metrics import Aggregator
from training.student_teacher_dataset import StudentTeacherDataset, TeacherDataset, IGNORE_INDEX
from training.utils import generate_answers, InfiniteSampler, save_with_base_model_config, save_with_deepspeed
from training.tulu_dataset import tulu_collate_fn, merge_with_tulu_batch

def get_ds_plugin(deepspeed_config_path: str) -> DeepSpeedPlugin:
    """Build a ``DeepSpeedPlugin`` from a DeepSpeed JSON configuration file.

    The file is parsed here and the resulting dict is handed to the plugin
    as ``hf_ds_config``.
    """
    config_text = Path(deepspeed_config_path).read_text()
    return DeepSpeedPlugin(hf_ds_config=json.loads(config_text))


def check_model_dtype(model):
    """Return True when every parameter of *model* is stored as ``torch.bfloat16``.

    A model without parameters trivially satisfies the check.
    """
    return not any(param.dtype != torch.bfloat16 for param in model.parameters())


def init_logit_train_dataset(base_llm, logit_train_file_list, hparams):
    """Create the logit-loss (distillation) training dataset.

    Builds a ``StudentTeacherDataset`` over *logit_train_file_list* using the
    verbosity, data path and max length from *hparams*.

    Raises:
        ValueError: if the logit loss is enabled but the dataset is empty.
    """
    dataset_kwargs = dict(
        verbose=hparams.verbose,
        datapath=hparams.datapath,
        max_length=hparams.max_length,
    )
    dataset = StudentTeacherDataset(base_llm, logit_train_file_list, **dataset_kwargs)

    logit_loss_enabled = bool(hparams.logit_loss_weight)
    if logit_loss_enabled and not len(dataset):
        raise ValueError("No logit training data available.")
    return dataset
    

def init_logit_train_dataloader(base_llm, logit_train_dataset, hparams):
    """Return a shuffling DataLoader over the logit-loss dataset, or None if that loss is off.

    Batches are built by the dataset's own ``collate_fn`` with ``padding_value=0``
    and ``llm=base_llm``; the batch size is ``hparams.logit_loss_micro_batch_size``.
    """
    if not hparams.logit_loss_weight:
        return None

    collate = partial(logit_train_dataset.collate_fn, padding_value=0, llm=base_llm)
    return DataLoader(
        logit_train_dataset,
        batch_size=hparams.logit_loss_micro_batch_size,
        collate_fn=collate,
        shuffle=True,
    )


def init_token_train_dataset(base_llm, token_train_file_list, hparams):
    """Create the token-loss (cross-entropy) training dataset.

    Returns a ``TeacherDataset`` over *token_train_file_list*; verbosity, data
    path, max length and the distractor dataset come from *hparams*.
    """
    dataset_kwargs = dict(
        verbose=hparams.verbose,
        datapath=hparams.datapath,
        max_length=hparams.max_length,
        distractor_dataset=hparams.distractor_dataset,
    )
    return TeacherDataset(base_llm, token_train_file_list, **dataset_kwargs)


def init_token_train_dataloader(base_llm, token_train_dataset, hparams, infinite=True):
    """Build the token-loss training DataLoader, or return None when the token loss is disabled.

    The token loss counts as disabled when the dataset is empty,
    ``hparams.token_loss_micro_batch_size`` is 0, or ``hparams.token_loss_weight``
    is 0. When verbose, a message is printed in that case.

    Args:
        base_llm: forwarded to the dataset's ``collate_fn`` as ``llm``.
        token_train_dataset: dataset exposing ``collate_fn``.
        hparams: supplies batch size, loss weight, verbosity and ``max_total_length``.
        infinite: when True, sample with ``InfiniteSampler`` so iteration never ends;
            otherwise use the DataLoader's default (sequential) sampling.
    """
    token_loss_disabled = (
        len(token_train_dataset) == 0
        or hparams.token_loss_micro_batch_size == 0
        or hparams.token_loss_weight == 0.
    )
    if token_loss_disabled:
        if hparams.verbose:
            print("No token training data available.")
        return None

    sampler = InfiniteSampler(len(token_train_dataset)) if infinite else None
    return DataLoader(
        token_train_dataset,
        batch_size=hparams.token_loss_micro_batch_size,
        collate_fn=partial(token_train_dataset.collate_fn, padding_value=0, llm=base_llm,
                           max_total_length=hparams.max_total_length),
        sampler=sampler,
    )
    

def init_validation_datasets_and_dataloaders(base_llm, accelerator, logit_val_file_list, token_val_file_list, hparams):
    """Build the logit and token validation datasets/dataloaders and prepare them with *accelerator*.

    Returns:
        ``(token_val_dataloader, logit_val_dataloader)`` after ``accelerator.prepare``;
        a dataloader is None when its dataset is empty.

    Raises:
        ValueError: if there is token validation data but
            ``hparams.token_loss_micro_batch_size`` is not positive.
    """
    ##### Validation data #####
    # NOTE(review): unlike the training datasets, no max_length is passed here (and the
    # token collate_fn gets no max_total_length) — confirm validation is meant to be untruncated.
    logit_val_dataset = StudentTeacherDataset(
        base_llm, logit_val_file_list, verbose=hparams.verbose,
        datapath=hparams.datapath,
    )
    token_val_dataset = TeacherDataset(
        base_llm, token_val_file_list, verbose=hparams.verbose,
        datapath=hparams.datapath,
        distractor_dataset=hparams.distractor_dataset,
    )

    logit_val_dataloader = None
    if len(logit_val_dataset) > 0:
        logit_val_dataloader = DataLoader(
            logit_val_dataset,
            batch_size=hparams.logit_loss_micro_batch_size,
            collate_fn=partial(logit_val_dataset.collate_fn, padding_value=0, llm=base_llm)
        )

    token_val_dataloader = None
    if len(token_val_dataset) > 0:
        # Raise rather than assert so the check survives `python -O`; ValueError matches
        # the error style of init_logit_train_dataset.
        if hparams.token_loss_micro_batch_size <= 0:
            raise ValueError("token loss micro batch size is zero.")
        token_val_dataloader = DataLoader(
            token_val_dataset,
            batch_size=hparams.token_loss_micro_batch_size,
            collate_fn=partial(token_val_dataset.collate_fn, padding_value=0, llm=base_llm)
        )

    if hparams.verbose:
        print(f"Validation data: token loss: {len(token_val_dataset)} examples, logit loss: {len(logit_val_dataset)} examples\n\n")

    token_val_dataloader, logit_val_dataloader = accelerator.prepare(
        token_val_dataloader, logit_val_dataloader
    )
    return token_val_dataloader, logit_val_dataloader


def initialize_run(
    hparams: SimpleNamespace,
):
    """Create the run directory, the ``Accelerator`` and (optionally) the wandb run.

    Mutates *hparams*: sets ``run_project_dir``, ``devices``, ``verbose``,
    ``log_to_wandb``, ``generate`` and ``checkpoint``. Also replaces
    ``sys.stdout`` with a ``DualOutput`` that writes to
    ``output_<process_index>.log`` inside the run directory.

    Returns:
        The configured ``Accelerator``.
    """
    run_name = hparams.run_name
    hparams.run_project_dir = hparams.project_path / run_name
    os.makedirs(hparams.run_project_dir, exist_ok=True)

    project_config = ProjectConfiguration(
        automatic_checkpoint_naming=True,
    )

    if hparams.deepspeed_path and hparams.teacher in ["student", "student_base"]:
        # No separate teacher model: a single DeepSpeed config for the student.
        deepspeed_plugin = get_ds_plugin(hparams.deepspeed_path)
    elif hparams.deepspeed_path:
        # Separate teacher: one named plugin per model ("teacher" is selected in init_models).
        student_plugin = DeepSpeedPlugin(hf_ds_config=hparams.deepspeed_path)
        teacher_plugin = DeepSpeedPlugin(hf_ds_config=hparams.deepspeed_path_teacher)
        deepspeed_plugin = {"student": student_plugin, "teacher": teacher_plugin}
    else:
        deepspeed_plugin = None

    accelerator = Accelerator(
        mixed_precision=hparams.mixed_precision,
        project_dir=hparams.run_project_dir,
        project_config=project_config,
        deepspeed_plugin=deepspeed_plugin
    )
    # Count every process, not only under DeepSpeed: train() passes its dataloaders through
    # accelerator.prepare in both cases and divides the dataset size by hparams.devices to
    # get the number of steps per epoch. For a single process this is 1, as before.
    hparams.devices = accelerator.state.num_processes
    hparams.verbose = accelerator.is_main_process
    output_log = hparams.project_path / run_name / f"output_{accelerator.process_index}.log"
    # Set the output to the console and to a file
    sys.stdout = DualOutput(output_log)

    if hparams.verbose:
        print("run_name:", run_name)
        print("Hyperparameters:")
        pprint.pprint(hparams)

    hparams.log_to_wandb = hparams.use_wandb and accelerator.is_main_process
    hparams.generate = hparams.generation_interval is not None and hparams.generation_interval != 0
    hparams.checkpoint = hparams.checkpoint_interval is not None and hparams.checkpoint_interval != 0

    if hparams.log_to_wandb:
        wandb.init(
            project=hparams.project_path.name,
            name=run_name,
            group=hparams.group_name,
            allow_val_change=True,
            config=hparams,
        )
    return accelerator


def init_models(base_llm, accelerator, hparams):
    """Load the PEFT-wrapped student, its AdamW optimizer, and the teacher.

    Returns:
        ``(student, teacher, optimizer)``. ``teacher`` is either the string
        ``hparams.teacher`` (``"student"`` or ``"student_base"``, resolved later in
        ``compute_logit_loss``) or a separately loaded model that has been passed
        through ``accelerator.prepare`` and put in eval mode.
    """
    verbose = hparams.verbose

    # ----- Student -----
    if verbose:
        print("Loading model from disk", flush=True)
    student = base_llm.load_model(training=True, deepspeed=hparams.deepspeed_path)
    if verbose:
        print(f"Number of trainable parameters: {num_parameters(student, requires_grad=True):,}")
        print(f"Number of non trainable parameters: {num_parameters(student, requires_grad=False):,}", flush=True)
        print("Preparing student", flush=True)

    student = get_peft_model(student, hparams.peft_config)
    # Without DeepSpeed the bf16 cast is done by hand here.
    if not hparams.deepspeed_path and hparams.mixed_precision == 'bf16':
        student = student.to(torch.bfloat16)

    if verbose:
        print("Peft model created", flush=True)
        print("Tuned student model")
        print(student, flush=True)
        print(f"Number of trainable parameters: {num_parameters(student, requires_grad=True):,}")
        print(f"Number of non trainable parameters: {num_parameters(student, requires_grad=False):,}", flush=True)

    # ----- Teacher selection (an external teacher is loaded after the student is prepared) -----
    self_teacher = hparams.teacher in ("student", "student_base")
    teacher = hparams.teacher if self_teacher else None
    teacher_llm = None if self_teacher else LLM(hparams.teacher, opening_message=hparams.opening_message)

    # ----- Optimizer + accelerate preparation -----
    optimizer = torch.optim.AdamW(student.parameters(), lr=hparams.learning_rate, weight_decay=hparams.weight_decay)
    accelerator.register_for_checkpointing(student)
    student, optimizer = accelerator.prepare(student, optimizer)
    if verbose:
        print("project_dir", accelerator.project_dir, flush=True)

    student_is_bf16 = check_model_dtype(student)
    if verbose:
        print(f"Student model is bf16: {student_is_bf16}")

    if teacher_llm is not None:
        # NOTE(review): assumes the dict of named DeepSpeed plugins built in initialize_run;
        # there is no non-DeepSpeed path here, and the active plugin is left on "teacher" — confirm.
        accelerator.state.select_deepspeed_plugin("teacher")
        teacher = accelerator.prepare(teacher_llm.load_model(training=False, deepspeed=True))
        teacher.eval()

    return student, teacher, optimizer

def log_to_wandb(accelerator, metrics_total, metrics_by_group, step, hparams):
    """Send aggregated and per-group metrics to wandb at *step*.

    This is a no-op unless wandb logging is enabled and this is the main process.
    The two dicts are logged with two separate ``wandb.log`` calls, in order.
    """
    if not (hparams.log_to_wandb and accelerator.is_main_process):
        return
    for payload in (metrics_total, metrics_by_group):
        wandb.log(payload, step=step)


def update_lr(step, optimizer, train_metrics, hparams, is_logging, max_steps):
    """Set the learning rate of every optimizer param group for *step*.

    Schedule:
      * linear warmup from 0 to ``hparams.learning_rate`` while
        ``step <= warmup_steps`` (only when ``warmup_steps`` > 0);
      * afterwards, when ``hparams.decay`` is set, a linear decay reaching 0 at
        ``max_steps``; otherwise a constant ``hparams.learning_rate``.

    A missing or None ``hparams.warmup_steps`` is treated as 0 (no warmup)
    instead of failing on ``step <= None``.

    When *is_logging* is true, the chosen rate is stored in ``train_metrics['lr']``.
    """
    warmup_steps = getattr(hparams, "warmup_steps", None) or 0

    if warmup_steps and step <= warmup_steps:
        # linear warmup
        lr = hparams.learning_rate * step / warmup_steps
    elif hparams.decay:
        decay_span = max(1, max_steps - warmup_steps)
        multiplier = max(0.0, float(max_steps - step) / float(decay_span))
        lr = hparams.learning_rate * multiplier
    else:
        lr = hparams.learning_rate

    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    if is_logging:
        train_metrics['lr'] = lr


def save_maybe_deepspeed(student, accelerator, base_llm, hparams):
    """Save *student* into ``hparams.run_project_dir``.

    With a DeepSpeed config, the DeepSpeed-aware saver (which needs *accelerator*)
    is used; otherwise the model is saved with the base model's config.
    """
    if not hparams.deepspeed_path:
        save_with_base_model_config(student, base_llm, hparams.run_project_dir)
        return
    save_with_deepspeed(student, accelerator, base_llm, hparams.run_project_dir)


def log_step(step, is_logging, accelerator, train_metrics, hparams, step_t0, max_steps, chunk=None, n_chunks=None, total_step=None, total_max_steps=None, all_losses=None, step_times=None):
    """Print a one-line training summary for *step* and optionally send it to wandb.

    Only acts when *is_logging* is true and this is the main process.

    Args:
        train_metrics: metrics for this step. ``logit_loss`` is reported if present,
            otherwise ``token_loss``. A step that logged neither (e.g. only ``lr``)
            prints ``n/a`` instead of crashing. When logging to wandb, ``step_time``
            (ms) is added to this dict.
        step_t0: ``time.perf_counter()`` value taken at the start of the step.
        chunk, n_chunks, total_step, total_max_steps: optional progress counters;
            the "Total step"/"Chunk" prefix is printed only when *n_chunks* is truthy.
            *total_step*, when given, is used as the wandb step.
        all_losses: optional running list of logged losses. It is appended to and
            averaged, and ``None`` losses are skipped.
        step_times: optional running list of logged iteration times in ms; appended to.
    """
    if not (is_logging and accelerator.is_main_process):
        return

    t1 = time.perf_counter()
    iter_ms = (t1 - step_t0) * 1000
    if step_times is not None:
        step_times.append(iter_ms)

    loss_to_print = train_metrics.get('logit_loss', train_metrics.get('token_loss', None))
    loss_type = 'logit' if 'logit_loss' in train_metrics else 'token'

    # A step without a loss must neither break the format spec nor poison the running average.
    avg_loss = None
    if all_losses is not None:
        if loss_to_print is not None:
            all_losses.append(loss_to_print)
        if all_losses:
            avg_loss = sum(all_losses) / len(all_losses)

    loss_str = f"{loss_to_print:.8f}" if loss_to_print is not None else "n/a"
    print(
        (f"Total step {total_step}/{total_max_steps}, " if n_chunks else "") +
        (f"Chunk {chunk}/{n_chunks}, " if n_chunks else "") +
        f"Step {step}/{max_steps}: {loss_type} loss {loss_str}, "
        f"iter time: {iter_ms:.2f}ms" +
        (f", avg iter time: {np.mean(step_times):.2f}ms" if step_times else "") +
        (f", Total avg loss {avg_loss:.4f} " if avg_loss is not None else ""),
        flush=True
    )

    if hparams.log_to_wandb:
        train_metrics['step_time'] = iter_ms
        wandb.log(train_metrics, step=total_step if total_step is not None else step)

def _infinite_batches(dataloader):
    """Yield batches from *dataloader* forever, starting a fresh pass after each full pass.

    Used instead of ``itertools.cycle``. ``cycle`` keeps a reference to every element
    of its first pass and then replays those same objects. That holds all first-pass
    batches in memory and repeats the first pass's order. Re-iterating the dataloader
    releases old batches and lets a shuffling sampler draw a new order each pass.
    If the dataloader yields nothing, the generator stops (as ``cycle`` does on an
    empty iterable) instead of spinning forever.
    """
    while True:
        yielded_any = False
        for batch in dataloader:
            yielded_any = True
            yield batch
        if not yielded_any:
            return


def train(
    project_path: Path,
    base_llm: LLM,
    data: Union[list[str], Tuple[list[str], list[str]]],
    hparams: SimpleNamespace,
    teacher_llm: Optional[LLM] = None,
) -> None:
    """Train a PEFT-wrapped student with a logit (distillation) loss and/or a token (cross-entropy) loss.

    Args:
        project_path: Not used directly; the run directory is derived from
            ``hparams.project_path`` in ``initialize_run``.
        base_llm: Base model wrapper used to build datasets, load the student and save it.
        data: Indexed as ``data[0]`` (training files) and ``data[1]`` (validation files).
        hparams: Run configuration. It is also mutated here (``warmup_steps``,
            ``eval_interval``, ``generation_interval``) and by ``initialize_run``.
        teacher_llm: Not used; the teacher is created from ``hparams.teacher`` in ``init_models``.
    """
    accelerator = initialize_run(hparams)

    ##### Training data files #####
    logit_train_file_list = data[0] if hparams.logit_loss_weight > 0 else []
    token_train_file_list = data[0] if hparams.token_loss_weight > 0 else []
    logit_val_file_list = data[1]
    token_val_file_list = data[1]

    ##### Train datasets and loaders #####
    logit_train_dataset = init_logit_train_dataset(base_llm, logit_train_file_list, hparams)
    logit_train_dataloader = init_logit_train_dataloader(base_llm, logit_train_dataset, hparams)
    token_train_dataset = init_token_train_dataset(base_llm, token_train_file_list, hparams)
    token_train_dataloader = init_token_train_dataloader(base_llm, token_train_dataset, hparams)

    if hparams.verbose:
        print(f"Training data: token loss: {len(token_train_dataset)} examples, logit loss: {len(logit_train_dataset)} examples\n\n")

    logit_train_dataloader, token_train_dataloader = accelerator.prepare(
        logit_train_dataloader, token_train_dataloader
    )

    ##### Validation data #####
    token_val_dataloader, logit_val_dataloader = init_validation_datasets_and_dataloaders(
        base_llm, accelerator, logit_val_file_list, token_val_file_list, hparams
    )

    ##### Init models #####
    student, teacher, optimizer = init_models(base_llm, accelerator, hparams)

    ##### Set up training #####
    # One step = one optimizer update accumulated over several micro-batches;
    # hparams.devices (set in initialize_run) is the number of processes.
    if hparams.logit_loss_weight:
        n_batches = math.ceil(len(logit_train_dataset) / hparams.n_logit_micro_batches_per_batch / hparams.logit_loss_micro_batch_size / hparams.devices)
    else:
        n_batches = math.ceil(len(token_train_dataset) / hparams.n_token_micro_batches_per_batch / hparams.token_loss_micro_batch_size / hparams.devices)
    max_steps = hparams.n_epochs * n_batches
    if hparams.verbose:
        print(f"Training for {hparams.n_epochs} epochs, {max_steps} iterations", flush=True)

    if hparams.tulu:
        tulu_dataset = load_dataset("allenai/tulu-3-sft-mixture", split="train")
        tulu_dataloader = DataLoader(
            tulu_dataset,
            batch_size=hparams.tulu_batch_size*2, # We sample extra to discard invalid ones
            shuffle=True,
            collate_fn=partial(tulu_collate_fn, padding_value=0, llm=base_llm, max_length=hparams.max_total_length,
                               system_msg=hparams.opening_message,
                               lesson_ix=len(logit_train_file_list),
                               logit_collate_fn=partial(logit_train_dataset.collate_fn,
                                                        padding_value=0,
                                                        llm=base_llm),
                               use_batch_size=hparams.tulu_batch_size,
                               ),
        )
        tulu_dataloader = accelerator.prepare(tulu_dataloader)
        iter_tulu_dataloader = _infinite_batches(tulu_dataloader)
    else:
        iter_tulu_dataloader = None

    # An explicit warmup_steps wins; otherwise derive it from warmup_ratio, and default to
    # no warmup so update_lr never has to compare a step against None.
    if getattr(hparams, 'warmup_steps', None) is None:
        warmup_ratio = getattr(hparams, 'warmup_ratio', None) or 0
        hparams.warmup_steps = int(warmup_ratio * max_steps)
    if hparams.verbose:
        print(f"Learning rate warmup: {hparams.warmup_steps} steps", flush=True)

    train_t0 = time.perf_counter()
    student.train()

    iter_logit_train_dataloader = (
        _infinite_batches(logit_train_dataloader) if logit_train_dataloader is not None else None
    )
    # init_token_train_dataloader uses an InfiniteSampler by default, so a plain iterator suffices.
    iter_token_train_dataloader = iter(token_train_dataloader) if token_train_dataloader is not None else None

    if hparams.eval_interval < 0:
        hparams.eval_interval = n_batches
        if hparams.verbose:
            print(f"Setting evaluation interval to {hparams.eval_interval}", flush=True)
    if hparams.generate and hparams.generation_interval < 0:
        hparams.generation_interval = n_batches
        if hparams.verbose:
            print(f"Setting generation interval to {hparams.generation_interval}", flush=True)

    last_checkpoint_time = time.time()
    n_saves = 0
    all_losses = []
    step_times = []

    for step in range(max_steps):
        # We start with validation to log the untrained model's performance.
        # eval_interval == 0 disables periodic validation instead of dividing by zero.
        if hparams.validate and hparams.eval_interval and step % hparams.eval_interval == 0:
            metrics_total, metrics_by_group = validate(student, teacher, token_val_dataloader, logit_val_dataloader, accelerator,
                                                       hparams.closed_book_token_loss, hparams)
            if hparams.verbose:
                print("Validation results:", metrics_total, flush=True)

            log_to_wandb(accelerator, metrics_total, metrics_by_group, step, hparams)

        if hparams.generate and step % hparams.generation_interval == 0:
            # Generate for one random training sample
            if hparams.logit_loss_weight:
                ix = torch.randint(high=len(logit_train_dataset), size=(1,)).item()
                generation_samples = [logit_train_dataset[ix]]
            else:
                # token loss
                ix = torch.randint(high=len(token_train_dataset), size=(1,)).item()
                generation_samples = [token_train_dataset[ix]]
            generate_answers(student, base_llm, generation_samples, accelerator)

        is_logging = step % hparams.log_interval == 0
        train_metrics = {}

        update_lr(step, optimizer, train_metrics, hparams, is_logging, max_steps)

        step_t0 = time.perf_counter()
        # Logit (distillation) loss; gradients accumulate over the micro-batches.
        if hparams.logit_loss_weight:
            for _ in range(hparams.n_logit_micro_batches_per_batch):
                batch = next(iter_logit_train_dataloader)
                if iter_tulu_dataloader is not None:
                    batch_tulu = next(iter_tulu_dataloader)
                    batch = merge_with_tulu_batch(batch, batch_tulu, padding_value=0)
                logit_loss = compute_logit_loss(batch, student, teacher, temperature=hparams.train_temperature, reverse_kl=hparams.reverse_kl, base_llm=base_llm)
                logit_loss = logit_loss.mean()
                loss = hparams.logit_loss_weight * logit_loss / hparams.n_logit_micro_batches_per_batch
                accelerator.backward(loss)

            if is_logging and hparams.n_logit_micro_batches_per_batch > 0:
                train_metrics['logit_loss'] = logit_loss.item()  # expensive device-to-host synchronization

        # Token (cross-entropy) loss. This is skipped when init_token_train_dataloader returned
        # None (no data / zero batch size), which would otherwise fail on next(None).
        if hparams.token_loss_weight and iter_token_train_dataloader is not None:
            for _ in range(hparams.n_token_micro_batches_per_batch):
                batch = next(iter_token_train_dataloader)
                token_loss = compute_token_loss(batch, student, reduction="batch",
                                                closed_book_token_loss=hparams.closed_book_token_loss,
                                                base_llm=base_llm)
                loss = hparams.token_loss_weight * token_loss / hparams.n_token_micro_batches_per_batch
                accelerator.backward(loss)

            if is_logging and hparams.n_token_micro_batches_per_batch > 0:
                train_metrics['token_loss'] = token_loss.item()  # expensive device-to-host synchronization

        if hparams.max_grad_norm:
            accelerator.clip_grad_norm_(student.parameters(), hparams.max_grad_norm)

        # Releases cached, unused CUDA memory every step (trades some speed for lower memory).
        torch.cuda.empty_cache()
        optimizer.step()
        optimizer.zero_grad()
        log_step(step, is_logging, accelerator, train_metrics, hparams, step_t0, max_steps, all_losses=all_losses, step_times=step_times)

        if hparams.save_during_training:
            current_time = time.time()
            if hparams.checkpoint_interval and (step + 1) % hparams.checkpoint_interval == 0:
                save_maybe_deepspeed(student, accelerator, base_llm, hparams)

            if hparams.checkpoint_interval_seconds and current_time - last_checkpoint_time >= hparams.checkpoint_interval_seconds:
                n_saves += 1
                save_idx = n_saves * hparams.checkpoint_interval_seconds
                checkpoint_subfolder = Path(hparams.run_project_dir) / f"checkpoint_time_{save_idx}"
                if hparams.deepspeed_path:
                    print("Saving deepspeed model")
                    save_with_deepspeed(student, accelerator, base_llm, checkpoint_subfolder)
                else:
                    save_with_base_model_config(student, base_llm, checkpoint_subfolder)
                last_checkpoint_time = current_time

    if hparams.save:
        save_maybe_deepspeed(student, accelerator, base_llm, hparams)

    print(f"Training time: {(time.perf_counter()-train_t0):.2f}s", flush=True)


def compute_token_loss(
    batch,
    model,
    reduction: Literal["batch", "sample"] = "batch",  # "batch" for training, "sample" for validation
    closed_book_token_loss: bool = False,
    base_llm: LLM = None,
):
    """Next-token cross-entropy loss of *model* on *batch*.

    Inputs are the sequences without their last position; labels are the label
    sequences shifted left by one. Positions labelled ``IGNORE_INDEX`` are excluded.

    Args:
        batch: mapping with ``seqs``/``labels``, or ``student_seqs``/``student_labels``
            when *closed_book_token_loss* is set. Presumably (batch_size, length)
            tensors — TODO confirm against the dataset collate_fn.
        model: module whose ``forward(inputs)`` returns an object with ``logits``.
        reduction: ``"batch"`` returns a scalar mean over all non-ignored tokens.
            ``"sample"`` returns a (batch_size,) tensor holding each sample's mean over
            its own non-ignored tokens.
        closed_book_token_loss: use the student sequences instead of ``seqs``/``labels``.
        base_llm: unused.

    Raises:
        ValueError: for an unknown *reduction*.
    """
    if reduction not in ("batch", "sample"):
        raise ValueError(f"reduction must be 'batch' or 'sample', got {reduction!r}")

    if closed_book_token_loss:
        inputs = batch['student_seqs'][..., :-1]  # (batch_size, seq_length)
        labels = batch['student_labels'][..., 1:]  # (batch_size, seq_length)
    else:
        inputs = batch['seqs'][..., :-1]  # (batch_size, seq_length)
        labels = batch['labels'][..., 1:]  # (batch_size, seq_length)

    batch_size, seq_length = inputs.shape
    output_logits = model.forward(inputs).logits

    if reduction == "batch":
        return F.cross_entropy(
            output_logits.flatten(0, 1),
            labels.flatten(0, 1),
            ignore_index=IGNORE_INDEX,
            reduction="mean",
        )

    per_token_loss = F.cross_entropy(
        output_logits.flatten(0, 1),
        labels.flatten(0, 1),
        ignore_index=IGNORE_INDEX,
        reduction="none",
    ).reshape(batch_size, seq_length)
    # Ignored positions contribute 0 to the sum. Divide by the number of supervised tokens,
    # not the padded length, so the value does not depend on batch padding (as with the
    # "batch" reduction and compute_logit_loss). clamp avoids 0/0 for fully ignored samples.
    n_tokens = (labels != IGNORE_INDEX).sum(-1).clamp(min=1)
    return per_token_loss.sum(-1) / n_tokens


def compute_logit_loss(
    batch,
    student: AutoModelForCausalLM,
    teacher,
    temperature,
    reverse_kl: bool = False,
    base_llm: LLM = None,
):
    """
    Compute a KL-divergence-based distillation loss from teacher to student.

    The teacher sees ``teacher_seqs`` (selected by ``teacher_masks``) and the student sees
    ``student_seqs`` (selected where ``student_labels != IGNORE_INDEX``). The selected
    positions are compared token-by-token, so both masks are expected to select the same
    number of positions — TODO confirm this is guaranteed by the dataset.

    Args:
        batch: mapping with ``student_seqs``, ``student_labels``, ``teacher_seqs`` and
            ``teacher_masks``.
        student: the model being trained.
        teacher: a separate model, or ``"student"`` (the student itself) or
            ``"student_base"`` (the student with its PEFT adapter disabled).
        temperature: divides both teacher and student logits before the softmax.
        reverse_kl: if True, compute KL(student || teacher); otherwise KL(teacher || student).
        base_llm: unused.

    Returns:
        A (batch_size,) tensor: the summed per-token KL of each sample divided by its
        number of supervised tokens (at least 1, to avoid NaN for empty samples).
    """
    student_inputs = batch['student_seqs'][..., :-1]  # (batch_size, seq_length)
    student_labels = batch['student_labels'][..., 1:]  # (batch_size, seq_length)
    # The mask determines which outputs should be used for loss calculation
    student_masks = student_labels != IGNORE_INDEX

    batch_size, seq_length = student_inputs.shape

    if teacher == "student_base":
        teacher_context = student.disable_adapter()
        teacher = student
    elif teacher == "student":
        teacher_context = nullcontext()
        teacher = student
    else:
        teacher_context = nullcontext()

    # teacher.eval() below also switches the student to eval mode when the student acts as
    # its own teacher; remember the mode so the student's own forward pass runs (and training
    # continues) in the mode the caller set.
    student_was_training = student.training

    with torch.no_grad(), teacher_context:
        teacher.eval()
        teacher_inputs = batch['teacher_seqs'][..., :-1]
        # The mask determines which outputs should be used for loss calculation
        teacher_masks = batch['teacher_masks'][..., 1:]

        teacher_logits = teacher.forward(teacher_inputs).logits[teacher_masks]
        t_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1)  # (n_tokens, vocab_size)

    if teacher is student and student_was_training:
        student.train()

    student_logits = student.forward(student_inputs).logits[student_masks]
    s_log_probs = F.log_softmax(student_logits / temperature, dim=-1)  # (n_tokens, vocab_size)

    # F.kl_div(input, target) computes KL(target || input) pointwise.
    if reverse_kl:
        logit_loss = F.kl_div(t_log_probs, s_log_probs, log_target=True, reduction="none")
    else:
        logit_loss = F.kl_div(s_log_probs, t_log_probs, log_target=True, reduction="none")

    logit_loss_t = logit_loss.sum(-1)  # (n_tokens,)

    # Scatter per-token losses back into (batch_size, seq_length) to average per sample.
    logit_loss_mx = torch.zeros(batch_size, seq_length, device=student_inputs.device, dtype=logit_loss_t.dtype)
    logit_loss_mx[student_masks] = logit_loss_t
    n_tokens = student_masks.sum(-1).clamp(min=1)
    return logit_loss_mx.sum(-1) / n_tokens  # (batch_size,)


def validate(student, teacher, token_dataloader, logit_dataloader, accelerator, closed_book_token_loss, hparams):
    """Evaluate *student* on the validation dataloaders without gradient tracking.

    Each non-None dataloader feeds per-sample losses into an ``Aggregator`` built from
    its dataset's ``lesson_names``. The token loader produces ``val_token_loss``
    (``compute_token_loss`` with ``reduction="sample"``). The logit loader produces
    ``val_logit_loss`` (``compute_logit_loss`` at temperature 1). The student is in
    eval mode while this runs and is put back in train mode before returning.

    Returns:
        ``(metrics_total, metrics_by_group)``: dicts merged from each aggregator's
        ``get_average()``; both are empty when neither dataloader is given.
    """
    metrics_total = {}
    metrics_by_group = {}

    def _aggregate(dataloader, metric_name, loss_fn):
        # Accumulate per-sample losses for every batch, then merge the averages.
        aggregator = Aggregator(dataloader.dataset.lesson_names, accelerator.device)
        for batch in dataloader:
            aggregator.add_batch(batch["lesson_ixs"], {metric_name: loss_fn(batch)}, accelerator)
        totals, by_group = aggregator.get_average()
        metrics_total.update(totals)
        metrics_by_group.update(by_group)

    student.eval()
    with torch.no_grad():
        if token_dataloader is not None:
            _aggregate(
                token_dataloader,
                "val_token_loss",
                lambda batch: compute_token_loss(batch, student, reduction="sample",
                                                 closed_book_token_loss=closed_book_token_loss),
            )
        if logit_dataloader is not None:
            _aggregate(
                logit_dataloader,
                "val_logit_loss",
                lambda batch: compute_logit_loss(batch, student, teacher, temperature=1,
                                                 reverse_kl=hparams.reverse_kl),
            )
    student.train()
    return metrics_total, metrics_by_group
