""" training.py
    Utilities for training models

    Collaboratively developed
    by Avi Schwarzschild, Eitan Borgnia,
    Arpit Bansal, and Zeyad Emam.

    Developed for DeepThinking project
    October 2021
"""

from dataclasses import dataclass
from random import randrange

import torch
import torch.nn as nn
from icecream import ic
from tqdm import tqdm

from deepthinking.utils.testing import get_predicted
from deepthinking.utils.rotation import rotate_batch


# Ignore statements for pylint:
#     Too many branches (R0912), Too many statements (R0915), No member (E1101),
#     Not callable (E1102), Invalid name (C0103), No exception (W0702),
#     Too many local variables (R0914), Missing docstring (C0116, C0115, C0114),
#     Unused import (W0611).
# pylint: disable=R0912, R0915, E1101, E1102, C0103, W0702, R0914, C0116, C0115, C0114, W0611


@dataclass
class TrainingSetup:
    """Attributes to describe the training procedure.

    The fields are plain containers; their order defines the positional
    constructor generated by ``@dataclass``. The comments below describe how
    the training functions in this module read each field.
    """
    # train_progressive calls zero_grad() and step() on this once per batch.
    optimizer: "typing.Any"
    # train_progressive calls step() on this once, after the batch loop.
    scheduler: "typing.Any"
    # train_progressive calls dampen() on this once, after the batch loop.
    warmup: "typing.Any"
    # When not None, passed to torch.nn.utils.clip_grad_norm_ before each step.
    clip: "typing.Any"
    # Mixing weight: loss = (1 - alpha) * max-iters loss + alpha * progressive loss.
    alpha: "typing.Any"
    # Mixing weight: (1 - beta) * classification loss + beta * ssh loss.
    beta: "typing.Any"
    # Coefficient on the squared difference (cls_loss - ssh_loss) ** 2.
    gama: "typing.Any"
    # When > 0 the net is expected to return a fourth value (sum_pt), and
    # act * sum_pt.mean() is subtracted from the loss in train_progressive.
    act: "typing.Any"
    # Iteration budget passed to the net as iters_to_do; also the upper bound
    # for the random split in get_output_for_prog_loss.
    max_iters: "typing.Any"
    # Problem name; compared against "mazes" and "sudoku" and forwarded to
    # get_predicted.
    problem: "typing.Any"
    # train_progressive only derives an (unused) boolean flag from this (> 0).
    self_verify: "typing.Any"
    # The next three fields are not read by the functions visible in this
    # module -- presumably consumed elsewhere; verify against callers.
    train_only_verification: "typing.Any"
    focal_loss_alpha: "typing.Any"
    focal_loss_gamma: "typing.Any"


def get_output_for_prog_loss(inputs, max_iters, net, use_act=False):
    """Run the randomly split forward pass used for the progressive loss.

    A random number ``n`` of iterations (possibly zero) is run first and the
    resulting intermediate features are detached; a second pass of ``k``
    iterations then continues from those features, with ``n + k`` never
    exceeding ``max_iters``.

    Args:
        inputs: Batch forwarded to ``net`` unchanged.
        max_iters: Upper bound on ``n + k``; must be at least 1 (``randrange``
            raises ``ValueError`` otherwise).
        net: Callable accepting ``iters_to_do``, ``iters_elapsed`` and
            ``interim_thought`` keyword arguments. It must return a 3-tuple,
            or a 4-tuple when ``use_act`` is true.
        use_act: Selects whether ``net`` results are unpacked as 4-tuples.

    Returns:
        ``(outputs, k, ssh_out)``: the first and third elements of the second
        pass's result, plus the number of iterations that pass ran.
    """
    def _first_three(result):
        # Strict unpacking so a result of the wrong arity still raises.
        if use_act:
            head, thought, ssh, _ = result
        else:
            head, thought, ssh = result
        return head, thought, ssh

    # Iterations run only to produce the starting features.
    warm_iters = randrange(0, max_iters)
    # Iterations whose outputs are returned to the caller.
    k = randrange(1, max_iters - warm_iters + 1)

    interim_thought = None
    if warm_iters > 0:
        _, warm_thought, _ = _first_three(net(inputs, iters_to_do=warm_iters))
        # Detach so gradients do not flow back through the warm-up pass.
        interim_thought = warm_thought.detach()

    outputs, _, ssh_out = _first_three(
        net(inputs, iters_elapsed=warm_iters, iters_to_do=k,
            interim_thought=interim_thought)
    )
    return outputs, k, ssh_out


def eval_train(net, loaders, mode, train_setup, device):
    """Evaluate ``net`` on the training data with the routine named by ``mode``.

    Only ``mode == "progressive"`` is supported; it delegates to
    ``eval_train_progressive``.

    Returns:
        ``(train_loss, acc)`` as produced by the selected routine.

    Raises:
        ValueError: If ``mode`` is anything other than ``"progressive"``.
    """
    # Guard clause: reject unsupported modes before doing any work.
    if mode != "progressive":
        raise ValueError(f"{ic.format()}: train_{mode}() not implemented.")

    result = eval_train_progressive(net, loaders, train_setup, device)
    mean_loss, accuracy = result
    return mean_loss, accuracy


def eval_train_progressive(net, loaders, train_setup, device):
    """Compute the progressive training loss and accuracy without updating ``net``.

    Mirrors the classification part of ``train_progressive`` (no optimizer,
    no backward pass). The net is left in ``train()`` mode, as in the original
    implementation.

    Args:
        net: Model called as ``net(inputs, iters_to_do=...)``. Only the first
            element of its result is used here, so both the 3-tuple and the
            4-tuple (ACT) return conventions used elsewhere in this module work.
        loaders: Mapping with a ``"train"`` entry yielding ``(inputs, targets)``.
        train_setup: Object providing ``alpha``, ``max_iters``, ``problem`` and
            ``act`` (see ``TrainingSetup``).
        device: Device the batches are moved to.

    Returns:
        ``(train_loss, acc)``: the loss averaged over batches and the
        percentage of samples whose prediction matches the target at every
        position.

    Raises:
        ValueError: If the training loader yields no batches.
    """
    trainloader = loaders["train"]
    net.train()
    alpha = train_setup.alpha
    max_iters = train_setup.max_iters
    problem = train_setup.problem
    # Same switch train_progressive uses to decide whether the net returns
    # the extra ACT value; get_output_for_prog_loss needs it to unpack results.
    use_act = train_setup.act > 0
    criterion = torch.nn.CrossEntropyLoss(reduction="none")

    train_loss = 0
    correct = 0
    total = 0
    num_batches = 0

    # Nothing is back-propagated here, so skip building autograd graphs.
    with torch.no_grad():
        for inputs, targets in tqdm(trainloader, leave=False):
            num_batches += 1
            inputs, targets = inputs.to(device), targets.to(device).long()
            targets = targets.view(targets.size(0), -1)
            if problem == "mazes":
                # True where any input channel is positive at that position.
                mask = inputs.view(inputs.size(0), inputs.size(1), -1).max(dim=1)[0] > 0

            # The fully unrolled outputs are always needed for the accuracy
            # computation below; index [0] works for 3- and 4-tuple results.
            outputs_max_iters = net(inputs, iters_to_do=max_iters)[0]

            # Fully unrolled loss; unused (set to zero) when alpha is 1.
            if alpha != 1:
                outputs_max_iters = outputs_max_iters.view(outputs_max_iters.size(0),
                                                           outputs_max_iters.size(1), -1)
                loss_max_iters = criterion(outputs_max_iters, targets)
            else:
                loss_max_iters = torch.zeros_like(targets).float()

            # Progressive loss; unused (set to zero) when alpha is 0.
            if alpha != 0:
                outputs, _, _ = get_output_for_prog_loss(inputs, max_iters, net, use_act)
                outputs = outputs.view(outputs.size(0), outputs.size(1), -1)
                loss_progressive = criterion(outputs, targets)
            else:
                loss_progressive = torch.zeros_like(targets).float()

            if problem == "mazes":
                # Keep only the per-position losses selected by the mask.
                loss_max_iters = (loss_max_iters * mask)[mask > 0]
                loss_progressive = (loss_progressive * mask)[mask > 0]

            loss = (1 - alpha) * loss_max_iters.mean() + alpha * loss_progressive.mean()

            train_loss += loss.item()
            predicted = get_predicted(inputs, outputs_max_iters, problem)
            # A sample counts as correct only if every position matches.
            correct += torch.amin(predicted == targets, dim=[-1]).sum().item()
            total += targets.size(0)

    if num_batches == 0:
        raise ValueError("eval_train_progressive: the training loader produced no batches.")

    train_loss = train_loss / num_batches
    acc = 100.0 * correct / total

    return train_loss, acc

def train(net, loaders, mode, train_setup, device):
    """Train ``net`` for one epoch with the routine named by ``mode``.

    Only ``mode == "progressive"`` is supported; it delegates to
    ``train_progressive``.

    Returns:
        ``(train_loss, acc, ssh_acc, cls_loss, ssh_loss, veri_loss)`` as
        produced by the selected routine.

    Raises:
        ValueError: If ``mode`` is anything other than ``"progressive"``.
    """
    # Guard clause: reject unsupported modes before doing any work.
    if mode != "progressive":
        raise ValueError(f"{ic.format()}: train_{mode}() not implemented.")

    epoch_stats = train_progressive(net, loaders, train_setup, device)
    train_loss, acc, ssh_acc, cls_loss, ssh_loss, veri_loss = epoch_stats
    return train_loss, acc, ssh_acc, cls_loss, ssh_loss, veri_loss

def train_progressive(net, loaders, train_setup, device):
    """Train ``net`` for one pass over ``loaders["train"]`` with the progressive loss.

    Per batch, two losses are combined:

    * a classification loss on ``(inputs, targets)``, and
    * an "ssh" loss on ``(ssh_inputs, ssh_labels)``, where the ssh batch comes
      from ``rotate_batch(inputs, 'rand')`` for every problem except
      ``"sudoku"`` (which reuses ``inputs`` and ``targets``).

    Each is a mix ``(1 - alpha) * max-iters loss + alpha * progressive loss``,
    and the total is
    ``(1 - beta) * cls + beta * ssh + gama * (cls - ssh) ** 2``, minus
    ``act * sum_pt.mean()`` when ``act > 0``.

    Args:
        net: Model called as ``net(inputs, iters_to_do=...)``; expected to
            return a 3-tuple, or a 4-tuple whose last element is ``sum_pt``
            when ``train_setup.act > 0``.
        loaders: Mapping with a ``"train"`` entry yielding ``(inputs, targets)``.
        train_setup: Object providing ``optimizer``, ``scheduler``, ``warmup``,
            ``alpha``, ``beta``, ``gama``, ``act``, ``max_iters``, ``problem``
            and ``clip`` (see ``TrainingSetup``).
        device: Device the batches are moved to.

    Returns:
        ``(train_loss, acc, ssh_acc, cls_train_loss, ssh_train_loss,
        veri_train_loss)``. Losses are averaged over batches and accuracies
        are percentages. ``veri_train_loss`` is always ``0.0`` because no
        verification loss is computed here.

    Raises:
        ValueError: If the training loader yields no batches.
    """
    trainloader = loaders["train"]
    net.train()
    optimizer = train_setup.optimizer
    lr_scheduler = train_setup.scheduler
    warmup_scheduler = train_setup.warmup
    alpha = train_setup.alpha
    beta = train_setup.beta
    gama = train_setup.gama
    act = train_setup.act
    max_iters = train_setup.max_iters
    problem = train_setup.problem
    clip = train_setup.clip
    criterion = torch.nn.CrossEntropyLoss(reduction="none")

    # A positive ``act`` both enables the 4-tuple return convention and is the
    # weight applied to the ACT term in the loss.
    use_act = act > 0

    train_loss = 0
    correct = 0
    ssh_correct = 0
    total = 0
    cls_train_loss = 0
    ssh_train_loss = 0
    num_batches = 0

    for batch_idx, (inputs, targets) in enumerate(tqdm(trainloader, leave=False)):
        num_batches = batch_idx + 1
        inputs, targets = inputs.to(device), targets.to(device).long()
        if problem != "sudoku":
            ssh_inputs, ssh_labels = rotate_batch(inputs, 'rand')
        else:
            ssh_inputs, ssh_labels = inputs, targets
        ssh_inputs, ssh_labels = ssh_inputs.to(device), ssh_labels.to(device).long()
        ssh_labels = ssh_labels.view(ssh_labels.size(0), -1)
        targets = targets.view(targets.size(0), -1)
        if problem == "mazes":
            # True where any input channel is positive at that position.
            mask = inputs.view(inputs.size(0), inputs.size(1), -1).max(dim=1)[0] > 0

        optimizer.zero_grad()

        # The fully unrolled passes are always run: their outputs feed the
        # accuracy computation even when alpha is 1.
        if use_act:
            outputs_max_iters, _, _, sum_pt = net(inputs, iters_to_do=max_iters)
            _, _, ssh_outputs_max_iter, _ = net(ssh_inputs, iters_to_do=max_iters)
        else:
            outputs_max_iters, _, _ = net(inputs, iters_to_do=max_iters)
            _, _, ssh_outputs_max_iter = net(ssh_inputs, iters_to_do=max_iters)

        # Fully unrolled losses; unused (set to zero) when alpha is 1.
        if alpha != 1:
            outputs_max_iters = outputs_max_iters.view(outputs_max_iters.size(0),
                                                       outputs_max_iters.size(1), -1)
            loss_max_iters = criterion(outputs_max_iters, targets)
            ssh_outputs_max_iter = ssh_outputs_max_iter.view(ssh_outputs_max_iter.size(0),
                                                             ssh_outputs_max_iter.size(1), -1)
            loss_ssh_max_iters = criterion(ssh_outputs_max_iter, ssh_labels)
        else:
            loss_max_iters = torch.zeros_like(targets).float()
            loss_ssh_max_iters = torch.zeros_like(ssh_labels).float()

        # Progressive losses; unused (set to zero) when alpha is 0.
        if alpha != 0:
            # NOTE(review): ssh_out is computed from ``inputs`` but scored
            # against ``ssh_labels``, which describe ``ssh_inputs`` -- confirm
            # this pairing is intended for the non-sudoku problems.
            outputs, _, ssh_out = get_output_for_prog_loss(inputs, max_iters, net, use_act)
            outputs = outputs.view(outputs.size(0), outputs.size(1), -1)
            ssh_out = ssh_out.view(ssh_out.size(0), ssh_out.size(1), -1)
            loss_progressive = criterion(outputs, targets)
            loss_ssh = criterion(ssh_out, ssh_labels)
        else:
            loss_progressive = torch.zeros_like(targets).float()
            loss_ssh = torch.zeros_like(ssh_labels).float()

        # Only the mazes classification losses are masked; every other
        # problem averages over all positions.
        if problem == "mazes":
            loss_max_iters = (loss_max_iters * mask)[mask > 0]
            loss_progressive = (loss_progressive * mask)[mask > 0]

        cls_loss = (1 - alpha) * loss_max_iters.mean() + alpha * loss_progressive.mean()
        ssh_loss = (1 - alpha) * loss_ssh_max_iters.mean() + alpha * loss_ssh.mean()
        loss = (1 - beta) * cls_loss + beta * ssh_loss + gama * (cls_loss - ssh_loss) ** 2

        if use_act:
            loss = loss - act * sum_pt.mean()

        loss.backward()
        # From here on the component losses are plain floats for logging.
        cls_loss = cls_loss.item()
        cls_train_loss += cls_loss
        ssh_loss = ssh_loss.item()
        ssh_train_loss += ssh_loss

        if batch_idx % 50 == 0:
            print(f"cls_loss = {cls_loss}\t\tssh_loss = {ssh_loss}")
            if use_act:
                act_loss = sum_pt.mean().item()
                print(f"act_loss = {act_loss}")
        if clip is not None:
            torch.nn.utils.clip_grad_norm_(net.parameters(), clip)
        optimizer.step()

        train_loss += loss.item()
        predicted = get_predicted(inputs, outputs_max_iters, problem)
        ssh_predicted = get_predicted(ssh_inputs, ssh_outputs_max_iter, problem)
        # A sample counts as correct only if every position matches.
        correct += torch.amin(predicted == targets, dim=[-1]).sum().item()
        ssh_correct += torch.amin(ssh_predicted == ssh_labels, dim=[-1]).sum().item()
        total += targets.size(0)

    if num_batches == 0:
        # Fail before touching the schedulers; averages would be undefined.
        raise ValueError("train_progressive: the training loader produced no batches.")

    train_loss = train_loss / num_batches
    cls_train_loss = cls_train_loss / num_batches
    ssh_train_loss = ssh_train_loss / num_batches
    # No verification loss is computed in this routine; kept for the callers'
    # 6-tuple interface.
    veri_train_loss = 0.0

    acc = 100.0 * correct / total
    ssh_acc = 100.0 * ssh_correct / total

    # Schedulers advance once per call (i.e. per pass over the loader).
    lr_scheduler.step()
    warmup_scheduler.dampen()

    return train_loss, acc, ssh_acc, cls_train_loss, ssh_train_loss, veri_train_loss
