#!/usr/bin/python3
import json
import os
import sys
import time

import hydra
from omegaconf import OmegaConf
import wandb

import numpy as np
import random
from sklearn.metrics import roc_auc_score
import torch
from torch.nn import Dropout, GELU, LayerNorm, LeakyReLU, Linear, Module, ModuleList, Parameter, Sequential, Tanh
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.seed import seed_everything

from torch_geometric.nn import BatchNorm, GCNConv, GCN2Conv
from sct_gnn import SCT, SCT_Resid, SCT_Scalar
from sct_gnn.smoothness import NodeFeatureSmoothness
from sct_gnn.kernel_vectors import kernel_vectors, ker_lapl
sys.path.append('/root/workspace/sct-gnn/baselines/')
from egnn import EGNNConv


from datasets.heterophilious_dataset import heterophilious_dataloaders

"""
Heterophilous Dataset
    Split: 20 Train per class
    10 folds: fold k uses column k of the dataset's predefined train/val/test masks
"""


import torch.nn as nn
class SReLU(nn.Module):
    """Shifted ReLU with a learnable shift.

    Computes ``relu(x - b) + b`` (equivalently ``max(x, b)``), where ``b`` is
    the ``srelu_bias`` parameter of length ``nc`` broadcast along the last
    dimension of ``x`` and initialised to the constant ``bias``.
    """

    def __init__(self, nc, bias):
        super().__init__()
        # Allocate uninitialised storage, then fill it with the initial shift.
        self.srelu_bias = nn.Parameter(torch.empty(nc))
        self.srelu_relu = nn.ReLU(inplace=True)
        nn.init.constant_(self.srelu_bias, bias)

    def forward(self, x):
        shift = self.srelu_bias
        # `x - shift` is a fresh tensor, so the in-place ReLU never touches `x`.
        return self.srelu_relu(x - shift) + shift

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Models
#----------------------------------------------------------------------------------------------------------------------------------------------------

class fixed_architecture(Module):
    """Encoder -> ``hidden_layers`` graph convolutions -> decoder.

    Base class for the models in this file. Subclasses fill ``self.gcn_list``
    with one convolution per hidden layer and call ``reset_parameters()`` at
    the end of their ``__init__`` to build the two optimizer parameter groups
    (``graph_params`` / ``fc_params``) that ``run_training`` reads.

    Args:
        in_channels: input feature dimension of the encoder.
        hidden_channels: width of the encoder output and of every hidden layer.
        hidden_layers: number of ``gcn_list`` entries applied in ``forward``.
        out_channels: output dimension of the decoder.
        dropout: dropout probability applied before every linear/graph stage.
        act_fn: 'relu', 'srelu', 'leaky', 'gelu' or 'tanh'.
        config: model config; not read here, accepted for subclass signatures.

    Raises:
        ValueError: if ``act_fn`` is not one of the names above.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config):
        super().__init__()
        self.fc_enc = Linear(in_channels, hidden_channels)
        self.gcn_list = ModuleList([])
        self.fc_dec = Linear(hidden_channels, out_channels)
        if act_fn == 'relu':
            # The bare name `ReLU` is not imported in this module; use the `nn` alias.
            self.act = nn.ReLU()
        elif act_fn == 'srelu':
            self.act = SReLU(hidden_channels, -1)
        elif act_fn == 'leaky':
            self.act = LeakyReLU()
        elif act_fn == 'gelu':
            self.act = GELU()
        elif act_fn == 'tanh':
            self.act = Tanh()
        else:
            # Fail fast here rather than with an AttributeError on `self.act` in forward.
            raise ValueError(
                f"Unknown act_fn {act_fn!r}; expected 'relu', 'srelu', 'leaky', 'gelu' or 'tanh'"
            )
        self.dropout = dropout
        self.layers = hidden_layers
        # Auxiliary loss term; subclasses such as `egnns` overwrite it during forward.
        self.reg = 0

    def reset_parameters(self):
        """Split the trainable parameters into the two optimizer groups.

        ``fc_params`` holds the encoder/decoder weights. ``graph_params`` holds
        every other parameter with ``requires_grad``: the graph convolutions
        plus any normalisation or learnable-activation modules a subclass
        registers. This ensures none of them is silently left out of the optimizer.
        """
        self.fc_params = list(self.fc_enc.parameters()) + list(self.fc_dec.parameters())
        fc_ids = {id(p) for p in self.fc_params}
        self.graph_params = [p for p in self.parameters() if p.requires_grad and id(p) not in fc_ids]

    def forward(self, x, edge_index, edge_weight):
        """Encode, apply ``self.layers`` convolutions with activation, decode.

        Per-layer hidden states (after the encoder and after each layer) are
        stored in ``self._layers``.
        """
        _layers = []
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.act(self.fc_enc(x))
        _layers.append(x)
        for i in range(self.layers):
            x = F.dropout(x, self.dropout, training=self.training)
            ax = self.gcn_list[i](x=x, edge_index=edge_index, edge_weight=edge_weight)
            x = self.act(ax)
            _layers.append(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_dec(x)
        self._layers = _layers
        return x

class gcn(fixed_architecture):
    """Residual GCN: BatchNorm'd encoder, then ``act(BN(GCNConv(x))) + previous`` per layer.

    ``data`` is accepted for a uniform constructor signature and is not used.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        # normalize=False: GCNConv does not renormalise the edge weights itself;
        # presumably the loader's 'sym-norm' adjacency already is — confirm.
        self.gcn_list = ModuleList([ GCNConv(hidden_channels,hidden_channels,normalize=False,cached=True,bias=False) for _ in range(hidden_layers) ])
        self.bn_enc = BatchNorm(hidden_channels)
        self.bn_list = ModuleList([ BatchNorm(hidden_channels) for _ in range(hidden_layers) ])
        self.reset_parameters()

    def forward(self, x, edge_index, edge_weight):
        """Return decoder logits; hidden states are stored in ``self._layers``."""
        _layers = []
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_enc(x)
        x = self.bn_enc(x)
        x = self.act(x)
        _layers.append(x)

        for i in range(self.layers):
            x = F.dropout(x, self.dropout, training=self.training)
            ax = self.gcn_list[i](x, edge_index, edge_weight=edge_weight)
            # Activate the BatchNorm output. Previously the activation was
            # recomputed from the raw `ax`, which discarded bn_list entirely.
            x = self.act(self.bn_list[i](ax))
            x = x + _layers[-1]  # residual connection to the previous layer
            _layers.append(x)

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_dec(x)
        self._layers = _layers
        return x

class gcnii(fixed_architecture):
    """GCNII baseline: ``GCN2Conv`` layers with initial residual to the encoder output.

    Each layer computes ``act(BN(GCN2Conv(x, x_0))) + previous`` with
    ``x_0 = _layers[0]``. ``data`` is accepted for a uniform constructor
    signature; it is not needed here.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        # (The kernel vectors previously computed here were never used, so that
        # potentially expensive call has been dropped.)
        self.gcn_list = ModuleList([ GCN2Conv(hidden_channels, alpha=config['alpha'], theta=config['theta'], layer=i+1, normalize=False, cached=True) for i in range(hidden_layers) ])
        self.bn_enc = BatchNorm(hidden_channels)
        self.bn_list = ModuleList([ BatchNorm(hidden_channels) for _ in range(hidden_layers) ])
        self.reset_parameters()

    def forward(self, x, edge_index, edge_weight):
        """Return decoder logits; hidden states are stored in ``self._layers``."""
        _layers = []
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_enc(x)
        x = self.bn_enc(x)
        x = self.act(x)
        _layers.append(x)

        for i in range(self.layers):
            x = F.dropout(x, self.dropout, training=self.training)
            ax = self.gcn_list[i](x=x, x_0=_layers[0], edge_index=edge_index, edge_weight=edge_weight)
            y = self.bn_list[i](ax)
            x = self.act(y)
            x = x + _layers[-1]  # residual connection to the previous layer
            _layers.append(x)

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_dec(x)
        self._layers = _layers
        return x

class egnn(fixed_architecture):
    """EGNN baseline: ``EGNNConv`` layers run through the base-class ``forward``.

    ``data`` is accepted for a uniform constructor signature and is not used.

    NOTE(review): the base ``forward`` calls each ``EGNNConv`` with only
    ``x``/``edge_index``/``edge_weight``, while ``egnns`` also passes ``x_0``,
    ``beta`` and ``residual_weight``. Confirm that EGNNConv has usable defaults
    for those.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        self.gcn_list = ModuleList([ EGNNConv(hidden_channels,hidden_channels,c_max=config['c_max'],normalize=False,cached=True) for _ in range(hidden_layers) ])
        # Build graph_params / fc_params. Without this call, run_training fails
        # with AttributeError on `model.graph_params`.
        self.reset_parameters()

#----------------------------------------------------------------------------------------------------------------------------------------------------

class fixed_bias_architecture(fixed_architecture):
    """Base for models that add a learned per-layer bias module (``bias_list``).

    Subclasses populate ``gcn_list``, ``bias_list``, ``bn_enc`` and ``bn_list``
    and then call ``reset_parameters()``.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        self.bias_list = ModuleList([])

    def reset_parameters(self):
        """Split the trainable parameters into the two optimizer groups.

        ``fc_params`` holds the encoder/decoder weights. ``graph_params`` holds
        every other parameter with ``requires_grad``: convolutions, bias
        modules and ``bn_list``, plus ``bn_enc`` and any learnable activation
        (e.g. ``SReLU``), which an explicit module list would leave untrained.
        """
        self.fc_params = list(self.fc_enc.parameters()) + list(self.fc_dec.parameters())
        fc_ids = {id(p) for p in self.fc_params}
        # Frozen parameters (requires_grad=False, e.g. the reference matrices
        # registered by egnns) are kept out of the optimizer.
        self.graph_params = [p for p in self.parameters() if p.requires_grad and id(p) not in fc_ids]

class gcns(fixed_bias_architecture):
    """GCN with a per-layer SCT bias added after BatchNorm.

    Layer update: ``act(BN(A x) + b / std(b)) + previous``. Here ``A x`` is a
    ``GCNConv`` output and ``b`` is an ``SCT`` module evaluated on that
    output. The SCT modules are built on the graph's kernel vectors.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        kernel = kernel_vectors(data.edge_index, edge_weight=data.edge_weight, single=True)
        # Construction order (convs, SCT biases, norms) is kept as-is in case
        # module initialisation draws from the global RNG.
        convs = [
            GCNConv(hidden_channels, hidden_channels, normalize=False, cached=True, bias=False)
            for _ in range(hidden_layers)
        ]
        self.gcn_list = ModuleList(convs)
        biases = [
            SCT(kernel.shape[0], hidden_channels, ker_vecs=kernel, indicators=None, cached=True)
            for _ in range(hidden_layers)
        ]
        self.bias_list = ModuleList(biases)
        self.bn_enc = BatchNorm(hidden_channels)
        self.bn_list = ModuleList([BatchNorm(hidden_channels) for _ in range(hidden_layers)])
        self.reset_parameters()

    def forward(self, x, edge_index, edge_weight):
        """Return decoder logits; hidden states are stored in ``self._layers``."""
        x = F.dropout(x, self.dropout, training=self.training)
        h = self.act(self.bn_enc(self.fc_enc(x)))
        history = [h]
        x0 = h

        for i in range(self.layers):
            h = F.dropout(h, self.dropout, training=self.training)
            agg = self.gcn_list[i](h, edge_index, edge_weight=edge_weight)
            bias = self.bias_list[i](x=agg, x0=x0, edge_index=edge_index, edge_weight=edge_weight)
            bias = bias / bias.std(dim=0)  # unit std per channel
            h = self.act(self.bn_list[i](agg) + bias) + history[-1]
            history.append(h)

        h = F.dropout(h, self.dropout, training=self.training)
        out = self.fc_dec(h)
        self._layers = history
        return out

class gcniis(fixed_bias_architecture):
    """GCNII with a per-layer SCT bias added after BatchNorm.

    Layer update: ``act(BN(GCN2Conv(x, x_0)) + b / std(b)) + previous``.
    Here ``x_0`` is the encoder output and ``b`` is an ``SCT`` module
    evaluated on the layer input ``x``.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        kernel = kernel_vectors(data.edge_index, edge_weight=data.edge_weight, single=True)
        alpha, theta = config['alpha'], config['theta']
        convs = []
        for depth in range(1, hidden_layers + 1):
            convs.append(GCN2Conv(hidden_channels, alpha=alpha, theta=theta, layer=depth, normalize=False, cached=True))
        self.gcn_list = ModuleList(convs)
        # An SCT_Resid variant of the bias module was tried here previously.
        self.bias_list = ModuleList([
            SCT(kernel.shape[0], hidden_channels, ker_vecs=kernel, indicators=None, cached=True)
            for _ in range(hidden_layers)
        ])
        self.bn_enc = BatchNorm(hidden_channels)
        self.bn_list = ModuleList([BatchNorm(hidden_channels) for _ in range(hidden_layers)])
        self.reset_parameters()

    def forward(self, x, edge_index, edge_weight):
        """Return decoder logits; hidden states are stored in ``self._layers``."""
        x = F.dropout(x, self.dropout, training=self.training)
        h = self.act(self.bn_enc(self.fc_enc(x)))
        history = [h]
        x0 = h

        for i in range(self.layers):
            h = F.dropout(h, self.dropout, training=self.training)
            conv_out = self.gcn_list[i](x=h, x_0=x0, edge_index=edge_index, edge_weight=edge_weight)
            bias = self.bias_list[i](x=h, x0=x0, edge_index=edge_index, edge_weight=edge_weight)
            bias = bias / bias.std(dim=0)  # unit std per channel
            h = self.act(self.bn_list[i](conv_out) + bias) + history[-1]
            history.append(h)

        h = F.dropout(h, self.dropout, training=self.training)
        out = self.fc_dec(h)
        self._layers = history
        return out

class egnns(fixed_bias_architecture):
    """EGNN with a per-layer SCT bias and an orthogonality regulariser.

    Layer update: ``act(BN(EGNNConv(x, x_0)) + b / std(b)) + previous``.
    In training mode ``forward`` also sets ``self.reg``. It is
    ``loss_weight`` times the summed ``torch.norm`` distances between each
    conv's ``weight`` and a fixed scaled identity: ``sqrt(c_max) * I`` for
    the first layer and ``I`` for the rest.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        kernel = kernel_vectors(data.edge_index, edge_weight=data.edge_weight, single=True)
        # Unlike gcns/gcniis, the SCT modules are created before the convs here;
        # that order is kept in case initialisation draws from the global RNG.
        self.bias_list = ModuleList([
            SCT(kernel.shape[0], hidden_channels, ker_vecs=kernel, indicators=None, cached=True)
            for _ in range(hidden_layers)
        ])
        self.gcn_list = ModuleList([
            EGNNConv(hidden_channels, hidden_channels, c_max=config['c_max'], normalize=False, cached=True)
            for _ in range(hidden_layers)
        ])
        self.beta = config['beta']
        self.theta = config['theta']  # forwarded to EGNNConv as residual_weight

        self.bn_enc = BatchNorm(hidden_channels)
        self.bn_list = ModuleList([BatchNorm(hidden_channels) for _ in range(hidden_layers)])

        # Orthogonality regulariser: its weight and the frozen target matrices.
        self.loss_weight = 1e-3
        identity = torch.eye(hidden_channels)
        self.weight_standard = Parameter(identity, requires_grad=False)
        self.weight_first_layer = Parameter(torch.eye(hidden_channels) * np.sqrt(config['c_max']), requires_grad=False)

        self.reset_parameters()

    def forward(self, x, edge_index, edge_weight):
        """Return decoder logits; hidden states go to ``self._layers`` and, when training, the penalty to ``self.reg``."""
        x = F.dropout(x, self.dropout, training=self.training)
        h = self.act(self.bn_enc(self.fc_enc(x)))
        history = [h]
        x0 = h

        for i in range(self.layers):
            h = F.dropout(h, self.dropout, training=self.training)
            conv_out = self.gcn_list[i](x=h, x_0=x0, edge_index=edge_index, edge_weight=edge_weight,
                                        beta=self.beta, residual_weight=self.theta)
            bias = self.bias_list[i](x=h, x0=x0, edge_index=edge_index, edge_weight=edge_weight)
            bias = bias / bias.std(dim=0)  # unit std per channel
            h = self.act(self.bn_list[i](conv_out) + bias) + history[-1]
            history.append(h)

        h = F.dropout(h, self.dropout, training=self.training)
        out = self.fc_dec(h)
        self._layers = history

        if self.training:
            penalty = torch.norm(self.gcn_list[0].weight - self.weight_first_layer)
            for i in range(1, self.layers):
                # Out-of-place add: an in-place add on a norm output would break its backward.
                penalty = penalty + torch.norm(self.gcn_list[i].weight - self.weight_standard)
            self.reg = self.loss_weight * penalty
        return out

#----------------------------------------------------------------------------------------------------------------------------------------------------

def setup(cfg):
    # Set device
    args = cfg.setup
    cfg['setup']['device'] = args['device'] if torch.cuda.is_available() else 'cpu'
    os.environ["WANDB_DIR"] = os.path.abspath(args['wandb_dir'])
    # Change file name for sweeping *Prior to setting seed*
    if args['sweep']:
        run_id = wandb.run.id
        cfg['load']['checkpoint_path']=cfg['load']['checkpoint_path'][:-3]+str(run_id)+'.pt'
    pass

#----------------------------------------------------------------------------------------------------------------------------------------------------


def train(cfg, data, model, optimizer, mask):
    """Run one full-batch optimisation step on the nodes selected by ``mask``.

    The optimised objective is BCE-with-logits on the masked nodes plus
    ``model.reg``. That is the auxiliary term a model may set during its
    training forward pass: ``egnns`` stores its orthogonality penalty there,
    and the other models keep the initial 0. Without adding it here the
    penalty would be computed but have no effect on training.

    Returns:
        (loss, auc): the BCE term alone, so it stays comparable with
        ``validate``, and the ROC-AUC of the masked logits.
    """
    model.train()
    optimizer.zero_grad()
    output = model(data.x, data.edge_index, data.edge_weight).squeeze()
    logits = output[mask]
    labels = data.y[mask]
    data_loss = F.binary_cross_entropy_with_logits(logits, labels.to(torch.float64))
    total_loss = data_loss + getattr(model, 'reg', 0)
    total_loss.backward()
    optimizer.step()

    acc = roc_auc_score(labels.squeeze().detach().cpu().numpy(), logits.squeeze().detach().cpu().numpy())
    return data_loss.item(), acc


@torch.no_grad()
def validate(cfg, data, model, mask):
    """Evaluate ``model`` on the nodes in ``mask`` in eval mode, without gradients.

    Returns:
        (loss, auc): BCE-with-logits loss and ROC-AUC on the masked nodes.
    """
    model.eval()
    logits = model(data.x, data.edge_index, data.edge_weight).squeeze()
    masked_logits = logits[mask]
    labels = data.y[mask]
    loss = F.binary_cross_entropy_with_logits(masked_logits, labels.to(torch.float64))
    auc = roc_auc_score(
        labels.squeeze().detach().cpu().numpy(),
        masked_logits.squeeze().detach().cpu().numpy(),
    )
    return loss.item(), auc


@torch.no_grad()
def test(cfg, data, model, mask):
    """Score ``model`` on the held-out nodes in ``mask`` (eval mode, no gradients).

    Returns:
        (loss, auc): BCE-with-logits loss and ROC-AUC on the masked nodes.
    """
    model.eval()
    y_logit = model(data.x, data.edge_index, data.edge_weight).squeeze()[mask]
    y_true = data.y[mask]
    loss = F.binary_cross_entropy_with_logits(y_logit, y_true.to(torch.float64))
    true_np = y_true.squeeze().detach().cpu().numpy()
    score_np = y_logit.squeeze().detach().cpu().numpy()
    return loss.item(), roc_auc_score(true_np, score_np)

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Main/Hydra/Fold/Train
#----------------------------------------------------------------------------------------------------------------------------------------------------

def run_training(cfg, data, model, tr_mask, val_mask):
    """Train ``model`` with early stopping on validation ROC-AUC.

    Each epoch runs one ``train`` step and one ``validate`` pass. Whenever the
    validation AUC improves, the model state is saved to
    ``cfg.load['checkpoint_path']``. Training stops once more than
    ``cfg.train['patience']`` consecutive epochs pass without improvement.

    Args:
        cfg: config with ``train`` (lr, wd1, wd2, epochs, patience) and ``load`` sections.
        data: graph data passed to ``train``/``validate``.
        model: model exposing ``graph_params`` and ``fc_params``.
        tr_mask: node mask used for the training loss.
        val_mask: node mask used for model selection.

    Returns:
        The best validation AUC, or NaN if no epoch ran.
    """
    args = cfg.train
    # Separate weight decay for graph-side parameters (wd1) and encoder/decoder (wd2).
    optimizer = optim.AdamW(
        [
            {'params': model.graph_params, 'weight_decay': args['wd1']},
            {'params': model.fc_params, 'weight_decay': args['wd2']},
        ],
        lr=args['lr'],
    )

    best = float('inf')      # lowest (1 - val AUC) seen so far
    best_acc = float('nan')  # val AUC at that epoch; defined even if no epoch runs
    bad_itr = 0              # consecutive epochs without improvement
    for epoch in range(args['epochs']):
        start = time.time()
        train_loss, train_acc = train(cfg, data, model, optimizer, tr_mask)  # sets model.train()
        end = time.time()
        val_loss, val_acc = validate(cfg, data, model, val_mask)

        perf_metric = 1 - val_acc  # lower is better

        if perf_metric < best:
            best = perf_metric
            best_acc = val_acc
            bad_itr = 0
            torch.save({'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'loss': val_loss,
                        },
                       cfg.load['checkpoint_path'])
        else:
            bad_itr += 1

        wandb.log({'epoch': epoch,
                   'train_loss': train_loss,
                   'train_acc': train_acc,
                   'val_loss': val_loss,
                   'val_acc': val_acc,
                   'best': best,
                   'time': end - start})
        print(f'Epoch({epoch}) '
              f'| train({100*train_acc:.2f},{train_loss:.4f}) '
              f'| val({100*val_acc:.2f},{val_loss:.4f}) '
              f'| best({best:.4f}) '
              f'| time({end-start:.4f}) '
              f'| lr({optimizer.param_groups[0]["lr"]:.2e})'
              )

        if bad_itr > args['patience']:
            break

    return best_acc

#----------------------------------------------------------------------------------------------------------------------------------------------------

def run_folds(cfg):
    """Train/evaluate one model per predefined split and report mean and std.

    Fold ``k`` uses column ``k`` of ``data.train_mask``, ``data.val_mask`` and
    ``data.test_mask``. Its checkpoint path is the configured path with its
    last three characters (assumed '.pt') replaced by ``_fold_{k}.pt``. When
    ``cfg.setup['train']`` is false, existing fold checkpoints are only
    evaluated; their validation AUC is recomputed rather than left at -1.

    Returns:
        1 (unchanged return value).

    Raises:
        ValueError: if ``cfg.model['name']`` does not name a model class in this module.
    """
    kfolds = 10  # number of split columns evaluated
    val_accs = [-1 for _ in range(kfolds)]
    test_accs = [-1 for _ in range(kfolds)]
    device = cfg.setup['device']
    original_path = cfg.load['checkpoint_path']

    # Loading does not depend on the fold, so do it once rather than per fold.
    dataset, _, _, _, _ = heterophilious_dataloaders(
        name=cfg.load['dataset'],
        adjacency='sym-norm',
    )
    data = dataset[0].to(device)

    model_name = cfg.model['name']
    model_cls = globals().get(model_name)
    if not (isinstance(model_cls, type) and issubclass(model_cls, Module)):
        raise ValueError(f"Unknown model name {model_name!r}: no Module subclass with that name in this module")

    for k in range(kfolds):
        cfg['load']['checkpoint_path'] = original_path[:-3] + f'_fold_{k}.pt'

        model = model_cls(
            dataset.num_features,
            cfg.model['hidden_channels'],
            cfg.model['hidden_layers'],
            1,  # single logit, matching the BCE-with-logits loss
            cfg.model['dropout'],
            cfg.model['act_fn'],
            cfg.model,
            data,
        )
        model.to(device)

        train_mask = data.train_mask[:, k]
        val_mask = data.val_mask[:, k]
        test_mask = data.test_mask[:, k]

        # Tensor .sum() instead of Python sum(), which iterates element by element.
        n_train = int(train_mask.sum())
        n_val = int(val_mask.sum())
        n_test = int(test_mask.sum())
        total = (n_train + n_val + n_test) or 1  # guard against an empty split
        print(f'Fold {k} Splits: train({100*n_train/total:.2f})'
              f'\tval({100*n_val/total:.2f})'
              f'\ttest({100*n_test/total:.2f})'
              f'\ttrv({n_train+n_val})'
              )

        if cfg.setup['train']:
            val_accs[k] = run_training(cfg, data, model, train_mask, val_mask)

        # Evaluate the best checkpoint; map_location allows loading GPU checkpoints on CPU.
        checkpoint = torch.load(cfg.load['checkpoint_path'], map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        if not cfg.setup['train']:
            _, val_accs[k] = validate(cfg, data, model, val_mask)
        _, test_accs[k] = test(cfg, data, model, test_mask)

    cfg['load']['checkpoint_path'] = original_path

    summary = {'val_mean': np.mean(val_accs),
               'val_std': np.std(val_accs),
               'test_mean': np.mean(test_accs),
               'test_std': np.std(test_accs)}
    print(summary)
    wandb.log(summary)

    return 1


#----------------------------------------------------------------------------------------------------------------------------------------------------

@hydra.main(version_base=None, config_path="/root/workspace/sct-gnn/config/", config_name="heterophilious")
def run_heterophilious(cfg):
    """Hydra entry point: start a wandb run, finalise ``cfg``, then run all folds.

    wandb logs online only for sweep runs and is disabled otherwise.
    """
    wandb_mode = 'online' if cfg.setup['sweep'] else 'disabled'
    dataset_name = cfg.load['dataset']
    run_config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    wandb.init(
        dir='/root/workspace/out/',
        entity='',
        mode=wandb_mode,
        name='prgnn-' + dataset_name,
        project='pr-inspired-aggregation',
        tags=['prgnn', dataset_name],
        config=run_config,
    )

    # setup() reads wandb.run for sweep runs, so it has to follow wandb.init.
    setup(cfg)
    print(OmegaConf.to_yaml(cfg))
    run_folds(cfg)
    return 1

#----------------------------------------------------------------------------------------------------------------------------------------------------

if __name__ == "__main__":
    # Hydra builds `cfg` from the config directory plus CLI overrides and injects it.
    run_heterophilious()
