import time

from tqdm import tqdm
import torch as pt


class Loop:
    """Step-budgeted training loop with periodic validation and callback hooks.

    Calling the instance trains epoch after epoch over ``dataset_t`` until
    ``total_step`` optimisation steps have been taken, running a validation
    pass over ``dataset_v`` after a training epoch whenever another
    ``val_interval`` steps have accumulated, and always after the final
    training epoch.

    Callbacks are invoked as ``callback.<hook>(**pack)`` where ``pack`` is the
    instance's own ``__dict__``.  Consequences worth knowing:

    * every instance attribute (including the live ``step_count``) reaches
      the hooks as a keyword argument;
    * the loop additionally writes the keys ``epoch``, ``step``, ``batch``,
      ``output``, ``loss`` and ``metric`` into it as they become available,
      so those also become instance attributes;
    * any new instance attribute added to this class would be forwarded to
      every hook as well -- keep helper state out of ``self``.

    Hooks used on ``callback_t``: ``before_train``, ``before_epoch``,
    ``before_step``, ``after_forward``, ``after_step``, ``after_epoch``,
    ``after_train``.  Hooks used on ``callback_v``: the same minus
    ``before_train`` / ``after_train``.
    """

    def __init__(
        self,
        dataset_t,
        dataset_v,
        model,
        optimiz,
        loss_fn,
        metric_fn,
        callback_t,
        callback_v,
        total_step,
        val_interval,
    ):
        """Store the loop's collaborators.

        Args:
            dataset_t: Training batches; must support ``len()`` and ``iter()``.
            dataset_v: Validation batches; must support ``len()`` and ``iter()``.
            model: Callable as ``model(batch)``; must provide ``train()``,
                ``eval()`` and ``parameters()``.
            optimiz: Optimiser exposing ``zero_grad()`` plus two extra
                attributes: ``gscale`` (used via ``scale`` / ``unscale_`` /
                ``step`` / ``update`` -- presumably a ``GradScaler``, confirm
                with the caller) and ``gclip`` (``None`` or a callable applied
                to ``model.parameters()``).
            loss_fn: ``loss_fn(output, batch)`` returning a mapping whose
                values are summed to form the scalar that is back-propagated.
            metric_fn: ``metric_fn(output, batch)``; the result is only
                forwarded to callbacks.
            callback_t: Iterable of training callbacks.
            callback_v: Iterable of validation callbacks.
            total_step: Total number of optimisation steps to run.
            val_interval: Minimum number of steps between validation passes.
        """
        self.dataset_t = dataset_t
        self.dataset_v = dataset_v
        self.model = model
        self.optimiz = optimiz
        self.loss_fn = loss_fn
        self.metric_fn = metric_fn
        self.callback_t = callback_t
        self.callback_v = callback_v
        self.total_step = total_step
        self.val_interval = val_interval
        self.step_count = 0

    @staticmethod
    def _emit(callbacks, hook, pack):
        """Call ``hook(**pack)`` on every callback, in order."""
        for callback in callbacks:
            getattr(callback, hook)(**pack)

    def __call__(self):
        """Run training (with interleaved validation) until ``total_step``.

        Raises:
            ValueError: If steps remain to be run but ``dataset_t`` is empty;
                no epoch could ever advance ``step_count``, so the loop would
                otherwise spin forever.
        """
        if self.step_count < self.total_step and len(self.dataset_t) == 0:
            raise ValueError(
                "dataset_t is empty: cannot reach total_step="
                f"{self.total_step} from step_count={self.step_count}"
            )

        # Deliberate alias, not a copy: hooks must observe the live
        # ``step_count`` and the per-step keys written below.
        pack = self.__dict__
        self._emit(self.callback_t, "before_train", pack)

        epoch_count = 0
        epoch_count_v = 0

        while self.step_count < self.total_step:
            pack["epoch"] = epoch_count

            pt.cuda.empty_cache()
            self.train_epoch(pack)

            val_due = self.step_count >= (epoch_count_v + 1) * self.val_interval
            finished = self.step_count >= self.total_step
            if val_due or finished:
                pt.cuda.empty_cache()
                self.val_epoch(pack)
                epoch_count_v += 1

            epoch_count += 1

        # Internal sanity check: train_epoch never oversteps the budget.
        assert self.step_count == self.total_step
        self._emit(self.callback_t, "after_train", pack)

    def train_epoch(self, pack):
        """Train for one pass over ``dataset_t`` or until the step budget ends.

        Args:
            pack: The shared keyword dictionary forwarded to callbacks.

        Prints the achieved batches-per-second, computed from the number of
        batches actually processed (the last epoch may stop early).
        """
        t0 = time.perf_counter()
        self.model.train()
        self._emit(self.callback_t, "before_epoch", pack)

        epoch_len = len(self.dataset_t)
        steps_done = 0
        # Context manager guarantees the bar is closed even if a step raises.
        with tqdm(total=epoch_len) as progress:
            batches = iter(self.dataset_t)
            for step in range(epoch_len):
                # Check the budget before fetching so no batch is wasted.
                if self.step_count + 1 > self.total_step:
                    break
                self._train_step(pack, step, next(batches))
                self.step_count += 1
                steps_done += 1
                progress.update()

        self._emit(self.callback_t, "after_epoch", pack)

        elapsed = time.perf_counter() - t0
        rate = steps_done / elapsed if elapsed > 0 else float("nan")
        print("b/s:", rate)

    # NOTE: wrapping a whole step in ``pt.compile`` caused NaNs; compile only
    # the model and the metric instead.
    def _train_step(self, pack, step, batch):
        """Run forward, loss, backward and optimiser update for one batch."""
        pack["step"] = step
        pack["batch"] = batch
        self._emit(self.callback_t, "before_step", pack)

        self.optimiz.zero_grad()

        with pt.autocast("cuda", enabled=True):
            output = self.model(batch)

            pack["output"] = output
            self._emit(self.callback_t, "after_forward", pack)

            loss = self.loss_fn(output, batch)
        # Kept outside autocast: computing the metric inside may cause inf.
        metric = self.metric_fn(output, batch)

        gscale = self.optimiz.gscale
        gscale.scale(sum(loss.values())).backward()
        if self.optimiz.gclip is not None:
            # Gradients must be unscaled before they are clipped.
            gscale.unscale_(self.optimiz)
            self.optimiz.gclip(self.model.parameters())
        gscale.step(self.optimiz)
        gscale.update()

        pack["loss"] = loss
        pack["metric"] = metric
        self._emit(self.callback_t, "after_step", pack)

    @pt.no_grad()
    def val_epoch(self, pack):
        """Evaluate the model on one full pass over ``dataset_v``.

        Runs with gradients disabled and the model in ``eval()`` mode; the
        model is not switched back here (``train_epoch`` calls ``train()``).

        Args:
            pack: The shared keyword dictionary forwarded to callbacks.
        """
        self.model.eval()
        self._emit(self.callback_v, "before_epoch", pack)

        batches = iter(self.dataset_v)
        for step in range(len(self.dataset_v)):
            self._val_step(pack, step, next(batches))

        self._emit(self.callback_v, "after_epoch", pack)

    # NOTE: wrapping a whole step in ``pt.compile`` caused NaNs; compile only
    # the model and the metric instead.
    def _val_step(self, pack, step, batch):
        """Run forward, loss and metric for one validation batch."""
        pack["step"] = step
        pack["batch"] = batch
        self._emit(self.callback_v, "before_step", pack)

        with pt.autocast("cuda", enabled=True):
            output = self.model(batch)

            pack["output"] = output
            self._emit(self.callback_v, "after_forward", pack)

            loss = self.loss_fn(output, batch)
        # Kept outside autocast: computing the metric inside may cause inf.
        metric = self.metric_fn(output, batch)

        pack["loss"] = loss
        pack["metric"] = metric
        self._emit(self.callback_v, "after_step", pack)
