#!/usr/bin/python3
import os, json, sys
import time

import hydra
from omegaconf import OmegaConf
import wandb

import numpy as np
import random
import torch
from torch.nn import Linear, Module, ReLU, Sequential, SiLU
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.graphgym import global_add_pool
from torch_geometric.loader import DataLoader
from torch_geometric.utils import remove_self_loops
import torch_geometric.transforms as T

# import pdb

from stable_equiv import qm9_gen_dataloaders
from _utils import (
    Queue,
    gradient_clipping,
    random_rotation,
    remove_mean,
)
from _dequant import ArgmaxAndVariationalDequantizer
from _prior import PositionFeaturePrior
from _distributions import DistributionNodes
from _dynamics import EGNN_dynamics_QM9
from _flows import Flow, FFJORD

"""
QM9 Dataset
"""

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Model
#----------------------------------------------------------------------------------------------------------------------------------------------------

class Model(Module):
    """
    Continuous normalizing flow over QM9 molecules with EGNN dynamics (FFJORD).

    Node features are dequantized by ``ArgmaxAndVariationalDequantizer``,
    concatenated with the positions (after ``remove_mean``), and pushed through
    an FFJORD flow whose dynamics are an ``EGNN_dynamics_QM9`` network. The
    returned log-likelihood combines:
    - the prior over the latent,
    - the flow's change of variables,
    - the dequantizer's variational term,
    - a prior over the number of nodes per graph.

    Args
    -----
        Dependent Args
        ~~~~~~~~~~~~~~~
        + context_dim: (int) size of the per-node context given to the dynamics
          (``context_node_nf``).
        + in_channels: (int) number of node feature channels.
        + particle_dim: (int) dimension of the physical space. It is used for the
          prior, for the dynamics, and for splitting the flow output into
          positions and features.

        Independent Args
        ~~~~~~~~~~~~~~~~~
        + augment_data: (bool) Whether or not to augment the data with a random rotation.
        + hidden_channels: (int) hidden width of the EGNN dynamics.
        + hidden_layers: (int) number of EGNN layers.

        Module Args
        ~~~~~~~~~~~
        + aggregation: (str) EGNN aggregation
        + attention: (bool) EGNN attention
        + tanh: (bool) EGNN tanh
        + condition_time: (bool) Whether the dynamics are conditioned on time
          (adds one input channel to the dynamics).
        + ode_reg: (float) weight applied to the ODE regularization term returned
          by ``forward``. It is also passed to FFJORD as ``ode_regularization``.
        + trace: (str) trace estimator name passed to FFJORD as ``trace_method``
          (the training loop selects 'hutch').

    """
    def __init__(self,
        context_dim: int,
        in_channels: int,
        particle_dim: int,
        augment_data: bool,
        hidden_channels: int,
        hidden_layers: int,
        aggregation: str,
        attention: bool,
        tanh: bool,
        condition_time: bool,
        ode_reg: float,
        trace: str,
    ):

        super().__init__()

        # Node-count histogram (num_nodes -> frequency) used to build the p(N) prior.
        # Values presumably come from the QM9 training split -- TODO confirm source.
        self.n_nodes = {22: 3393, 17: 13025, 23: 4848, 21: 9970, 19: 13832, 20: 9482, 16: 10644, 13: 3060,
                15: 7796, 25: 1506, 18: 13364, 12: 1689, 11: 807, 24: 539, 14: 5136, 26: 48, 7: 16, 10: 362,
                8: 49, 9: 124, 27: 266, 4: 4, 29: 25, 6: 9, 5: 5, 3: 1}
        self.nodes_dist = DistributionNodes(self.n_nodes)

        # Time conditioning appends one extra input channel to the dynamics network.
        dynamics_in_node_nf = in_channels + 1 if condition_time else in_channels
        self.aug_data = augment_data
        self.ode_reg = ode_reg
        self.n_dims = particle_dim

        self.prior = PositionFeaturePrior(n_dim=particle_dim, in_node_nf=in_channels)
        self.dequantizer = ArgmaxAndVariationalDequantizer(in_channels)

        net_dynamics = EGNN_dynamics_QM9(
            in_node_nf=dynamics_in_node_nf,
            context_node_nf=context_dim,
            # Previously hard-coded to 3. It must agree with particle_dim, because
            # forward() splits the flow state at self.n_dims.
            n_dims=particle_dim,
            hidden_nf=hidden_channels,
            act_fn=torch.nn.SiLU(),
            n_layers=hidden_layers,
            recurrent=True,
            tanh=tanh,
            attention=attention,
            condition_time=condition_time,
            agg=aggregation,
        )

        ffjord = FFJORD(
            net_dynamics,
            trace_method=trace,
            ode_regularization=ode_reg,
        )
        self.flow = Flow(transformations=[ffjord])

        # Grad-norm history consumed by gradient_clipping during training; seeded with 3000.
        self.gradnorm_queue = Queue()
        self.gradnorm_queue.add(3000)

    def forward(self, h, x, edge_index, batch):
        """
        Compute the batch-mean log-likelihood estimate and the weighted ODE regularizer.

        Args:
            h: node features, given to the dequantizer.
            x: node positions. They form the first ``particle_dim`` columns of the flow state.
            edge_index: graph connectivity, passed to the dequantizer and the flow.
            batch: graph index of every node.

        Returns:
            ``(log_px, reg)``:
            - ``log_px`` is ``log p(z) + delta_logp - log q(h|x) + log p(N)``
              averaged over the batch entries.
            - ``reg`` is ``ode_reg * reg_term.mean()``.

        Side effects:
            Stores ``self.mean_abs_z`` and ``self.mean_reg`` as Python floats, for logging.
        """
        x = remove_mean(x)
        if self.aug_data:
            x = random_rotation(x).detach()

        h, log_qh_x = self.dequantizer(h, x, edge_index, batch)
        h = torch.cat([h['categorical'], h['integer']], dim=1)

        xh = torch.cat([x, h], dim=1)
        z, delta_logp, reg_term = self.flow(xh, edges=edge_index, batch=batch)
        # Split the latent back into its position and feature parts.
        z_x, z_h = z[:, 0:self.n_dims].clone(), z[:, self.n_dims:].clone()

        # Counts of each distinct graph id in `batch`. This equals the nodes per graph
        # when graph ids are contiguous (as in PyG batching) -- TODO confirm for all loaders.
        _, hist = torch.unique(batch, return_counts=True)
        hist = hist.to(x.device).to(torch.long)

        log_pN = self.nodes_dist.log_prob(hist)

        log_pz = self.prior(z_x, z_h, batch)
        assert log_pz.size() == delta_logp.size()
        log_px = (log_pz + delta_logp - log_qh_x + log_pN).mean()  # Average over batch.

        self.mean_abs_z = torch.mean(torch.abs(z)).item()
        self.mean_reg = reg_term.mean().item()

        return log_px, self.ode_reg * reg_term.mean()

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Config/Model/Dataset
#----------------------------------------------------------------------------------------------------------------------------------------------------

def setup(cfg):
    # Set device
    args = cfg.setup
    cfg['setup']['device'] = args['device'] if torch.cuda.is_available() else 'cpu'
    os.environ["WANDB_DIR"] = os.path.abspath(args['wandb_dir'])
    # Change file name for sweeping *Prior to setting seed*
    if args['sweep']:
        wandb_id = wandb.run.id
        cfg['load']['checkpoint_path']=cfg['load']['checkpoint_path'][:-3]+str(wandb_id)+'.pt'
    # Set Backends
    torch.backends.cudnn.deterministic = True
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    pass

#----------------------------------------------------------------------------------------------------------------------------------------------------

def load(cfg):
    """
    Build the QM9 dataloaders and the flow model, and optionally restore weights.

    Reads ``cfg.load`` (``split``, ``batch_size``, ``checkpoint_path``,
    ``load_checkpoint``) and ``cfg.model`` (model hyper-parameters).
    The weights in ``checkpoint['model_state_dict']`` are loaded only when
    ``checkpoint_path`` exists and ``load_checkpoint`` is truthy.

    Returns:
        ``(model, train_dl, val_dl, test_dl)``; the model is left on the CPU.
    """
    args = cfg.load
    _, _, _, _, train_dl, val_dl, test_dl = qm9_gen_dataloaders(
        featurization='dict',
        adjacency='full',
        split=args['split'],
        batch_size=args['batch_size'],
    )

    model_kwargs = OmegaConf.to_container(cfg.model)
    model = Model(
        context_dim=0,
        in_channels=6,
        particle_dim=3,
        augment_data=model_kwargs['aug_data'],
        hidden_channels=model_kwargs['hidden_channels'],
        hidden_layers=model_kwargs['hidden_layers'],
        aggregation=model_kwargs['aggregation'],
        attention=model_kwargs['attention'],
        tanh=model_kwargs['tanh'],
        condition_time=model_kwargs['condition_time'],
        ode_reg=model_kwargs['ode_reg'],
        trace=model_kwargs['trace'],
    )

    checkpoint_path = args['checkpoint_path']
    if os.path.exists(checkpoint_path) and args['load_checkpoint']:
        # The model has not been moved to a device yet, so load onto the CPU.
        # Without map_location, a checkpoint written on a GPU fails on CPU-only hosts.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
    return model, train_dl, val_dl, test_dl


#----------------------------------------------------------------------------------------------------------------------------------------------------
# Train/Validate/Test
#----------------------------------------------------------------------------------------------------------------------------------------------------

def train(cfg, data, model, optimizer):
    """
    Perform a single optimisation step on one batch.

    The objective is ``-log_px + reg`` as returned by ``model``'s forward pass.
    Gradients of ``model.flow`` are clipped via ``gradient_clipping`` when
    ``cfg.train['clip_grad']`` is truthy, before ``optimizer.step()``.

    Returns:
        The objective value as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    log_px, reg_term = model(data.x, data.pos, data.edge_index, data.batch)
    objective = reg_term - log_px
    objective.backward()
    if cfg.train['clip_grad']:
        # Clip using the helper and the model's grad-norm history queue.
        gradient_clipping(model.flow, model.gradnorm_queue)
    optimizer.step()
    return objective.item()

@torch.no_grad()
def validate(cfg, data, model):
    """
    Return the negative log-likelihood of one validation batch as a float.

    ``cfg`` is accepted for signature parity with ``train`` and is unused.
    Runs with autograd disabled and with the model switched to eval mode.
    """
    model.eval()
    inputs = (data.x, data.pos, data.edge_index, data.batch)
    log_px, _reg = model(*inputs)
    return (-log_px).item()

@torch.no_grad()
def test(cfg, data, model):
    model.eval()
    log_px, _ = model(data.x, data.pos, data.edge_index, data.batch)
    loss = -log_px
    return loss.item()


#----------------------------------------------------------------------------------------------------------------------------------------------------
# Main/Hydra/Fold/Train
#----------------------------------------------------------------------------------------------------------------------------------------------------

def run_training(cfg, model, train_dl, val_dl):
    """
    Train ``model`` with AdamW and early stopping on the mean validation loss.

    Each epoch runs one pass over ``train_dl`` and then one over ``val_dl``.
    The losses are unweighted means over batches. Whenever the validation loss
    improves, a checkpoint is written to ``cfg.load['checkpoint_path']``
    containing the epoch, model/optimizer state dicts, learning rate and
    validation loss. Training stops once ``cfg.train['patience']`` consecutive
    epochs pass without improvement.

    Returns:
        The best (lowest) mean validation loss observed.
    """
    args = cfg.train
    device = cfg.setup['device']

    # Move the model first so the optimizer is built over the on-device parameters.
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args['lr'], weight_decay=args['wd'], amsgrad=True)

    best = 1e8
    # Initialised up front. If the first epoch does not beat `best` (e.g. a NaN or
    # inf loss), the `else` branch below would otherwise hit an unbound local.
    bad_itr = 0
    for epoch in range(args['epochs']):

        model.train()
        model.flow.set_trace('hutch')
        train_loss, count = 0, 0
        start = time.time()
        for i, data in enumerate(train_dl):
            data = data.to(device)
            batch_loss = train(cfg, data, model, optimizer)
            train_loss += batch_loss
            count += 1
            print(f'Train({epoch}) | batch({i:03d}) | loss({batch_loss:.4f}) | reg({model.mean_reg:.2f})')
        end = time.time()
        train_loss = train_loss / count

        model.eval()
        model.flow.set_trace('hutch')
        val_loss, count = 0, 0
        for i, data in enumerate(val_dl):
            data = data.to(device)
            batch_loss = validate(cfg, data, model)
            val_loss += batch_loss
            count += 1
            print(f'Valid({epoch}) | batch({i:03d}) | loss({batch_loss:.4f})')
        val_loss = val_loss / count

        perf_metric = val_loss  # your performance metric here
        lr = optimizer.param_groups[0]['lr']

        if perf_metric < best:
            best = perf_metric
            bad_itr = 0
            torch.save({'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        # Previously this key held the learning-rate float. Store the real
                        # optimizer state here and keep the lr under its own key.
                        'optimizer_state_dict': optimizer.state_dict(),
                        'lr': lr,
                        'loss': val_loss,
                        },
                       cfg.load['checkpoint_path'])
        else:
            bad_itr += 1
        # Log results
        wandb.log({'epoch': epoch,
                   'train_loss': train_loss,
                   'val_loss': val_loss,
                   'best': best,
                   'lr': lr,
                   'time': end - start})
        print(f'Epoch({epoch}) '
              f'| train({train_loss:.4f}) '
              f'| val({val_loss:.4f}) '
              f'| lr({lr:.2e}) '
              f'| best({best:.4f}) '
              f'| time({end-start:.4f})'
              f'\n')

        if bad_itr > args['patience']:
            break

    return best

#----------------------------------------------------------------------------------------------------------------------------------------------------

@hydra.main(version_base=None, config_path="/root/workspace/stable-equiv/stable_enf/", config_name="qm9")
def run_qm9(cfg):
    """
    Execute run saving details to wandb server.
    """
    # Setup Weights and Bias
    wandb.config = OmegaConf.to_container(
        cfg, resolve=True, throw_on_missing=True
    )
    wandb.init(entity='',
                project='stable-equiv',
                mode='disabled',
                name='semistable-enflows-qm9',
                dir='/root/workspace/out/',
                tags=['gen-qm9', 'enflows', 'semistable'],
                config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
    )
    
    # Execute
    setup(cfg)
    # pdb.set_trace() #Args
    print(OmegaConf.to_yaml(cfg))
    model, train_dl, val_dl, test_dl = load(cfg)
    print(model)
    if cfg.setup['train']:
        run_training(cfg, model, train_dl, val_dl)

    checkpoint = torch.load(cfg.load['checkpoint_path'])
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(cfg.setup['device'])
    test_loss, count = 0, 0
    for data in test_dl:
        data.to(cfg.setup['device'])
        batch_loss = test(cfg, data, model)

        batch_size = data.batch.max()
        test_loss += batch_loss #* batch_size
        count += 1 #batch_size
    test_loss = test_loss/count

    print(f'\ntest({test_loss})')
    wandb.log({'test_loss':test_loss})
    # Terminate
    return 1

#----------------------------------------------------------------------------------------------------------------------------------------------------

if __name__ == "__main__":
    # Hydra supplies `cfg` from the config directory and command-line overrides.
    run_qm9()