"""Implement training routines for sanity checks."""

"""interfaces:
.search(model, trainloader, metaloader, validloader, defs, callback=None)
.evaluate(model, trainloader, validloader, defs)
"""

import torch

from collections import defaultdict
from dataclasses import dataclass
import time
import warnings
import logging

# from .analyzer import Analyzer
from .scheduler import GradualWarmupScheduler
from ._alpha_step import _fill_alpha_gradients_liu, _fill_alpha_gradients_binaryconnect
from ._alpha_step import _fill_alpha_gradients_higher, _fill_alpha_gradients_single
from ._alpha_step import _loss_evaluation

from .optim import MirrorDescent, MirrorAdam
from .utils import psnr_compute


@dataclass
class State:
    """Optimization State.

    Bundles the optimizers, the learning-rate schedulers and the epoch counter that the
    training routines in this module hand to each other.
    """

    # Optimizer for the model parameters. `set_optimizer` yields None when the model has no parameters,
    # so this field can be None despite the annotation.
    param_optimizer: torch.optim.Optimizer
    # Optimizer for the architecture parameters. `evaluate` always passes None here, and `set_optimizer`
    # yields None when there are no (or only degenerate) architecture parameters.
    alpha_optimizer: torch.optim.Optimizer
    # Scheduler belonging to `param_optimizer`; None whenever that optimizer is None.
    param_scheduler: torch.optim.lr_scheduler._LRScheduler
    # Scheduler belonging to `alpha_optimizer`; None whenever that optimizer is None.
    alpha_scheduler: torch.optim.lr_scheduler._LRScheduler
    # Index of the current epoch; `search` and `evaluate` use this attribute directly as their loop variable.
    epoch: int


def _report_search_status(model, validloader, state, defs, stats):
    """Validate the model and its current genotype, then log a status summary.

    Appends the 'valid_*' and 'test_*' entries to `stats` (via `validate_model` and
    `validate_genotype`) and logs the training status together with the entropy figures
    and the genotype reported by the model.
    """
    validate_model(model, validloader, defs, stats)
    validate_genotype(model, validloader, defs, stats)
    status = training_status(state, stats)
    status += f'Mean entropy is {model.mean_entropy()}. Normalized entropy: {model.normalized_entropy():6.2%}\n'
    status += f'Current genotype: {model.genotype()}'
    logging.info(status)


def search(model, trainloader, metaloader, validloader, defs, callback=None):
    """Search for optimal genotype for the given search space defined within 'model'.

    Args:
        model: Search model; must provide `genotype()`, `mean_entropy()` and `normalized_entropy()`,
            and a `temperature` attribute when entropy decay without mirror descent is used.
        trainloader: Loader for the parameter updates.
        metaloader: Loader for the architecture (alpha) updates; iterated in lockstep with `trainloader`.
        validloader: Loader used for the validation reports.
        defs: Training definitions. Read here: `epochs`, `decay_entropy`, `mirror_descent`, `dryrun`,
            plus everything `set_optimizer` and `train_epoch` consume.
        callback: Reporting interval in epochs. A report is produced whenever `epoch % callback == 0`;
            None disables the intermediate reports. Must not be 0.

    Returns:
        Tuple `(genotype_final, stats)` where `stats` is a `defaultdict(list)` of per-epoch records.
    """
    param_optimizer, alpha_optimizer, param_scheduler, alpha_scheduler = set_optimizer(model, defs)

    state = State(param_optimizer, alpha_optimizer, param_scheduler, alpha_scheduler, epoch=0)
    stats = defaultdict(list)

    # The loop writes the epoch index straight into `state.epoch`.
    for state.epoch in range(defs.epochs):
        train_epoch(model, trainloader, metaloader, state, defs, stats)

        if param_scheduler is not None:
            param_scheduler.step()
        if alpha_scheduler is not None:
            alpha_scheduler.step()

        # Nothing is being optimized, so one pass over the data is all that is needed.
        if param_optimizer is None and alpha_optimizer is None:
            break

        if defs.decay_entropy and alpha_optimizer is not None:
            if defs.mirror_descent:
                # Linear decay from 1 at epoch 0 towards -1 at the final epoch.
                entropy = 1 - 2 * state.epoch / defs.epochs
                for param_group in alpha_optimizer.param_groups:
                    param_group['entropy'] = entropy
                print(f'Entropy regularization is now set to {entropy}.')
            else:
                model.temperature *= 1.15
                print(f'Model temperature is now {model.temperature}.')

        if callback is not None and state.epoch % callback == 0:
            # Intermediate report every `callback` epochs.
            _report_search_status(model, validloader, state, defs, stats)

        if not torch.as_tensor(stats['train_loss'][-1]).isfinite() or defs.dryrun:
            break

    # Final report after the search loop has finished (or was cut short).
    _report_search_status(model, validloader, state, defs, stats)

    # Compute final genotype:
    genotype_final = model.genotype()

    return genotype_final, stats


def evaluate(model, trainloader, validloader, defs):
    """Retrain a given optimal genotype for the given architecture defined within 'model' arch_parameters().

    Only the model parameters are trained; the architecture optimizer and scheduler returned by
    `set_optimizer` are discarded and the state carries None in their place.

    Args:
        model: Model to retrain.
        trainloader: Loader for the training batches.
        validloader: Loader for the final validation pass.
        defs: Training definitions. Read here: `epochs` and `dryrun`, plus everything
            `set_optimizer` and `train_genotype` consume.

    Returns:
        `defaultdict(list)` with the per-epoch training records and the final 'valid_*' entries.
    """
    # Retrain:
    param_optimizer, _, param_scheduler, _ = set_optimizer(model, defs)

    state = State(param_optimizer, None, param_scheduler, None, epoch=0)
    stats = defaultdict(list)

    for state.epoch in range(defs.epochs):
        train_genotype(model, state, trainloader, defs, stats, key='train')
        # Scheduling learning rates
        if param_scheduler is not None:
            param_scheduler.step()
        if not torch.as_tensor(stats['train_loss'][-1]).isfinite():
            print('Training non-finite - cancelling')
            break
        if defs.dryrun:
            # A dry run stops after a single epoch; this is not a failure, so nothing is reported.
            break
        if state.param_optimizer is None:
            # if no parameters are trained, break after seeing one epoch
            break

    # Stored under the 'valid' key because that is what `training_status` reads.
    validate_genotype(model, validloader, defs, stats, key='valid')
    # Log the final status once training has finished:
    status = training_status(state, stats, tested=False)
    logging.info(status)

    return stats


def train_epoch(model, trainloader, metaloader, state, defs, stats):
    """Train for a single epoch.

    Iterates `trainloader` and `metaloader` in lockstep (stopping with the shorter one), runs one
    `_train_step` per pair of batches and appends the elapsed time and the mean loss / PSNR of the
    processed batches to `stats['train_time']`, `stats['train_loss']` and `stats['train_psnr']`.

    Raises:
        ValueError: If the loaders yield no batch at all, since no epoch average can be formed.
    """
    epoch_time = time.time()  # roughly estimate training time per epoch
    epoch_loss, epoch_psnr = 0, 0
    num_batches = 0

    # Train
    model.train()

    for data_train, data_meta in zip(trainloader, metaloader):

        data_train = [t.to(device=model.setup['device']) for t in data_train]
        data_meta = [t.to(device=model.setup['device']) for t in data_meta]

        # Minibatching
        loss, psnr = _train_step(model, state, data_train, data_meta, defs)
        epoch_loss += loss
        epoch_psnr += psnr
        num_batches += 1

        if not torch.as_tensor(epoch_loss).isfinite():
            warnings.warn('Loss is NaN/Inf ... terminating early ...')
            break
        if defs.dryrun:
            break

    if num_batches == 0:
        # Previously this surfaced as an UnboundLocalError on the loop variable.
        raise ValueError('The training/meta loaders yielded no batches; cannot compute epoch statistics.')

    stats['train_time'].append(time.time() - epoch_time)
    stats['train_loss'].append(epoch_loss / num_batches)
    stats['train_psnr'].append(epoch_psnr / num_batches)


def _train_step(model, state, data_train, data_meta, defs):
    """Take a single minibatch step and record results.

    First updates the architecture parameters (if `state.alpha_optimizer` exists), then the model
    parameters (if `state.param_optimizer` exists), and finally projects the model onto its
    constraint set when `defs.project` is set.

    Args:
        model: Model being searched/trained.
        state: Optimization state holding the (possibly None) optimizers.
        data_train: Batch used for the parameter update; unpacked into `_loss_evaluation`.
        data_meta: Batch used for the architecture update; unpacked into `_loss_evaluation`.
        defs: Training definitions. Read here: `update` and `project`.

    Returns:
        Tuple `(loss, psnr)` of Python scalars taken from the last loss evaluation of this step.
    """
    loss = None
    # Alpha update step
    if state.alpha_optimizer is not None:
        state.alpha_optimizer.zero_grad()

        if state.param_optimizer is not None:  # If we are already training parameters, then the approximate Hessians make sense
            state.param_optimizer.zero_grad()
            # Use the approx. Hessian from the original DARTS paper.
            if defs.update in ['liu', 'liu-1th']:
                _fill_alpha_gradients_liu(model, state, data_train, data_meta)
            elif defs.update in ['liu-0th', 'alternating']:
                # `_loss_evaluation` returns a (loss, psnr) pair, as at every other call site in this
                # module; it has to be unpacked before calling backward() on the loss.
                loss, psnr = _loss_evaluation(model, *data_meta)
                loss.backward()
            elif defs.update == 'liu-single':
                _fill_alpha_gradients_single(model, state, data_meta)
            elif defs.update == 'higher':
                # Use a higher-order approximation:
                _fill_alpha_gradients_higher(model, state, data_train, data_meta, iterations=2)
            elif defs.update == 'binary-connect':
                # Binary connect is a stand-in for Bayesian sampling approaches
                _fill_alpha_gradients_binaryconnect(model, state, data_train)
            # NOTE(review): any other value of `defs.update` fills no alpha gradients here, yet the
            # optimizer step below is still taken - confirm whether this should raise instead.
        else:  # otherwise we always take a simple step
            loss, psnr = _loss_evaluation(model, *data_meta)
            loss.backward()
        state.alpha_optimizer.step()

    # Parameter update step
    if state.param_optimizer is not None:
        # Also discards any parameter gradients left over from the alpha step above.
        state.param_optimizer.zero_grad()

        loss, psnr = _loss_evaluation(model, *data_train)
        loss.backward()
        state.param_optimizer.step()

    elif loss is None:
        # Nothing was optimized in this step: only measure the loss for the record.
        with torch.no_grad():
            loss, psnr = _loss_evaluation(model, *data_train)

    # Project if projective methods are registered.
    with torch.no_grad():
        if defs.project:
            model.project_onto_constraint()

    return loss.item(), psnr.item()


def validate_model(model, validloader, defs, stats, key='valid'):
    """Validate the full (non-truncated) model.

    Evaluates `_loss_evaluation` on every batch of `validloader` without gradients (only the first
    batch when `defs.dryrun` is set) and appends the mean loss and PSNR to `stats[f'{key}_loss']`
    and `stats[f'{key}_psnr']`.

    Raises:
        ValueError: If `validloader` yields no batch, since no average can be formed.
    """
    model.eval()
    epoch_loss, epoch_psnr = 0, 0
    num_batches = 0

    for data in validloader:
        data = [t.to(device=model.setup['device']) for t in data]
        with torch.no_grad():
            loss, psnr = _loss_evaluation(model, *data)
            epoch_loss += loss.item()
            epoch_psnr += psnr.item()
        num_batches += 1

        if defs.dryrun:
            break

    if num_batches == 0:
        # Previously this surfaced as an UnboundLocalError on the loop variable.
        raise ValueError('The validation loader yielded no batches; cannot compute validation statistics.')

    stats[f'{key}_loss'].append(epoch_loss / num_batches)
    stats[f'{key}_psnr'].append(epoch_psnr / num_batches)


def validate_genotype(model, validloader, defs, stats, key='test'):
    """Validate performance of current genotype, i.e. binary weights.

    Runs `model.forward_argmax` on every batch of `validloader` without gradients (only the first
    batch when `defs.dryrun` is set) and appends the mean loss and PSNR to `stats[f'{key}_loss']`
    and `stats[f'{key}_psnr']`. Each batch is expected to provide the input at index 0 and the
    target at index 1.

    Raises:
        ValueError: If `validloader` yields no batch, since no average can be formed.
    """
    model.eval()
    epoch_loss, epoch_psnr = 0, 0
    num_batches = 0

    for data in validloader:
        data = [t.to(device=model.setup['device']) for t in data]
        with torch.no_grad():
            outputs, aux_loss = model.forward_argmax(data[0], x_true=data[1])
            loss = model.criterion(outputs, data[1])
            psnr = psnr_compute(outputs, data[1])
            if aux_loss is not None:
                # Combine with the auxiliary loss and normalize by the number of layers.
                loss = (loss + aux_loss) / model.layers
            epoch_psnr += psnr.item()
            epoch_loss += loss.item()
        num_batches += 1
        if defs.dryrun:
            break

    if num_batches == 0:
        # Previously this surfaced as an UnboundLocalError on the loop variable.
        raise ValueError('The validation loader yielded no batches; cannot compute genotype statistics.')

    stats[f'{key}_loss'].append(epoch_loss / num_batches)
    stats[f'{key}_psnr'].append(epoch_psnr / num_batches)


def _genotype_loss_and_psnr(model, data):
    """Run `model.forward_argmax` on one batch and return its `(loss, psnr)` tensors."""
    outputs, aux_loss = model.forward_argmax(data[0], x_true=data[1])
    loss = model.criterion(outputs, data[1])
    psnr = psnr_compute(outputs, data[1])
    if aux_loss is not None:
        # Combine with the auxiliary loss and normalize by the number of layers.
        loss = (loss + aux_loss) / model.layers
    return loss, psnr


def train_genotype(model, state, trainloader, defs, stats, key='finetune'):
    """Train with fixed alpha. Useful for finetuning with binarized alpha or retraining.

    Takes one optimizer step per batch when `state.param_optimizer` exists; otherwise the batches
    are only evaluated without gradients. Appends the elapsed time to `stats['train_time']`
    (independent of `key`) and the mean loss / PSNR to `stats[f'{key}_loss']` and
    `stats[f'{key}_psnr']`. Only the first batch is processed when `defs.dryrun` is set.

    Raises:
        ValueError: If `trainloader` yields no batch, since no average can be formed.
    """
    epoch_time = time.time()  # roughly estimate training time per epoch
    model.train()
    epoch_loss, epoch_psnr = 0, 0
    num_batches = 0

    for data in trainloader:
        data = [t.to(device=model.setup['device']) for t in data]
        if state.param_optimizer is not None:
            # Parameter update step
            state.param_optimizer.zero_grad()
            loss, psnr = _genotype_loss_and_psnr(model, data)
            loss.backward()
            state.param_optimizer.step()
        else:
            # No trainable parameters: only record the loss.
            with torch.no_grad():
                loss, psnr = _genotype_loss_and_psnr(model, data)

        epoch_loss += loss.item()
        epoch_psnr += psnr.item()
        num_batches += 1
        if defs.dryrun:
            break

    if num_batches == 0:
        # Previously this surfaced as an UnboundLocalError on the loop variable.
        raise ValueError('The training loader yielded no batches; cannot compute epoch statistics.')

    stats['train_time'].append(time.time() - epoch_time)
    stats[f'{key}_loss'].append(epoch_loss / num_batches)
    stats[f'{key}_psnr'].append(epoch_psnr / num_batches)


def set_optimizer(model, defs):
    """Construct optimizers and schedulers for parameters and architecture, where applicable.

    A parameter optimizer is built only if the model has at least one parameter. An architecture
    optimizer is built only if there is at least one architecture parameter and every one of them
    has more than a single element. Whatever is not built is returned as None.

    Returns:
        Tuple `(param_optimizer, alpha_optimizer, param_scheduler, alpha_scheduler)`.
    """
    # Start from "nothing to optimize" and fill in whatever applies.
    param_optimizer = param_scheduler = None
    alpha_optimizer = alpha_scheduler = None

    weights = list(model.parameters())
    if weights:
        param_optimizer, param_scheduler = _retrieve_optimizer(
            model.parameters(), defs.param_optimizer, defs.param_lr, defs.param_weight_decay,
            defs.param_scheduler, defs.epochs, warmup=defs.param_warmup)

    alphas = list(model.arch_parameters())
    if alphas and all(alpha.numel() > 1 for alpha in alphas):
        alpha_optimizer, alpha_scheduler = _retrieve_optimizer(
            model.arch_parameters(), defs.alpha_optimizer, defs.alpha_lr, defs.alpha_weight_decay,
            defs.alpha_scheduler, defs.epochs, warmup=defs.alpha_warmup)

    return param_optimizer, alpha_optimizer, param_scheduler, alpha_scheduler


def _retrieve_optimizer(parameter_iterable, optim, lr, weight_decay, schedule, epochs, warmup=False):
    """Build model optimizer and scheduler from defs.

    The linear scheduler drops the learning rate in intervals.
    # Example: epochs=160 leads to drops at 60, 100, 140.
    """
    if warmup:
        lr = lr * 0.1
    if optim in ['SGD', 'GD']:
        optimizer = torch.optim.SGD(parameter_iterable, lr=lr, momentum=0.9,
                                    weight_decay=weight_decay, nesterov=True)
    elif optim == 'Adam':
        optimizer = torch.optim.Adam(parameter_iterable, lr=lr, weight_decay=weight_decay)
    elif optim == 'MirrorAdam':
        optimizer = MirrorAdam(parameter_iterable, lr=lr, betas=(0.5, 0.999), weight_decay=0, entropy=0.0)
    elif optim == 'MirrorDescent':
        optimizer = MirrorDescent(parameter_iterable, lr=lr, momentum=0.0, weight_decay=0.0, entropy=0.0)
    elif optim == 'MirrrorNesterov':
        optimizer = MirrorDescent(parameter_iterable, lr=lr, momentum=0.9, nesterov=True, weight_decay=0.0, entropy=0.0)
    else:
        raise ValueError(f'Invalid optimizer definition {optim} given.')

    if schedule == 'linear':
        # Drop at 5/8, 6/8, 7/8
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50 // 2.667, 50 // 1.6,
                                                                                50 // 1.142], gamma=0.1)
    elif schedule == 'cosine-decay-1':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 25, eta_min=lr / 25)
    elif schedule == 'cosine-decay-2':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 25, eta_min=0.0)
    else:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100_000], gamma=1.0)

    if warmup:
        scheduler = GradualWarmupScheduler(optimizer, multiplier=10, total_epoch=10, after_scheduler=scheduler)

    return optimizer, scheduler


def training_status(state, stats, tested=True):
    """Compose the status message for the most recent entries in `stats`.

    Reports the epoch, both learning rates ('-' where the optimizer is unavailable), the last
    training time and the last train / validation figures. The figures of the truncated model
    ('test_*' entries) are appended only when `tested` is true.
    """

    def _learning_rate(optimizer_name):
        # An optimizer that is None has no `param_groups`, which surfaces as an AttributeError.
        try:
            return f'{getattr(state, optimizer_name).param_groups[0]["lr"]:.4f}'
        except AttributeError:
            return '-'

    current_lr = _learning_rate('param_optimizer')
    arch_lr = _learning_rate('alpha_optimizer')

    pieces = [
        f'Epoch: {state.epoch}| param lr: {current_lr} | arch lr: {arch_lr} | Time: {stats["train_time"][-1]:.0f}s | \n',
        f'TRAIN loss {stats["train_loss"][-1]:6.4f} | TRAIN PSNR {stats["train_psnr"][-1]:.3f} |',
        f'VAL loss {stats["valid_loss"][-1]:6.4f} | VAL PSNR {stats["valid_psnr"][-1]:.3f} |',
    ]
    if tested:
        pieces.append(f'Trunc. loss {stats["test_loss"][-1]:6.4f} | Trunc. PSNR {stats["test_psnr"][-1]:.3f}| \n')
    return ''.join(pieces)
