import os
import torch
import numpy as np
import matplotlib.pyplot as plt
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

from model import MLP, TwoModel, Repara, RandomFeature, SkipConnection
from flow import Flow, ScoreMatch, RectifiedFlow, ScoreMatchRF
from misc import draw_plot, count_parameters
# from data_config import samples_0, samples_1, save_path, pis, d, mus, sigmas
from data_config import config, batchsize, iterations


# Train (or reuse a cached checkpoint for) a ScoreMatch flow for each selected
# trajectory / reparameterisation, then plot the learned transport.
#
# Only 'vp' runs because of the [0:1] slice; widen it to sweep 'linear' and
# 'subvp' as well.
for trajectory in ['vp', 'linear', 'subvp'][0:1]:
    for reparam in [None]:
        model = MLP(config['d'], structure_str='100-100-100', relu=False).to('cuda')
        # Alternative architectures, swap in as needed:
        # model = TwoModel(config['d'], structure_str='10-10-10', relu=False).to('cuda')
        # model = Repara(config['d'], structure_str='100-100-100', relu=False).to('cuda')
        # model = SkipConnection(config['d'], structure_str='100-100-100', relu=False).to('cuda')
        flow = ScoreMatch(model=model, trajectory=trajectory,
                          config=config, reparam=reparam)

        # Derive every artefact path from a single directory so the existence
        # check, the loss-curve image and the checkpoint load all agree.
        # (Prefixing './' onto an absolute save_path would silently turn it
        # into a relative path.)
        save_path = config['save_path']
        run_dir = os.path.join(save_path, flow.name)
        ckpt_path = os.path.join(run_dir, f'{model.name}.pth')
        os.makedirs(run_dir, exist_ok=True)

        count_parameters(flow.model)
        print(model.name)
        print(ckpt_path)

        if not os.path.exists(ckpt_path):
            optimizer = torch.optim.Adam(flow.model.parameters(), lr=5e-3)
            # Snapshot the torch CPU RNG and restore it after training so the
            # random draws made afterwards do not depend on how many random
            # numbers training consumed.
            current_random_state = torch.random.get_rng_state()
            train_fn = flow.train_by_gt if config['train_by_gt'] else flow.train
            loss_curve = train_fn(optimizer, config['samples_0'],
                                  config['samples_1'], batchsize, iterations)
            torch.set_rng_state(current_random_state)

            # NOTE(review): it is not visible here whether the train methods
            # write the checkpoint themselves. The load below requires it, so
            # save the final weights as a fallback, but never overwrite a file
            # the trainer already produced.
            if not os.path.exists(ckpt_path):
                torch.save(model.state_dict(), ckpt_path)

            # Take x from the curve's actual length; a curve shorter than
            # iterations + 1 would otherwise give an x/y length mismatch.
            curve = loss_curve[:iterations + 1]
            plt.plot(np.arange(len(curve)), curve)
            plt.title('Training Loss Curve')
            plt.savefig(os.path.join(run_dir, 'loss_curve.png'))
            plt.show()
            plt.close()

        # Re-create the flow around the same model before loading weights.
        # NOTE(review): presumably this discards any state ScoreMatch
        # accumulated during training; confirm whether it is still required.
        flow = ScoreMatch(model=model, trajectory=trajectory, config=config, reparam=reparam)
        model.load_state_dict(torch.load(ckpt_path))

        draw_plot(flow, config['samples_0'], config['samples_1'], 100)

