import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.distributions import Normal, MixtureSameFamily, Categorical

from model import MLP
from flow import Flow, ScoreMatch, RectifiedFlow
from misc import draw_plot, estimate_marginal_accuracy, count_parameters
# from data_config import samples_0, samples_1, save_path, pis, d, mus, sigmas
from data_config import batchsize, iterations, data_config4, config
from visualize import draw_x_line

# Device for the score model, the evaluation grid and the sampling noise.
device = 'cuda'

# Bookkeeping values; marginal_accuracies, enumerate_values, first_d and
# norm_grad are not read anywhere else in this script.
marginal_accuracies, enumerate_values = [], range(5)
first_d, norm_grad = True, False

# One score curve per trained model is appended to each list in the main loop.
gaussian_scores, data_scores = [], []

# Time values passed to flow.score when probing the learned score
# (presumably t=0 is the Gaussian end and t~1 the data end -- confirm in flow.py).
gaussian_t, data_t = 0.0, 0.99

# Arguments for data_config4; also used as the sigma labels of the histograms.
strings = [1, 0.2, 0.1, 0.01, 0.001]

# Evaluation grid: 1000 evenly spaced points on [-3, 3], shaped (1000, 1).
x = torch.linspace(-3, 3, 1000).unsqueeze(-1).to(device)

# For every value in `strings`: build the corresponding data set, train (or load)
# a score model for it, and overlay a histogram of ODE samples on one figure.
plt.figure()

# Training settings shared by every configuration. They override the
# `iterations` / `batchsize` imported from data_config and are part of the
# checkpoint directory name below.
iterations = 2000
batchsize = 2000
probability_flow = True
train_by_gt = False

for structure_string in strings:
    # Build the data set exactly once per value (a second call would redo the
    # work and, if data_config4 samples randomly, yield a different data set).
    pis, d, mus, sigmas, name, samples_0, samples_1 = data_config4(structure_string)

    save_path = rf'./saved/{name}-iter{iterations}-b{batchsize}-nums{samples_0.shape[0]}'
    os.makedirs(save_path, exist_ok=True)

    # Always describe the *current* data set. If this were only built when
    # train_by_gt is True, every iteration would reuse the `config` imported
    # from data_config and hence the same checkpoint path / model.
    config = {
        'pis': pis,
        'd': d,
        'mus': mus,
        'sigmas': sigmas,
        'name': name,
        'num_samples': samples_0.shape[0],
        'samples_1': samples_1,
        'samples_0': samples_0,
        'probability_flow': probability_flow,
        'save_path': save_path,
        'train_by_gt': train_by_gt,
    }

    for trajectory in ['vp', 'linear', 'subvp'][0:1]:
        for reparam in ['x_pred', 'epsilon_pred', None, 'negative'][2:3]:
            # The architecture is fixed; the commented-out variant would derive
            # it from structure_string instead.
            # model = MLP(config['d'],  structure_str=structure_string).to(device)
            model = MLP(config['d'], structure_str='100-100').to(device)
            print(model.name)
            flow = ScoreMatch(model=model, trajectory=trajectory,
                              config=config, reparam=reparam)
            print(flow.config['name'])
            save_path = config['save_path']
            os.makedirs(f'{save_path}/{flow.name}', exist_ok=True)

            # One path for both the existence check and the load, so they can
            # never refer to different files.
            ckpt_path = f'{save_path}/{flow.name}/{model.name}.pth'
            if not os.path.exists(ckpt_path):
                optimizer = torch.optim.Adam(flow.model.parameters(), lr=5e-3)
                # Train under a fixed seed, then restore the CPU RNG state (even
                # if training raises) so the sampling noise below is unaffected.
                current_random_state = torch.random.get_rng_state()
                torch.manual_seed(1)
                try:
                    if config['train_by_gt']:
                        loss_curve = flow.train_by_gt(optimizer, config['samples_0'],
                                                      config['samples_1'], batchsize, iterations)
                    else:
                        loss_curve = flow.train(optimizer, config['samples_0'],
                                                config['samples_1'], batchsize, iterations)
                finally:
                    torch.set_rng_state(current_random_state)

            # NOTE(review): assumes flow.train / flow.train_by_gt write the
            # checkpoint to ckpt_path; otherwise this load fails -- confirm in flow.py.
            model.load_state_dict(torch.load(ckpt_path, map_location=device))

            # Probe the score only after the trained weights are in place.
            # detach() so .numpy() works even if the output tracks gradients.
            gaussian_score = flow.score(x, torch.ones_like(x) * gaussian_t)
            data_score = flow.score(x, torch.ones_like(x) * data_t)
            gaussian_scores.append(gaussian_score.detach().cpu().numpy())
            data_scores.append(data_score.detach().cpu().numpy())

            test_z0 = torch.randn(50000, 1).to(device)
            traj = flow.sample_ode(z0=test_z0, N=100, use_gt=False)

            # NOTE(review): presumably traj holds one state per ODE step, so
            # traj[40] is an intermediate state, not the final sample -- confirm.
            # Raw f-string: '\s' is not a valid escape in a normal string literal.
            plt.hist(traj[40].detach().cpu().numpy(), bins=200, density=True, alpha=0.5,
                     range=(-3, 3), label=rf'$\sigma=${structure_string}')
#
# print(save_path)
# plt.figure()
# plt.plot(x.cpu().numpy(), gaussian_scores[0])
# # plt.legend()
# plt.grid()
# plt.savefig(f'{save_path}/gaussian_scores.png')
# plt.show()
# plt.close()
#
# plt.figure()
# for i, data_score in enumerate(data_scores):
#     plt.plot(x.cpu().numpy(), data_score, label=strings[i])
# plt.legend()
# plt.grid()
# plt.savefig(f'{save_path}/data_scores.png')
# plt.show()
# plt.close()
# Finish the shared histogram figure: grid plus legend (one entry per sigma),
# write it to ./dis_sigma.png, display it, then release it.
plt.grid()
plt.legend()
plt.savefig('./dis_sigma.png')
plt.show()
plt.close()
