""" training_utils.py
    Utilities for training models
    Developed as part of DeepThinking project
    July 2021
"""

import sys
import typing
from dataclasses import dataclass
from random import randrange

import torch
from icecream import ic
from tqdm import tqdm

from utils.testing_utils import get_predicted


# Ignore statements for pylint:
#     Too many branches (R0912), Too many statements (R0915), No member (E1101),
#     Not callable (E1102), Invalid name (C0103), No exception (W0702),
#     Too many local variables (R0914), Missing docstring (C0116, C0115, C0114),
#     Unused import (W0611).
# pylint: disable=R0912, R0915, E1101, E1102, C0103, W0702, R0914, C0116, C0115, C0114, W0611


@dataclass
class TrainingSetup:
    """Attributes describing the training procedure.

    The fields are read by ``train_progressive`` in this module; ``inner_lr`` is
    not read by any code visible here.
    """

    # Optimizer; ``zero_grad()`` and ``step()`` are called once per batch.
    optimizer: typing.Any
    # Learning-rate scheduler; ``step()`` is called once at the end of each epoch.
    scheduler: typing.Any
    # Warmup scheduler; ``dampen()`` is called once at the end of each epoch.
    warmup: typing.Any
    # Max gradient norm passed to ``clip_grad_norm_``; ``None`` disables clipping.
    clip: typing.Any
    # Weight of the progressive loss; the full-unroll loss is weighted by ``1 - progressive_loss_weight``.
    progressive_loss_weight: typing.Any
    # Lower bound of the random number of detached ("warm-up") iterations n.
    min_n: typing.Any
    # Lower bound of the random number of trained iterations k run after the n warm-up ones.
    min_k: typing.Any
    # Iterations of the fully unrolled pass; the progressive split keeps n + k <= max_iters.
    max_iters: typing.Any
    # Problem name; ``"mazes"`` restricts the loss to masked positions.
    problem: typing.Any
    # NOTE(review): not read by any code visible in this chunk — confirm its consumer.
    inner_lr: typing.Any


def calculate_progressive_output(inputs, min_n, max_iters, min_k, net):
    """Run ``net`` with a random split of its iterations (for the progressive loss).

    First draws n from [min_n, max_iters - min_k], then k from
    [min_k, max_iters - n], so n + k <= max_iters always holds.
    When n > 0, ``net(inputs, k=n)`` is evaluated and its second output is
    detached (presumably a torch tensor, so no gradient flows through those n
    iterations) and passed back in as ``interim_thought``. The second call
    ``net(inputs, n=n, k=k, interim_thought=...)`` runs the remaining k iterations.

    Returns:
        The first element of the second ``net`` call.

    Raises:
        ValueError: from ``randrange`` if a sampling range is empty,
            e.g. when min_n > max_iters - min_k.
    """
    warm_iters = randrange(min_n, max_iters - min_k + 1)
    extra_iters = randrange(min_k, max_iters - warm_iters + 1)

    seed_thought = None
    if warm_iters > 0:
        _, seed_thought = net(inputs, k=warm_iters)
        # Cut the graph so only the final extra_iters iterations are trained.
        seed_thought = seed_thought.detach()

    final_outputs, _ = net(inputs, n=warm_iters, k=extra_iters, interim_thought=seed_thought)
    return final_outputs


def get_batch(data_loader, iterator):
    """Fetch the next ``(inputs, labels)`` batch, restarting the loader when exhausted.

    Args:
        data_loader: iterable of 2-item batches; used to create a fresh iterator
            once ``iterator`` runs out.
        iterator: current iterator over ``data_loader``.

    Returns:
        ``(inputs, labels, iterator)``, where ``iterator`` is the original one or
        a fresh iterator if a wrap-around happened. Callers must keep the
        returned iterator for the next call.

    Raises:
        StopIteration: if ``data_loader`` yields no batches at all.
    """
    try:
        # Builtin next() works for every Python 3 iterator; the ``.next()``
        # method is a Python 2 leftover that most iterators do not provide.
        task_data = next(iterator)
    except StopIteration:
        # Wrap around: start a fresh pass over the loader.
        iterator = iter(data_loader)
        task_data = next(iterator)

    inputs, labels = task_data
    return inputs, labels, iterator


def get_params_data(net):
    """Return a list with the ``.data`` of each parameter of ``net``.

    Order follows ``net.named_parameters()``; names are discarded. No clone is
    taken, so the list holds the ``.data`` objects exactly as the parameters
    expose them.
    """
    return [param.data for _name, param in net.named_parameters()]


def load_params(net, weights):
    """Assign ``weights[i]`` to ``.data`` of the i-th parameter of ``net``.

    Parameters are visited in ``net.named_parameters()`` order, which is the
    order ``get_params_data`` produces. Extra trailing entries in ``weights``
    are ignored. Too few entries raise IndexError after the earlier parameters
    have already been overwritten.

    Returns:
        ``net`` itself, modified in place.
    """
    for position, (_name, param) in enumerate(net.named_parameters()):
        param.data = weights[position]
    return net


def train(net, loaders, mode, train_setup, device, disable_tqdm=False):
    """Dispatch one training epoch to the module-level ``train_<mode>`` function.

    Args:
        net: model to train; forwarded unchanged.
        loaders: data loaders; forwarded unchanged.
        mode: suffix selecting the routine, e.g. ``"progressive"`` selects
            ``train_progressive``.
        train_setup: training configuration; forwarded unchanged.
        device: forwarded unchanged.
        disable_tqdm: forwarded unchanged.

    Returns:
        ``(train_loss, acc)`` as returned by the selected routine.

    Raises:
        SystemExit: with status 1 if no callable ``train_<mode>`` exists in
            this module.
    """
    # Look the routine up by name rather than eval(): ``mode`` can no longer
    # execute arbitrary code. A NameError raised *inside* the routine now
    # propagates instead of being misreported as "not implemented".
    train_fn = globals().get(f"train_{mode}")
    if not callable(train_fn):
        print(f"{ic.format()}: train_{mode}() not implemented. Exiting.")
        # Non-zero status so scripts and schedulers see the failure.
        sys.exit(1)
    train_loss, acc = train_fn(net, loaders, train_setup, device, disable_tqdm)
    return train_loss, acc


def train_progressive(net, loaders, train_setup, device, disable_tqdm):
    """Run one training epoch mixing the full-unroll loss and the progressive loss.

    For every batch the loss is

        loss = (1 - wt) * mean(L_max_iters) + wt * mean(L_progressive)

    The terms are:
      - ``wt`` is ``train_setup.progressive_loss_weight``.
      - ``L_max_iters`` is the per-element cross-entropy of ``net`` run for
        ``max_iters`` iterations.
      - ``L_progressive`` is the per-element cross-entropy of
        ``calculate_progressive_output`` (a random detached-prefix /
        trained-suffix split of the iterations).

    For ``problem == "mazes"`` only positions where the max of ``inputs`` over
    dim 1 is > 0 contribute to the means. After the epoch,
    ``train_setup.scheduler.step()`` and ``train_setup.warmup.dampen()`` are
    each called once.

    Args:
        net: model called as ``net(inputs, k=...)`` returning a 2-tuple whose
            first element is the logits (class dimension at dim 1).
        loaders: dict whose ``"train"`` entry yields ``(inputs, targets)``.
        train_setup: TrainingSetup with the optimizer, schedulers and loop settings.
        device: device every batch is moved to.
        disable_tqdm: forwarded to tqdm's ``disable`` flag.

    Returns:
        ``(train_loss, acc)``. ``train_loss`` is the mean per-batch loss.
        ``acc`` is the percentage of samples whose every element is predicted
        correctly (assumes ``get_predicted`` returns a tensor aligned with the
        flattened ``targets`` — TODO confirm).

    Raises:
        ValueError: if the training loader yields no batches.
    """
    trainloader = loaders["train"]
    net.train()
    optimizer = train_setup.optimizer
    lr_scheduler = train_setup.scheduler
    warmup_scheduler = train_setup.warmup
    wt = train_setup.progressive_loss_weight
    min_n = train_setup.min_n
    min_k = train_setup.min_k
    max_iters = train_setup.max_iters
    problem = train_setup.problem
    clip = train_setup.clip
    # reduction="none" keeps per-element losses so the maze mask can be applied
    # before averaging.
    criterion = torch.nn.CrossEntropyLoss(reduction="none")

    train_loss = 0
    correct = 0
    total = 0
    num_batches = 0

    for inputs, targets in tqdm(trainloader, leave=False, disable=disable_tqdm):
        num_batches += 1
        inputs, targets = inputs.to(device), targets.to(device).long()
        # Flatten targets to (batch, -1) to line up with the flattened logits below.
        targets = targets.view(targets.size(0), -1)
        if problem == "mazes":
            # True where the max over dim 1 of the input is positive; every other
            # position is excluded from the loss.
            mask = inputs.view(inputs.size(0), inputs.size(1), -1).max(dim=1)[0] > 0

        optimizer.zero_grad()

        # The fully unrolled forward pass always runs because its outputs feed the
        # training accuracy below. Its loss is only computed when wt != 1; otherwise
        # the term has weight 0 and a zero placeholder is used.
        outputs_max_iters, _ = net(inputs, k=max_iters)
        if wt != 1:
            outputs_max_iters = outputs_max_iters.view(outputs_max_iters.size(0),
                                                       outputs_max_iters.size(1), -1)
            loss_max_iters = criterion(outputs_max_iters, targets)
        else:
            loss_max_iters = torch.zeros_like(targets).float()

        # The progressive loss has weight 0 when wt == 0, so it is skipped then.
        if wt != 0:
            outputs = calculate_progressive_output(inputs, min_n, max_iters, min_k, net)
            outputs = outputs.view(outputs.size(0), outputs.size(1), -1)
            loss_progressive = criterion(outputs, targets)
        else:
            loss_progressive = torch.zeros_like(targets).float()

        if problem == "mazes":
            # Keep only masked positions so each mean is taken over them alone.
            loss_max_iters = loss_max_iters[mask]
            loss_progressive = loss_progressive[mask]

        loss = (1 - wt) * loss_max_iters.mean() + wt * loss_progressive.mean()
        loss.backward()

        if clip is not None:
            torch.nn.utils.clip_grad_norm_(net.parameters(), clip)
        optimizer.step()

        train_loss += loss.item()
        predicted = get_predicted(inputs, outputs_max_iters, problem)
        # A sample counts as correct only if all of its elements match the target.
        correct += torch.amin(predicted == targets, dim=[-1]).sum().item()
        total += targets.size(0)

    if num_batches == 0:
        # Previously this surfaced as an UnboundLocalError on ``batch_idx``.
        raise ValueError("train_progressive: the training loader yielded no batches")

    train_loss = train_loss / num_batches
    acc = 100.0 * correct / total

    lr_scheduler.step()
    warmup_scheduler.dampen()

    return train_loss, acc
