import torch
from tqdm import tqdm
from .helpers import ModelCheckpoint, BasicCallbacks


class Trainer(BasicCallbacks):
    """Train, validate and test a model, appending per-phase metrics to CSV logs.

    Callback hooks inherited from ``BasicCallbacks`` (``on_epoch_start``,
    ``on_iter_start``, ``on_iter_end``, ``on_model_change``, ``on_epoch_end``)
    are invoked at the points marked below; what they do is defined elsewhere.

    Args:
        train_cfg: Training configuration. This class reads ``logdir`` (a
            path-like object supporting ``/``), ``device``, ``n_epochs`` and
            ``loss``, and calls ``get_optimizer``, ``get_lr_sched``,
            ``monitor``, ``_update_metrics`` and ``_reset_metrics`` on it.
        model: The network to train (presumably a ``torch.nn.Module`` --
            it is used via ``.to``, ``.train``, ``.eval`` and ``__call__``).
        loaders: Mapping with ``"train"``, ``"val"`` and ``"test"`` entries,
            each an iterable of ``(inputs, targets)`` tensor pairs.
    """

    def __init__(self, train_cfg, model, loaders):
        self.cfg = train_cfg
        self.model = model
        self.data = loaders

    def start(self):
        """Move the model to ``cfg.device`` and run the full train/val/test cycle.

        Checkpoints are managed by a ``ModelCheckpoint`` targeting
        ``<cfg.logdir>/model.pt``.
        """
        checkpoint = ModelCheckpoint(
            filepath=self.cfg.logdir / "model.pt",
            model=self.model,
            period=1,
        )
        self.model.to(self.cfg.device)
        self.training_loop(self.cfg.logdir, checkpoint)

    def training_loop(self, logdir, checkpoint):
        """Run ``cfg.n_epochs`` train+val epochs, save checkpoints, then test.

        Metrics are appended to ``train.csv``, ``val.csv`` and ``test.csv``
        under ``logdir``. The test split is evaluated twice: row index 0 with
        the weights after the final epoch, row index 1 after
        ``checkpoint.reload_best()``.

        Args:
            logdir: Directory (path-like, supports ``/``) for the CSV logs.
            checkpoint: Checkpoint manager; receives ``update(epoch, score)``
                after every epoch, then ``save``, ``save_last`` and
                ``reload_best``.
        """
        # Weight-specific objects: built from the current model/optimizer,
        # so they are (re)created on every run.
        self.optim = self.cfg.get_optimizer(self.model)
        self.lr_sched = self.cfg.get_lr_sched(self.optim)

        # Train + val loop. ``with`` guarantees both logs are flushed and
        # closed even if an epoch raises (the original leaked them).
        with open(logdir / "train.csv", "a", newline="") as ftrain, \
                open(logdir / "val.csv", "a", newline="") as fval:
            for epoch in tqdm(range(self.cfg.n_epochs), desc="Epochs", leave=False):
                self.train_epoch(ftrain, epoch)
                current_score = self.validate(fval, epoch)
                checkpoint.update(epoch, current_score)

        checkpoint.save()
        checkpoint.save_last()

        # Test with the final-epoch model (index 0), then with the best model
        # (index 1). on_model_change() is fired because reload_best() swaps
        # the weights.
        with open(logdir / "test.csv", "a", newline="") as ftest:
            self.test(ftest, 0)
            checkpoint.reload_best()
            self.on_model_change()
            self.test(ftest, 1)

    def train_epoch(self, ftrain, epoch):
        """Run one optimization pass over ``self.data["train"]``.

        Args:
            ftrain: Open text file that ``cfg._reset_metrics`` logs to
                (presumably one CSV row per epoch -- confirm in the config).
            epoch: Zero-based epoch index, forwarded to hooks and the log.
        """
        self.on_epoch_start(epoch)
        self.model.train()
        for iteration, batch in enumerate(tqdm(self.data["train"], desc="Train", leave=False)):
            self.on_iter_start(epoch, iteration)
            inputs, targets = batch
            # non_blocking lets host->device copies overlap compute when the
            # loader yields pinned memory.
            inputs = inputs.to(self.cfg.device, non_blocking=True)
            targets = targets.to(self.cfg.device, non_blocking=True)

            self.optim.zero_grad()
            outputs = self.model(inputs)
            loss = self.cfg.loss(outputs, targets)
            loss.backward()

            self.optim.step()
            self.cfg._update_metrics(outputs, targets)
            self.on_iter_end(epoch, iteration)
            # Weights changed in optim.step(); notify listeners every step.
            self.on_model_change()

        # The LR schedule advances once per epoch, not per iteration.
        self.lr_sched.step()
        self.cfg._reset_metrics(ftrain, index=epoch)
        self.on_epoch_end(epoch)

    @torch.no_grad()
    def _run_eval(self, split, desc):
        """Accumulate metrics over ``self.data[split]`` in eval mode, without gradients."""
        self.model.eval()
        for inputs, targets in tqdm(self.data[split], desc=desc, leave=False):
            inputs = inputs.to(self.cfg.device)
            targets = targets.to(self.cfg.device)
            outputs = self.model(inputs)
            self.cfg._update_metrics(outputs, targets)

    @torch.no_grad()
    def validate(self, fval, epoch):
        """Evaluate on ``self.data["val"]``, log the metrics and return the monitored score.

        Args:
            fval: Open text file that ``cfg._reset_metrics`` logs to.
            epoch: Row index passed to ``cfg._reset_metrics``.

        Returns:
            ``cfg.monitor()``, read before the metrics are reset.
        """
        self._run_eval("val", "Val")
        # Read the score first: _reset_metrics presumably clears the
        # accumulated state after logging it.
        current_score = self.cfg.monitor()
        self.cfg._reset_metrics(fval, index=epoch)
        return current_score

    @torch.no_grad()
    def test(self, ftest, idx):
        """Evaluate on ``self.data["test"]`` and log the metrics under row index ``idx``."""
        self._run_eval("test", "Test")
        self.cfg._reset_metrics(ftest, index=idx)
