#!/usr/bin/python3
import os, json, sys
import time

import hydra
from omegaconf import OmegaConf
import wandb

import numpy as np
import random
import torch
from torch.nn import LeakyReLU, Linear, Module, ModuleList, Parameter, ReLU
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.utils import degree, to_undirected

from torch_geometric.nn import BatchNorm, GCNConv, GCN2Conv
from sct_gnn import SCT, SCT_Resid
from sct_gnn.smoothness import NodeFeatureSmoothness
from sct_gnn.kernel_vectors import kernel_vectors, ker_lapl
sys.path.append('/root/workspace/sct-gnn/baselines/')
from egnn import EGNNConv

from torch_geometric.data import ClusterData, ClusterLoader, NeighborSampler
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
"""
OGB Arxiv Dataset
"""

import torch.nn as nn
class SReLU(nn.Module):
    """Shifted ReLU with a learnable per-channel threshold.

    Computes ``relu(x - b) + b`` (equivalently ``max(x, b)``), where ``b`` is a
    vector of length ``nc`` that broadcasts against the last dimension of ``x``.

    Args:
        nc: number of channels, i.e. the length of the threshold vector.
        bias: initial value written into every entry of the threshold.
    """

    def __init__(self, nc, bias):
        super().__init__()
        # Attribute names are preserved: they are the keys used in state_dict().
        self.srelu_bias = nn.Parameter(torch.empty(nc))
        self.srelu_relu = nn.ReLU(inplace=True)
        nn.init.constant_(self.srelu_bias, bias)

    def forward(self, x):
        threshold = self.srelu_bias
        # x - threshold is a fresh tensor, so the in-place ReLU never touches x.
        shifted = x - threshold
        return self.srelu_relu(shifted) + threshold

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Models
#----------------------------------------------------------------------------------------------------------------------------------------------------

class fixed_architecture(Module):
    """Encoder -> stack of graph convolutions -> decoder backbone.

    Subclasses fill ``self.gcn_list`` with ``hidden_layers`` convolution modules
    and call ``reset_parameters()`` so that ``run_folds`` can build its optimizer.

    Args:
        in_channels: input feature dimension.
        hidden_channels: width of the hidden representation.
        hidden_layers: number of convolution layers applied in ``forward``.
        out_channels: output dimension of the decoder.
        dropout: dropout probability used before every linear / convolution layer.
        act_fn: activation name, one of ``'relu'``, ``'srelu'`` or ``'leaky'``.
        config: model config; unused here, consumed by subclasses.

    Raises:
        ValueError: if ``act_fn`` is not one of the supported names.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config):
        super().__init__()
        self.fc_enc = Linear(in_channels, hidden_channels)
        self.gcn_list = ModuleList([])
        self.fc_dec = Linear(hidden_channels, out_channels)
        if act_fn == 'relu':
            self.act = ReLU()
        elif act_fn == 'srelu':
            self.act = SReLU(hidden_channels, -1)
        elif act_fn == 'leaky':
            self.act = LeakyReLU()
        else:
            # Fail here: otherwise self.act stays unset and forward() dies with an opaque AttributeError.
            raise ValueError(f"unsupported act_fn {act_fn!r}; expected 'relu', 'srelu' or 'leaky'")
        self.dropout = dropout
        self.layers = hidden_layers
        # Extra loss term; subclasses may overwrite it during forward().
        self.reg = 0

    def reset_parameters(self):
        """Record the two optimizer parameter groups.

        Despite its name, this method does not re-initialise any weights.
        ``graph_params`` holds the convolution parameters; ``fc_params`` holds
        the encoder and decoder linear layers.
        """
        self.graph_params = list(self.gcn_list.parameters())
        self.fc_params = list(self.fc_enc.parameters()) + list(self.fc_dec.parameters())

    def forward(self, x, edge_index, edge_weight):
        """Encode, apply ``self.layers`` convolutions with activation, then decode.

        The activations after the encoder and after each convolution are kept in
        ``self._layers`` for later inspection.
        """
        _layers = []
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.act(self.fc_enc(x))
        _layers.append(x)
        for i in range(self.layers):
            x = F.dropout(x, self.dropout, training=self.training)
            ax = self.gcn_list[i](x=x, edge_index=edge_index, edge_weight=edge_weight)
            x = self.act(ax)
            _layers.append(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_dec(x)
        self._layers = _layers
        return x

class egnn(fixed_architecture):
    """Baseline: the plain backbone with ``EGNNConv`` layers and no SCT bias term.

    ``data`` is accepted only so the constructor signature matches the other
    models built by ``run_folds``; it is not used.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        # NOTE(review): the inherited forward() calls each EGNNConv with only x/edge_index/edge_weight,
        # whereas egnns also passes x_0, beta and residual_weight -- confirm EGNNConv defaults suit this.
        self.gcn_list = ModuleList([
            EGNNConv(hidden_channels, hidden_channels, c_max=config['c_max'], normalize=False, cached=True)
            for _ in range(hidden_layers)
        ])
        # run_folds builds the optimizer from graph_params / fc_params; without this call
        # those attributes are never created for this model.
        self.reset_parameters()

#----------------------------------------------------------------------------------------------------------------------------------------------------

class fixed_bias_architecture(fixed_architecture):
    """Backbone for models that add a learned bias module after each graph convolution.

    Subclasses are expected to populate ``gcn_list``, ``bias_list`` and ``bn_list``
    (and, optionally, an encoder BatchNorm ``bn_enc``), then call
    ``reset_parameters()``.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        self.bias_list = ModuleList([])

    def reset_parameters(self):
        """Record the two optimizer parameter groups (no weights are re-initialised).

        ``graph_params``: convolutions, bias modules and per-layer BatchNorms.
        ``fc_params``: encoder/decoder linear layers plus the encoder BatchNorm ``bn_enc``
        when the subclass defines one.
        """
        self.graph_params = (list(self.gcn_list.parameters())
                             + list(self.bias_list.parameters())
                             + list(self.bn_list.parameters()))
        # bn_enc used to be in neither group, so the optimizer never updated its affine
        # weight/bias. Group it with the encoder layer it normalises.
        enc_bn = getattr(self, 'bn_enc', None)
        enc_bn_params = list(enc_bn.parameters()) if enc_bn is not None else []
        self.fc_params = (list(self.fc_enc.parameters())
                          + list(self.fc_dec.parameters())
                          + enc_bn_params)

class gcns(fixed_bias_architecture):
    """GCN model with an SCT bias term added after every convolution.

    ``forward``: dropout -> ``fc_enc`` -> BatchNorm -> activation, then ``hidden_layers``
    GCNConv layers. The last GCNConv maps directly to ``out_channels``, so the
    inherited ``fc_dec`` is never used in ``forward`` (its parameters still sit in
    ``fc_params`` but receive no gradient).

    Args:
        data: graph whose ``edge_index`` / ``edge_weight`` are passed to
            ``kernel_vectors`` once; the result is shared by every SCT module.
    """
    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        ker_vecs = kernel_vectors(data.edge_index, edge_weight=data.edge_weight, single=True) # presumably kernel (null-space) vectors of the graph operator -- confirm in sct_gnn.kernel_vectors
        # hidden_layers-1 hidden->hidden convolutions followed by one hidden->out convolution.
        # normalize=False presumably because run_folds already applies T.GCNNorm to the graph.
        gcn_list = [ GCNConv(hidden_channels,hidden_channels,normalize=False,cached=True,bias=False) for _ in range(hidden_layers-1) ]
        gcn_list.append(GCNConv(hidden_channels,out_channels,normalize=False,cached=True,bias=False))
        self.gcn_list = ModuleList(gcn_list)
        # One SCT bias module per convolution, with the matching output width.
        bias_list = [ SCT(ker_vecs.shape[0], hidden_channels, ker_vecs=ker_vecs, indicators=None, cached=True) for _ in range(hidden_layers-1) ]
        bias_list.append(SCT(ker_vecs.shape[0], out_channels, ker_vecs=ker_vecs, indicators=None, cached=True) )
        self.bias_list = ModuleList(bias_list)
        self.bn_enc = BatchNorm(hidden_channels)
        self.bn_list = ModuleList([ BatchNorm(hidden_channels) for _ in range(hidden_layers-1) ]+[BatchNorm(out_channels)])
        # Builds graph_params / fc_params for the optimizer in run_folds.
        self.reset_parameters()
        pass
    def forward(self, x, edge_index, edge_weight):
        """Return the output of the last convolution layer; the per-layer activations are stored in ``self._layers``."""
        _layers = []
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_enc(x)
        x = self.bn_enc(x)
        x = self.act(x)
        _layers.append(x)

        for i in range(0,self.layers):
            x = F.dropout(x, self.dropout, training=self.training)
            ax = self.gcn_list[i](x, edge_index, edge_weight=edge_weight)
            # NOTE(review): the bias is computed from the convolved ax here, while gcniis/egnns
            # pass the pre-convolution x -- confirm which is intended.
            b = self.bias_list[i](x=ax, x0=_layers[0], edge_index=edge_index, edge_weight=edge_weight) #add bias
            # Rescale each bias column by its std over dim 0 (assumed to be the node axis).
            # NOTE(review): no epsilon -- a constant column gives std 0 and inf/nan.
            b = b/b.std(dim=0)
            y = self.bn_list[i](ax)+b
            # NOTE(review): the activation is also applied to the final (out_channels) output.
            # With act_fn='srelu' the SReLU threshold has hidden_channels entries and cannot
            # broadcast against that output unless the widths match.
            x = self.act(y)
            # Residual connection on every layer except the last, whose width differs.
            if i<self.layers-1:
                x = x + _layers[-1]
            _layers.append(x)

        self._layers = _layers
        return x

class gcniis(fixed_bias_architecture):
    """GCNII model with an SCT bias term added after every convolution.

    ``forward``: dropout -> ``fc_enc`` -> BatchNorm -> activation, then ``hidden_layers``
    GCN2Conv layers (each also fed the encoder output as ``x_0``) with residual
    connections, then dropout -> ``fc_dec``.

    Args:
        config: must provide ``'alpha'`` and ``'theta'`` for GCN2Conv.
        data: graph whose ``edge_index`` / ``edge_weight`` are passed to
            ``kernel_vectors`` once; the result is shared by every SCT module.
    """
    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        ker_vecs = kernel_vectors(data.edge_index, edge_weight=data.edge_weight, single=True) # presumably kernel (null-space) vectors of the graph operator -- confirm in sct_gnn.kernel_vectors
        # layer=i+1 is the 1-based layer index GCN2Conv receives; normalize=False presumably because
        # run_folds already applies T.GCNNorm to the graph.
        self.gcn_list = ModuleList([ GCN2Conv(hidden_channels, alpha=config['alpha'], theta=config['theta'], layer=i+1, normalize=False, cached=True) for i in range(hidden_layers) ])
        # Alternative residual SCT bias, currently disabled:
        # self.bias_list = ModuleList([ SCT_Resid(ker_vecs.shape[0],hidden_channels, alpha=config['alpha'], theta=config['theta'], layer=i+1, ker_vecs=ker_vecs, indicators=None, cached=True) for i in range(hidden_layers) ])
        self.bias_list = ModuleList([ SCT(ker_vecs.shape[0],hidden_channels, ker_vecs=ker_vecs, indicators=None, cached=True) for _ in range(hidden_layers) ])
        self.bn_enc = BatchNorm(hidden_channels)
        self.bn_list = ModuleList([ BatchNorm(hidden_channels) for _ in range(hidden_layers) ])
        # Builds graph_params / fc_params for the optimizer in run_folds.
        self.reset_parameters()
        pass
    def forward(self, x, edge_index, edge_weight):
        """Return decoder outputs; per-layer activations are stored in ``self._layers``."""
        _layers = []
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_enc(x)
        x = self.bn_enc(x)
        x = self.act(x)
        _layers.append(x)

        for i in range(0, self.layers):
            x = F.dropout(x, self.dropout, training=self.training)
            # _layers[0] (the encoder output) is passed as the initial representation x_0.
            ax = self.gcn_list[i](x=x, x_0=_layers[0], edge_index=edge_index, edge_weight=edge_weight)
            b = self.bias_list[i](x=x, x0=_layers[0], edge_index=edge_index, edge_weight=edge_weight) #add bias
            # Rescale each bias column by its std over dim 0 (assumed to be the node axis).
            # NOTE(review): no epsilon -- a constant column gives std 0 and inf/nan.
            b = b/b.std(dim=0)
            y = self.bn_list[i](ax)+b
            x = self.act(y)
            # Residual connection to the previous layer's activation.
            x = x + _layers[-1]
            _layers.append(x)

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_dec(x)
        self._layers = _layers
        return x

class egnns(fixed_bias_architecture):
    """EGNN model with an SCT bias term and an orthogonality penalty on the conv weights.

    ``forward``: dropout -> ``fc_enc`` -> BatchNorm -> activation, then ``hidden_layers``
    EGNNConv layers with residual connections, then dropout -> ``fc_dec``. In
    training mode ``forward`` also stores the weighted penalty in ``self.reg``; it
    only affects training if the caller adds ``model.reg`` to the loss.

    Args:
        config: must provide ``'c_max'``, ``'beta'`` and ``'theta'``.
        data: graph whose ``edge_index`` / ``edge_weight`` are passed to
            ``kernel_vectors`` once; the result is shared by every SCT module.
    """
    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        ker_vecs = kernel_vectors(data.edge_index, edge_weight=data.edge_weight, single=True) # presumably kernel (null-space) vectors of the graph operator -- confirm in sct_gnn.kernel_vectors
        self.bias_list = ModuleList([ SCT(ker_vecs.shape[0],hidden_channels, ker_vecs=ker_vecs, indicators=None, cached=True) for _ in range(hidden_layers) ])
        self.gcn_list = ModuleList([ EGNNConv(hidden_channels,hidden_channels,c_max=config['c_max'],normalize=False,cached=True) for _ in range(hidden_layers) ])
        # Forwarded to every EGNNConv call as beta / residual_weight.
        self.beta = config['beta']
        self.theta = config['theta']

        self.bn_enc = BatchNorm(hidden_channels)
        self.bn_list = ModuleList([ BatchNorm(hidden_channels) for _ in range(hidden_layers) ])

        # Scale of the orthogonality penalty computed in forward().
        self.loss_weight = 1e-3
        # Fixed targets for the penalty: identity for layers 1.., sqrt(c_max) * identity for layer 0.
        # Frozen Parameters (requires_grad=False, in neither optimizer group) so .to(device) moves them.
        self.weight_standard = Parameter(torch.eye(hidden_channels), requires_grad=False)
        self.weight_first_layer = Parameter(torch.eye(hidden_channels) * np.sqrt(config['c_max']), requires_grad=False)

        # Builds graph_params / fc_params for the optimizer in run_folds.
        self.reset_parameters()
        pass
    def forward(self, x, edge_index, edge_weight):
        """Return decoder outputs; in training mode also set ``self.reg`` to the orthogonality penalty."""
        _layers = []
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_enc(x)
        x = self.bn_enc(x)
        x = self.act(x)
        _layers.append(x)

        for i in range(0,self.layers):
            x = F.dropout(x, self.dropout, training=self.training)
            # _layers[0] (the encoder output) is passed as the initial representation x_0.
            ax = self.gcn_list[i](x=x, x_0=_layers[0], edge_index=edge_index, edge_weight=edge_weight, beta=self.beta, residual_weight=self.theta)
            b = self.bias_list[i](x=x, x0=_layers[0], edge_index=edge_index, edge_weight=edge_weight) #add bias
            # Rescale each bias column by its std over dim 0 (assumed to be the node axis).
            # NOTE(review): no epsilon -- a constant column gives std 0 and inf/nan.
            b = b/b.std(dim=0)
            y = self.bn_list[i](ax)+b
            x = self.act(y)
            # Residual connection to the previous layer's activation.
            x = x + _layers[-1]
            _layers.append(x)

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_dec(x)
        self._layers = _layers

        if self.training:
            # Sum of norms of (W_i - target_i) over all conv layers.
            # NOTE(review): gcn_list[0] raises IndexError when hidden_layers == 0.
            loss_orthogonal = 0.
            loss_orthogonal += torch.norm(self.gcn_list[0].weight - self.weight_first_layer)
            for i in range(1, self.layers):
                loss_orthogonal += torch.norm(self.gcn_list[i].weight - self.weight_standard)

            # Only affects training if the caller adds model.reg to the loss.
            self.reg =  self.loss_weight * loss_orthogonal
        return x

#----------------------------------------------------------------------------------------------------------------------------------------------------

def setup(cfg):
    # Set device
    args = cfg.setup
    cfg['setup']['device'] = args['device'] if torch.cuda.is_available() else 'cpu'
    os.environ["WANDB_DIR"] = os.path.abspath(args['wandb_dir'])
    # Change file name for sweeping *Prior to setting seed*
    if args['sweep']:
        run_id = wandb.run.id
        cfg['load']['checkpoint_path']=cfg['load']['checkpoint_path'][:-3]+str(run_id)+'.pt'
    pass

#----------------------------------------------------------------------------------------------------------------------------------------------------

def train(model, data, train_idx, optimizer):
    """Run one full-batch optimisation step on the training nodes.

    The loss is the NLL of the training nodes plus ``model.reg`` (0 when the model
    has no such attribute). ``egnns`` stores its orthogonality penalty in ``reg``
    during the forward pass, so ``reg`` is read only after the forward call.

    Args:
        model: module called as ``model(x, edge_index, edge_weight)`` that returns per-node logits.
        data: graph with ``x``, ``edge_index``, ``edge_weight`` and labels ``y`` of shape (N, 1).
        train_idx: indices of the training nodes.
        optimizer: optimizer over the model's parameters.

    Returns:
        float: the total loss value for this step.
    """
    model.train()

    optimizer.zero_grad()

    out = model(data.x, data.edge_index, data.edge_weight)[train_idx]
    pred = F.log_softmax(out, dim=1)

    loss = F.nll_loss(pred, data.y.squeeze(1)[train_idx])
    # Previously the regulariser was computed by the model but never added to the loss.
    loss = loss + getattr(model, 'reg', 0)
    loss.backward()
    optimizer.step()

    return loss.item()



@torch.no_grad()
def test(model, data, y_true,split_idx, evaluator):
    model.eval()

    out = model(data.x, data.edge_index, data.edge_weight)
    out = F.log_softmax(out, dim=1)
    y_pred = out.argmax(dim=-1, keepdim=True)

    train_acc = evaluator.eval({
        'y_true': y_true[split_idx['train']],
        'y_pred': y_pred[split_idx['train']],
    })['acc']
    valid_acc = evaluator.eval({
        'y_true': y_true[split_idx['valid']],
        'y_pred': y_pred[split_idx['valid']],
    })['acc']
    test_acc = evaluator.eval({
        'y_true': y_true[split_idx['test']],
        'y_pred': y_pred[split_idx['test']],
    })['acc']

    return train_acc, valid_acc, test_acc


#----------------------------------------------------------------------------------------------------------------------------------------------------
# Main/Hydra/Fold/Train
#----------------------------------------------------------------------------------------------------------------------------------------------------

def run_folds(cfg):
    """Train and evaluate ``cfg.model['name']`` on ogbn-arxiv over 10 independent runs.

    Each run builds a fresh model and an Adam optimizer with two parameter groups
    (``graph_params`` with weight decay ``wd1``, ``fc_params`` with ``wd2``), trains
    full-batch and early-stops on validation accuracy. It records the test accuracy
    at the best-validation epoch. Per-epoch metrics and the final mean/std test
    accuracy (in percent) are logged to wandb.

    Raises:
        ValueError: if ``cfg.model['name']`` does not name a model class in this module.
    """
    args = cfg.train
    device = cfg.setup['device']

    dataset = PygNodePropPredDataset(root='/root/workspace/data',name='ogbn-arxiv')
    split_idx = dataset.get_idx_split()
    data = dataset[0].to(device)
    train_idx = split_idx['train'].to(device)
    data.edge_index = to_undirected(data.edge_index, num_nodes=data.num_nodes)
    # Normalise once here; the models' convolutions use normalize=False.
    data = T.GCNNorm()(data)
    evaluator = Evaluator(name='ogbn-arxiv')

    # Resolve the model class by name; reject anything that is not a Module class
    # (previously an unknown name gave "'NoneType' object is not callable").
    model_name = cfg.model['name']
    model_cls = globals().get(model_name)
    if not (isinstance(model_cls, type) and issubclass(model_cls, Module)):
        raise ValueError(f"cfg.model['name']={model_name!r} is not a model class defined in this module")
    dropout = cfg.model['dropout']
    hidden_channels = cfg.model['hidden_channels']
    hidden_layers = cfg.model['hidden_layers']

    acc_list = []
    for run in range(10):
        model = model_cls(dataset.num_features, hidden_channels, hidden_layers, dataset.num_classes,
                          dropout, cfg.model['act_fn'], cfg.model, data)
        model.to(device)

        optimizer = optim.Adam([{'params':model.graph_params, 'weight_decay':cfg.train['wd1']},
                                {'params':model.fc_params, 'weight_decay':cfg.train['wd2']}],
                               lr=cfg.train['lr'])

        bad_counter = 0
        best_val = 0
        final_test_acc = 0
        for epoch in range(1, 1 + args.epochs):
            loss = train(model, data, train_idx, optimizer)
            train_acc, valid_acc, test_acc = test(model, data, data.y, split_idx, evaluator)
            print(f'Run: {run + 1:02d}, '
                  f'Epoch: {epoch:02d}, '
                  f'Loss: {loss:.4f}, '
                  f'Train: {100 * train_acc:.2f}%, '
                  f'Valid: {100 * valid_acc:.2f}% '
                  f'Test: {100 * test_acc:.2f}%')
            # Early stopping: keep the test accuracy of the best-validation epoch.
            if valid_acc > best_val:
                best_val = valid_acc
                final_test_acc = test_acc
                bad_counter = 0
            else:
                bad_counter += 1

            wandb.log({'epoch':epoch,
                'train_acc':train_acc,
                'val_acc':valid_acc,
                'test_acc_epoch':test_acc,
                })

            if bad_counter == args.patience:
                break
        acc_list.append(final_test_acc*100)
        print(run+1,':',acc_list[-1])
    acc_list = torch.tensor(acc_list)
    mean_acc, std_acc = acc_list.mean(), acc_list.std()
    print(f'Avg Test: {mean_acc:.2f} ± {std_acc:.2f}')
    # Log plain floats rather than 0-d tensors.
    wandb.log({'test_acc':mean_acc.item(),
        'test_std':std_acc.item(),
        })
    

#----------------------------------------------------------------------------------------------------------------------------------------------------

@hydra.main(version_base=None, config_path="/root/workspace/sct-gnn/config/", config_name="obgn-arxiv")
def run_arxiv(cfg):
    """Hydra entry point: start a wandb run, resolve runtime settings, then train.

    Returns:
        int: always 1.
    """
    # WANDB_DIR must be set before wandb.init for it to affect this run; setup() sets it
    # again later, but only after the run has started.
    os.environ["WANDB_DIR"] = os.path.abspath(cfg.setup['wandb_dir'])
    # Hand the resolved config to wandb.init so it is recorded with the run,
    # instead of assigning wandb.config before the run exists.
    wandb.init(entity='',
               project='sct-gnn',
               name='sct',
               tags=['test', 'gcn-conv'],
               config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True))

    # setup() may read wandb.run.id (sweeps), so it must run after wandb.init.
    setup(cfg)
    print(OmegaConf.to_yaml(cfg))
    run_folds(cfg)
    return 1

#----------------------------------------------------------------------------------------------------------------------------------------------------

if __name__ == '__main__':
    run_arxiv()
