import os
import sys

import hydra
from omegaconf import OmegaConf

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.nn import Module
from torch_geometric.data import DataLoader

sys.path.insert(1, '/root/workspace/stable-equiv/stable_equiv')
from lj13_dataset import LJ13, lj13_dataloaders
from samplers import sample_graphs
from _utils import (
    remove_mean,
)
from _prior import PositionPrior
from _dynamics import EGNN_dynamics
from _flows import FFJORD


#----------------------------------------------------------------------------------------------------------------------------------------------------
# Model
#----------------------------------------------------------------------------------------------------------------------------------------------------



class Model(Module):
    """Continuous normalizing flow: an ``EGNN_dynamics`` field integrated by ``FFJORD``.

    The constructor builds a ``PositionPrior`` and an ``FFJORD`` flow whose
    dynamics network is an ``EGNN_dynamics`` configured from the arguments.

    Args:
        particle_count: forwarded to ``EGNN_dynamics`` as ``n_particles``.
        particle_dim: forwarded to ``EGNN_dynamics`` as ``n_dimension``.
        hidden_channels: forwarded as ``hidden_nf``.
        hidden_layers: forwarded as ``n_layers``.
        normalize, norm_const, normalize_type, reg_para, reg_clip:
            forwarded unchanged to ``EGNN_dynamics``.
        ode_solver, ode_mesh: forwarded unchanged to ``FFJORD``.
    """

    def __init__(
        self,
        particle_count,
        particle_dim,
        hidden_channels,
        hidden_layers,
        normalize,
        norm_const,
        normalize_type,
        reg_para,
        ode_solver,
        ode_mesh,
        reg_clip,
    ):
        super().__init__()

        self.prior = PositionPrior()

        # Settings of the vector field; everything that is not an argument of
        # this constructor is fixed for this analysis script.
        dynamics_kwargs = dict(
            n_particles=particle_count,
            n_dimension=particle_dim,
            hidden_nf=hidden_channels,
            act_fn=torch.nn.SiLU(),
            n_layers=hidden_layers,
            recurrent=True,
            tanh=True,
            attention=True,
            condition_time=True,
            agg='sum',
            normalize=normalize,
            norm_const=norm_const,
            normalize_type=normalize_type,
            reg_para=reg_para,
            reg_clip=reg_clip,
        )
        vector_field = EGNN_dynamics(**dynamics_kwargs)

        self.flow = FFJORD(
            vector_field,
            trace_method='hutch',
            hutch_noise='bernoulli',
            ode_regularization=0,
            ode_solver=ode_solver,
            ode_mesh=ode_mesh,
        )

    def forward(self, x):
        """Return ``(log_px, dynamics_reg)`` for a batch of positions ``x``.

        ``x`` is passed through ``remove_mean`` and then through the flow.
        ``log_px`` is the batch mean of the prior log-density of the latent
        plus the flow's log-density change; the second element is whatever
        ``self.flow.dynamics_reg`` holds after the pass.
        """
        centered = remove_mean(x)
        z, delta_logp, _ = self.flow(centered)

        log_pz = self.prior(z)
        per_sample = log_pz + delta_logp.view(-1)
        return per_sample.mean(), self.flow.dynamics_reg
#----------------------------------------------------------------------------------------------------------------------------------------------------
# Config/Model/Dataset
#----------------------------------------------------------------------------------------------------------------------------------------------------

def setup(cfg):
    # Set device
    args = cfg.setup
    cfg['setup']['device'] = args['device'] if torch.cuda.is_available() else 'cpu'
    os.environ["WANDB_DIR"] = os.path.abspath(args['wandb_dir'])
    # Change file name for sweeping *Prior to setting seed*
    if args['sweep']:
        rand_id = np.random.randint(0,1e3)
        cfg['load']['checkpoint_path']=cfg['load']['checkpoint_path'][:-3]+str(rand_id)+'.pt'
    # Set Backends
    torch.backends.cudnn.deterministic = True
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    pass

#----------------------------------------------------------------------------------------------------------------------------------------------------

def load(cfg):
    """Build the dataloaders and restore the trained model from its checkpoint.

    Args:
        cfg: configuration with ``load`` (``split``, ``batch_size``,
            ``checkpoint_path``), ``model`` (hyper-parameters) and ``setup``
            (``device``) sections.

    Returns:
        ``(model, train_dl, val_dl, test_dl)``; the model has the weights
        stored under ``'model_state_dict'`` in the checkpoint loaded.
    """
    load_cfg = cfg.load
    dataset, _, _, _, train_dl, val_dl, test_dl = lj13_dataloaders(
        split=load_cfg['split'],
        batch_size=load_cfg['batch_size'],
    )

    # Hyper-parameters that are handed to Model under the same name.
    hparams = OmegaConf.to_container(cfg.model)
    forwarded = (
        'hidden_channels',
        'hidden_layers',
        'normalize',
        'norm_const',
        'normalize_type',
        'reg_para',
        'ode_solver',
        'ode_mesh',
        'reg_clip',
    )
    model = Model(
        particle_count=dataset.num_nodes,
        particle_dim=dataset.pos_dim,
        **{key: hparams[key] for key in forwarded},
    )

    # Map the stored tensors straight onto the device chosen in the config.
    target_device = torch.device(cfg.setup['device'])
    checkpoint = torch.load(cfg.load['checkpoint_path'], map_location=target_device)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model, train_dl, val_dl, test_dl

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Analysis
#----------------------------------------------------------------------------------------------------------------------------------------------------

def compute_loss(cfg, dataloader, model):
    """Return the negative log-likelihood averaged over every sample in ``dataloader``.

    ``model(pos)`` is expected to return ``(log_px, _)`` where ``log_px`` is
    the *mean* log-likelihood of the batch (as ``Model.forward`` in this file
    does). Each batch mean is therefore weighted by its batch size before
    dividing by the total number of samples; summing the raw means and
    dividing by the sample count would under-report the loss by roughly a
    factor of the batch size.

    Args:
        cfg: configuration; ``cfg.setup['device']`` is the compute device.
        dataloader: iterable of batches exposing ``to``, ``pos`` and ``batch``.
        model: callable with an ``eval()`` method.

    Returns:
        Average negative log-likelihood per sample as a Python float.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batches.
    """
    model.eval()
    device = cfg.setup['device']

    # NOTE(review): gradients are deliberately not disabled here; confirm that
    # the flow does not rely on autograd internally before adding no_grad().
    total_nll, count = 0.0, 0
    for data in dataloader:
        data = data.to(device)
        # ``batch`` holds the graph index of every node, so max + 1 is the
        # number of graphs in this batch.
        batch_size = int(data.batch.max().item()) + 1
        pos = data.pos.view(batch_size, -1, data.pos.shape[-1])
        log_px, _ = model(pos)
        # Undo the batch mean so that uneven batches are weighted correctly.
        total_nll += -log_px.detach().item() * batch_size
        count += batch_size
    return total_nll / count

class Sampling:
    """Static helpers that draw standard-normal tensors of a requested shape."""

    @staticmethod
    def gaussian(size):
        """Return ``torch.randn(size)``."""
        return torch.randn(size)

    @staticmethod
    def center_gravity_zero_gaussian(size):
        """Return Gaussian noise of shape ``size`` passed through ``remove_mean``."""
        # Projecting after sampling is only valid because the Gaussian is
        # rotation invariant around zero and the samples are independent.
        return remove_mean(Sampling.gaussian(size))

    @staticmethod
    def x(size):
        """Return plain Gaussian noise of shape ``size``."""
        return torch.randn(size)

    @staticmethod
    def pos(size):
        """Return mean-removed Gaussian noise; ``size`` must have three entries."""
        assert len(size) == 3
        return Sampling.center_gravity_zero_gaussian(size)

def energy_histogram(true, true_ind, new, new_ind, fname):
    """Plot normalised histograms of ``LJ13.potential_function`` for two sample sets.

    The reference set is binned into 200 bins and the sampled set into 100;
    each curve is normalised by its total count and labelled with
    ``mean±stddev`` of the underlying values. The figure is written to a
    fixed PDF path.

    Args:
        true: reference positions, forwarded to ``LJ13.potential_function``.
        true_ind: edge indices belonging to ``true``.
        new: sampled positions.
        new_ind: edge indices belonging to ``new``.
        fname: accepted for call compatibility but currently unused; the
            output path is hard-coded below.
    """
    plt.style.use('classic')
    plt.rcParams.update({
        "figure.figsize": (16, 9),
        "font.size": 50,
        "xtick.color": 'black',
        "ytick.color": 'black',
        "axes.edgecolor": 'black',
        "axes.linewidth": 1,
    })

    plt.figure()
    series = (
        ('True', true, true_ind, 200),
        ('Sampled', new, new_ind, 100),
    )
    for label, positions, indices, n_bins in series:
        # detach/cpu: sampled tensors may live on the GPU or carry autograd
        # history, and Tensor.numpy() rejects both.
        energies = LJ13.potential_function(positions, indices).detach().cpu().numpy()
        mean, stddev = np.mean(energies), np.std(energies)
        counts, bins = np.histogram(energies, bins=n_bins)
        plt.stairs(
            counts / counts.sum(),
            bins,
            # Raw string keeps the mathtext backslash; the separator makes the
            # label read "mean±stddev" instead of two numbers run together.
            label=rf'{label}({mean:.2f}$\pm${stddev:.1f})',
            linewidth=4,
        )

    plt.legend(fancybox=True,handlelength=1.5,handletextpad=.4,shadow=True,loc='upper right',bbox_to_anchor=(1.,1.02),ncol=1, fontsize=50)
    plt.savefig('/root/workspace/out/stable_equiv/lj13_stable_energy.pdf', format='pdf', bbox_inches='tight')
    plt.close()


def relative_distances_histogram(true, true_ind, new, new_ind, fname):
    """Plot normalised histograms of pairwise distances along the given edges.

    For each sample set the positions are flattened to one row per node, the
    Euclidean distance between the two endpoints of every edge is computed,
    and the distances are binned into 100 bins normalised by the total count.
    The figure is written to a fixed PDF path.

    Args:
        true: reference positions; the last axis holds the coordinates.
        true_ind: pair ``(i, j)`` of node-index sequences into the flattened
            ``true`` positions.
        new: sampled positions, same layout as ``true``.
        new_ind: edge indices belonging to ``new``.
        fname: accepted for call compatibility but currently unused; the
            output path is hard-coded below.
    """
    plt.style.use('classic')
    plt.rcParams.update({
        "figure.figsize": (16, 9),
        "font.size": 50,
        "xtick.color": 'black',
        "ytick.color": 'black',
        "axes.edgecolor": 'black',
        "axes.linewidth": 1,
    })

    plt.figure()
    for label, positions, edge_index in (('True', true, true_ind), ('Sample', new, new_ind)):
        # Flatten to (num_nodes, dim) using the actual coordinate dimension;
        # a hard-coded 2 scrambles rows for 3-D data such as LJ13.
        # detach/cpu because Tensor.numpy() needs a grad-free CPU tensor.
        flat = positions.detach().cpu().reshape(-1, positions.shape[-1])
        node_i_idxs, node_j_idxs = (torch.as_tensor(idx).cpu() for idx in edge_index)
        pos_diffs = flat[node_i_idxs] - flat[node_j_idxs]
        dij = (torch.sum(pos_diffs**2, dim=1)**(1/2)).ravel().numpy()
        mean, stddev = np.mean(dij), np.std(dij)
        counts, bins = np.histogram(dij, bins=100)
        # Raw f-string: '\p' is not a valid escape in a normal string literal.
        plt.stairs(counts / counts.sum(), bins, label=rf'{label}({mean:.2f}$\pm${stddev:.2f})', linewidth=4)

    plt.xlabel('Relative Distance', color='black',fontsize=55)
    plt.ylabel('Probability', color='black',fontsize=55)
    plt.legend(fancybox=True,handlelength=0.8,handletextpad=.4,shadow=True,loc='upper right',bbox_to_anchor=(1.02,1.04),ncol=1, fontsize=50)
    plt.savefig('/root/workspace/out/stable_equiv/lj13_stable_rel_dist.pdf', format='pdf', bbox_inches='tight')
    plt.close()

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Drivers
#----------------------------------------------------------------------------------------------------------------------------------------------------

@hydra.main(version_base=None, config_path="/root/workspace/stable-equiv/checkpoints/", config_name="stable_lj13")
def analyze_lj13(cfg):
    """
    Sample from a trained LJ13 flow and plot the samples against validation data.

    Steps: run ``setup``, print the resolved config, restore the model with
    ``load``, take the first validation batch as reference, push one batch of
    Gaussian latents through ``model.flow.reverse`` and hand both sets to
    ``energy_histogram`` and ``relative_distances_histogram``.

    Returns:
        Always ``1``.
    """
    setup(cfg)
    print(OmegaConf.to_yaml(cfg))
    model, _, val_dl, _ = load(cfg)

    device = cfg.setup['device']
    model.to(device)
    model.eval()
    # compute_loss(cfg, <dataloader>, model) can be called here to report the
    # train/val/test negative log-likelihood.

    # Reference configurations: the first validation batch, reshaped to
    # (graphs, nodes, dim).
    reference = next(iter(val_dl))
    true = reference.pos.view(reference.batch.max()+1, -1, reference.pos.shape[-1])
    true_ind = reference.edge_index

    # Latents take their per-graph shape from the reference data instead of a
    # hard-coded (13, 3).
    # NOTE(review): these latents are plain Gaussian, whereas Sampling.pos in
    # this file draws mean-removed noise; confirm which one the flow expects.
    count = cfg.load['batch_size']
    pos = torch.randn((count, *true.shape[1:]))
    dataset = sample_graphs(pos=pos)
    sample_dl = DataLoader(dataset, batch_size=count, shuffle=True)
    latent = next(iter(sample_dl))
    latent = latent.to(device)
    z = latent.pos.view(latent.batch.max()+1, -1, latent.pos.shape[-1])
    new = model.flow.reverse(z=z, edges=latent.edge_index, batch=latent.batch)
    new_ind = latent.edge_index

    # The plotting helpers convert to NumPy, which needs grad-free CPU tensors.
    true, true_ind = true.detach().cpu(), true_ind.cpu()
    new, new_ind = new.detach().cpu(), new_ind.cpu()

    fname = 'lj13_'+cfg.load['split']+'_sampled'
    energy_histogram(true, true_ind, new, new_ind, fname=fname)
    relative_distances_histogram(true, true_ind, new, new_ind, fname=fname)

    # Terminate
    return 1

#----------------------------------------------------------------------------------------------------------------------------------------------------

# Script entry point: analyze_lj13 is wrapped by hydra.main, so it is called
# without arguments here and receives its cfg from the decorator.
if __name__ == '__main__':
    analyze_lj13()
