import logging

import torch
import tqdm

import utils
from .label_smoothing import LabelSmoothingLoss


class Trainer:
    """Train and evaluate a PyTorch model with a configured loss, optimizer and scheduler.

    A ``Trainer`` is used in two states:

    * an *unbound* template (from :meth:`from_params`, or the constructor with
      ``model=None``) holding the loss and the factories that build an
      optimizer / LR scheduler for a model, and
    * a *bound* trainer returned by :meth:`init`, which holds a model on
      ``device`` plus its optimizer and scheduler, and can run :meth:`train`
      and :meth:`evaluate`.
    """

    def __init__(
        self,
        loss, optim_factory, sched_factory,
        train_batchsize, test_batchsize, epochs, device,
        model=None, optim=None, sched=None
    ):
        """Store configuration; callers normally go through ``from_params``/``init``.

        Args:
            loss: callable ``loss(yh, y)`` returning a scalar tensor; :meth:`init`
                calls ``loss.to(device)`` on it (e.g. a ``torch.nn.Module``).
            optim_factory: ``model -> optimizer``, used by :meth:`init`.
            sched_factory: ``(optimizer, epochs) -> scheduler or None``, used by
                :meth:`init`.
            train_batchsize: DataLoader batch size used by :meth:`train`.
            test_batchsize: DataLoader batch size used by :meth:`evaluate`.
            epochs: number of epochs :meth:`train` runs.
            device: device that model, loss and batches are moved to.
            model, optim, sched: only set on a bound trainer (see :meth:`init`).
        """
        self._model = model
        self._optim = optim
        self._sched = sched

        self._loss = loss
        self._optim_factory = optim_factory
        self._sched_factory = sched_factory

        self._train_batchsize = train_batchsize
        self._test_batchsize = test_batchsize
        self._epochs = epochs
        self._device = device

    @staticmethod
    def from_params(device, loss, optim, sched, train_batchsize, test_batchsize, epochs):
        """Build an unbound trainer from plain configuration values.

        Args:
            device: device passed to ``.to(...)`` for model, loss and batches.
            loss: dict with ``"type"`` (``"cross-entropy"`` or ``"label-smooth"``)
                and optional ``"config"`` kwargs for the loss constructor.
            optim: dict with ``"type"`` (``"adam"`` or ``"sgd"``; the legacy
                misspelling ``"sdg"`` is still accepted) and optional
                ``"config"`` kwargs for the optimizer constructor (e.g. ``lr``).
            sched: ``None`` for no scheduler, or ``"multistep-lr"`` for a
                ``MultiStepLR`` with milestones at 80% and 95% of the epochs
                (default ``gamma``).
            train_batchsize, test_batchsize, epochs: see ``__init__``.

        Returns:
            An unbound ``Trainer``; call :meth:`init` with a model to use it.

        Raises:
            KeyError: if ``loss["type"]``, ``optim["type"]`` or ``sched`` is not
                one of the recognised values.
        """
        loss_type = loss["type"]
        if loss_type == "cross-entropy":
            loss_cls = torch.nn.CrossEntropyLoss
        elif loss_type == "label-smooth":
            loss_cls = LabelSmoothingLoss
        else:
            raise KeyError(loss_type)
        lossf = loss_cls(**loss.get("config", {}))

        optimizers = {
            "adam": torch.optim.Adam,
            "sgd": torch.optim.SGD,
            "sdg": torch.optim.SGD,  # legacy misspelling, kept so old configs still load
        }
        optim_cls = optimizers[optim["type"]]
        optim_config = optim.get("config", {})

        # Close over the already-selected class. Building one lambda per entry
        # inside a comprehension would late-bind the loop variable, so every
        # entry would construct the *last* optimizer class in the list.
        def optim_factory(model):
            return optim_cls(model.parameters(), **optim_config)

        sched_factory = {
            None: lambda optimizer, epochs: None,
            "multistep-lr": lambda optimizer, epochs: torch.optim.lr_scheduler.MultiStepLR(
                optimizer, [int(epochs * 0.80), int(epochs * 0.95)]
            ),
        }[sched]

        return Trainer(
            loss=lossf,
            optim_factory=optim_factory,
            sched_factory=sched_factory,
            train_batchsize=train_batchsize,
            test_batchsize=test_batchsize,
            epochs=epochs,
            device=device
        )

    def init(self, model, epochs=None):
        """Return a new trainer bound to ``model`` with a fresh optimizer/scheduler.

        The model is moved to the device *before* the optimizer is built, as
        the PyTorch optimizer documentation recommends. ``Module.to`` works in
        place, so ``model`` itself ends up on the device. The loss is moved the
        same way; for an ``nn.Module`` loss that object is shared with this
        template.

        Args:
            model: ``torch.nn.Module`` to train.
            epochs: epoch count for the bound trainer; ``None`` means use the
                configured value.

        Returns:
            A bound ``Trainer`` ready for :meth:`train` / :meth:`evaluate`.
        """
        assert self._model is None, "init() must be called on an unbound trainer"
        assert self._optim is None, "init() must be called on an unbound trainer"
        if epochs is None:
            epochs = self._epochs
        logging.info("Creating trainer with model on device: {}".format(
            self._device
        ))
        model = model.to(self._device)
        optim = self._optim_factory(model)
        sched = self._sched_factory(optim, epochs)
        return Trainer(
            loss=self._loss.to(self._device),
            optim_factory=None,
            sched_factory=None,
            train_batchsize=self._train_batchsize,
            test_batchsize=self._test_batchsize,
            epochs=epochs,
            device=self._device,
            model=model,
            optim=optim,
            sched=sched
        )

    def train(self, datapool):
        """Train the bound model on ``datapool`` for ``self._epochs`` epochs.

        Batches are shuffled. The scheduler (if any) is stepped once per epoch.

        Args:
            datapool: dataset accepted by ``torch.utils.data.DataLoader`` that
                yields ``(X, y)`` pairs and also provides ``get_name()``.
                The model must also provide ``get_name()``.

        Returns:
            The final value of the moving-average training accuracy
            (``utils.MovingAvg`` with ``alpha=0.95``).
        """
        assert self._model is not None, "call init() before train()"
        logging.info(
            "Training {} across {} data points in {}...".format(
                self._model.get_name(), len(datapool), datapool.get_name()
            )
        )
        self._model.train()
        avg = utils.MovingAvg(alpha=0.95)
        loss_avg = utils.MovingAvg(alpha=0.95)
        for epoch, bar, yh, y in self._iter_model_process(
            dataset=datapool, batchsize=self._train_batchsize,
            shuffle=True, epochs=self._epochs, update_sched=True
        ):
            loss = self._loss(yh, y)

            self._optim.zero_grad()
            loss.backward()
            self._optim.step()

            self._update_accuracy(avg, yh, y)
            loss_avg.update(loss.item())
            bar.set_description("[E{}] {:.2f} (L)/ {:.2f} (A)".format(epoch, loss_avg.peek(), avg.peek()))

        logging.info("Training accuracy: {:.4f}".format(avg.peek()))
        return avg.peek()

    def evaluate(self, testset):
        """Evaluate the bound model on ``testset`` without gradient tracking.

        Args:
            testset: dataset accepted by ``torch.utils.data.DataLoader`` that
                yields ``(X, y)`` pairs. The model output must be 2-D and the
                targets 1-D (asserted per batch).

        Returns:
            The accuracy accumulated in a ``utils.Average`` across all batches,
            with each batch weighted by its size.
        """
        assert self._model is not None, "call init() before evaluate()"
        logging.info("Testing on {} data points...".format(len(testset)))
        self._model.eval()
        avg = utils.Average()
        with torch.no_grad():
            for _, bar, yh, y in self._iter_model_process(
                dataset=testset, batchsize=self._test_batchsize, shuffle=False
            ):
                assert len(yh.size()) == 2 and len(y.size()) == 1
                self._update_accuracy(avg, yh, y)
                bar.set_description("Accuracy {:.4f}".format(avg.peek()))

        return avg.peek()

    # === PROTECTED ===

    def _update_accuracy(self, avg, yh, y):
        """Record this batch's correct-prediction count in ``avg``, weighted by batch size.

        A prediction is correct when ``yh.argmax(dim=1)`` equals the target.
        Presumably ``avg`` turns the (count, n) updates into a fraction;
        confirm against ``utils``.
        """
        num_correct = (yh.argmax(dim=1) == y).sum().item()
        avg.update(num_correct, n=len(y))

    def _iter_model_process(
        self,
        dataset, batchsize, shuffle,
        epochs=1, update_sched=False
    ):
        """Run the model over ``dataset`` and yield each batch's outputs.

        Yields:
            ``(epoch, bar, yh, y)`` where ``epoch`` is 1-based, ``bar`` is the
            ``utils.Bar`` progress object, ``yh`` is the model output for the
            batch and ``y`` the targets, both already on ``self._device``.

        If ``update_sched`` is true and a scheduler exists, it is stepped once
        after each epoch.
        """
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=batchsize, shuffle=shuffle
        )
        with utils.Bar(range(len(dataloader) * epochs)) as bar:
            for epoch in range(1, epochs + 1):
                for X, y in dataloader:
                    bar.update()
                    X = X.to(self._device)
                    y = y.to(self._device)

                    yh = self._model(X)
                    yield epoch, bar, yh, y

                # Runs only after the consumer has processed the epoch's last
                # batch, so the scheduler steps after that optimizer step.
                if update_sched and self._sched is not None:
                    self._sched.step()
