""" training.py
    Utilities for training models

    Borrowed from code for DeepThinking project
"""

from dataclasses import dataclass
import pdb
from random import randrange

import torch
from icecream import ic
from tqdm import tqdm

from deepthinking.utils.testing import get_predicted

@dataclass
class TrainingSetup:
    """Attributes to describe the training procedure.

    Attributes (as consumed by the training functions in this module):
        optimizer: object whose ``zero_grad()`` and ``step()`` are called once
            per batch.
        scheduler: learning-rate scheduler; ``step()`` is called once at the
            end of each training call. ``train_deq`` also reads its
            ``_step_count`` attribute.
        warmup: warmup scheduler; ``dampen()`` is called once at the end of
            each training call.
        clip: maximum gradient norm handed to
            ``torch.nn.utils.clip_grad_norm_``, or ``None`` to skip clipping.
        alpha: weight of the progressive loss in ``train_progressive``; the
            fully unrolled loss is weighted by ``1 - alpha``.
        max_iters: number of iterations requested from the network for the
            fully unrolled forward pass in ``train_progressive``.
        problem: problem name; the value ``"mazes"`` enables masking of the
            per-position loss, and the value is forwarded to ``get_predicted``.
    """
    optimizer: "typing.Any"
    scheduler: "typing.Any"
    warmup: "typing.Any"
    clip: "typing.Any"
    alpha: "typing.Any"
    max_iters: "typing.Any"
    problem: "typing.Any"


def get_output_for_prog_loss(inputs, max_iters, net):
    """Produce network outputs for the progressive loss term.

    A random number of warm-up iterations ``n`` in ``[0, max_iters)`` is drawn
    first, then a random number of tracked iterations ``k`` in
    ``[1, max_iters - n]``. When ``n`` is positive the network is run for
    ``n`` iterations and the resulting features are detached, so gradients
    only flow through the final ``k`` iterations.

    Args:
        inputs: batch handed to ``net`` unchanged.
        max_iters: upper bound on ``n + k``.
        net: callable returning an ``(outputs, features)`` pair and accepting
            the keyword arguments ``iters_to_do``, ``iters_elapsed`` and
            ``interim_thought``.

    Returns:
        Tuple ``(outputs, k)``.
    """
    # Draw order matters for reproducibility: warm-up count first, then k.
    warmup_iters = randrange(0, max_iters)
    k = randrange(1, max_iters - warmup_iters + 1)

    # No warm-up means the network starts from scratch (no interim features).
    interim_thought = None
    if warmup_iters > 0:
        _, warm_features = net(inputs, iters_to_do=warmup_iters)
        # Cut the graph so the warm-up iterations receive no gradient.
        interim_thought = warm_features.detach()

    outputs, _ = net(
        inputs,
        iters_elapsed=warmup_iters,
        iters_to_do=k,
        interim_thought=interim_thought,
    )
    return outputs, k


def train(net, loaders, mode, train_setup, device, **kwargs):
    """Run one training pass with the routine selected by ``mode``.

    Args:
        net: model to train.
        loaders: mapping of data loaders, forwarded to the selected routine.
        mode: ``"progressive"`` selects ``train_progressive`` (extra keyword
            arguments are not forwarded); ``"deq"`` selects ``train_deq``
            (extra keyword arguments are forwarded).
        train_setup: training configuration, forwarded to the selected routine.
        device: device forwarded to the selected routine.
        **kwargs: extra options, only used by the ``"deq"`` mode.

    Returns:
        Tuple ``(train_loss, acc)`` as produced by the selected routine.

    Raises:
        ValueError: if ``mode`` is neither ``"progressive"`` nor ``"deq"``.
    """
    if mode == "progressive":
        epoch_result = train_progressive(net, loaders, train_setup, device)
    elif mode == 'deq':
        epoch_result = train_deq(net, loaders, train_setup, device, **kwargs)
    else:
        raise ValueError(f"{ic.format()}: train_{mode}() not implemented.")

    train_loss, acc = epoch_result
    return train_loss, acc


def train_progressive(net, loaders, train_setup, device):
    """Train ``net`` for one pass over ``loaders["train"]`` with the progressive loss.

    The per-batch loss is
    ``(1 - alpha) * loss_max_iters + alpha * loss_progressive`` where
    ``loss_max_iters`` comes from a forward pass with ``max_iters`` iterations
    and ``loss_progressive`` from ``get_output_for_prog_loss``. When the
    problem is ``"mazes"`` only positions where the input is positive in at
    least one channel contribute to either loss.

    The learning-rate scheduler is stepped and the warmup scheduler dampened
    once, after all batches have been processed.

    Args:
        net: model; called as ``net(inputs, iters_to_do=...)`` and put into
            training mode.
        loaders: mapping with a ``"train"`` entry yielding ``(inputs, targets)``.
        train_setup: ``TrainingSetup`` providing optimizer, schedulers, clip,
            alpha, max_iters and problem.
        device: device the batches are moved to.

    Returns:
        Tuple ``(train_loss, acc)``: loss averaged over batches and the
        percentage of samples whose prediction matches the target at every
        position.

    Raises:
        ValueError: if the train loader yields no batches.
    """
    trainloader = loaders["train"]
    net.train()
    optimizer = train_setup.optimizer
    lr_scheduler = train_setup.scheduler
    warmup_scheduler = train_setup.warmup
    alpha = train_setup.alpha
    max_iters = train_setup.max_iters
    problem = train_setup.problem
    clip = train_setup.clip
    criterion = torch.nn.CrossEntropyLoss(reduction="none")

    train_loss = 0
    correct = 0
    total = 0
    num_batches = 0

    for inputs, targets in tqdm(trainloader, leave=False):
        inputs, targets = inputs.to(device), targets.to(device).long()
        targets = targets.view(targets.size(0), -1)

        # Boolean mask of positions that are positive in at least one channel.
        mask = None
        if problem == "mazes":
            mask = inputs.view(inputs.size(0), inputs.size(1), -1).max(dim=1)[0] > 0

        optimizer.zero_grad()

        # Fully unrolled pass. Its outputs are always needed for the accuracy
        # bookkeeping below, but its loss term only matters if alpha is not 1.
        if alpha != 1:
            outputs_max_iters, _ = net(inputs, iters_to_do=max_iters)
            outputs_max_iters = outputs_max_iters.view(outputs_max_iters.size(0),
                                                       outputs_max_iters.size(1), -1)
            loss_max_iters = criterion(outputs_max_iters, targets)
        else:
            # The loss term is unused, so skip building an autograd graph that
            # would never be back-propagated through.
            with torch.no_grad():
                outputs_max_iters, _ = net(inputs, iters_to_do=max_iters)
            loss_max_iters = torch.zeros_like(targets).float()

        # Progressive loss is only computed if alpha is not 0; otherwise the
        # term is unused and we save time by setting it equal to 0.
        if alpha != 0:
            outputs, _ = get_output_for_prog_loss(inputs, max_iters, net)
            outputs = outputs.view(outputs.size(0), outputs.size(1), -1)
            loss_progressive = criterion(outputs, targets)
        else:
            loss_progressive = torch.zeros_like(targets).float()

        if mask is not None:
            # Boolean indexing keeps exactly the unmasked positions.
            loss_max_iters = loss_max_iters[mask]
            loss_progressive = loss_progressive[mask]

        loss_max_iters_mean = loss_max_iters.mean()
        loss_progressive_mean = loss_progressive.mean()

        loss = (1 - alpha) * loss_max_iters_mean + alpha * loss_progressive_mean
        loss.backward()

        if clip is not None:
            torch.nn.utils.clip_grad_norm_(net.parameters(), clip)
        optimizer.step()

        train_loss += loss.item()
        predicted = get_predicted(inputs, outputs_max_iters, problem)
        # A sample counts as correct only if every position matches.
        correct += torch.amin(predicted == targets, dim=[-1]).sum().item()
        total += targets.size(0)
        num_batches += 1

    if num_batches == 0:
        raise ValueError("loaders['train'] yielded no batches; nothing to train on.")

    train_loss = train_loss / num_batches
    acc = 100.0 * correct / total

    lr_scheduler.step()
    warmup_scheduler.dampen()

    return train_loss, acc

def train_deq(net, loaders, train_setup, device, **kwargs):
    """Train a DEQ-style ``net`` for one pass over ``loaders["train"]``.

    The per-batch loss is the mean cross-entropy of the network output, plus
    at most one auxiliary term selected by ``cfg.problem.deq.loss``:

    * ``jac_loss`` enabled: the network's second return value is averaged and
      scaled by a weight derived from the scheduler's step count.
    * otherwise, ``layer_loss`` enabled: cross-entropy of every intermediate
      output, weighted by ``cfg.problem.deq.loss.gamma[idx]``.

    If both flags are enabled, the Jacobian term takes precedence and no layer
    loss is computed, because a single forward call yields only one of the two.

    When the problem is ``"mazes"`` only positions where the input is positive
    in at least one channel contribute to the cross-entropy terms.

    The learning-rate scheduler is stepped and the warmup scheduler dampened
    once, after all batches have been processed.

    Args:
        net: model; called as ``net(inputs, train_step=...)`` and, in the
            Jacobian mode, additionally with ``compute_jac_loss=...``.
        loaders: mapping with a ``"train"`` entry yielding ``(inputs, targets)``.
        train_setup: ``TrainingSetup`` providing optimizer, schedulers, clip
            and problem.
        device: device the batches are moved to.
        **kwargs: must contain ``cfg``, the configuration object.

    Returns:
        Tuple ``(train_loss, acc)``: loss averaged over batches and the
        percentage of samples whose prediction matches the target at every
        position.

    Raises:
        ValueError: if ``cfg`` is not supplied or the train loader yields no
            batches.
    """
    cfg = kwargs.get("cfg")
    if cfg is None:
        raise ValueError("train_deq() requires a 'cfg' keyword argument.")
    loss_cfg = cfg.problem.deq.loss

    trainloader = loaders["train"]
    net.train()
    optimizer = train_setup.optimizer
    lr_scheduler = train_setup.scheduler
    warmup_scheduler = train_setup.warmup

    problem = train_setup.problem
    clip = train_setup.clip
    criterion = torch.nn.CrossEntropyLoss(reduction="none")

    use_jac_loss = loss_cfg.jac_loss
    use_layer_loss = loss_cfg.layer_loss

    # The scheduler is only stepped after the loop, so this index is constant
    # for every batch of this call.
    current_epoch = lr_scheduler._step_count - 1

    # Jacobian loss weight (dynamically scheduled); also constant for this call.
    factor = 0
    if use_jac_loss:
        deq_steps = current_epoch - cfg.problem.train.pretrain_steps
        if deq_steps < 0:
            # We can also regularize the output Jacobian when pretraining.
            factor = loss_cfg.pretrain_jac_loss_weight
        elif current_epoch >= loss_cfg.jac_stop_epoch:
            # Above a certain epoch we may want to stop Jacobian regularization
            # (e.g., when the original loss is 0.01 and the jac loss is 0.05,
            # the regularization will dominate and hurt performance).
            factor = 0
        else:
            # Dynamically schedule the Jacobian regularization weight.
            factor = loss_cfg.jac_loss_weight + 0.1 * (deq_steps // loss_cfg.jac_loss_incremental)

    train_loss = 0
    correct = 0
    total = 0
    num_batches = 0

    for inputs, targets in tqdm(trainloader, leave=False):
        inputs, targets = inputs.to(device), targets.to(device).long()
        targets = targets.view(targets.size(0), -1)

        # Boolean mask of positions that are positive in at least one channel.
        mask = None
        if problem == "mazes":
            mask = inputs.view(inputs.size(0), inputs.size(1), -1).max(dim=1)[0] > 0

        optimizer.zero_grad()

        # Auxiliary loss term; stays None when neither option is enabled.
        aux_loss = None
        if use_jac_loss:
            # Only request the Jacobian term on a random subset of batches, and
            # never when its weight is zero.
            compute_jac_loss = (torch.rand([]).item() < loss_cfg.jac_loss_freq) and (factor > 0)
            outputs_max_iters, jac_loss = net(inputs, train_step=current_epoch,
                                              compute_jac_loss=compute_jac_loss)
            aux_loss = factor * jac_loss.mean()
        elif use_layer_loss:
            outputs_max_iters, interm_outputs = net(inputs, train_step=current_epoch)
            aux_loss = 0
            for idx, interm_output in enumerate(interm_outputs):
                interm_output = interm_output.view(interm_output.size(0),
                                                   interm_output.size(1), -1)
                interm_layer_loss = criterion(interm_output, targets)
                if mask is not None:
                    interm_layer_loss = interm_layer_loss[mask]
                aux_loss = aux_loss + loss_cfg.gamma[idx] * interm_layer_loss.mean()
        else:
            outputs_max_iters, _ = net(inputs, train_step=current_epoch)

        outputs_max_iters = outputs_max_iters.view(outputs_max_iters.size(0),
                                                   outputs_max_iters.size(1), -1)
        loss_max_iters = criterion(outputs_max_iters, targets)
        if mask is not None:
            # Boolean indexing keeps exactly the unmasked positions.
            loss_max_iters = loss_max_iters[mask]

        loss = loss_max_iters.mean()
        if aux_loss is not None:
            # Add only the term that was actually computed for this batch.
            loss = loss + aux_loss

        loss.backward()

        if clip is not None:
            torch.nn.utils.clip_grad_norm_(net.parameters(), clip)
        optimizer.step()

        train_loss += loss.item()
        predicted = get_predicted(inputs, outputs_max_iters, problem)
        # A sample counts as correct only if every position matches.
        correct += torch.amin(predicted == targets, dim=[-1]).sum().item()
        total += targets.size(0)
        num_batches += 1

    if num_batches == 0:
        raise ValueError("loaders['train'] yielded no batches; nothing to train on.")

    train_loss = train_loss / num_batches
    acc = 100.0 * correct / total

    lr_scheduler.step()
    warmup_scheduler.dampen()

    return train_loss, acc
 