import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.distributions import Normal, MixtureSameFamily, Categorical

from model import MLP, TwoModel, Repara, SkipConnection
from flow import Flow, ScoreMatch, RectifiedFlow
from misc import draw_plot, estimate_marginal_accuracy, count_parameters, draw_end_data_velocity, alpha_hill_estimate
# from data_config import samples_0, samples_1, save_path, pis, d, mus, sigmas
from data_config import batchsize, iterations, data_config4, config
from visualize import draw_x_line

# Quick overrides (disabled):
# config['save_path'] = 'dummy'
# iterations = 2000
# batchsize = 2048
device = 'cuda'

# Accumulators filled by the sweep below (one entry per structure string).
marginal_accuracies = []
zero_means, zero_stds = [], []
end_means, end_stds = [], []
gts_data = []

# velocity_gradient*: if True take the first derivative w.r.t. x, otherwise
# differentiate once more; norm_grad rescales the result so max |grad| == 1.
first_d = True
norm_grad = False
# Times at which the velocity error is evaluated.
gaussian_t = 0.0
data_t = 0.9

# Forwarded to the model constructor as `relu=`.
relu = False

# Structure strings to sweep; the first '-'-separated field is the width shown
# on the x axis of the figures.
strings = [f'{w}-{w}' for w in (*range(30, 101, 10), *range(200, 1001, 100))]
# Other sweeps used previously:
# strings = ['10-10', '20-20', *strings]
# strings = ['30-30']
# strings = [f'{n}' for n in (*range(10, 101, 10), *range(200, 1001, 100))]
# strings = ['10-10', '50-50', '100-100', '200-200', '300-300']
# strings = [f'{n}-{n}-{n}' for n in (*range(20, 101, 10), *range(200, 1001, 100))]
# strings = ['-'.join(['100'] * i) for i in range(2, 10 + 1)]
# strings = [f'{n}' for n in range(100, 2100, 100)]
# strings = [f'{n}' for n in range(1000, 11000, 1000)]
# strings = [f'{n}-{n}' for n in range(1000, 11000, 1000)]
# strings = [f'{n}-{n}-{n}-{n}' for n in range(10, 110, 10)]
# strings = [f'{n}-{n}' for n in range(2, 11)]
# strings = [f'{n}-{n}' for n in range(11, 100)]
# strings = [f'{n}-{n}' for n in range(11, 47)]
# strings = [f'{n}-{n}' for n in range(30, 210, 10)]
# strings = [f'{n}-{n}' for n in range(100, 1100, 100)]
# strings = [2, 0.2, 0.02, 0.002, 0.0002]
# strings = [f'{n}-{n}-{n}' for n in range(100, 1100, 100)]
# strings = [f'{n}-{n}-{n}' for n in range(1000, 11000, 1000)]
# strings = [f'{n}-{n}-{n}' for n in range(10, 110, 10)]

# Widths whose velocity curves are collected and plotted (previously [30, 100, 200]).
width_interest = [30, 100, 500]

# Reference statistic of the target samples: alpha_hill_estimate on the first
# coordinate of (up to) the first 20000 samples, presumably a Hill tail-index
# estimate. Only the disabled TID code in the sweep reads `alpha_data`.
z1 = config['samples_1']
eps = 0.02
alpha_data = alpha_hill_estimate(z1[:20000, :1], eps=eps)
# Debug histogram of that coordinate (disabled):
# plt.hist(z1[:50000, 0:1].cpu().detach().numpy(), bins=100); plt.show(); plt.close(); quit()

# Velocity curves for the widths in `width_interest` (filled by the sweep).
easy_parts, hard_parts = [], []
# Per-architecture mean / std of the tail-index difference across repeats.
alpha_diffs, alpha_diffs_std = [], []

# Training repeats per architecture.
enumerate_values = range(5)
# Sweep over architectures. For each structure string the model is trained
# (or loaded from disk) once per entry of `enumerate_values`, and a
# density-weighted velocity error is measured at t = gaussian_t and t = data_t;
# mean/std over repeats are appended to the module-level accumulators.
for str_id, structure_string in enumerate(strings):
    error_at_zero = []
    error_at_end = []
    _alpha_diffs = []
    for run_idx in enumerate_values:
        for trajectory in ['vp', 'linear', 'subvp'][0:1]:
            # for reparam in ['x_pred', 'epsilon_pred', None, 'negative'][1:2]:
            for reparam in ['x_pred', 'epsilon_pred', None, 'negative'][2:3]:
                # Alternatives taking the same constructor arguments:
                # MLP, Repara, SkipConnection (all imported from `model`).
                model = TwoModel(config['d'],
                                 structure_str=structure_string,
                                 relu=relu,
                                 ).to(device)
                flow = ScoreMatch(model=model, trajectory=trajectory,
                                  config=config, reparam=reparam)
                save_path = config['save_path']
                os.makedirs(f'{save_path}/{flow.name}', exist_ok=True)

                # count_parameters(flow.model)

                # Repeats after the first get a numbered checkpoint name so
                # each repeat trains and stores its own weights.
                if run_idx > 0:
                    model.name = model.name + '_' + str(run_idx)

                # A single path for the existence check, the fallback save and
                # the load, so they cannot disagree (e.g. './' + an absolute path).
                ckpt_path = f'{save_path}/{flow.name}/{model.name}.pth'
                print(ckpt_path)
                if not os.path.exists(ckpt_path):
                    optimizer = torch.optim.Adam(flow.model.parameters(), lr=5e-3)
                    # Train under a fixed seed, then restore the previous CPU RNG
                    # state even if training raises. NOTE: torch.manual_seed also
                    # reseeds the CUDA generators, which are not restored here.
                    current_random_state = torch.random.get_rng_state()
                    torch.manual_seed(1)
                    try:
                        if config['train_by_gt']:
                            loss_curve = flow.train_by_gt(optimizer, config['samples_0'],
                                                          config['samples_1'], batchsize, iterations)
                        else:
                            loss_curve = flow.train(optimizer, config['samples_0'],
                                                    config['samples_1'], batchsize, iterations)
                    finally:
                        torch.set_rng_state(current_random_state)
                    # NOTE(review): nothing in this script writes the checkpoint;
                    # presumably flow.train / train_by_gt does. Save the trained
                    # weights only if it did not, so the load below cannot fail.
                    if not os.path.exists(ckpt_path):
                        torch.save(model.state_dict(), ckpt_path)

                model.load_state_dict(torch.load(ckpt_path, map_location=device))

                rectified_flow = flow
                z0 = config['samples_0']

                # Tail-index (TID) evaluation is disabled; a 0 placeholder keeps
                # one entry per repeat for the TID figure.
                # traj = rectified_flow.sample_ode(z0=z0[:], N=100, batch_size=None)
                # alpha_samples = alpha_hill_estimate(traj[-1][:20000, 0:1], eps=eps)
                # alpha_diff = alpha_data - alpha_samples
                # print(alpha_data, alpha_samples, alpha_diff)
                # _alpha_diffs.append(alpha_diff)
                _alpha_diffs.append(0)

                def preprocess_x(x):
                    """Place a column of 1-D points (N, 1) into the first coordinate
                    of zero vectors of dimension ``z0.shape[1]`` on ``device``."""
                    _x = torch.zeros(len(x), z0.shape[1], device=device)
                    _x[:, 0:1] = x.clone()
                    return _x

                # Collect velocity curves for the selected widths (first repeat only).
                # if str_id == 0 or str_id == (len(strings)-1):
                if run_idx == 0 and int(structure_string.split('-')[0]) in width_interest:
                    easy_part, gt_easy_part, hard_part, gt_hard_part = \
                        draw_end_data_velocity(rectified_flow, return_values=True,
                                               vis=False, t=[gaussian_t, data_t])
                    easy_parts.append(easy_part)
                    hard_parts.append(hard_part)

                def velocity_error(x_grid, t):
                    """Density-weighted velocity error on a 1-D grid.

                    Returns ``p * (v_pred - v_true)`` for the first coordinate,
                    where ``p`` is ``rectified_flow.p`` at the embedded grid,
                    normalized to sum to 1 and detached. When ``t`` contains
                    ``gaussian_t`` the prediction is first shifted by
                    ``sum(p * v_pred)``.
                    """
                    # Inputs are marked as requiring grad, presumably because
                    # p / velocity differentiate w.r.t. them — TODO confirm.
                    x_grid.requires_grad_()
                    t = t.clone().detach().requires_grad_(True)
                    p = rectified_flow.p(preprocess_x(x_grid.unsqueeze(-1)), t)
                    p = p / torch.sum(p)
                    p = p.detach()
                    # NOTE(review): assumes p has shape (N, 1) like predicted_v;
                    # an (N,) p would broadcast the products below to (N, N).
                    predicted_v = rectified_flow.predicted_velocity(
                        preprocess_x(x_grid[:, None]), t)[:, :1]
                    if gaussian_t in t:
                        predicted_v = predicted_v - torch.sum(p * predicted_v)
                    true_v = rectified_flow.velocity(preprocess_x(x_grid[:, None]), t)[:, :1]
                    error = (p * (predicted_v - true_v))[:, 0]
                    return error

                def velocity_gradient(x_grid, t):
                    """Gradient w.r.t. ``x_grid`` of the density-weighted predicted velocity.

                    Backpropagates ``(p * v_pred)[:, 0:1]`` (p normalized and
                    detached) into ``x_grid.grad``; when ``first_d`` is False
                    that gradient is differentiated once more. Rescaled to
                    max |grad| == 1 when ``norm_grad`` is set. Currently only
                    referenced by commented-out code below.
                    """
                    x_grid.requires_grad_()
                    x_grid.grad = None
                    t = t.clone().detach().requires_grad_(True)
                    p = rectified_flow.p(preprocess_x(x_grid.unsqueeze(-1)), t)
                    p = p / torch.sum(p)
                    p = p.detach()
                    predicted_v = rectified_flow.predicted_velocity(
                        preprocess_x(x_grid[:, None]), t)
                    error = (p * predicted_v)[:, 0:1]
                    # create_graph=True keeps the graph so the gradient can be
                    # differentiated again in the second-derivative branch.
                    error.backward(torch.ones_like(error), create_graph=True)

                    if not first_d:
                        grad = x_grid.grad
                        x_grid.grad = None
                        grad.backward(torch.ones_like(grad))
                    grad = x_grid.grad

                    if norm_grad:
                        grad = grad / torch.max(torch.abs(grad))
                    return grad

                # Summarize each evaluation time as sum |p * (v_pred - v_true)|
                # over a 1000-point grid on [-3, 3].
                for _t in torch.tensor([gaussian_t, data_t]).to(device):
                    x_values = torch.linspace(-3, 3, 1000).to(device)
                    # func_values = velocity_gradient(x_values, _t * torch.ones_like(x_values.unsqueeze(-1))).detach().cpu().numpy()
                    func_values = velocity_error(
                        x_values, _t * torch.ones_like(x_values.unsqueeze(-1))).detach().cpu().numpy()
                    error = np.sum(np.abs(func_values))
                    if _t == gaussian_t:
                        error_at_zero.append(error)
                    else:
                        error_at_end.append(error)

    _alpha_diffs_mean = np.mean(_alpha_diffs)
    _alpha_diffs_std = np.std(_alpha_diffs)
    alpha_diffs.append(_alpha_diffs_mean)
    alpha_diffs_std.append(_alpha_diffs_std)

    error_zero_mean = np.mean(error_at_zero)
    error_zero_std = np.std(error_at_zero)
    zero_means.append(error_zero_mean)
    zero_stds.append(error_zero_std)

    error_end_mean = np.mean(error_at_end)
    error_end_std = np.std(error_at_end)
    end_means.append(error_end_mean)
    end_stds.append(error_end_std)

    print(error_at_zero, error_at_end)
    print('error at zero mean: ', error_zero_mean)
    print('error at zero std: ', error_zero_std)
    print('error at end mean: ', error_end_mean)
    print('error at end std: ', error_end_std)

    def velocity_gradient_gt(x_grid, t):
        """Ground-truth counterpart of ``velocity_gradient``.

        Same computation but with ``rectified_flow.velocity`` in place of the
        model prediction; uses ``rectified_flow`` / ``preprocess_x`` from the
        last inner-loop iteration. Currently only referenced by commented-out
        code below.
        """
        x_grid.requires_grad_()
        x_grid.grad = None
        t = t.clone().detach().requires_grad_(True)
        p = rectified_flow.p(preprocess_x(x_grid.unsqueeze(-1)), t)
        p = p / torch.sum(p)
        p = p.detach()
        predicted_v = rectified_flow.velocity(
            preprocess_x(x_grid[:, None]), t)
        error = (p * predicted_v)[:, 0:1]
        error.backward(torch.ones_like(error), create_graph=True)

        if not first_d:
            grad = x_grid.grad
            x_grid.grad = None
            grad.backward(torch.ones_like(grad))

        grad = x_grid.grad
        if norm_grad:
            grad = grad / torch.max(torch.abs(grad))
        return grad

    x_values = torch.linspace(-3, 3, 1000).to(device)
    # gt_data = np.trapz(np.abs(velocity_gradient_gt(x_values, data_t * torch.ones_like(x_values.unsqueeze(-1))).detach().cpu().numpy()), x_values.detach().cpu().numpy())
    gt_data = 0  # placeholder while the ground-truth reference above is disabled
    gts_data.append(gt_data)
    print(f'error at end gt: {gt_data}')


# Components of the output figure filenames.
d = config['d']
sigma = config['sigmas'].detach().cpu().numpy()[0][0][0]
# `model` is the last model built by the sweep; its name is checked for
# 'relu' to pick the activation label.
act = 'relu' if 'relu' in model.name else 'tanh'
n_hidden = len(strings[0].split('-'))
model_type = f'{n_hidden}_layer'

# Figure 1: tail-index difference (TID) vs. model width, mean ± std over repeats.
# (`plt` is already bound to matplotlib.pyplot by the file's imports.)
print(alpha_diffs)
print(alpha_diffs_std)
x = np.arange(len(alpha_diffs))
dummy_strings = [s.split('-')[0] for s in strings]

plt.errorbar(x, alpha_diffs, yerr=alpha_diffs_std, fmt='-o', capsize=5, color='blue')
plt.xticks(x, dummy_strings, rotation=90, fontsize=20)
plt.yticks(fontsize=20)
plt.xlabel('Model Width', fontsize=20)
plt.ylabel('TID', fontsize=20)
plt.grid()
plt.tight_layout()
plt.savefig(f'{model_type}_{act}_{d}_{sigma:.4f}_TID_along_model_width.png', dpi=300)
plt.show()
plt.close()
# quit()  # uncomment to stop after the TID figure

# Figure 2: velocity error vs. model width at both evaluation times,
# mean ± std over training repeats.
for summary in (zero_means, zero_stds, end_means, end_stds):
    print(summary)
save_path = config['save_path']
# plt.savefig(f'{save_path}/{flow.name}/error at data.png')  # per-run alternative

plt.figure()
x = np.arange(len(end_means))
for means, stds, series_label, series_color in (
        (end_means, end_stds, rf'$v_\theta(\cdot,{1-data_t:.1f})$', 'blue'),
        (zero_means, zero_stds, r'$v_\theta(\cdot,1)$', 'orange'),
):
    plt.errorbar(x, means, yerr=stds, fmt='-o', capsize=5,
                 label=series_label, color=series_color)

dummy_strings = [s.split('-')[0] for s in strings]
plt.xticks(x, dummy_strings, rotation=90, fontsize=20)
plt.yticks(fontsize=20)
plt.xlabel('Model Width', fontsize=20)
plt.ylabel('Mean Absolute Error', fontsize=20)
plt.legend(fontsize=18)
plt.grid()
plt.tight_layout()
plt.savefig(f'{model_type}_{act}_{d}_{sigma:.4f}_velocity_mae.png', dpi=300)
plt.show()
plt.close()

def _plot_velocity_parts(parts, gt_part, widths, ylabel, suffix, ylim=None):
    """Plot mean-centered velocity curves per width against the ground truth.

    Each curve in ``parts`` is shifted by its own mean before plotting; the
    ground-truth curve is plotted as-is. The x grid (1000 points on [-3, 3])
    is assumed to match the grid used by ``draw_end_data_velocity`` —
    NOTE(review): confirm. Saves
    ``{model_type}_{act}_{d}_{sigma:.4f}_velocity_{suffix}.png``.
    """
    x_plot = torch.linspace(-3, 3, 1000)
    for width, part in zip(widths, parts):
        plt.plot(x_plot, part - np.mean(part), label=f'Width={width}')
    plt.plot(x_plot, gt_part, label='Ground Truth')
    plt.xlabel(r'$x$', fontsize=20)
    plt.ylabel(ylabel, fontsize=20)
    plt.legend(fontsize=15)
    plt.xticks(fontsize=20)
    if ylim is not None:
        plt.ylim(*ylim)
    plt.yticks(fontsize=20)
    plt.grid()
    plt.tight_layout()
    plt.savefig(f'{model_type}_{act}_{d}_{sigma:.4f}_velocity_{suffix}.png', dpi=300)
    plt.show()
    plt.close()


# Widths whose curves were actually collected, in collection order (the sweep
# appends once per structure string whose width is in width_interest). Using
# these keeps legend labels aligned with easy_parts / hard_parts even when a
# width_interest entry is not part of `strings`.
plotted_widths = [int(s.split('-')[0]) for s in strings
                  if int(s.split('-')[0]) in width_interest]

# gt_easy_part / gt_hard_part only exist if at least one width was collected.
if easy_parts:
    _plot_velocity_parts(easy_parts, gt_easy_part, plotted_widths,
                         r'$v_\theta(x, 1)$', 'easy_part')
if hard_parts:
    _plot_velocity_parts(hard_parts, gt_hard_part, plotted_widths,
                         rf'$v_\theta(x, {1-data_t:.1f})$', 'hard_part',
                         ylim=(-20, 20))
