import time

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader

from approaches.pgn import pnn_utils


class Appr(object):
    """Task-incremental trainer for a column-per-task (progressive) network.

    For each task ``t`` only the column unfrozen by ``model.unfreeze_column(t)``
    is trained, using plain SGD over the parameters that require gradients.
    The learning rate is divided by ``lr_factor`` whenever the validation loss
    has not improved for ``lr_patience`` epochs, training of the task stops
    early once the rate falls below ``lr_min``, and the weights with the
    lowest validation loss are restored when :meth:`train` returns.
    """

    def __init__(self, device: str, model: nn.Module,
                 nepochs: int, sbatch: int,
                 lr: float, lr_min: float, lr_factor: float,
                 lr_patience: int, clipgrad=10000, args=None):
        """Store the hyper-parameters and build the initial optimizer.

        Args:
            device: Device every batch is moved to via ``tensor.to(device)``.
            model: Network called as ``model(x, t)``; the result is indexed
                with ``[t]`` to obtain the logits for task ``t``. It must also
                provide ``unfreeze_column(t)``.
            nepochs: Maximum number of epochs per task.
            sbatch: Samples per batch; only used to report per-batch timings
                (presumably equal to the loaders' batch size -- confirm).
            lr: Initial learning rate for every task.
            lr_min: Training of a task stops once the decayed rate drops
                below this value.
            lr_factor: Divisor applied to the learning rate on a plateau.
            lr_patience: Epochs without validation improvement before the
                learning rate is decayed.
            clipgrad: Maximum total gradient norm used for clipping.
            args: Accepted for interface compatibility; not used.
        """
        self.device = device
        self.model = model

        self.nepochs = nepochs
        self.sbatch = sbatch
        self.lr = lr
        self.lr_min = lr_min
        self.lr_factor = lr_factor
        self.lr_patience = lr_patience
        self.clipgrad = clipgrad

        # Default reduction='mean': the loss is averaged over the batch.
        self.criterion = torch.nn.CrossEntropyLoss()
        self.optimizer = self._get_optimizer(self.lr)
    # enddef

    def _get_optimizer(self, lr: float):
        """Return a fresh SGD optimizer over the currently trainable params."""
        # Frozen columns (requires_grad=False) are excluded entirely.
        return torch.optim.SGD(filter(lambda p: p.requires_grad, self.model.parameters()), lr=lr)
    # enddef

    def train(self, t: int, dl_train: DataLoader, dl_val: DataLoader):
        """Train the column for task ``t``.

        Each epoch runs one optimisation pass over ``dl_train``, re-evaluates
        on the training and validation loaders and prints a progress line.
        On return the model holds the weights that reached the lowest
        validation loss.

        Args:
            t: Index of the task / column to train.
            dl_train: Training loader yielding ``(inputs, targets)`` batches.
            dl_val: Validation loader used for model selection and lr decay.
        """
        best_loss = np.inf
        # Snapshot via the project helper; presumably a copy of the current
        # parameters -- confirm against pnn_utils.get_model.
        best_model = pnn_utils.get_model(self.model)
        lr = self.lr
        patience = self.lr_patience

        # train only the column for the current task
        self.model.unfreeze_column(t)

        # the optimizer trains solely the params for the current task
        self.optimizer = self._get_optimizer(lr)

        n_train = len(dl_train.dataset)

        # Loop epochs
        for e in range(self.nepochs):
            # Train (perf_counter is monotonic, unlike time.time)
            clock0 = time.perf_counter()
            self.train_epoch(t, dl_train)
            clock1 = time.perf_counter()
            train_loss, train_acc = self.eval(t, dl_train)
            clock2 = time.perf_counter()
            print('| Epoch {:3d}, time={:5.1f}ms/{:5.1f}ms | Train: loss={:.3f}, acc={:5.1f}% |'.format(
                e + 1,
                1000 * self.sbatch * (clock1 - clock0) / n_train,
                1000 * self.sbatch * (clock2 - clock1) / n_train,
                train_loss, 100 * train_acc), end='')
            # Valid
            valid_loss, valid_acc = self.eval(t, dl_val)
            print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss, 100 * valid_acc), end='')
            # Adapt lr
            if valid_loss < best_loss:
                best_loss = valid_loss
                best_model = pnn_utils.get_model(self.model)
                patience = self.lr_patience
                print(' *', end='')
            else:
                patience -= 1
                if patience <= 0:
                    lr /= self.lr_factor
                    print(' lr={:.1e}'.format(lr), end='')
                    if lr < self.lr_min:
                        # Learning rate exhausted: stop training this task.
                        print()
                        break
                    # endif

                    patience = self.lr_patience
                    # Rebuild the optimizer so the new learning rate applies.
                    self.optimizer = self._get_optimizer(lr)
                # endif
            # endif

            print()
        # endfor

        # Restore best
        pnn_utils.set_model_(self.model, best_model)
    # enddef

    def train_epoch(self, t: int, dl_train: DataLoader):
        """Run one SGD pass over ``dl_train`` on the output head for task ``t``."""
        self.model.train()

        # Loop batches
        for x, y in dl_train:
            images = x.to(self.device)
            targets = y.to(self.device)

            # Forward through __call__ so registered module hooks still run.
            output = self.model(images, t)[t]
            loss = self.criterion(output, targets)

            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            # clip_grad_norm_ is the in-place API; clip_grad_norm is deprecated.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad)
            self.optimizer.step()
        # endfor
    # enddef

    def eval(self, t: int, dl: DataLoader):
        """Return ``(mean_loss, accuracy)`` of head ``t`` over all of ``dl``.

        The loss is the per-sample mean (batch means weighted by batch size);
        accuracy is the fraction of samples whose arg-max prediction matches
        the target. Leaves the model in eval mode. Runs without gradient
        tracking, so no autograd graph is built.
        """
        total_loss = 0.0
        total_acc = 0
        total_num = 0
        self.model.eval()

        with torch.no_grad():
            # Loop batches
            for x, y in dl:
                images = x.to(self.device)
                targets = y.to(self.device)

                # Forward
                output = self.model(images, t)[t]
                loss = self.criterion(output, targets)
                pred = output.argmax(dim=1)

                # Log
                batch_size = targets.shape[0]
                total_loss += loss.item() * batch_size
                total_acc += (pred == targets).sum().item()
                total_num += batch_size
            # endfor

        return total_loss / total_num, total_acc / total_num
    # enddef
