#!/usr/bin/python3
import json
import os
import sys
import time

import hydra
from omegaconf import OmegaConf
import wandb

import numpy as np
import random
import torch
from torch.nn import Dropout, GELU, LayerNorm, LeakyReLU, Linear, Module, ModuleList, Parameter, Sequential, Tanh
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.seed import seed_everything

from torch_geometric.nn import BatchNorm, GCNConv, GCN2Conv
from sct_gnn import SCT, SCT_Resid, SCT_Scalar
from sct_gnn.smoothness import NodeFeatureSmoothness
from sct_gnn.kernel_vectors import kernel_vectors, ker_lapl
sys.path.append('/root/workspace/sct-gnn/baselines/')
from egnn import EGNNConv


from datasets.heterophilious_dataset import heterophilious_dataloaders


"""
Heterophilious Dataset
    Split: 20 Train per class
    10-Fold Cross Validation
"""


import torch.nn as nn
class SReLU(nn.Module):
    """Shifted ReLU with a learnable per-channel shift.

    Evaluates ``relu(x - b) + b`` (mathematically ``max(x, b)``), where ``b``
    is a trainable vector of length ``nc`` whose entries all start at ``bias``.
    """

    def __init__(self, nc, bias):
        super().__init__()
        # Allocate the shift vector, then fill it with the requested constant.
        shift = torch.empty(nc)
        nn.init.constant_(shift, bias)
        self.srelu_bias = nn.Parameter(shift)
        self.srelu_relu = nn.ReLU(inplace=True)

    def forward(self, x):
        shift = self.srelu_bias
        # ``x - shift`` is a fresh tensor, so the in-place ReLU never
        # overwrites the caller's ``x``.
        rectified = self.srelu_relu(x - shift)
        return rectified + shift

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Models
#----------------------------------------------------------------------------------------------------------------------------------------------------

class fixed_architecture(Module):
    """Shared encoder / graph-layer stack / decoder skeleton.

    The constructor builds a linear encoder (``fc_enc``), an empty
    ``gcn_list`` that subclasses are expected to replace, a linear decoder
    (``fc_dec``) and the activation selected by ``act_fn``.

    Args:
        in_channels: input feature width of ``fc_enc``.
        hidden_channels: width between encoder, graph layers and decoder.
        hidden_layers: number of graph layers ``forward`` iterates over.
        out_channels: output width of ``fc_dec``.
        dropout: dropout probability applied before every layer.
        act_fn: one of ``'relu'``, ``'srelu'`` or ``'leaky'``.
        config: model configuration; not read by this base class.

    Raises:
        ValueError: if ``act_fn`` is not one of the supported names.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config):
        super().__init__()
        self.fc_enc = Linear(in_channels, hidden_channels)
        self.gcn_list = ModuleList([])
        self.fc_dec = Linear(hidden_channels, out_channels)
        if act_fn == 'relu':
            # ``ReLU`` is not among the names imported from torch.nn at the
            # top of the file, so reach it through the ``nn`` alias instead.
            self.act = nn.ReLU()
        elif act_fn == 'srelu':
            self.act = SReLU(hidden_channels, -1)
        elif act_fn == 'leaky':
            self.act = LeakyReLU()
        else:
            # Fail here rather than with an AttributeError on the first
            # forward pass, where the cause is much harder to spot.
            raise ValueError(
                f"Unsupported act_fn {act_fn!r}; expected 'relu', 'srelu' or 'leaky'"
            )
        self.dropout = dropout
        self.layers = hidden_layers
        # Extra regularisation term a subclass may compute during forward.
        self.reg = 0

    def reset_parameters(self):
        """Collect the two parameter groups exposed to the optimizer.

        Despite the name, no weights are re-initialised: this only stores
        ``graph_params`` (graph layers) and ``fc_params`` (encoder/decoder).
        """
        self.graph_params = list(self.gcn_list.parameters())
        self.fc_params = list(self.fc_enc.parameters()) + list(self.fc_dec.parameters())

    def forward(self, x, edge_index, edge_weight):
        """Encode, apply ``self.layers`` graph layers, then decode.

        The encoder output and every layer's activation are kept in
        ``self._layers`` after the call.
        """
        _layers = []
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.act(self.fc_enc(x))
        _layers.append(x)
        for i in range(self.layers):
            x = F.dropout(x, self.dropout, training=self.training)
            ax = self.gcn_list[i](x=x, edge_index=edge_index, edge_weight=edge_weight)
            x = self.act(ax)
            _layers.append(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_dec(x)
        self._layers = _layers
        return x

class gcn(fixed_architecture):
    """Residual GCN stack with batch normalisation after every convolution.

    ``data`` is accepted so that all models share one constructor signature;
    it is not read here.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        self.gcn_list = ModuleList([ GCNConv(hidden_channels,hidden_channels,normalize=False,cached=True,bias=False) for _ in range(hidden_layers) ])
        self.bn_enc = BatchNorm(hidden_channels)
        self.bn_list = ModuleList([ BatchNorm(hidden_channels) for _ in range(hidden_layers) ])
        # NOTE(review): the inherited reset_parameters() only collects
        # gcn_list / fc_enc / fc_dec, so the batch-norm parameters are not in
        # either optimizer group — confirm that is intended.
        self.reset_parameters()

    def forward(self, x, edge_index, edge_weight):
        """Return decoder outputs; per-layer activations end up in ``self._layers``."""
        _layers = []
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_enc(x)
        x = self.bn_enc(x)
        x = self.act(x)
        _layers.append(x)

        for i in range(0, self.layers):
            x = F.dropout(x, self.dropout, training=self.training)

            ax = self.gcn_list[i](x, edge_index, edge_weight=edge_weight)
            y = self.bn_list[i](ax)
            # Activate the batch-normalised output (same layer recipe as the
            # other residual models in this file).
            x = self.act(y)
            # Residual connection to the previous layer's activation.
            x = x + _layers[-1]
            _layers.append(x)

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_dec(x)
        self._layers = _layers
        return x

class gcnii(fixed_architecture):
    """Residual stack of ``GCN2Conv`` layers with batch normalisation.

    ``config`` must provide ``'alpha'`` and ``'theta'``. ``data`` is accepted
    so that all models share one constructor signature; this model has no SCT
    layers, so no kernel vectors are computed from it.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        self.gcn_list = ModuleList([ GCN2Conv(hidden_channels, alpha=config['alpha'], theta=config['theta'], layer=i+1, normalize=False, cached=True) for i in range(hidden_layers) ])
        self.bn_enc = BatchNorm(hidden_channels)
        self.bn_list = ModuleList([ BatchNorm(hidden_channels) for _ in range(hidden_layers) ])
        self.reset_parameters()

    def forward(self, x, edge_index, edge_weight):
        """Return decoder outputs; per-layer activations end up in ``self._layers``."""
        _layers = []
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_enc(x)
        x = self.bn_enc(x)
        x = self.act(x)
        _layers.append(x)

        for i in range(0, self.layers):
            x = F.dropout(x, self.dropout, training=self.training)
            # Every layer also receives the encoder output as ``x_0``.
            ax = self.gcn_list[i](x=x, x_0=_layers[0], edge_index=edge_index, edge_weight=edge_weight)
            y = self.bn_list[i](ax)
            x = self.act(y)
            # Residual connection to the previous layer's activation.
            x = x + _layers[-1]
            _layers.append(x)

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_dec(x)
        self._layers = _layers
        return x

class egnn(fixed_architecture):
    """Plain stack of ``EGNNConv`` layers driven by the inherited ``forward``.

    ``config`` must provide ``'c_max'``. ``data`` is accepted so that all
    models share one constructor signature; it is not read here.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        self.gcn_list = ModuleList([ EGNNConv(hidden_channels,hidden_channels,c_max=config['c_max'],normalize=False,cached=True) for _ in range(hidden_layers) ])
        # NOTE(review): the inherited forward calls each layer with only
        # x / edge_index / edge_weight, whereas egnns also passes x_0, beta and
        # residual_weight — confirm EGNNConv accepts the shorter call.
        # Populate graph_params / fc_params; the optimizer setup in
        # run_training reads both attributes.
        self.reset_parameters()

#----------------------------------------------------------------------------------------------------------------------------------------------------

class fixed_bias_architecture(fixed_architecture):
    """Base for models that pair every graph layer with an extra bias module.

    Adds an empty ``bias_list`` placeholder; subclasses fill it and must also
    define ``bn_list`` before calling ``reset_parameters``.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        self.bias_list = ModuleList([])

    def reset_parameters(self):
        """Store the optimizer groups: graph/bias/batch-norm layers vs. encoder/decoder."""
        graph_side = (self.gcn_list, self.bias_list, self.bn_list)
        fc_side = (self.fc_enc, self.fc_dec)
        self.graph_params = [p for module in graph_side for p in module.parameters()]
        self.fc_params = [p for module in fc_side for p in module.parameters()]

class gcns(fixed_bias_architecture):
    """GCN stack whose layers add an SCT term to the batch-normalised convolution."""

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        # Kernel vectors are computed once from the graph and handed to every SCT layer.
        ker_vecs = kernel_vectors(data.edge_index, edge_weight=data.edge_weight, single=True)

        convs = []
        for _ in range(hidden_layers):
            convs.append(GCNConv(hidden_channels, hidden_channels, normalize=False, cached=True, bias=False))
        self.gcn_list = ModuleList(convs)

        sct_layers = []
        for _ in range(hidden_layers):
            sct_layers.append(SCT(ker_vecs.shape[0], hidden_channels, ker_vecs=ker_vecs, indicators=None, cached=True))
        self.bias_list = ModuleList(sct_layers)

        self.bn_enc = BatchNorm(hidden_channels)
        self.bn_list = ModuleList([BatchNorm(hidden_channels) for _ in range(hidden_layers)])
        self.reset_parameters()

    def forward(self, x, edge_index, edge_weight):
        """Return decoder outputs; per-layer activations end up in ``self._layers``."""
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.act(self.bn_enc(self.fc_enc(x)))
        encoded = x
        hidden_states = [encoded]

        for idx in range(self.layers):
            residual = x
            x = F.dropout(residual, self.dropout, training=self.training)
            propagated = self.gcn_list[idx](x, edge_index, edge_weight=edge_weight)
            # NOTE(review): the SCT layer is fed the convolution output here,
            # while gcniis/egnns feed it the pre-convolution features — confirm.
            sct_term = self.bias_list[idx](x=propagated, x0=encoded, edge_index=edge_index, edge_weight=edge_weight)
            # Rescale each feature column of the SCT term by its std over dim 0.
            sct_term = sct_term / sct_term.std(dim=0)
            x = self.act(self.bn_list[idx](propagated) + sct_term)
            x = x + residual
            hidden_states.append(x)

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_dec(x)
        self._layers = hidden_states
        return x

class gcniis(fixed_bias_architecture):
    """``GCN2Conv`` stack whose layers add an SCT term to the batch-normalised output.

    ``config`` must provide ``'alpha'`` and ``'theta'``.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        # Kernel vectors are computed once from the graph and handed to every SCT layer.
        ker_vecs = kernel_vectors(data.edge_index, edge_weight=data.edge_weight, single=True)

        convs = []
        for depth in range(1, hidden_layers + 1):
            convs.append(GCN2Conv(hidden_channels, alpha=config['alpha'], theta=config['theta'], layer=depth, normalize=False, cached=True))
        self.gcn_list = ModuleList(convs)

        # An SCT_Resid-based bias (taking alpha/theta/layer) was tried here
        # earlier; the plain SCT layer is what is currently used.
        sct_layers = []
        for _ in range(hidden_layers):
            sct_layers.append(SCT(ker_vecs.shape[0], hidden_channels, ker_vecs=ker_vecs, indicators=None, cached=True))
        self.bias_list = ModuleList(sct_layers)

        self.bn_enc = BatchNorm(hidden_channels)
        self.bn_list = ModuleList([BatchNorm(hidden_channels) for _ in range(hidden_layers)])
        self.reset_parameters()

    def forward(self, x, edge_index, edge_weight):
        """Return decoder outputs; per-layer activations end up in ``self._layers``."""
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.act(self.bn_enc(self.fc_enc(x)))
        encoded = x
        hidden_states = [encoded]

        for idx in range(self.layers):
            residual = x
            x = F.dropout(residual, self.dropout, training=self.training)
            propagated = self.gcn_list[idx](x=x, x_0=encoded, edge_index=edge_index, edge_weight=edge_weight)
            sct_term = self.bias_list[idx](x=x, x0=encoded, edge_index=edge_index, edge_weight=edge_weight)
            # Rescale each feature column of the SCT term by its std over dim 0.
            sct_term = sct_term / sct_term.std(dim=0)
            x = self.act(self.bn_list[idx](propagated) + sct_term)
            x = x + residual
            hidden_states.append(x)

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_dec(x)
        self._layers = hidden_states
        return x

class egnns(fixed_bias_architecture):
    """``EGNNConv`` stack with SCT terms and a weight-to-identity penalty.

    ``config`` must provide ``'c_max'``, ``'beta'`` and ``'theta'``. While in
    training mode, ``forward`` stores in ``self.reg`` the scaled sum of the
    norms between each layer's weight and a fixed (scaled) identity matrix.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        # Kernel vectors are computed once from the graph and handed to every SCT layer.
        ker_vecs = kernel_vectors(data.edge_index, edge_weight=data.edge_weight, single=True)

        sct_layers = []
        for _ in range(hidden_layers):
            sct_layers.append(SCT(ker_vecs.shape[0], hidden_channels, ker_vecs=ker_vecs, indicators=None, cached=True))
        self.bias_list = ModuleList(sct_layers)

        convs = []
        for _ in range(hidden_layers):
            convs.append(EGNNConv(hidden_channels, hidden_channels, c_max=config['c_max'], normalize=False, cached=True))
        self.gcn_list = ModuleList(convs)

        self.beta = config['beta']
        self.theta = config['theta']

        self.bn_enc = BatchNorm(hidden_channels)
        self.bn_list = ModuleList([BatchNorm(hidden_channels) for _ in range(hidden_layers)])

        # Frozen reference matrices for the penalty computed in forward().
        self.loss_weight = 1e-3
        identity = torch.eye(hidden_channels)
        self.weight_standard = Parameter(identity, requires_grad=False)
        self.weight_first_layer = Parameter(torch.eye(hidden_channels) * np.sqrt(config['c_max']), requires_grad=False)

        self.reset_parameters()

    def forward(self, x, edge_index, edge_weight):
        """Return decoder outputs; per-layer activations end up in ``self._layers``."""
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.act(self.bn_enc(self.fc_enc(x)))
        encoded = x
        hidden_states = [encoded]

        for idx in range(self.layers):
            residual = x
            x = F.dropout(residual, self.dropout, training=self.training)
            propagated = self.gcn_list[idx](
                x=x,
                x_0=encoded,
                edge_index=edge_index,
                edge_weight=edge_weight,
                beta=self.beta,
                residual_weight=self.theta,
            )
            sct_term = self.bias_list[idx](x=x, x0=encoded, edge_index=edge_index, edge_weight=edge_weight)
            # Rescale each feature column of the SCT term by its std over dim 0.
            sct_term = sct_term / sct_term.std(dim=0)
            x = self.act(self.bn_list[idx](propagated) + sct_term)
            x = x + residual
            hidden_states.append(x)

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_dec(x)
        self._layers = hidden_states

        if not self.training:
            # Outside training self.reg keeps whatever value it had before.
            return x

        # First layer is compared against the scaled identity, the rest
        # against the plain identity.
        penalty = 0. + torch.norm(self.gcn_list[0].weight - self.weight_first_layer)
        for idx in range(1, self.layers):
            penalty = penalty + torch.norm(self.gcn_list[idx].weight - self.weight_standard)
        self.reg = self.loss_weight * penalty
        return x

#----------------------------------------------------------------------------------------------------------------------------------------------------

def setup(cfg):
    # Set device
    args = cfg.setup
    cfg['setup']['device'] = args['device'] if torch.cuda.is_available() else 'cpu'
    os.environ["WANDB_DIR"] = os.path.abspath(args['wandb_dir'])
    # Change file name for sweeping *Prior to setting seed*
    if args['sweep']:
        run_id = wandb.run.id
        cfg['load']['checkpoint_path']=cfg['load']['checkpoint_path'][:-3]+str(run_id)+'.pt'
    pass

#----------------------------------------------------------------------------------------------------------------------------------------------------


def train(cfg, data, model, optimizer, mask):
    """Run one optimisation step on the nodes selected by ``mask``.

    The objective is the cross-entropy on the masked nodes plus ``model.reg``
    when the model exposes one (the base architecture defaults it to 0;
    ``egnns`` refreshes it on every training forward pass).

    Args:
        cfg: run configuration; not read here.
        data: object providing ``x``, ``edge_index``, ``edge_weight``, ``y``.
        model: module called as ``model(x, edge_index, edge_weight)``.
        optimizer: optimizer stepping the model's parameters.
        mask: node selector used to index outputs and labels.

    Returns:
        ``(loss, accuracy)`` where ``loss`` is the optimised objective
        (cross-entropy plus regulariser) as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    output = model(data.x, data.edge_index, data.edge_weight)
    # output = F.log_softmax(output, dim=1)
    loss = F.cross_entropy(output[mask], data.y[mask])
    # Read the regulariser after the forward pass so that it belongs to the
    # current autograd graph; models without the attribute contribute 0.
    loss = loss + getattr(model, 'reg', 0)
    loss.backward()
    optimizer.step()

    pred = output[mask].max(1)[1]
    acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
    return loss.item(), acc


@torch.no_grad()
def validate(cfg, data, model, mask):
    """Score ``model`` in eval mode on the nodes selected by ``mask``.

    ``cfg`` is accepted for signature parity and is not read.
    Returns ``(cross_entropy_loss, accuracy)`` as Python numbers.
    """
    model.eval()
    logits = model(data.x, data.edge_index, data.edge_weight)
    # output = F.log_softmax(output, dim=1)
    selected_logits = logits[mask]
    targets = data.y[mask]
    loss = F.cross_entropy(selected_logits, targets)

    _, predicted = selected_logits.max(1)
    n_correct = predicted.eq(targets).sum().item()
    n_selected = mask.sum().item()
    return loss.item(), n_correct / n_selected


@torch.no_grad()
def test(cfg, data, model, mask):
    model.eval()
    output = model(data.x,data.edge_index,data.edge_weight)
    # output = F.log_softmax(output, dim=1)
    loss = F.cross_entropy(output[mask], data.y[mask])

    pred = output[mask].max(1)[1]
    acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
    return loss.item(), acc

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Main/Hydra/Fold/Train
#----------------------------------------------------------------------------------------------------------------------------------------------------

def run_training(cfg, data, model, tr_mask, val_mask):
    """Train ``model`` with early stopping on validation accuracy.

    An AdamW optimizer is built over ``model.graph_params`` (weight decay
    ``wd1``) and ``model.fc_params`` (weight decay ``wd2``). After every epoch
    the model is validated; whenever the validation error ``1 - val_acc``
    improves, a checkpoint is written to ``cfg.load['checkpoint_path']``.
    Training stops once more than ``patience`` consecutive epochs brought no
    improvement.

    Returns:
        The validation accuracy of the best epoch, or ``0.0`` when no epoch
        was run (``epochs == 0``).
    """
    args = cfg.train
    optimizer = optim.AdamW(
        [
            {'params': model.graph_params, 'weight_decay': args['wd1']},
            {'params': model.fc_params, 'weight_decay': args['wd2']},
        ],
        lr=args['lr'],
    )

    best = 1e8
    # Defined up front so the function is well-behaved even if the loop body
    # never executes.
    best_acc = 0.0
    bad_itr = 0
    for epoch in range(args['epochs']):
        # Only the optimisation step is timed; train() itself switches the
        # model to training mode.
        start = time.time()
        train_loss, train_acc = train(cfg, data, model, optimizer, tr_mask)
        end = time.time()
        val_loss, val_acc = validate(cfg, data, model, val_mask)

        perf_metric = (1-val_acc) #your performance metric here

        if perf_metric < best:
            best = perf_metric
            best_acc = val_acc
            bad_itr = 0
            torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'loss': val_loss,
                },
                cfg.load['checkpoint_path']
            )
        else:
            bad_itr += 1
        # Log results
        wandb.log({'epoch':epoch,
            'train_loss':train_loss,
            'train_acc':train_acc,
            'val_loss':val_loss,
            'val_acc':val_acc,
            'best':best,
            'time':end-start})
        print(f'Epoch({epoch}) '
            f'| train({100*train_acc:.2f},{train_loss:.4f}) '
            f'| val({100*val_acc:.2f},{val_loss:.4f}) '
            f'| best({best:.4f}) '
            f'| time({end-start:.4f})'
            f'| lr({optimizer.param_groups[0]["lr"]:.2e})'
            )

        if bad_itr>args['patience']:
            break

    return best_acc

#----------------------------------------------------------------------------------------------------------------------------------------------------

def run_folds(cfg):
    """Train and/or evaluate the configured model on each data split.

    For fold ``k`` the ``k``-th column of the train/val/test masks is used and
    the checkpoint path gets a ``_fold_{k}`` suffix (written into ``cfg``).
    When ``cfg.setup['train']`` is false the existing checkpoint is evaluated
    and the fold's validation accuracy stays at the ``-1`` placeholder.
    Mean and standard deviation over folds are printed and logged to wandb.

    Raises:
        ValueError: if ``cfg.model['name']`` does not name a callable defined
            in this module.

    Returns:
        Always ``1``.
    """
    kfolds = 10  # number of mask columns (splits) that are evaluated
    val_accs = [-1 for _ in range(kfolds)]
    test_accs = [-1 for _ in range(kfolds)]

    original_path = cfg.load['checkpoint_path']
    device = cfg.setup['device']

    # Resolve the model class once and fail with a readable message instead of
    # "'NoneType' object is not callable" on a misspelled name.
    model_name = cfg.model['name']
    model_cls = globals().get(model_name)
    if not callable(model_cls):
        raise ValueError(f'Unknown model name {model_name!r}')

    dropout = cfg.model['dropout']
    hidden_channels = cfg.model['hidden_channels']
    hidden_layers = cfg.model['hidden_layers']

    for k in range(kfolds):
        # Assumes the configured path ends in '.pt' (last three characters).
        cfg['load']['checkpoint_path'] = original_path[:-3] + f'_fold_{k}.pt'
        # Load
        dataset, _, _, _, _ = heterophilious_dataloaders(
            name=cfg.load['dataset'],
            adjacency='sym-norm',
        )
        data = dataset[0]
        data = data.to(device)

        model = model_cls(dataset.num_features, hidden_channels, hidden_layers, dataset.num_classes, dropout, cfg.model['act_fn'], cfg.model, data)
        model.to(device)

        train_mask = data.train_mask[:, k]
        val_mask = data.val_mask[:, k]
        test_mask = data.test_mask[:, k]

        # Tensor.sum() reduces on-device; the builtin sum() would iterate the
        # mask one element at a time.
        n_train = train_mask.sum()
        n_val = val_mask.sum()
        n_test = test_mask.sum()
        total = n_train + n_val + n_test
        print(f'Fold {k} Splits: train({100*n_train/total:.2f})'
            f'\tval({100*n_val/total:.2f})'
            f'\ttest({100*n_test/total:.2f})'
            f'\ttrv({n_train+n_val})'
        )
        if cfg.setup['train']:
            val_acc = run_training(cfg, data, model, train_mask, val_mask)
            val_accs[k] = val_acc

        # Test
        # map_location lets a checkpoint saved on another device be evaluated
        # on the device chosen for this run.
        checkpoint = torch.load(cfg.load['checkpoint_path'], map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        test_loss, test_acc = test(cfg, data, model, test_mask)
        test_accs[k] = test_acc

    summary = {'val_mean': np.mean(val_accs),
        'val_std': np.std(val_accs),
        'test_mean': np.mean(test_accs),
        'test_std': np.std(test_accs)}
    print(summary)
    wandb.log(summary)

    return 1


#----------------------------------------------------------------------------------------------------------------------------------------------------

@hydra.main(version_base=None, config_path="/root/workspace/sct-gnn/config/", config_name="heterophilious")
def run_heterophilious(cfg):
    """Hydra entry point: start the wandb run, finalise ``cfg``, run all folds.

    wandb logs online only for sweeps (``cfg.setup['sweep']``); otherwise it
    is initialised in disabled mode. Always returns ``1``.
    """
    dataset_name = cfg.load['dataset']
    if cfg.setup['sweep']:
        wandb_mode = 'online'
    else:
        wandb_mode = 'disabled'

    # Initialize settings to wandb server
    resolved_cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    wandb.init(
        dir='/root/workspace/out/',
        entity='',
        mode=wandb_mode,
        name='prgnn-' + dataset_name,
        project='pr-inspired-aggregation',
        tags=['prgnn', dataset_name],
        config=resolved_cfg,
    )

    # Execute
    setup(cfg)
    print(OmegaConf.to_yaml(cfg))
    run_folds(cfg)
    return 1

#----------------------------------------------------------------------------------------------------------------------------------------------------

# Script entry point. The function is called without arguments; its ``cfg``
# parameter is expected to be supplied by the ``@hydra.main`` decorator.
if __name__ == '__main__':
    run_heterophilious()
