#!/usr/bin/python3
import os, json, sys
import time

import hydra
from omegaconf import OmegaConf
import wandb

import numpy as np
import random
import torch
from torch.nn import LeakyReLU, Linear, Module, ModuleList, ReLU
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets import WebKB
import torch_geometric.transforms as T

from torch_geometric.nn import GCNConv, GCN2Conv
from sct_gnn import SCT, SCT_Resid, SCT_Scalar
from sct_gnn.smoothness import NodeFeatureSmoothness
from sct_gnn.kernel_vectors import kernel_vectors, ker_lapl
sys.path.append('/root/workspace/sct-gnn/baselines/')
from egnn import EGNNConv

"""
PyG WebKB Dataset
    Split: 48/32/20
    10-Fold Cross Validation
"""

import torch.nn as nn
class SReLU(nn.Module):
    """Shifted ReLU with one learnable shift per channel.

    Computes ``relu(x - b) + b``, where ``b`` is a learnable vector of
    length ``nc`` whose entries all start at ``bias``.
    """

    def __init__(self, nc, bias):
        """
        Args:
            nc: number of entries in the learnable shift vector.
            bias: initial value written to every entry of the shift.
        """
        super().__init__()
        # Create the shift already filled with its initial constant.
        self.srelu_bias = nn.Parameter(torch.full((nc,), float(bias)))
        self.srelu_relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(x - srelu_bias) + srelu_bias``."""
        # The in-place ReLU only touches this temporary, never the caller's x.
        shifted = x - self.srelu_bias
        return self.srelu_relu(shifted) + self.srelu_bias

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Models
#----------------------------------------------------------------------------------------------------------------------------------------------------

class fixed_architecture(Module):
    """Shared skeleton: linear encoder -> graph layers -> linear decoder.

    This base class creates an *empty* ``gcn_list``; subclasses are expected
    to replace it with ``hidden_layers`` graph layers and then call
    ``reset_parameters()``. ``forward`` indexes ``gcn_list[i]`` for every
    ``i < hidden_layers`` and therefore raises ``IndexError`` if the list was
    left empty while ``hidden_layers > 0``.

    Attributes:
        reg: extra scalar added to the training loss by ``train()``;
            initialised to 0 here.
        _layers: list of activations recorded by the last ``forward`` call
            (encoder output followed by each graph-layer output).
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config):
        """
        Args:
            in_channels: input width of the encoder.
            hidden_channels: width of the hidden representation.
            hidden_layers: number of graph layers applied in ``forward``.
            out_channels: output width of the decoder.
            dropout: dropout probability passed to ``F.dropout``.
            act_fn: one of ``'relu'``, ``'srelu'`` or ``'leaky'``.
            config: model configuration; not read by this base class.

        Raises:
            ValueError: if ``act_fn`` is not one of the supported names.
        """
        super().__init__()
        self.fc_enc = Linear(in_channels, hidden_channels)
        self.gcn_list = ModuleList([])
        self.fc_dec = Linear(hidden_channels, out_channels)
        if act_fn == 'relu':
            self.act = ReLU()
        elif act_fn == 'srelu':
            self.act = SReLU(hidden_channels, -10)
        elif act_fn == 'leaky':
            self.act = LeakyReLU()
        else:
            # Fail at construction instead of leaving `self.act` undefined,
            # which previously only surfaced as an AttributeError in forward.
            raise ValueError(
                f"Unknown act_fn {act_fn!r}; expected 'relu', 'srelu' or 'leaky'"
            )
        self.dropout = dropout
        self.layers = hidden_layers
        self.reg = 0

    def reset_parameters(self):
        """Build the two parameter groups consumed by the optimiser.

        Despite the name, no weights are re-initialised: this only collects
        ``graph_params`` (graph layers) and ``fc_params`` (encoder + decoder).
        """
        self.graph_params = list(self.gcn_list.parameters())
        self.fc_params = list(self.fc_enc.parameters()) + list(self.fc_dec.parameters())

    def forward(self, x, edge_index, edge_weight):
        """Encode, apply ``self.layers`` graph layers, then decode.

        Dropout is applied before the encoder, before every graph layer and
        before the decoder. Returns the decoder output (no softmax).
        """
        _layers = []
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.act(self.fc_enc(x))
        _layers.append(x)
        for i in range(self.layers):
            x = F.dropout(x, self.dropout, training=self.training)
            ax = self.gcn_list[i](x=x, edge_index=edge_index, edge_weight=edge_weight)
            x = self.act(ax)
            _layers.append(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_dec(x)
        # Keep the intermediate activations available for later inspection.
        self._layers = _layers
        return x

class egnn(fixed_architecture):
    """``fixed_architecture`` whose graph layers are ``EGNNConv`` modules.

    ``data`` is accepted so that every model shares one constructor
    signature; it is not read here.

    NOTE(review): the inherited ``forward`` calls each layer with only
    ``x``/``edge_index``/``edge_weight``, whereas ``egnns`` also passes
    ``x_0``, ``beta`` and ``residual_weight`` to ``EGNNConv`` — confirm that
    ``EGNNConv`` accepts the shorter call.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        """Build ``hidden_layers`` EGNNConv layers using ``config['c_max']``."""
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        self.gcn_list = ModuleList([
            EGNNConv(hidden_channels, hidden_channels, c_max=config['c_max'], normalize=False, cached=True)
            for _ in range(hidden_layers)
        ])
        # Populate graph_params / fc_params like the sibling models do;
        # run_training reads both when it builds the optimiser.
        self.reset_parameters()

#----------------------------------------------------------------------------------------------------------------------------------------------------

class fixed_bias_architecture(fixed_architecture):
    """``fixed_architecture`` extended with a per-layer ``bias_list``.

    The list is created empty here; subclasses fill it. Its parameters are
    placed in the same optimiser group as the graph layers.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config):
        """Forward all arguments to the base class and add an empty ``bias_list``."""
        super().__init__(
            in_channels, hidden_channels, hidden_layers,
            out_channels, dropout, act_fn, config,
        )
        self.bias_list = ModuleList()

    def reset_parameters(self):
        """Collect the optimiser groups: graph layers + biases, and encoder + decoder."""
        self.graph_params = [
            *self.gcn_list.parameters(),
            *self.bias_list.parameters(),
        ]
        self.fc_params = [
            *self.fc_enc.parameters(),
            *self.fc_dec.parameters(),
        ]

class gcns(fixed_bias_architecture):
    """GCNConv stack with an additive SCT term on every layer.

    The last graph layer maps straight to ``out_channels`` and ``forward``
    returns its activated output; ``fc_dec`` is never applied in this model
    (it is still created by the base class and listed in ``fc_params``).

    NOTE(review): the activation is also applied to the ``out_channels``-wide
    final layer. With ``act_fn='srelu'`` the shift vector has
    ``hidden_channels`` entries — confirm this is intended when the two
    widths differ.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        """Build the GCNConv layers first, then the matching SCT modules.

        ``data.edge_index`` / ``data.edge_weight`` are passed to
        ``kernel_vectors``; its second return value is unused.
        """
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        indicators, _, ker_vecs = kernel_vectors(data.edge_index, edge_weight=data.edge_weight, return_all=True)

        # Output width of each layer: hidden everywhere except the last one.
        widths = [hidden_channels] * (hidden_layers - 1) + [out_channels]

        self.gcn_list = ModuleList([
            GCNConv(hidden_channels, width, normalize=False, cached=True, bias=False)
            for width in widths
        ])
        self.bias_list = ModuleList([
            SCT(ker_vecs.shape[0], width, ker_vecs=ker_vecs, indicators=indicators, cached=True)
            for width in widths
        ])
        self.reset_parameters()

    def forward(self, x, edge_index, edge_weight):
        """Encode, then apply ``act(conv(h) + sct(conv(h), h0))`` per layer.

        Dropout is applied to the raw input only. The SCT module receives the
        convolved features as ``x`` and the encoder output as ``x0``.
        """
        h = F.dropout(x, self.dropout, training=self.training)
        h = self.act(self.fc_enc(h))
        history = [h]

        for i in range(self.layers):
            conv_out = self.gcn_list[i](h, edge_index, edge_weight=edge_weight)
            shift = self.bias_list[i](x=conv_out, x0=history[0], edge_index=edge_index, edge_weight=edge_weight)
            h = self.act(conv_out + shift)
            history.append(h)

        self._layers = history
        return h

class gcniis(fixed_bias_architecture):
    """GCN2Conv stack with an additive SCT_Scalar term on every layer.

    Reads ``config['alpha']`` and ``config['theta']`` for the GCN2Conv
    layers, whose ``layer`` argument is numbered from 1.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        """Build the GCN2Conv layers first, then one SCT_Scalar per layer.

        ``kernel_vectors`` is called with ``largest=True``; only its third
        return value (``ker_vecs``) is used here.
        """
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        _, _, ker_vecs = kernel_vectors(data.edge_index, edge_weight=data.edge_weight, return_all=True, largest=True)

        convs = []
        for depth in range(1, hidden_layers + 1):
            convs.append(GCN2Conv(hidden_channels, alpha=config['alpha'], theta=config['theta'], layer=depth, normalize=False, cached=True))
        self.gcn_list = ModuleList(convs)

        # SCT_Scalar is the bias that is instantiated; the SCT_Resid variant
        # (alpha/theta/layer aware) is not used.
        biases = []
        for _ in range(hidden_layers):
            biases.append(SCT_Scalar(ker_vecs=ker_vecs, cached=True))
        self.bias_list = ModuleList(biases)

        self.reset_parameters()

    def forward(self, x, edge_index, edge_weight):
        """Encode, apply ``act(conv(h, h0) + sct(h, h0))`` per layer, decode.

        Dropout is applied before the encoder and before the decoder, not
        between graph layers. Both the convolution and the SCT module receive
        the layer input ``h`` and the encoder output ``h0``.
        """
        h = F.dropout(x, self.dropout, training=self.training)
        h = self.act(self.fc_enc(h))
        history = [h]

        for i in range(self.layers):
            conv_out = self.gcn_list[i](x=h, x_0=history[0], edge_index=edge_index, edge_weight=edge_weight)
            shift = self.bias_list[i](x=h, x0=history[0], edge_index=edge_index, edge_weight=edge_weight)
            h = self.act(conv_out + shift)
            history.append(h)

        out = self.fc_dec(F.dropout(h, self.dropout, training=self.training))
        self._layers = history
        return out

class egnns(fixed_bias_architecture):
    """EGNNConv stack with an additive SCT term and a weight regulariser.

    Reads ``config['c_max']``, ``config['beta']`` and ``config['theta']``.
    While in training mode, ``forward`` stores in ``self.reg`` a penalty of
    ``loss_weight`` times the summed norms of
    ``gcn_list[0].weight - sqrt(c_max) * I`` and ``gcn_list[i].weight - I``
    for the remaining layers.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        """Build the SCT modules first, then the EGNNConv layers.

        ``kernel_vectors`` is called with ``largest=True``; its second return
        value is unused.
        """
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        indicators, _deg, ker_vecs = kernel_vectors(data.edge_index, edge_weight=data.edge_weight, return_all=True, largest=True)
        self.bias_list = ModuleList([
            SCT(ker_vecs.shape[0], hidden_channels, ker_vecs=ker_vecs, indicators=indicators, cached=True)
            for _ in range(hidden_layers)
        ])
        self.gcn_list = ModuleList([
            EGNNConv(hidden_channels, hidden_channels, c_max=config['c_max'], normalize=False, cached=True)
            for _ in range(hidden_layers)
        ])
        self.beta = config['beta']
        self.theta = config['theta']

        self.loss_weight = 20
        # Frozen reference matrices for the regulariser. `Parameter` was never
        # imported as a bare name in this file, so go through `nn`.
        self.weight_standard = nn.Parameter(torch.eye(hidden_channels), requires_grad=False)
        self.weight_first_layer = nn.Parameter(torch.eye(hidden_channels) * np.sqrt(config['c_max']), requires_grad=False)

        self.reset_parameters()

    def forward(self, x, edge_index, edge_weight):
        """Encode, apply ``act(conv(h, h0) + sct(h, h0))`` per layer, decode.

        Dropout is applied before the encoder, before every graph layer and
        before the decoder. In training mode ``self.reg`` is refreshed as a
        side effect; in eval mode it keeps its previous value.
        """
        _layers = []
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.act(self.fc_enc(x))
        _layers.append(x)

        for i in range(0, self.layers):
            x = F.dropout(x, self.dropout, training=self.training)
            ax = self.gcn_list[i](x=x, x_0=_layers[0], edge_index=edge_index, edge_weight=edge_weight, beta=self.beta, residual_weight=self.theta)
            b = self.bias_list[i](x=x, x0=_layers[0], edge_index=edge_index, edge_weight=edge_weight)  # additive bias
            x = self.act(ax + b)
            _layers.append(x)

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_dec(x)
        self._layers = _layers

        if self.training:
            # First layer is pulled toward sqrt(c_max) * I, the rest toward I.
            loss_orthogonal = 0.
            loss_orthogonal += torch.norm(self.gcn_list[0].weight - self.weight_first_layer)
            for i in range(1, self.layers):
                loss_orthogonal += torch.norm(self.gcn_list[i].weight - self.weight_standard)

            self.reg = self.loss_weight * loss_orthogonal

        return x


#----------------------------------------------------------------------------------------------------------------------------------------------------
# Helper
#----------------------------------------------------------------------------------------------------------------------------------------------------

def index_to_mask(index, size):
    """Return a boolean tensor of ``size`` that is True at ``index``.

    ``index`` is used directly for tensor indexing, so it may hold integer
    positions or be a boolean mask itself. The result lives on the same
    device as ``index``.
    """
    selected = index.new_zeros(size, dtype=torch.bool)
    selected[index] = True
    return selected

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Config/Model/Dataset
#----------------------------------------------------------------------------------------------------------------------------------------------------

def setup(cfg):
    """Prepare the process for a run; mutates ``cfg`` and global state.

    In order: picks the device (falling back to ``'cpu'`` without CUDA),
    exports ``WANDB_DIR``, suffixes the checkpoint path with the wandb run id
    when sweeping, forces deterministic / non-TF32 backends, and seeds the
    torch, numpy and ``random`` generators from ``cfg.setup['seed']``.
    """
    args = cfg.setup

    # Device selection.
    if torch.cuda.is_available():
        cfg['setup']['device'] = args['device']
    else:
        cfg['setup']['device'] = 'cpu'

    os.environ["WANDB_DIR"] = os.path.abspath(args['wandb_dir'])

    # Give each sweep run its own checkpoint file. This happens before the
    # seed is set; the last three characters of the path (expected to be
    # '.pt') are replaced by '<run_id>.pt'.
    if args['sweep']:
        run_id = wandb.run.id
        base_path = cfg['load']['checkpoint_path'][:-3]
        cfg['load']['checkpoint_path'] = base_path + str(run_id) + '.pt'

    # Deterministic backends, TF32 disabled.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    torch.backends.cuda.matmul.allow_tf32 = False
    cudnn.allow_tf32 = False

    # Seed every random number generator used by the script.
    seed = args['seed']
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

#----------------------------------------------------------------------------------------------------------------------------------------------------

def load(cfg):
    """Return the WebKB dataset named by ``cfg.load['dataset']``.

    The dataset is rooted at ``/root/workspace/data/<name>`` and every sample
    is passed through ``NormalizeFeatures`` followed by ``GCNNorm``.
    """
    dataset_name = cfg.load['dataset']
    pipeline = T.Compose([T.NormalizeFeatures(), T.GCNNorm()])
    return WebKB(
        root="/root/workspace/data/" + dataset_name,
        name=dataset_name,
        transform=pipeline,
    )

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Train/Validate/Test
#----------------------------------------------------------------------------------------------------------------------------------------------------

def train(cfg, data, model, optimizer):
    """Run one full-batch optimisation step on the training nodes.

    The loss is the NLL over ``data.train_mask`` plus ``model.reg`` (read
    after the forward pass). ``cfg`` is unused.

    Returns:
        ``(loss, accuracy)`` on the training nodes as Python floats, both
        computed from the pre-step forward pass.
    """
    model.train()
    optimizer.zero_grad()

    log_probs = F.log_softmax(model(data.x, data.edge_index, data.edge_weight), dim=1)
    mask = data.train_mask
    targets = data.y[mask]

    loss = F.nll_loss(log_probs[mask], targets) + model.reg
    loss.backward()
    optimizer.step()

    _, predicted = log_probs[mask].max(1)
    n_correct = predicted.eq(targets).sum().item()
    return loss.item(), n_correct / mask.sum().item()


def validate(cfg, data, model):
    """Evaluate ``model`` on the validation nodes.

    The model is switched to eval mode and the forward pass runs under
    ``torch.no_grad()`` so no autograd graph is built for evaluation.
    ``model.reg`` is not added to the loss. ``cfg`` is unused.

    Returns:
        ``(loss, accuracy)`` over ``data.val_mask`` as Python floats.
    """
    model.eval()
    with torch.no_grad():
        output = model(data.x, data.edge_index, data.edge_weight)
        output = F.log_softmax(output, dim=1)
        loss = F.nll_loss(output[data.val_mask], data.y[data.val_mask])

        pred = output[data.val_mask].max(1)[1]
        acc = pred.eq(data.y[data.val_mask]).sum().item() / data.val_mask.sum().item()
    return loss.item(), acc


def test(cfg, data, model):
    """Load the saved checkpoint into ``model`` and evaluate on the test nodes.

    The checkpoint at ``cfg.load['checkpoint_path']`` is deserialised onto the
    CPU so that a file written on a GPU host also loads where CUDA is
    unavailable; ``load_state_dict`` then copies the tensors into the model's
    existing parameters. Evaluation runs in eval mode under
    ``torch.no_grad()``.

    Returns:
        ``(loss, accuracy)`` over ``data.test_mask`` as Python floats.
    """
    checkpoint = torch.load(cfg.load['checkpoint_path'], map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    with torch.no_grad():
        output = model(data.x, data.edge_index, data.edge_weight)
        output = F.log_softmax(output, dim=1)
        loss = F.nll_loss(output[data.test_mask], data.y[data.test_mask])

        pred = output[data.test_mask].max(1)[1]
        acc = pred.eq(data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()
    return loss.item(), acc

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Main/Hydra/Fold/Train
#----------------------------------------------------------------------------------------------------------------------------------------------------

def run_training(cfg, data, model):
    """Train ``model`` with early stopping on the validation loss.

    Adam is built with two parameter groups: ``model.graph_params`` with
    weight decay ``wd1`` and ``model.fc_params`` with ``wd2``. Whenever the
    validation loss improves, a checkpoint is written to
    ``cfg.load['checkpoint_path']``. Training stops once more than
    ``patience`` consecutive epochs fail to improve, or after ``epochs``
    epochs. Per-epoch metrics are logged to wandb and printed.

    Returns:
        The validation accuracy at the best (lowest) validation loss, or NaN
        if no epoch ever improved on the initial threshold (for example a NaN
        loss or ``epochs == 0``), in which case no checkpoint was written.
    """
    args = cfg.train
    optimizer = optim.Adam(
        [
            {'params': model.graph_params, 'weight_decay': args['wd1']},
            {'params': model.fc_params, 'weight_decay': args['wd2']},
        ],
        lr=args['lr'],
    )

    best = 1e8
    # Defined up front so the final `return` cannot hit an unbound name when
    # the improvement branch below is never taken.
    acc = float('nan')
    bad_itr = 0
    for epoch in range(args['epochs']):
        start = time.time()
        train_loss, train_acc = train(cfg, data, model, optimizer)
        val_loss, val_acc = validate(cfg, data, model)

        # Model selection is driven by the validation loss.
        perf_metric = val_loss

        if perf_metric < best:
            best = perf_metric
            acc = val_acc
            bad_itr = 0
            torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': val_loss,
                'acc': val_acc},
                cfg.load['checkpoint_path']
            )
        else:
            bad_itr += 1

        # Log results
        t_time = time.time() - start
        wandb.log({'epoch': epoch,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
            'perf_metric': perf_metric,
            'best': best,
            'time': t_time})
        print(f'Epoch({epoch:03d}) | train({train_loss:.4f},{100*train_acc:.2f}) '
            f'| val({val_loss:.4f},{100*val_acc:.2f}) '
            f'| perf({perf_metric:.3f}) | time({t_time:.3f})')

        if bad_itr > args['patience']:
            break

    return acc

#----------------------------------------------------------------------------------------------------------------------------------------------------

def run_folds(cfg):
    """Run the 10 predefined folds and report mean/std accuracy.

    For every fold a fresh model of class ``cfg.model['name']`` (looked up
    among this module's globals) is built, the fold's column of the dataset's
    train/val/test masks is installed on ``data``, the model is trained when
    ``cfg.setup['train']`` is set, and the checkpoint is evaluated on the
    test nodes. Summary statistics are printed and logged to wandb.

    Returns:
        Always 1.

    Raises:
        ValueError: if ``cfg.model['name']`` does not name a callable defined
            in this module.
    """
    kfolds = 10
    val_accs = [None] * kfolds
    test_accs = [None] * kfolds

    # Resolve the model class once and fail with a clear message rather than
    # "'NoneType' object is not callable" inside the loop.
    model_name = cfg.model['name']
    model_cls = globals().get(model_name)
    if not callable(model_cls):
        raise ValueError(f"Unknown model name {model_name!r}")

    dataset = load(cfg)
    device = cfg.setup['device']
    # Use the returned object instead of relying on an in-place move.
    data = dataset[0].to(device)
    # Keep the original per-fold mask matrices; data.*_mask is overwritten
    # with a single fold's mask on every iteration below.
    masks = [data.train_mask, data.val_mask, data.test_mask]

    dropout = cfg.model['dropout']
    hidden_channels = cfg.model['hidden_channels']
    hidden_layers = cfg.model['hidden_layers']

    for k in range(kfolds):
        # Set Model
        model = model_cls(dataset.num_features, hidden_channels, hidden_layers, dataset.num_classes, dropout, cfg.model['act_fn'], cfg.model, dataset[0])
        model.to(device)
        print(model)

        # Split Masks
        train_col, val_col, test_col = (m[:, k] for m in masks)
        data.train_mask = index_to_mask(train_col, data.num_nodes)
        data.val_mask = index_to_mask(val_col, data.num_nodes)
        data.test_mask = index_to_mask(test_col, data.num_nodes)

        # Reduce on the tensor instead of iterating it element by element.
        n_train = int(train_col.sum())
        n_val = int(val_col.sum())
        n_test = int(test_col.sum())
        total = n_train + n_val + n_test

        print(f'{dataset} fold {k} '
            f'| train({100*n_train/total:.2f}) '
            f'| \tval({100*n_val/total:.2f})'
            f'| \ttest({100*n_test/total:.2f})'
            f'| \ttrv({n_train+n_val})'
        )

        # Train
        if cfg.setup['train']:
            val_accs[k] = run_training(cfg, data, model)

        # Test
        test_loss, test_acc = test(cfg, data, model)
        test_accs[k] = test_acc

    # Validation accuracies only exist for folds that were trained; without
    # training, report NaN instead of crashing on a list of None.
    trained = [a for a in val_accs if a is not None]
    summary = {
        'val_mean': np.mean(trained) if trained else float('nan'),
        'val_std': np.std(trained) if trained else float('nan'),
        'test_mean': np.mean(test_accs),
        'test_std': np.std(test_accs),
    }
    print(summary)
    wandb.log(summary)

    return 1

#----------------------------------------------------------------------------------------------------------------------------------------------------


@hydra.main(version_base=None, config_path="/root/workspace/sct-gnn/config/", config_name="pyg-webkb")
def run_webkb(cfg):
    """Hydra entry point: start wandb, configure the process, run all folds.

    wandb runs in ``'online'`` mode only when ``cfg.setup['sweep']`` is set
    and is ``'disabled'`` otherwise. The config snapshot handed to wandb is
    taken before ``setup`` mutates ``cfg``. Always returns 1.
    """
    dataset_name = cfg.load['dataset']
    wandb_mode = 'online' if cfg.setup['sweep'] else 'disabled'

    # Initialize settings to wandb server
    wandb.init(
        config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
        entity='',
        name='gcn-sct-' + dataset_name,
        project='sct-gnn',
        tags=['test', 'gcn-sct2', dataset_name, '48/32/20'],
        mode=wandb_mode,
    )

    # Execute
    setup(cfg)
    print(OmegaConf.to_yaml(cfg))
    run_folds(cfg)
    return 1

#----------------------------------------------------------------------------------------------------------------------------------------------------

# Script entry point: run_webkb takes `cfg`, but is called without arguments
# here because it is wrapped by the @hydra.main decorator.
if __name__ == '__main__':
    run_webkb()
