import os.path
from itertools import cycle

import numpy as np
import torch
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader

from datasets import get_dataset
from flow import RectifiedFlow
from models import get_model


def _repeat_epochs(loader):
    """Yield batches from ``loader`` indefinitely, starting a fresh pass each epoch.

    ``itertools.cycle`` records the items of its first pass and replays those
    saved items forever. With a shuffling ``DataLoader`` that means every later
    epoch repeats the first epoch's batch order, and every batch is kept alive
    in memory. This generator instead re-iterates ``loader`` whenever it is
    exhausted, so a new shuffle is drawn each epoch.

    Like ``cycle``, it stops (rather than spinning forever) if a pass over
    ``loader`` yields nothing.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            return


def train_flow(config):
    """Train a conditional rectified-flow model and show its training-loss curve.

    Side effects on ``config``: stores the loaded datasets on
    ``config.train_set`` / ``config.test_set`` and appends ``'_condition'`` to
    ``config.model`` after the model class has been resolved.

    Args:
        config: Training configuration. Reads ``batch_size``, ``condition_type``,
            ``model``, ``model_config``, ``device``, ``flow``, ``trajectory``,
            ``repara``, ``learning_rate`` and ``iterations``. Optionally reads
            ``num_workers`` (DataLoader worker processes, default 16).

    Raises:
        TypeError: if ``config.flow`` does not construct a ``RectifiedFlow``.
    """
    train_set, test_set = get_dataset(config)
    config.train_set, config.test_set = train_set, test_set
    # Replace the raw training data with its conditioned form before batching.
    train_set.data = train_set.get_condition(train_set.data, config.condition_type)

    num_workers = getattr(config, 'num_workers', 16)
    loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True,
                        num_workers=num_workers,
                        # Keep workers alive across epochs; _repeat_epochs
                        # re-iterates the loader once per epoch.
                        persistent_workers=num_workers > 0)
    train_loader = _repeat_epochs(loader)

    model_cls = get_model(config.model)
    model = model_cls(**config.model_config).to(config.device)
    config.model += '_condition'

    flow = config.flow(model=model, trajectory=config.trajectory,
                       config=config, reparam=config.repara)
    if not isinstance(flow, RectifiedFlow):
        raise TypeError(f"config.flow must build a RectifiedFlow, got {type(flow).__name__}")

    opt = torch.optim.Adam(flow.model.parameters(), lr=config.learning_rate)
    # Presumably an indexable sequence of per-iteration losses - TODO confirm
    # against RectifiedFlow.train.
    loss_curve = flow.train(opt, train_loader)

    iterations = config.iterations
    # Plot steps 0..iterations at most. Deriving x from the truncated curve
    # keeps x and y the same length even if fewer losses were recorded.
    curve = loss_curve[:iterations + 1]
    plt.plot(np.arange(len(curve)), curve)
    plt.title('Training Loss Curve')
    plt.show()
    plt.close()


def eval_flow(config):
    """Load a trained conditional rectified-flow checkpoint, sample from it, and save a plot.

    The checkpoint is read from ``config.model_path + f'_iter_{config.iterations}'``
    and its ``'model'`` entry is loaded into ``flow.model``. Samples are drawn
    with ``flow.sample_ode`` using a random placeholder condition. They are
    written via ``train_set.visualize_batch`` to ``<config.data_dir>/pred.png``.

    Side effects on ``config``: stores the loaded datasets on
    ``config.train_set`` / ``config.test_set`` and appends ``'_condition'`` to
    ``config.model`` after the model class has been resolved.

    Args:
        config: Evaluation configuration. Reads ``model``, ``model_config``,
            ``device``, ``flow``, ``trajectory``, ``repara``, ``model_path``,
            ``iterations``, ``sampling_steps`` and ``data_dir``. Optionally reads
            ``num_eval_samples`` (number of samples to draw, default 256).

    Raises:
        TypeError: if ``config.flow`` does not construct a ``RectifiedFlow``.
    """
    train_set, test_set = get_dataset(config)
    config.train_set, config.test_set = train_set, test_set
    model_cls = get_model(config.model)
    model = model_cls(**config.model_config).to(config.device)
    config.model += '_condition'

    flow = config.flow(model=model, trajectory=config.trajectory,
                       config=config, reparam=config.repara)
    if not isinstance(flow, RectifiedFlow):
        raise TypeError(f"config.flow must build a RectifiedFlow, got {type(flow).__name__}")

    checkpoint_path = config.model_path + f'_iter_{config.iterations}'
    state_dict = torch.load(checkpoint_path, map_location=config.device)['model']
    # Rename 'model.model.*' keys to 'model.*' so they match flow.model's
    # parameter names. NOTE(review): str.replace rewrites every occurrence, not
    # only a leading prefix - presumably fine for these checkpoints; confirm.
    new_state_dict = {k.replace("model.model.", "model."): v for k, v in state_dict.items()}
    flow.model.load_state_dict(new_state_dict)
    # Inference: disable training-only behaviour (dropout, batch-norm updates).
    flow.model.eval()

    num_samples = getattr(config, 'num_eval_samples', 256)
    # Random placeholder condition, one per sample. It is drawn on CPU (same RNG
    # stream as before) and then moved to the model's device.
    condition = torch.randn((num_samples, 1)).to(config.device)
    samples, _ = flow.sample_ode(num_samples=num_samples, num_steps=config.sampling_steps,
                                 batch_size=num_samples, return_z0=True, c=condition, param={})

    save_path = os.path.join(config.data_dir, 'pred.png')
    train_set.visualize_batch(samples=samples.detach().cpu(), save_path=save_path)


if __name__ == '__main__':
    from config.poisson_config import TrainConfig

    # Run training, then evaluation. Each stage builds its own fresh config,
    # because both stages mutate the config they receive (for example, they
    # append '_condition' to config.model).
    for stage in (train_flow, eval_flow):
        stage(TrainConfig())
