import os
import torch
import torch.nn.functional as F
from ray import tune
from get_loaders import *
from utils.utils import get_linear_schedule_with_warmup
from vae_models import *
import math


def train(model, optimizer, scheduler, train_loader, M_N):
    """Run one pass over ``train_loader``, updating ``model`` in place.

    Args:
        model: Module called as ``model(data, labels=target)``. Its outputs
            are unpacked into ``model.loss_function(..., M_N=M_N)``, which
            must return a mapping with a ``'loss'`` entry.
        optimizer: Optimizer whose gradients are zeroed before, and which is
            stepped after, every batch.
        scheduler: Optional scheduler stepped once per batch, right after the
            optimizer. Pass ``None`` to skip scheduling.
        train_loader: Iterable yielding ``(data, target)`` batches.
        M_N: Forwarded unchanged to ``model.loss_function`` as ``M_N``.
    """
    model.train()

    # Batches follow the device of the model's own parameters instead of a
    # hard-coded ``.cuda()``, so the loop also runs on CPU-only hosts and on
    # models placed on a non-default device. A model without parameters keeps
    # the previous CUDA placement.
    anchor = next(iter(model.parameters()), None)

    for data, target in train_loader:
        if anchor is None:
            data, target = data.cuda(), target.cuda()
        else:
            data, target = data.to(anchor.device), target.to(anchor.device)

        optimizer.zero_grad()
        outputs = model(data, labels=target)
        train_loss = model.loss_function(*outputs, M_N=M_N)['loss']
        train_loss.backward()
        optimizer.step()

        # The scheduler advances per batch, not per epoch.
        if scheduler is not None:
            scheduler.step()


def test(model, data_loader, M_N):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            data, target = data.cuda(), target.cuda()
            outputs = model(data)
            val_loss += model.loss_function(*outputs,
                                            M_N=M_N,
                                            )['loss'].item()

    return val_loss / len(data_loader)


class VAECelebA(tune.Trainable):
    """Ray Tune trainable that fits a ``VanillaVAE`` on the CelebA loaders.

    Each ``_train`` call runs one pass over the training loader followed by
    one evaluation pass over the test loader.
    """

    def _setup(self, config):
        """Build the model, data loaders, optimizer and LR scheduler.

        Args:
            config: Mapping with the entries ``"task_config"`` (providing
                ``"data_root"``, ``"batch_size"`` and ``"max_t"``),
                ``"optimizer"`` (a callable invoked as
                ``Opt(params, **kwargs)``), ``"warmup"`` and
                ``"decay_rate"``. An ``"args"`` entry is accepted but not
                used here. Every remaining entry is passed to the optimizer
                as a keyword argument. The mapping is not modified.
        """
        task_config = config["task_config"]
        Opt = config["optimizer"]
        # arguments
        fpath = task_config["data_root"]
        batch_size = task_config["batch_size"]
        max_t = task_config["max_t"]
        warmup = config["warmup"]
        decay_rate = config["decay_rate"]

        # NOTE(review): presumably (in_channels=3, latent_dim=128) -- confirm
        # against the VanillaVAE signature in vae_models.
        self.model = VanillaVAE(3, 128).cuda()
        self.train_loader, self.test_loader = get_celeba_loaders(fpath, batch_size)
        # Fixed value handed to the loss as ``M_N``; the earlier
        # ``batch_size / len(self.train_loader.dataset)`` derivation was
        # disabled in favour of this constant.
        self.M_N = 0.005

        # Everything except the bookkeeping entries is an optimizer keyword
        # argument. Filter into a new dict rather than deleting keys from the
        # caller's mapping.
        reserved = ("args", "task_config", "optimizer", "warmup", "decay_rate")
        optimizer_kwargs = {
            key: value for key, value in config.items() if key not in reserved
        }
        self.optimizer = Opt(self.model.parameters(), **optimizer_kwargs)

        # The scheduler is stepped once per batch by ``train``, so both
        # lengths are expressed in batches: ``warmup`` scales the number of
        # batches per pass, and ``max_t`` passes give the total.
        batches_per_pass = len(self.train_loader)
        warmup_iters = int(warmup * batches_per_pass)
        total_iters = max_t * batches_per_pass
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, warmup_iters, total_iters, decay_rate
        )

    def _train(self):
        """Train for one pass, evaluate, and report the result.

        Returns:
            Dict with ``"mean_loss"`` (the value returned by ``test`` on the
            test loader) and ``"early_stop"`` (``True`` when that loss is
            NaN).
        """
        train(self.model, self.optimizer, self.scheduler, self.train_loader, self.M_N)
        loss = test(self.model, self.test_loader, self.M_N)
        return {"mean_loss": loss, "early_stop": math.isnan(loss)}

    def _save(self, checkpoint_dir):
        """Write the model weights to ``model.pth`` and return that path."""
        # NOTE(review): only the model state dict is stored; optimizer and
        # scheduler state are not part of the checkpoint, so a restored trial
        # continues with freshly constructed ones -- confirm this is intended.
        checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return checkpoint_path

    def _restore(self, checkpoint_path):
        """Load model weights from the file previously written by ``_save``."""
        self.model.load_state_dict(torch.load(checkpoint_path))
