import os.path
from itertools import cycle

import numpy as np
import torch
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader

from datasets import get_dataset
from flow import RectifiedFlow
from models import get_model


def _repeat_epochs(iterable):
    """Yield items from ``iterable`` endlessly, re-iterating it on every pass.

    Unlike ``itertools.cycle``, which saves the first pass and replays that
    saved copy, this calls ``iter(iterable)`` anew each epoch. A ``DataLoader``
    with ``shuffle=True`` therefore reshuffles every epoch, and batches are not
    all retained in memory. Like ``cycle``, it stops if a pass yields nothing
    instead of spinning forever on an empty iterable.
    """
    while True:
        exhausted_without_items = True
        for item in iterable:
            exhausted_without_items = False
            yield item
        if exhausted_without_items:
            return


def train_flow(config):
    """Build dataset, model and flow from ``config``, train it, and plot the loss.

    Args:
        config: Training configuration object. Reads ``batch_size``, ``model``,
            ``model_config``, ``device``, ``flow``, ``trajectory``, ``repara``,
            ``learning_rate``, ``iterations`` and, optionally, ``num_workers``
            (defaults to 16). Side effect: ``train_set`` and ``test_set`` are
            assigned onto it (presumably so the flow can reach them through
            ``config`` -- confirm).

    Raises:
        TypeError: if ``config.flow(...)`` does not produce a ``RectifiedFlow``.
    """
    train_set, test_set = get_dataset(config)
    config.train_set, config.test_set = train_set, test_set
    loader = DataLoader(
        train_set,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=getattr(config, 'num_workers', 16),
    )
    # Endless batch stream; re-iterating the loader reshuffles every epoch.
    train_loader = _repeat_epochs(loader)

    model_cls = get_model(config.model)
    model = model_cls(**config.model_config).to(config.device)

    flow = config.flow(model=model, trajectory=config.trajectory,
                       config=config, reparam=config.repara)
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if not isinstance(flow, RectifiedFlow):
        raise TypeError(
            f"config.flow must build a RectifiedFlow, got {type(flow).__name__}"
        )

    opt = torch.optim.Adam(flow.model.parameters(), lr=config.learning_rate)
    # flow.train is defined elsewhere; its result is sliced and plotted below,
    # presumably one loss value per training step -- confirm.
    loss_curve = flow.train(opt, train_loader)

    # Plot at most ``iterations + 1`` points (steps 0..iterations). The x-axis
    # comes from the sliced length, so a shorter curve does not trigger a
    # length-mismatch error in ``plt.plot``.
    iterations = config.iterations
    curve = loss_curve[:iterations + 1]
    plt.figure()
    plt.plot(np.arange(len(curve)), curve)
    plt.title('Training Loss Curve')
    plt.show()
    plt.close()




if __name__ == '__main__':
    # Imported here so that merely importing this module does not require
    # the config package.
    from config.poisson_config import TrainConfig as _TrainConfig

    _config = _TrainConfig()
    train_flow(_config)
