#!/usr/bin/python3
import os
import time

import hydra
from omegaconf import OmegaConf
import wandb

import numpy as np
import torch
from torch.nn import Module
import torch.optim as optim

from stable_equiv import lj13_dataloaders
from _utils import (
    remove_mean,
)
from _prior import PositionPrior
from _dynamics import EGNN_dynamics
from _flows import FFJORD

"""
LJ13 Dataset
"""

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Model
#----------------------------------------------------------------------------------------------------------------------------------------------------

class Model(Module):
    """Flow model pairing a ``PositionPrior`` with an ``FFJORD`` flow.

    The flow's dynamics are provided by an ``EGNN_dynamics`` network that is
    configured from the constructor arguments.
    """

    def __init__(self, particle_count, particle_dim, hidden_channels, hidden_layers):
        """Create the prior, the EGNN dynamics network and the FFJORD flow.

        Args:
            particle_count: passed to ``EGNN_dynamics`` as ``n_particles``.
            particle_dim: passed to ``EGNN_dynamics`` as ``n_dimension``.
            hidden_channels: passed to ``EGNN_dynamics`` as ``hidden_nf``.
            hidden_layers: passed to ``EGNN_dynamics`` as ``n_layers``.
        """
        super().__init__()

        self.prior = PositionPrior()

        # Settings for the dynamics network; everything except the four
        # constructor arguments is fixed for this script.
        dynamics_settings = dict(
            n_particles=particle_count,
            n_dimension=particle_dim,
            hidden_nf=hidden_channels,
            act_fn=torch.nn.SiLU(),
            n_layers=hidden_layers,
            recurrent=True,
            tanh=True,
            attention=True,
            condition_time=True,
            agg='sum',
        )
        dynamics = EGNN_dynamics(**dynamics_settings)

        # The flow starts with the 'hutch' trace method and Bernoulli noise;
        # ODE regularisation is switched off (0).
        self.flow = FFJORD(
            dynamics,
            trace_method='hutch',
            hutch_noise='bernoulli',
            ode_regularization=0,
        )

    def forward(self, x):
        """Return the batch-mean log-likelihood of ``x`` and the flow's regulariser.

        Args:
            x: positions; passed through ``remove_mean`` before entering the
                flow (assumes a batched position tensor -- TODO confirm shape).

        Returns:
            A tuple ``(log_px, reg_term)``: ``log_px`` is the mean over the
            batch of the prior log-density of the flow output plus the flow's
            ``delta_logp``; ``reg_term`` is returned unchanged from the flow.
        """
        centered = remove_mean(x)

        latent, delta_logp, reg_term = self.flow(centered)

        # Flatten delta_logp so it adds element-wise to the prior log-density.
        per_sample_log_px = self.prior(latent) + delta_logp.view(-1)
        return per_sample_log_px.mean(), reg_term

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Config/Model/Dataset
#----------------------------------------------------------------------------------------------------------------------------------------------------

def setup(cfg):
    """Prepare the run environment and normalise ``cfg`` in place.

    * ``cfg['setup']['device']`` is replaced by ``'cpu'`` when CUDA is not
      available, otherwise the configured device is kept.
    * ``WANDB_DIR`` is exported as the absolute form of ``setup.wandb_dir``.
    * When ``setup.sweep`` is truthy, a random integer id is inserted in
      front of the last three characters of ``load.checkpoint_path`` (which
      are replaced by ``'.pt'``).
    * cuDNN is switched to deterministic mode and TF32 is disabled.

    Args:
        cfg: run configuration with ``setup`` and ``load`` sections.
    """
    setup_cfg = cfg.setup

    # Resolve the compute device.
    if torch.cuda.is_available():
        cfg['setup']['device'] = setup_cfg['device']
    else:
        cfg['setup']['device'] = 'cpu'

    os.environ["WANDB_DIR"] = os.path.abspath(setup_cfg['wandb_dir'])

    # Give every sweep run its own checkpoint file. The id comes from
    # NumPy's global RNG; no seed is set inside this function.
    if setup_cfg['sweep']:
        run_id = str(np.random.randint(0, 1e3))
        path_stem = cfg['load']['checkpoint_path'][:-3]
        cfg['load']['checkpoint_path'] = path_stem + run_id + '.pt'

    # Backend flags: deterministic cuDNN, no TF32 for matmul or cuDNN.
    torch.backends.cudnn.deterministic = True
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

#----------------------------------------------------------------------------------------------------------------------------------------------------

def load(cfg):
    """Build the LJ13 data loaders and the model, optionally restoring weights.

    Args:
        cfg: run configuration. ``cfg.load`` supplies ``split``,
            ``batch_size``, ``checkpoint_path`` and ``load_checkpoint``;
            ``cfg.model`` supplies ``hidden_channels`` and ``hidden_layers``.

    Returns:
        A tuple ``(model, train_dl, val_dl, test_dl)``.
    """
    args = cfg.load
    dataset, _, _, _, train_dl, val_dl, test_dl = lj13_dataloaders(
        split=args['split'],
        batch_size=args['batch_size'],
    )

    model_kwargs = OmegaConf.to_container(cfg.model)
    model = Model(
        particle_count=dataset.num_nodes,
        particle_dim=dataset.pos_dim,
        hidden_channels=model_kwargs['hidden_channels'],
        hidden_layers=model_kwargs['hidden_layers'],
    )

    checkpoint_path = args['checkpoint_path']
    if os.path.exists(checkpoint_path) and args['load_checkpoint']:
        # Map the stored tensors to the CPU so that a checkpoint written from
        # a CUDA model can still be restored on a host without CUDA;
        # load_state_dict copies the values into the model's own parameters.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
    return model, train_dl, val_dl, test_dl


#----------------------------------------------------------------------------------------------------------------------------------------------------
# Train/Validate/Test
#----------------------------------------------------------------------------------------------------------------------------------------------------

def train(cfg, data, model, optimizer):
    """Perform one optimisation step on a single batch.

    Args:
        cfg: run configuration (not used by this function).
        data: batch exposing ``pos`` and ``batch`` tensors.
        model: callable returning ``(log_px, reg_term)`` for a position tensor.
        optimizer: optimizer that updates the model's parameters.

    Returns:
        The batch loss ``-log_px + reg_term.mean()`` as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    # Regroup the flat positions into one row per graph; the number of graphs
    # is the largest batch index plus one.
    graph_count = data.batch.max() + 1
    coord_dim = data.pos.shape[-1]
    positions = data.pos.view(graph_count, -1, coord_dim)

    log_likelihood, regulariser = model(positions)
    negative_ll = -log_likelihood
    objective = negative_ll + regulariser.mean()

    objective.backward()
    optimizer.step()
    return objective.item()

@torch.no_grad()
def validate(cfg, data, model):
    """Compute the negative log-likelihood of one validation batch.

    Runs with gradients disabled and the model in eval mode.

    Args:
        cfg: run configuration (not used by this function).
        data: batch exposing ``pos`` and ``batch`` tensors.
        model: callable returning ``(log_px, reg_term)`` for a position tensor.

    Returns:
        ``-log_px`` as a Python float; the regularisation term is ignored.
    """
    model.eval()

    # One row per graph: graph count is the largest batch index plus one.
    graph_count = data.batch.max() + 1
    coord_dim = data.pos.shape[-1]
    positions = data.pos.view(graph_count, -1, coord_dim)

    log_likelihood, _ = model(positions)
    return (-log_likelihood).item()

@torch.no_grad()
def test(cfg, data, model):
    """Compute the negative log-likelihood of one test batch.

    Runs with gradients disabled and the model in eval mode.

    Args:
        cfg: run configuration (not used by this function).
        data: batch exposing ``pos`` and ``batch`` tensors.
        model: callable returning ``(log_px, reg_term)`` for a position tensor.

    Returns:
        ``-log_px`` as a Python float; the regularisation term is ignored.
    """
    model.eval()

    coords = data.pos
    # Regroup the flat positions into (graphs, -1, coordinate dimension).
    batched_coords = coords.view(data.batch.max() + 1, -1, coords.shape[-1])

    log_px, _ = model(batched_coords)
    negative_ll = -log_px
    return negative_ll.item()


#----------------------------------------------------------------------------------------------------------------------------------------------------
# Main/Hydra/Fold/Train
#----------------------------------------------------------------------------------------------------------------------------------------------------

def run_training(cfg, model, train_dl, val_dl):
    """Train ``model`` with early stopping, checkpointing the best epoch.

    Every epoch makes one pass over ``train_dl`` with the flow's trace set to
    ``'hutch'`` and one pass over ``val_dl`` with the trace set to
    ``'exact'``. Losses are averaged per graph (each batch loss is weighted by
    ``data.batch.max() + 1``). When the validation loss improves, a
    checkpoint is written to ``cfg.load['checkpoint_path']``. Training stops
    early once the validation loss has failed to improve for more than
    ``cfg.train['patience']`` consecutive epochs.

    Args:
        cfg: run configuration providing ``train`` (``lr``, ``wd``,
            ``epochs``, ``patience``), ``setup['device']`` and
            ``load['checkpoint_path']``.
        model: module exposing ``flow.set_trace``; moved to the configured
            device.
        train_dl: iterable of training batches.
        val_dl: iterable of validation batches.

    Returns:
        The lowest validation loss observed, or ``float('inf')`` when no
        epoch produced an improvement.
    """
    args = cfg.train
    device = cfg.setup['device']

    optimizer = optim.AdamW(model.parameters(), lr=args['lr'], weight_decay=args['wd'], amsgrad=True)

    model = model.to(device)

    # Start from +inf so any finite first validation loss is an improvement.
    best = float('inf')
    # Initialise the early-stopping counter up front: the first epoch may
    # fail to improve (e.g. NaN validation loss), in which case the counter
    # is incremented before it would otherwise have been assigned.
    bad_itr = 0
    for epoch in range(args['epochs']):

        # ---- training pass (Hutchinson trace estimator) ----
        model.train()
        model.flow.set_trace('hutch')
        train_loss, count = 0, 0
        start = time.time()
        for i, data in enumerate(train_dl):
            data = data.to(device)
            batch_loss = train(cfg, data, model, optimizer)

            batch_size = data.batch.max().item() + 1
            train_loss += batch_loss * batch_size
            count += batch_size
            print(f'Train({epoch}) | batch({i:03d}) | loss({batch_loss:.4f})')
        # Only the training pass is timed.
        end = time.time()
        train_loss = train_loss/count

        # ---- validation pass (exact trace) ----
        model.eval()
        model.flow.set_trace('exact')
        val_loss, count = 0, 0
        for i, data in enumerate(val_dl):
            data = data.to(device)
            batch_loss = validate(cfg, data, model)

            batch_size = data.batch.max().item() + 1
            val_loss += batch_loss * batch_size
            count += batch_size
            print(f'Valid({epoch}) | batch({i:03d}) | loss({batch_loss:.4f})')
        val_loss = val_loss/count

        perf_metric = val_loss  # metric driving checkpointing/early stopping
        lr = optimizer.param_groups[0]['lr']

        if perf_metric < best:
            best = perf_metric
            bad_itr = 0
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    # Store the real optimizer state under this key (it used
                    # to hold only the learning-rate float).
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': val_loss,
                },
                cfg.load['checkpoint_path'],
            )
        else:
            bad_itr += 1

        # Log results
        wandb.log({'epoch': epoch,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'best': best,
            'lr': lr,
            'time': end-start})
        print(f'Epoch({epoch}) '
            f'| train({train_loss:.4f}) '
            f'| val({val_loss:.4f}) '
            f'| lr({lr:.2e}) '
            f'| best({best:.4f}) '
            f'| time({end-start:.4f})'
            f'\n')

        if bad_itr > args['patience']:
            break

    return best

#----------------------------------------------------------------------------------------------------------------------------------------------------

@hydra.main(version_base=None, config_path="/root/workspace/stable-equiv/baselines/unstable_enflows/", config_name="lj13")
def run_lj13(cfg):
    """
    Execute run saving details to wandb server.

    Initialises wandb, applies ``setup``, builds the model and data loaders
    with ``load``, optionally trains (when ``cfg.setup['train']`` is truthy),
    then restores the checkpoint at ``cfg.load['checkpoint_path']`` and
    reports the per-graph average test loss.

    Args:
        cfg: Hydra configuration for the run.

    Returns:
        Always ``1``.
    """
    # Setup Weights and Bias
    run_config = OmegaConf.to_container(
        cfg, resolve=True, throw_on_missing=True
    )
    wandb.config = run_config
    wandb.init(entity='',
                project='stable-equiv',
                mode='disabled',
                name='unstable-enflows-lj13',
                dir='/root/workspace/out/',
                tags=['lj13', 'enflows', 'unstable'],
                config=run_config,
    )

    # Execute
    setup(cfg)
    print(OmegaConf.to_yaml(cfg))
    model, train_dl, val_dl, test_dl = load(cfg)
    if cfg.setup['train']:
        run_training(cfg, model, train_dl, val_dl)

    device = cfg.setup['device']

    # Map stored tensors to the CPU so a checkpoint written from a CUDA model
    # can be restored on any host; load_state_dict copies the values into the
    # model's parameters, which are then moved to the configured device.
    checkpoint = torch.load(cfg.load['checkpoint_path'], map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    # Evaluate with the same trace method as validation. Without this the
    # method depends on whether training ran: the flow keeps its initial
    # 'hutch' setting when training is skipped.
    model.flow.set_trace('exact')

    test_loss, count = 0, 0
    for data in test_dl:
        # Keep the returned object, as the train/validation loops do.
        data = data.to(device)
        batch_loss = test(cfg, data, model)

        batch_size = data.batch.max().item() + 1
        test_loss += batch_loss * batch_size
        count += batch_size
    test_loss = test_loss/count

    print(f'\ntest({test_loss})')
    wandb.log({'test_loss': test_loss})
    # Terminate
    return 1

#----------------------------------------------------------------------------------------------------------------------------------------------------

# Script entry point: invoke the Hydra-decorated runner only when this file is
# executed directly, not when it is imported.
if __name__ == '__main__':
    run_lj13()
