#!/usr/bin/python3
import os
import time

import hydra
from omegaconf import OmegaConf
import wandb

import numpy as np
import pandas as pd
import torch
from torch.nn import Module
import torch.optim as optim
import sys
sys.path.insert(1, '/root/workspace/stable-equiv/stable_equiv')
from dw4_dataset import dw4_dataloaders
from _utils import (
    remove_mean,
)
from _prior import PositionPrior
from _dynamics import EGNN_dynamics
from _flows import FFJORD

"""
DW4 Dataset
"""

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Model
#----------------------------------------------------------------------------------------------------------------------------------------------------

class Model(Module):
    """FFJORD flow with EGNN dynamics scored against a ``PositionPrior``.

    The constructor arguments are forwarded to ``EGNN_dynamics`` (particle
    count/dimension, hidden sizes, normalisation and regularisation options)
    and to ``FFJORD`` (``ode_solver``, ``ode_mesh``).
    """

    def __init__(self, particle_count, particle_dim, hidden_channels, hidden_layers, normalize, norm_const, normalize_type, reg_para, ode_solver, ode_mesh, reg_clip):
        super().__init__()

        # Prior is registered before the flow so the sub-module order is stable.
        self.prior = PositionPrior()

        dynamics = EGNN_dynamics(
            n_particles=particle_count,
            n_dimension=particle_dim,
            hidden_nf=hidden_channels,
            act_fn=torch.nn.SiLU(),
            n_layers=hidden_layers,
            recurrent=True,
            tanh=True,
            attention=True,
            condition_time=True,
            agg='sum',
            normalize=normalize,
            norm_const=norm_const,
            normalize_type=normalize_type,
            reg_para=reg_para,
            reg_clip=reg_clip,
        )

        self.flow = FFJORD(
            dynamics,
            trace_method='hutch',
            hutch_noise='bernoulli',
            ode_regularization=0,
            ode_solver=ode_solver,
            ode_mesh=ode_mesh,
        )

    def forward(self, x):
        """Return ``(mean log-likelihood, flow.dynamics_reg)`` for positions ``x``.

        ``x`` is passed through ``remove_mean`` before entering the flow; the
        log-likelihood is the prior log-density of the flow output plus the
        flattened ``delta_logp`` returned by the flow, averaged over entries.
        """
        centered = remove_mean(x)
        z, delta_logp, _ = self.flow(centered)

        log_pz = self.prior(z)
        per_sample = log_pz + delta_logp.view(-1)
        return per_sample.mean(), self.flow.dynamics_reg

    def gradient_listener(self):
        """Collect per-layer listener values and flow statistics and log them to wandb.

        For every ``gcl_<i>`` layer of the EGNN, listener entries whose key is
        not in the skip list are logged as ``gcl_<i>_<key>``. The norm of
        ``phi_x.T @ m_ij`` is computed at most once per layer and logged as
        ``gcl_<i>_inner_prod`` (it is ``None`` until a triggering key is seen).
        """
        egnn = self.flow.odefunc.dynamics.egnn
        skipped = ('h[l]', 'r_ij', 'm_ij', 'phi_x', 'x^[l+1]', 'phi_h', 'h[l+1]')
        triggers = ('m_ij', 'phi_x')

        log_vals = {}
        for i in range(0, egnn.n_layers):
            layer = egnn._modules["gcl_%d" % i]
            inner_prod = None
            for key in layer.get_listener():
                if key not in skipped:
                    log_vals["gcl_%d_%s" % (i, key)] = layer.listener[key]
                elif inner_prod is None and key in triggers:
                    inner_prod = torch.norm(layer.listener['phi_x'].T @ layer.listener['m_ij'])
                # Written on every pass so the last value seen is what gets logged.
                log_vals["gcl_%d_%s" % (i, 'inner_prod')] = inner_prod

        flow = self.flow
        log_vals['nfe'] = flow.odefunc.num_evals
        log_vals['max_d_ldj'] = flow.max_d_ldj.item()
        log_vals['max_d_x'] = flow.max_d_x.item()
        log_vals['max_max_d_ldj'] = flow.max_max_d_ldj.item()
        log_vals['max_max_d_x'] = flow.max_max_d_x.item()
        log_vals['jac_norm'] = flow.odefunc.jacobian.norm().item()
        wandb.log(log_vals)

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Config/Model/Dataset
#----------------------------------------------------------------------------------------------------------------------------------------------------

def setup(cfg):
    # Set device
    args = cfg.setup
    cfg['setup']['device'] = args['device'] if torch.cuda.is_available() else 'cpu'
    os.environ["WANDB_DIR"] = os.path.abspath(args['wandb_dir'])
    # Change file name for sweeping *Prior to setting seed*
    if args['sweep']:
        rand_id = np.random.randint(0,1e3)
        cfg['load']['checkpoint_path']=cfg['load']['checkpoint_path'][:-3]+str(rand_id)+'.pt'
    # Set Backends
    torch.backends.cudnn.deterministic = True
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    pass

#----------------------------------------------------------------------------------------------------------------------------------------------------

def load(cfg):
    """Build the DW4 data loaders and the model, optionally restoring weights.

    Reads ``cfg.load`` (``split``, ``batch_size``, ``checkpoint_path``,
    ``load_checkpoint``) and ``cfg.model`` (hyper-parameters forwarded to
    ``Model``). The particle count and dimension come from the dataset
    returned by ``dw4_dataloaders``.

    Returns:
        ``(model, train_dl, val_dl, test_dl)``. The model is not moved to any
        device here.
    """
    args = cfg.load
    dataset, _, _, _, train_dl, val_dl, test_dl = dw4_dataloaders(
        split=args['split'],
        batch_size=args['batch_size'],
    )

    model_kwargs = OmegaConf.to_container(cfg.model)
    model = Model(
        particle_count=dataset.num_nodes,
        particle_dim=dataset.pos_dim,
        hidden_channels=model_kwargs['hidden_channels'],
        hidden_layers=model_kwargs['hidden_layers'],
        normalize=model_kwargs['normalize'],
        norm_const=model_kwargs['norm_const'],
        normalize_type=model_kwargs['normalize_type'],
        reg_para=model_kwargs['reg_para'],
        ode_solver=model_kwargs['ode_solver'],
        ode_mesh=model_kwargs['ode_mesh'],
        reg_clip=model_kwargs['reg_clip'],
    )

    if os.path.exists(args['checkpoint_path']) and args['load_checkpoint']:
        # The freshly built model lives on the CPU, so map the stored tensors
        # there; without map_location a checkpoint written on a GPU cannot be
        # read on a machine where CUDA is unavailable.
        checkpoint = torch.load(args['checkpoint_path'], map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
    return model, train_dl, val_dl, test_dl


#----------------------------------------------------------------------------------------------------------------------------------------------------
# Train/Validate/Test
#----------------------------------------------------------------------------------------------------------------------------------------------------

def train(cfg, data, model, optimizer):
    """Run one optimisation step on a single batch.

    ``data.pos`` is reshaped to ``(data.batch.max() + 1, -1, pos_dim)`` before
    being fed to the model. The loss is the negative log-likelihood plus
    ``cfg.train['reg_weight']`` times the mean regularisation term.
    ``model.gradient_listener()`` is invoked after the backward pass and
    before the optimizer step.

    Returns:
        ``(loss, mean regularisation term)`` as Python floats.
    """
    model.train()
    optimizer.zero_grad()

    num_graphs = data.batch.max() + 1
    coord_dim = data.pos.shape[-1]
    positions = data.pos.view(num_graphs, -1, coord_dim)

    log_px, reg_term = model(positions)
    reg_mean = reg_term.mean()
    loss = -log_px + cfg.train['reg_weight'] * reg_mean

    loss.backward()
    model.gradient_listener()
    optimizer.step()

    return loss.item(), reg_mean.item()

@torch.no_grad()
def validate(cfg, data, model):
    """Evaluate one batch without gradient tracking.

    The model is switched to eval mode and ``data.pos`` is reshaped to
    ``(data.batch.max() + 1, -1, pos_dim)``. ``cfg`` is accepted for signature
    parity with ``train`` but is not read.

    Returns:
        ``(negative log-likelihood, mean regularisation term)`` as floats.
    """
    model.eval()

    num_graphs = data.batch.max() + 1
    positions = data.pos.view(num_graphs, -1, data.pos.shape[-1])

    log_px, reg_term = model(positions)
    nll = -log_px
    return nll.item(), reg_term.mean().item()

@torch.no_grad()
def test(cfg, data, model):
    """Return the negative log-likelihood of one batch as a Python float.

    Runs in eval mode without gradient tracking; ``data.pos`` is reshaped to
    ``(data.batch.max() + 1, -1, pos_dim)``. The regularisation term returned
    by the model is discarded and ``cfg`` is not read.
    """
    model.eval()

    num_graphs = data.batch.max() + 1
    positions = data.pos.view(num_graphs, -1, data.pos.shape[-1])

    log_px, _ = model(positions)
    nll = -log_px
    return nll.item()


#----------------------------------------------------------------------------------------------------------------------------------------------------
# Main/Hydra/Fold/Train
#----------------------------------------------------------------------------------------------------------------------------------------------------

def _graphs_in_batch(data):
    """Return the number of graphs in ``data`` as a Python int.

    ``data.batch`` holds zero-based graph indices (``train``/``validate`` use
    ``data.batch.max() + 1`` as the leading dimension), so the count is the
    largest index plus one.
    """
    return int(data.batch.max()) + 1


def run_training(cfg, model, train_dl, val_dl):
    """Train ``model`` with early stopping and checkpoint the best epoch.

    Each epoch runs a training pass with the ``'hutch'`` trace and a
    validation pass with the ``'exact'`` trace. Per-batch losses are averaged
    weighted by the number of graphs in the batch. Whenever the validation
    loss improves, a checkpoint is written to ``cfg.load['checkpoint_path']``;
    training stops once the number of consecutive non-improving epochs exceeds
    ``cfg.train['patience']``.

    Args:
        cfg: configuration with ``train`` (``lr``, ``wd``, ``epochs``,
            ``patience``, ``reg_weight``), ``setup`` (``device``), ``model``
            (``normalize``) and ``load`` (``checkpoint_path``) sections.
        model: the flow model to optimise; it is moved to the configured device.
        train_dl: iterable of training batches.
        val_dl: iterable of validation batches.

    Returns:
        The best (lowest) validation loss seen, or ``1e8`` if no epoch improved.
    """
    args = cfg.train
    device = cfg.setup['device']

    # Move the model first so the optimizer is built from on-device parameters.
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args['lr'], weight_decay=args['wd'], amsgrad=True)

    best = 1e8
    # Must exist before the loop: a first epoch that does not improve on
    # `best` (e.g. a NaN loss) would otherwise hit an unbound local.
    bad_itr = 0
    for epoch in range(args['epochs']):

        # ---- training pass ------------------------------------------------
        model.train()
        model.flow.set_trace('hutch')
        train_loss, train_reg_dynamics, count = 0.0, 0.0, 0
        start = time.time()
        for i, data in enumerate(train_dl):
            data = data.to(device)
            batch_loss, batch_reg_dynamics = train(cfg, data, model, optimizer)

            batch_size = _graphs_in_batch(data)
            train_loss += batch_loss * batch_size
            train_reg_dynamics += batch_reg_dynamics * batch_size
            count += batch_size
            # Progress is reported for every batch.
            print(f'Train({epoch}) '
                  f'| batch({i:03d}) '
                  f'| loss({batch_loss:.4f}) '
                  f'| nfe({model.flow.odefunc.num_evals:04d}) '
                  f'| norm({cfg.model.normalize}) - reg({cfg.train.reg_weight==0.0})'
                  f' reg({model.flow.dynamics_reg.mean().item():05.2f}) '
                  f'| time({time.time()-start:05.2f})')
        end = time.time()
        train_loss = train_loss / count
        train_reg_dynamics = train_reg_dynamics / count

        # ---- validation pass ----------------------------------------------
        model.eval()
        model.flow.set_trace('exact')
        val_loss, val_reg_dynamics, count = 0.0, 0.0, 0
        for i, data in enumerate(val_dl):
            data = data.to(device)
            batch_loss, batch_reg_dynamics = validate(cfg, data, model)

            batch_size = _graphs_in_batch(data)
            val_loss += batch_loss * batch_size
            val_reg_dynamics += batch_reg_dynamics * batch_size
            count += batch_size
            print(f'Valid({epoch}) '
                  f'| batch({i:03d}) '
                  f'| loss({batch_loss:.4f}) '
                  f'| nfe({model.flow.odefunc.num_evals:04d}) '
                  f'| reg(--.--) '
                  f'| time(--.--)')
        val_loss = val_loss / count
        val_reg_dynamics = val_reg_dynamics / count

        perf_metric = val_loss  # your performance metric here
        lr = optimizer.param_groups[0]['lr']

        # ---- checkpointing / early-stopping bookkeeping ---------------------
        if perf_metric < best:
            best = perf_metric
            bad_itr = 0
            torch.save({'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        # Store the real optimizer state (not just the learning
                        # rate) so training can be resumed from this file.
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': val_loss,
                        },
                       cfg.load['checkpoint_path'])
        else:
            bad_itr += 1

        # Log results
        wandb.log({'epoch': epoch,
                   'train_loss': train_loss,
                   'val_loss': val_loss,
                   'best': best,
                   'lr': lr,
                   'time': end - start,
                   'train_reg_dynamics': train_reg_dynamics,
                   'val_reg_dynamics': val_reg_dynamics})
        print(f'Epoch({epoch}) '
              f'| train({train_loss:.4f}) '
              f'| val({val_loss:.4f}) '
              f'| lr({lr:.2e}) '
              f'| best({best:.4f}) '
              f'| time({end-start:.4f})'
              f'\n')

        if bad_itr > args['patience']:
            break

    return best

#----------------------------------------------------------------------------------------------------------------------------------------------------

@hydra.main(version_base=None, config_path="/root/workspace/stable-equiv/stable_enf/", config_name="dw4")
def run_dw4(cfg):
    """
    Execute run saving details to wandb server.

    Initialises wandb (online only when ``cfg.setup['sweep']`` is set), applies
    ``setup``, builds the model and loaders with ``load``, optionally trains,
    then reloads the checkpoint at ``cfg.load['checkpoint_path']`` and reports
    the graph-weighted mean test loss. Returns ``1``.
    """
    # Setup Weights and Bias
    wandb.config = OmegaConf.to_container(
        cfg, resolve=True, throw_on_missing=True
    )
    mode = 'online' if cfg.setup['sweep'] else 'disabled'
    # NOTE(review): setup() exports WANDB_DIR only after this call; the output
    # directory is passed explicitly through `dir` here.
    wandb.init(entity='',
               project='stable-equiv',
               mode=mode,
               # The previous `.format(...)` call had no placeholders in the
               # template, so the run name was always this literal string.
               name='stable-dLdtheta',
               dir='/root/workspace/out/',
               tags=['dw4', 'enflows', 'stable', 'listener-v2'],
               config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
    )

    # Execute
    setup(cfg)
    print(OmegaConf.to_yaml(cfg))
    model, train_dl, val_dl, test_dl = load(cfg)
    if cfg.setup['train']:
        run_training(cfg, model, train_dl, val_dl)

    device = cfg.setup['device']
    # map_location lets a checkpoint written on a GPU be evaluated on the
    # device chosen by setup(), including the CPU fallback.
    checkpoint = torch.load(cfg.load['checkpoint_path'], map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    test_loss, count = 0.0, 0
    for data in test_dl:
        data = data.to(device)
        batch_loss = test(cfg, data, model)

        # data.batch holds zero-based graph indices, so the number of graphs
        # is max + 1 (the same value test() uses to reshape the positions).
        batch_size = int(data.batch.max()) + 1
        test_loss += batch_loss * batch_size
        count += batch_size
    test_loss = test_loss / count

    print(f'\ntest({test_loss})')
    wandb.log({'test_loss': test_loss})
    # Terminate
    return 1

#----------------------------------------------------------------------------------------------------------------------------------------------------

# Script entry point: run_dw4 is wrapped by @hydra.main, so it is called
# without arguments and receives its cfg from Hydra.
if __name__ == '__main__':
    run_dw4()