#!/usr/bin/python3
import os
import time

import hydra
from omegaconf import OmegaConf
import wandb

import torch
from torch.nn import Module
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric.transforms as T

from stable_equiv import qm9_gen_dataloaders
from _utils import (
    Queue,
    gradient_clipping,
    random_rotation,
    remove_mean,
)
from _prior import PositionPrior
from _dynamics import EGNN_dynamics
from _flows import FFJORD


"""
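QM9 positional dataset: train and evaluate a continuous normalizing flow (FFJORD with EGNN dynamics) on QM9 atomic positions.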
"""

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Model
#----------------------------------------------------------------------------------------------------------------------------------------------------

class Model(Module):
    """
    Positional flow model: an FFJORD flow with EGNN dynamics, scored against a ``PositionPrior``.

    Args
    -----
        Dependent Args
        ~~~~~~~~~~~~~~~
        + particle_count: (int) number of particles in the system
        + particle_dim: (int) dimension of the physical space

        Independent Args
        ~~~~~~~~~~~~~~~~~
        + augment_data: (bool) Whether or not to augment the data with a random rotation.
        + hidden_channels: (int) passed to ``EGNN_dynamics`` as ``hidden_nf``
        + hidden_layers: (int) passed to ``EGNN_dynamics`` as ``n_layers``

        Module Args
        ~~~~~~~~~~~
        + aggregation: (str) EGNN aggregation
        + attention: (bool) EGNN attention
        + tanh: (bool) EGNN tanh
        + condition_time: (bool) Dynamics conditioning of time
        + ode_reg: (float) weight applied to the flow's regularisation term in
          ``forward``; also passed to FFJORD as ``ode_regularization``
        + noise: Flow trace-estimator noise, forwarded unchanged as FFJORD's
          ``hutch_noise`` (NOTE(review): the expected type is defined by FFJORD — confirm)
        + trace: (str) Flow trace-estimator method, forwarded as FFJORD's
          ``trace_method`` (e.g. ``'hutch'``)

    """
    def __init__(self,
        particle_count: int,
        particle_dim: int,
        augment_data: bool,
        hidden_channels: int,
        hidden_layers: int,
        aggregation: str,
        attention: bool,
        tanh: bool,
        condition_time: bool,
        ode_reg: float,
        noise,
        trace: str,
    ):
        super().__init__()

        self.aug_data = augment_data
        self.ode_reg = ode_reg

        self.prior = PositionPrior()

        net_dynamics = EGNN_dynamics(
            n_particles=particle_count,
            n_dimension=particle_dim,
            hidden_nf=hidden_channels,
            act_fn=torch.nn.SiLU(),
            n_layers=hidden_layers,
            recurrent=True,
            tanh=tanh,
            attention=attention,
            condition_time=condition_time,
            agg=aggregation,
        )

        self.flow = FFJORD(
            net_dynamics,
            trace_method=trace,
            hutch_noise=noise,
            ode_regularization=ode_reg,
        )

        # History of gradient norms consumed by `gradient_clipping` during training.
        # Seeded with a huge value — presumably so the first clipping threshold
        # derived from it does not clip; confirm against `gradient_clipping`.
        self.gradnorm_queue = Queue()
        self.gradnorm_queue.add(1e30)

    def forward(self, x):
        """
        Compute the batch-mean log-likelihood of particle positions.

        Args:
            x: particle positions; NOTE(review): presumably shaped
               ``(batch, particle_count, particle_dim)`` as built by ``train`` — confirm.

        Returns:
            ``(log_px, reg)``. ``log_px`` is the mean over the batch of
            ``log p(z) + delta_logp``. ``reg`` is the mean flow regularisation
            term scaled by ``ode_reg``.

        Side effect: stores ``mean(|z|)`` as a Python float in ``self.mean_abs_z``.
        """
        x = remove_mean(x)
        if self.aug_data:
            # Detached so the random rotation itself is not part of the autograd graph.
            x = random_rotation(x).detach()

        z, delta_logp, reg_term = self.flow(x)

        log_pz = self.prior(z)
        log_px = (log_pz + delta_logp.view(-1)).mean()
        self.mean_abs_z = torch.mean(torch.abs(z)).item()

        return log_px, self.ode_reg*reg_term.mean()

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Config/Model/Dataset
#----------------------------------------------------------------------------------------------------------------------------------------------------

def setup(cfg):
    # Set device
    args = cfg.setup
    cfg['setup']['device'] = args['device'] if torch.cuda.is_available() else 'cpu'
    os.environ["WANDB_DIR"] = os.path.abspath(args['wandb_dir'])
    # Change file name for sweeping *Prior to setting seed*
    if args['sweep']:
        wandb_id = wandb.run.id
        cfg['load']['checkpoint_path']=cfg['load']['checkpoint_path'][:-3]+str(wandb_id)+'.pt'
    # Set Backends
    torch.backends.cudnn.deterministic = True
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    pass

#----------------------------------------------------------------------------------------------------------------------------------------------------

def load(cfg):
    """
    Build the QM9 dataloaders and the ``Model``, optionally restoring saved weights.

    Args:
        cfg: Hydra/OmegaConf config. Reads ``cfg.load`` (``split``,
            ``batch_size``, ``checkpoint_path``, ``load_checkpoint``) and
            ``cfg.model`` (the ``Model`` hyper-parameters).

    Returns:
        ``(model, train_dl, val_dl, test_dl)``. No device move is done here;
        callers move the model to the target device.
    """
    args = cfg.load
    dataset, _, _, _, train_dl, val_dl, test_dl = qm9_gen_dataloaders(
        filter=True,
        adjacency='full',
        split=args['split'],
        batch_size=args['batch_size'],
    )

    model_kwargs = OmegaConf.to_container(cfg.model)
    model = Model(
        particle_count=dataset.num_nodes,
        particle_dim=dataset.pos_dim,
        augment_data=model_kwargs['aug_data'],
        hidden_channels=model_kwargs['hidden_channels'],
        hidden_layers=model_kwargs['hidden_layers'],
        attention=model_kwargs['attention'],
        tanh=model_kwargs['tanh'],
        aggregation=model_kwargs['aggregation'],
        condition_time=model_kwargs['condition_time'],
        ode_reg=model_kwargs['ode_reg'],
        trace=model_kwargs['trace'],
        noise=model_kwargs['noise'],
    )

    checkpoint_path = args['checkpoint_path']
    if os.path.exists(checkpoint_path) and args['load_checkpoint']:
        # map_location='cpu': the freshly built model has not been moved yet,
        # and a checkpoint written from a GPU run must still load on CPU-only hosts.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
    return model, train_dl, val_dl, test_dl


#----------------------------------------------------------------------------------------------------------------------------------------------------
# Train/Validate/Test
#----------------------------------------------------------------------------------------------------------------------------------------------------

def train(cfg, data, model, optimizer):
    """
    Run a single optimisation step on one batch.

    The flat ``data.pos`` tensor is reshaped to
    ``(num_graphs, -1, data.pos.shape[-1])``, where ``num_graphs`` is
    ``data.batch.max() + 1``. This is only meaningful when every graph in the
    batch has the same node count. NOTE(review): that holds for a fixed
    ``particle_count`` model — confirm for the loader in use.

    Args:
        cfg: config; reads ``cfg.train['clip_grad']``.
        data: batch exposing ``pos`` and ``batch`` tensors.
        model: callable returning ``(log_px, reg_term)``.
        optimizer: torch optimizer over ``model``'s parameters.

    Returns:
        The loss ``-log_px + reg_term`` as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    num_graphs = int(data.batch.max()) + 1
    pos = data.pos.view(num_graphs, -1, data.pos.shape[-1])
    log_px, reg_term = model(pos)
    loss = -log_px + reg_term
    loss.backward()
    if cfg.train['clip_grad']:
        # Called for its in-place clipping of the gradients; the returned norm is not needed here.
        gradient_clipping(model, model.gradnorm_queue)
    optimizer.step()
    return loss.item()

@torch.no_grad()
def validate(cfg, data, model):
    """
    Score one validation batch without tracking gradients.

    ``cfg`` is unused and kept only so the signature mirrors ``train``. The
    regularisation term returned by the model is discarded.

    Returns:
        The negative batch log-likelihood ``-log_px`` as a Python float.
    """
    model.eval()
    n_graphs = data.batch.max() + 1
    feat_dim = data.pos.shape[-1]
    coords = data.pos.view(n_graphs, -1, feat_dim)
    log_likelihood, _unused_reg = model(coords)
    return (-log_likelihood).item()

@torch.no_grad()
def test(cfg, data, model):
    model.eval()
    pos = data.pos.view(data.batch.max()+1, -1, data.pos.shape[-1])
    log_px, _ = model(pos)
    loss = -log_px
    return loss.item()


#----------------------------------------------------------------------------------------------------------------------------------------------------
# Main/Hydra/Fold/Train
#----------------------------------------------------------------------------------------------------------------------------------------------------

def _mean_batch_loss(loader, device, step, tag, epoch):
    """
    Apply ``step`` to every batch of ``loader`` and return the mean per-batch loss.

    Each batch is moved to ``device`` before ``step`` is called. Each batch's
    loss is printed as ``{tag}({epoch}) | batch(NNN) | loss(x.xxxx)``. The mean
    is unweighted: every batch counts equally, whatever its size.
    """
    total, count = 0, 0
    for i, data in enumerate(loader):
        data = data.to(device)
        batch_loss = step(data)
        total += batch_loss
        count += 1
        print(f'{tag}({epoch}) | batch({i:03d}) | loss({batch_loss:.4f})')
    return total / count


def run_training(cfg, model, train_dl, val_dl):
    """
    Train ``model`` with AdamW, checkpointing whenever the validation loss improves.

    Each epoch runs one pass over ``train_dl`` (via ``train``) and one over
    ``val_dl`` (via ``validate``). It logs the mean per-batch losses to W&B and
    stdout. Whenever the mean validation loss improves, it saves a checkpoint to
    ``cfg.load['checkpoint_path']`` containing the epoch, the model and
    optimizer state, the learning rate and the validation loss. Training stops
    once the validation loss has failed to improve for more than
    ``cfg.train['patience']`` consecutive epochs.

    Returns:
        The best (lowest) mean validation loss observed.
    """
    args = cfg.train
    device = cfg.setup['device']

    # Move the model before constructing the optimizer, as torch.optim recommends.
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args['lr'], weight_decay=args['wd'], amsgrad=True)

    best = 1e8
    # Must exist before the loop: if the first epoch does not beat `best`
    # (e.g. a NaN validation loss), the else-branch below increments it.
    bad_itr = 0
    for epoch in range(args['epochs']):

        # Training pass; only this part is timed.
        model.train()
        model.flow.set_trace('hutch')
        start = time.time()
        train_loss = _mean_batch_loss(
            train_dl, device, lambda data: train(cfg, data, model, optimizer), 'Train', epoch,
        )
        end = time.time()

        # Validation pass.
        model.eval()
        model.flow.set_trace('hutch')
        val_loss = _mean_batch_loss(
            val_dl, device, lambda data: validate(cfg, data, model), 'Valid', epoch,
        )

        perf_metric = val_loss  # metric used for model selection and early stopping
        lr = optimizer.param_groups[0]['lr']

        if perf_metric < best:
            best = perf_metric
            bad_itr = 0
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    # Full optimizer state (enables resuming); lr is also kept on its own key.
                    'optimizer_state_dict': optimizer.state_dict(),
                    'lr': lr,
                    'loss': val_loss,
                },
                cfg.load['checkpoint_path'],
            )
        else:
            bad_itr += 1
        # Log results
        wandb.log({'epoch':epoch,
            'train_loss':train_loss,
            'val_loss':val_loss,
            'best':best,
            'lr':lr,
            'time':end-start})
        print(f'Epoch({epoch}) '
            f'| train({train_loss:.4f}) '
            f'| val({val_loss:.4f}) '
            f'| lr({lr:.2e}) '
            f'| best({best:.4f}) '
            f'| time({end-start:.4f})'
            f'\n')

        if bad_itr>args['patience']:
            break

    return best

#----------------------------------------------------------------------------------------------------------------------------------------------------

@hydra.main(version_base=None, config_path="/root/workspace/stable-equiv/stable_enf/", config_name="qm9_pos")
def run_qm9_pos(cfg):
    """
    Hydra entry point: set up, optionally train, then evaluate the saved checkpoint on the test set.

    Steps:
      1. Start a W&B run (mode ``'disabled'``) with the resolved config.
      2. Call ``setup``, print the config, then call ``load``.
      3. Call ``run_training`` if ``cfg.setup['train']`` is set.
      4. Reload ``cfg.load['checkpoint_path']`` and report the mean per-batch
         test loss to stdout and W&B.

    Returns:
        ``1``.
    """
    # Setup Weights and Biases. `wandb.init` installs the run config as
    # `wandb.config`, so no separate assignment is needed beforehand.
    wandb.init(entity='',
                project='stable-equiv',
                mode='disabled',
                name='semistable-enflows-qm9_pos',
                dir='/root/workspace/out/',
                tags=['qm9_pos', 'enflows', 'semistable'],
                config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
    )

    # Execute
    setup(cfg)
    print(OmegaConf.to_yaml(cfg))
    model, train_dl, val_dl, test_dl = load(cfg)
    if cfg.setup['train']:
        run_training(cfg, model, train_dl, val_dl)

    device = cfg.setup['device']
    # map_location lets a checkpoint written on another device load on this one.
    checkpoint = torch.load(cfg.load['checkpoint_path'], map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    test_loss, count = 0, 0
    for data in test_dl:
        data = data.to(device)
        test_loss += test(cfg, data, model)
        count += 1
    # Unweighted mean over batches, matching the train/validation reporting.
    test_loss = test_loss/count

    print(f'\ntest({test_loss})')
    wandb.log({'test_loss':test_loss})
    # Terminate
    wandb.finish()
    return 1

#----------------------------------------------------------------------------------------------------------------------------------------------------

if "__main__" == __name__:
    # Script entry: the @hydra.main decorator builds `cfg` and passes it in.
    run_qm9_pos()