#!/usr/bin/python3
import numpy as np
import random
import torch

import hydra
from omegaconf import OmegaConf
import wandb

# Set Backends: force deterministic cuDNN kernels and turn off cuDNN
# autotuning and TF32 math so repeated runs yield identical numbers.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
# Set Seed: seed Python, NumPy and torch (CPU + every CUDA device) with 0.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

import os, sys, time

from torch.nn import LeakyReLU, Linear, Module, ModuleList, Parameter, ReLU
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets import WebKB, Planetoid
import torch_geometric.transforms as T

from torch_geometric.nn import GCNConv, GCN2Conv
from sct_gnn import SCT, SCT_Resid, SCT_Scalar
from sct_gnn.smoothness import NodeFeatureSmoothness
from sct_gnn.kernel_vectors import kernel_vectors, ker_lapl
sys.path.append('/root/workspace/sct-gnn/baselines/')
from egnn import EGNNConv

"""
PyG Planetoid Dataset
    Split: 48/32/20
    10-Fold Cross Validation
"""

import torch.nn as nn
class SReLU(nn.Module):
    """Shifted ReLU: ``relu(x - b) + b`` with a learnable per-channel shift ``b``.

    Elementwise this equals ``max(x, b)``, i.e. outputs are floored at ``b``
    instead of at zero.

    Args:
        nc: number of channels (length of the shift vector ``b``).
        bias: initial value written into every entry of ``b``.
    """

    def __init__(self, nc, bias):
        super().__init__()
        # Allocate uninitialised storage; it is overwritten with ``bias`` below.
        self.srelu_bias = nn.Parameter(torch.empty(nc))
        self.srelu_relu = nn.ReLU(inplace=True)
        nn.init.constant_(self.srelu_bias, bias)

    def forward(self, x):
        shift = self.srelu_bias
        # The in-place ReLU only touches the temporary ``x - shift``, never ``x``.
        shifted = x - shift
        return self.srelu_relu(shifted) + shift

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Models
#----------------------------------------------------------------------------------------------------------------------------------------------------

class fixed_architecture(Module):
    """Encoder -> stack of graph convolutions -> decoder.

    Base class for the models in this file. Subclasses fill ``self.gcn_list``
    with ``hidden_layers`` graph convolutions and then call
    ``reset_parameters()`` so the optimizer parameter groups exist.

    Args:
        in_channels: input feature size.
        hidden_channels: width of the encoder output and of every graph layer.
        hidden_layers: number of graph convolutions applied in ``forward``.
        out_channels: decoder output size.
        dropout: dropout probability applied before every linear/graph layer.
        act_fn: one of ``'relu'``, ``'srelu'`` or ``'leaky'``.
        config: model config; unused here, kept for subclass signature parity.

    Raises:
        ValueError: if ``act_fn`` is not one of the supported names.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config):
        super().__init__()
        self.fc_enc = Linear(in_channels, hidden_channels)
        self.gcn_list = ModuleList([])
        self.fc_dec = Linear(hidden_channels, out_channels)
        if act_fn == 'relu':
            self.act = ReLU()
        elif act_fn == 'srelu':
            self.act = SReLU(hidden_channels, -10)
        elif act_fn == 'leaky':
            self.act = LeakyReLU()
        else:
            # Fail fast: otherwise self.act is never set and the first forward
            # pass dies with an unhelpful AttributeError.
            raise ValueError(f"Unknown act_fn {act_fn!r}; expected 'relu', 'srelu' or 'leaky'")
        self.dropout = dropout
        self.layers = hidden_layers
        # Extra loss term (0 = none); subclasses may overwrite it in forward().
        self.reg = 0

    def reset_parameters(self):
        """Collect the optimizer parameter groups.

        Despite the name no weights are re-initialised: this snapshots the
        graph-layer parameters into ``graph_params`` and the encoder/decoder
        parameters into ``fc_params`` so they can get separate weight decay.
        Call it again after replacing ``gcn_list``.
        """
        # NOTE(review): parameters of self.act (e.g. SReLU's learnable shift)
        # are in neither group, so the optimizer never updates them — confirm
        # this is intended.
        self.graph_params = list(self.gcn_list.parameters())
        self.fc_params = list(self.fc_enc.parameters()) + list(self.fc_dec.parameters())

    def forward(self, x, edge_index, edge_weight):
        """Run the encoder, ``self.layers`` graph convolutions and the decoder.

        The activated encoder output and each activated graph-layer output are
        stored in ``self._layers``.

        Returns:
            Unnormalised class scores from ``fc_dec``.
        """
        _layers = []
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.act(self.fc_enc(x))
        _layers.append(x)
        for i in range(self.layers):
            x = F.dropout(x, self.dropout, training=self.training)
            ax = self.gcn_list[i](x=x, edge_index=edge_index, edge_weight=edge_weight)
            x = self.act(ax)
            _layers.append(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_dec(x)
        self._layers = _layers
        return x

class egnn(fixed_architecture):
    """EGNN baseline: encoder -> ``EGNNConv`` stack -> decoder.

    Every ``EGNNConv`` also receives the encoder output as ``x_0``, plus
    ``beta=config['beta']`` and ``residual_weight=config['theta']``.

    In training mode ``forward`` stores an orthogonality penalty in
    ``self.reg``: ``loss_weight * sum_i torch.norm(W_i - T_i)``, where
    ``T_0 = sqrt(c_max) * I`` and ``T_i = I`` for later layers. Outside
    training ``self.reg`` keeps its last value.

    ``data`` is accepted for signature parity with the other models and is unused.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        self.gcn_list = ModuleList([ EGNNConv(hidden_channels,hidden_channels,c_max=config['c_max'],normalize=False,cached=True) for _ in range(hidden_layers) ])
        self.beta = config['beta']
        self.theta = config['theta']

        # Penalty weight and its fixed (non-trainable) target matrices.
        self.loss_weight = 20
        self.weight_standard = Parameter(torch.eye(hidden_channels), requires_grad=False)
        self.weight_first_layer = Parameter(torch.eye(hidden_channels) * np.sqrt(config['c_max']), requires_grad=False)

        self.reset_parameters()

    def forward(self, x, edge_index, edge_weight):
        """Encode, apply ``self.layers`` EGNN layers, decode; hidden states go to ``self._layers``."""
        _layers = []
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.act(self.fc_enc(x))
        _layers.append(x)

        for i in range(0, self.layers):
            x = F.dropout(x, self.dropout, training=self.training)
            ax = self.gcn_list[i](x=x, x_0=_layers[0], edge_index=edge_index, edge_weight=edge_weight, beta=self.beta, residual_weight=self.theta)
            x = self.act(ax)
            _layers.append(x)

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_dec(x)
        self._layers = _layers

        if self.training:
            self.reg = self.loss_weight * self._orthogonal_loss()

        return x

    def _orthogonal_loss(self):
        """Sum of ``torch.norm(W_i - T_i)`` over the graph layers actually built.

        The first layer is pulled towards ``weight_first_layer`` and all later
        layers towards ``weight_standard``. Iterating over the layers (instead
        of indexing ``gcn_list[0]`` unconditionally) makes ``hidden_layers == 0``
        yield 0 rather than an IndexError.
        """
        loss_orthogonal = 0.
        for i, conv in enumerate(self.gcn_list[:self.layers]):
            target = self.weight_first_layer if i == 0 else self.weight_standard
            loss_orthogonal += torch.norm(conv.weight - target)
        return loss_orthogonal

#----------------------------------------------------------------------------------------------------------------------------------------------------

class fixed_bias_architecture(fixed_architecture):
    """``fixed_architecture`` plus a list of per-layer bias modules (``bias_list``).

    Subclasses fill ``bias_list``. Its parameters are optimised in the
    ``fc_params`` group together with the encoder/decoder, not with the graph
    layers.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        # Placeholder; concrete models replace it with one bias module per graph layer.
        self.bias_list = ModuleList()

    def reset_parameters(self):
        """Rebuild the optimizer groups, appending ``bias_list`` parameters to ``fc_params``."""
        # NOTE(review): parameters of self.act (e.g. SReLU's shift) are still
        # in neither group — confirm intended.
        self.graph_params = [*self.gcn_list.parameters()]
        self.fc_params = [
            *self.fc_enc.parameters(),
            *self.fc_dec.parameters(),
            *self.bias_list.parameters(),
        ]

class gcns(fixed_bias_architecture):
    """GCN stack where every layer's output is shifted by an ``SCT_Scalar`` bias.

    Layer i computes ``act(GCNConv_i(x) + SCT_Scalar_i(x, x0))``, where ``x0``
    is the encoder output. ``data`` is only used to precompute the kernel
    vectors shared by every ``SCT_Scalar`` module.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        # jb: get ker — computed once and shared by all bias layers.
        ker_vecs = kernel_vectors(data.edge_index, edge_weight=data.edge_weight, single=True)
        # normalize=False: presumably because load() already applies T.GCNNorm to the edge weights.
        self.gcn_list = ModuleList(
            GCNConv(hidden_channels, hidden_channels, normalize=False, cached=True, bias=True)
            for _ in range(hidden_layers)
        )
        self.bias_list = ModuleList(
            SCT_Scalar(ker_vecs=ker_vecs, cached=True)
            for _ in range(hidden_layers)
        )
        self.reset_parameters()

    def forward(self, x, edge_index, edge_weight):
        """Encode, apply ``self.layers`` biased GCN layers, decode; hidden states go to ``self._layers``."""
        x = F.dropout(x, self.dropout, training=self.training)
        x0 = self.act(self.fc_enc(x))
        hidden = [x0]
        x = x0

        for idx in range(self.layers):
            x = F.dropout(x, self.dropout, training=self.training)
            conv_out = self.gcn_list[idx](x, edge_index, edge_weight=edge_weight)
            shift = self.bias_list[idx](x=x, x0=x0, edge_index=edge_index, edge_weight=edge_weight)
            x = self.act(conv_out + shift)
            hidden.append(x)

        x = F.dropout(x, self.dropout, training=self.training)
        logits = self.fc_dec(x)
        self._layers = hidden
        self.reg = 0
        return logits

class gcniis(fixed_bias_architecture):
    """GCNII stack where only the last (up to) two layers add an ``SCT_Scalar`` bias.

    Every layer is a ``GCN2Conv`` that also receives the encoder output as
    ``x_0``. Layers with index >= ``max(hidden_layers - 2, 0)`` additionally add
    ``bias_list[i](x, x0)`` before the activation. Bias modules for earlier
    layers are built (and sit in ``fc_params``) but are never called.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        # NOTE(review): egnns unpacks a 3-tuple from kernel_vectors(..., return_all=True)
        # and gcns omits return_all; confirm that return_all=True together with
        # single=True returns a single tensor here rather than a tuple.
        ker_vecs = kernel_vectors(data.edge_index, edge_weight=data.edge_weight, return_all=True, single=True)
        self.gcn_list = ModuleList(
            GCN2Conv(hidden_channels, alpha=config['alpha'], theta=config['theta'], layer=depth, normalize=False, cached=True)
            for depth in range(1, hidden_layers + 1)
        )
        self.bias_list = ModuleList(
            SCT_Scalar(ker_vecs=ker_vecs, cached=True)
            for _ in range(hidden_layers)
        )
        self.reset_parameters()

    def forward(self, x, edge_index, edge_weight):
        """Encode, apply ``self.layers`` GCNII layers (last two biased), decode."""
        x = F.dropout(x, self.dropout, training=self.training)
        x0 = self.act(self.fc_enc(x))
        hidden = [x0]
        x = x0

        # Only the final two graph layers are biased (all of them if there are <= 2).
        first_biased = max(self.layers - 2, 0)
        for idx in range(self.layers):
            x = F.dropout(x, self.dropout, training=self.training)
            out = self.gcn_list[idx](x=x, x_0=x0, edge_index=edge_index, edge_weight=edge_weight)
            if idx >= first_biased:
                out = out + self.bias_list[idx](x=x, x0=x0, edge_index=edge_index, edge_weight=edge_weight)
            x = self.act(out)
            hidden.append(x)

        x = F.dropout(x, self.dropout, training=self.training)
        logits = self.fc_dec(x)
        self._layers = hidden

        self.reg = 0
        return logits

class egnns(fixed_bias_architecture):
    """EGNN stack with an ``SCT`` bias added to every layer's output.

    Layer i computes ``act(EGNNConv_i(x, x_0) + SCT_i(x, x0))``, where ``x0`` is
    the encoder output. Each ``EGNNConv`` gets ``beta=config['beta']`` and
    ``residual_weight=config['theta']``.

    In training mode ``forward`` stores an orthogonality penalty in
    ``self.reg``: ``loss_weight * sum_i torch.norm(W_i - T_i)``, where
    ``T_0 = sqrt(c_max) * I`` and ``T_i = I`` for later layers.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        # jb: get ker — indicators and kernel vectors feed the SCT bias modules;
        # the middle return value is unused.
        indicators, _deg, ker_vecs = kernel_vectors(data.edge_index, edge_weight=data.edge_weight, return_all=True)
        # bias_list is built before gcn_list; keep this order, since it decides
        # which random draws each module's initialisation receives.
        self.bias_list = ModuleList([ SCT(ker_vecs.shape[0], hidden_channels, ker_vecs=ker_vecs, indicators=indicators, cached=True) for _ in range(hidden_layers) ])
        self.gcn_list = ModuleList([ EGNNConv(hidden_channels,hidden_channels,c_max=config['c_max'],normalize=False,cached=True) for _ in range(hidden_layers) ])
        self.beta = config['beta']
        self.theta = config['theta']

        # Penalty weight and its fixed (non-trainable) target matrices.
        self.loss_weight = 20
        self.weight_standard = Parameter(torch.eye(hidden_channels), requires_grad=False)
        self.weight_first_layer = Parameter(torch.eye(hidden_channels) * np.sqrt(config['c_max']), requires_grad=False)

        self.reset_parameters()

    def forward(self, x, edge_index, edge_weight):
        """Encode, apply ``self.layers`` biased EGNN layers, decode; hidden states go to ``self._layers``."""
        _layers = []
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.act(self.fc_enc(x))
        _layers.append(x)

        for i in range(0, self.layers):
            x = F.dropout(x, self.dropout, training=self.training)
            ax = self.gcn_list[i](x=x, x_0=_layers[0], edge_index=edge_index, edge_weight=edge_weight, beta=self.beta, residual_weight=self.theta)
            b = self.bias_list[i](x=x, x0=_layers[0], edge_index=edge_index, edge_weight=edge_weight)  # add bias
            x = self.act(ax + b)
            _layers.append(x)

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_dec(x)
        self._layers = _layers

        if self.training:
            self.reg = self.loss_weight * self._orthogonal_loss()

        return x

    def _orthogonal_loss(self):
        """Sum of ``torch.norm(W_i - T_i)`` over the graph layers actually built.

        The first layer is pulled towards ``weight_first_layer`` and all later
        layers towards ``weight_standard``. Iterating over the layers (instead
        of indexing ``gcn_list[0]`` unconditionally) makes ``hidden_layers == 0``
        yield 0 rather than an IndexError.
        """
        loss_orthogonal = 0.
        for i, conv in enumerate(self.gcn_list[:self.layers]):
            target = self.weight_first_layer if i == 0 else self.weight_standard
            loss_orthogonal += torch.norm(conv.weight - target)
        return loss_orthogonal

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Helper
#----------------------------------------------------------------------------------------------------------------------------------------------------

def accuracy(output, labels):
    """Return the fraction of rows of ``output`` whose arg-max (dim 1) equals ``labels``.

    The result is a 0-dim float64 tensor.
    """
    _, predicted = output.max(dim=1)
    hits = (predicted.type_as(labels) == labels).double().sum()
    return hits / len(labels)

def index_to_mask(index, size):
    """Return a boolean mask of length ``size`` that is True at ``index``.

    ``index`` may be an integer index tensor or a boolean mask of length
    ``size`` (which is then effectively copied). The mask is created on
    ``index.device``.
    """
    out = torch.zeros(size, dtype=torch.bool, device=index.device)
    out[index] = True
    return out

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Config/Model/Dataset
#----------------------------------------------------------------------------------------------------------------------------------------------------

def setup(cfg):
    # Set device
    args = cfg.setup
    cfg['setup']['device'] = args['device'] if torch.cuda.is_available() else 'cpu'
    os.environ["WANDB_DIR"] = os.path.abspath(args['wandb_dir'])
    # Change file name for sweeping *Prior to setting seed*
    if args['sweep']:
        run_id = wandb.run.id
        cfg['load']['checkpoint_path']=cfg['load']['checkpoint_path'][:-3]+str(run_id)+'.pt'
    pass

#----------------------------------------------------------------------------------------------------------------------------------------------------

def load(cfg):
    """Return the Planetoid dataset named by ``cfg.load.dataset`` with the geom-gcn splits.

    The transform pipeline (applied whenever a graph is accessed) row-normalises
    the features, makes the graph undirected and attaches GCN-normalised edge
    weights. Files live under ``/root/workspace/data/<dataset>``.
    """
    dataset_name = cfg.load['dataset']
    pipeline = T.Compose([
        T.NormalizeFeatures(),
        T.ToUndirected(),
        T.GCNNorm(),
    ])
    return Planetoid(
        root="/root/workspace/data/" + dataset_name,
        name=dataset_name,
        split='geom-gcn',
        transform=pipeline,
    )

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Train/Validate/Test
#----------------------------------------------------------------------------------------------------------------------------------------------------

def train(cfg, data, model, optimizer):
    """Run one full-batch optimisation step on the nodes in ``data.train_mask``.

    The optimised objective is the NLL on the training nodes plus the model's
    regularisation term ``model.reg``, which the model sets during its forward
    pass (the EGNN models store an orthogonal-weight penalty there; the other
    models use 0).

    Returns:
        (training NLL as a float, training accuracy tensor)
    """
    model.train()
    optimizer.zero_grad()
    output = model(data.x, data.edge_index, data.edge_weight)
    output = F.log_softmax(output, dim=1)
    nll = F.nll_loss(output[data.train_mask], data.y[data.train_mask])
    # Without this the penalty computed in forward() would never reach the
    # gradients; getattr keeps models without a `reg` attribute working.
    loss = nll + getattr(model, 'reg', 0)
    loss.backward()
    optimizer.step()
    acc = accuracy(output[data.train_mask], data.y[data.train_mask])
    # Report only the classification loss so it stays comparable with validate().
    return nll.item(), acc

@torch.no_grad()
def validate(cfg, data, model):
    """Evaluate ``model`` on ``data.val_mask`` in eval mode without tracking gradients.

    Returns:
        (validation NLL as a float, validation accuracy tensor)
    """
    model.eval()
    log_probs = F.log_softmax(model(data.x, data.edge_index, data.edge_weight), dim=1)
    mask = data.val_mask
    val_log_probs, val_targets = log_probs[mask], data.y[mask]
    loss = F.nll_loss(val_log_probs, val_targets)
    return loss.item(), accuracy(val_log_probs, val_targets)

@torch.no_grad()
def test(cfg, data, model):
    """Restore the saved checkpoint into ``model`` and evaluate it on ``data.test_mask``.

    The checkpoint at ``cfg.load['checkpoint_path']`` is mapped onto
    ``cfg.setup['device']``. Without that, a checkpoint written on a GPU cannot
    be loaded on a CPU-only machine, where ``setup`` has already switched the
    device to ``'cpu'``.

    Returns:
        (test NLL as a float, test accuracy tensor)
    """
    checkpoint = torch.load(cfg.load['checkpoint_path'], map_location=cfg.setup['device'])
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    output = model(data.x, data.edge_index, data.edge_weight)
    output = F.log_softmax(output, dim=1)
    loss = F.nll_loss(output[data.test_mask], data.y[data.test_mask])

    acc = accuracy(output[data.test_mask], data.y[data.test_mask])
    return loss.item(), acc

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Main/Hydra/Fold/Train
#----------------------------------------------------------------------------------------------------------------------------------------------------

def run_training(cfg, data, model):
    """Train ``model`` with early stopping on the validation loss.

    Each epoch runs one ``train`` step and one ``validate`` pass and logs the
    metrics to wandb and stdout. Whenever the validation loss improves, a
    checkpoint is saved to ``cfg.load['checkpoint_path']``. Training stops
    after ``cfg.train.epochs`` epochs, or once more than ``cfg.train.patience``
    consecutive epochs fail to improve.

    Returns:
        (best validation loss, validation accuracy at that epoch)

    Raises:
        RuntimeError: if no epoch improved the validation loss (e.g.
            ``epochs == 0`` or the loss is NaN throughout); in that case this
            call wrote no checkpoint.
    """
    args = cfg.train
    # Graph layers get weight decay wd1; encoder/decoder (and any bias modules) get wd2.
    optimizer = optim.Adam([{'params': model.graph_params, 'weight_decay': args['wd1']},
                            {'params': model.fc_params, 'weight_decay': args['wd2']}],
                           lr=args['lr'])

    best = float('inf')
    # Stay None until some epoch improves on `best`; checked after the loop.
    loss, acc = None, None
    bad_itr = 0
    for epoch in range(args['epochs']):
        start = time.time()
        train_loss, train_acc = train(cfg, data, model, optimizer)
        val_loss, val_acc = validate(cfg, data, model)
        stop = time.time()

        perf_metric = val_loss  # your performance metric here (lower is better)

        if perf_metric < best:
            best = perf_metric
            acc, loss = val_acc, val_loss
            bad_itr = 0
            torch.save({'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': val_loss,
                        'acc': val_acc},
                       cfg.load['checkpoint_path'])
        else:
            bad_itr += 1
        # Log results (time covers one train step plus one validation pass)
        t_time = stop - start
        wandb.log({'epoch': epoch,
                   'train_loss': train_loss,
                   'train_acc': train_acc,
                   'val_loss': val_loss,
                   'val_acc': val_acc,
                   'perf_metric': perf_metric,
                   'best': best,
                   'time': t_time})
        print(f'Epoch({epoch:03d}) | train({train_loss:.4f},{100*train_acc:.2f}) '
            f'| val({val_loss:.4f},{100*val_acc:.2f}) '+
            f'| perf({perf_metric:.3f}) | time({t_time:.3f})')

        if bad_itr > args['patience']:
            break

    if loss is None:
        raise RuntimeError(
            f"No epoch improved the validation loss within {args['epochs']} epochs; "
            f"no checkpoint was saved to {cfg.load['checkpoint_path']!r}")
    return loss, acc

#----------------------------------------------------------------------------------------------------------------------------------------------------

def _summarize_folds(val_accs, test_accs):
    """Return ``{val,test}_{mean,std}`` over per-fold accuracies, skipping ``None`` folds (NaN if none)."""
    summary = {}
    for prefix, values in (('val', val_accs), ('test', test_accs)):
        present = [v for v in values if v is not None]
        summary[prefix + '_mean'] = np.mean(present) if present else float('nan')
        summary[prefix + '_std'] = np.std(present) if present else float('nan')
    return summary


def run_folds(cfg):
    """Train/evaluate one fresh model per geom-gcn split and report mean/std accuracy.

    For each of the ``kfolds`` split columns a new model is built. Its class is
    looked up by ``cfg.model['name']`` in this module. The model is trained when
    ``cfg.setup['train']`` is true, then evaluated on the split's test nodes
    from the checkpoint at ``cfg.load['checkpoint_path']``.

    Returns:
        1 (kept for compatibility with existing callers).

    Raises:
        ValueError: if ``cfg.model['name']`` does not name a ``Module`` subclass here.
    """
    # NOTE(review): every fold shares cfg.load['checkpoint_path']. With
    # cfg.setup['train'] false, all folds therefore evaluate the same checkpoint,
    # i.e. whatever model training last wrote there.
    kfolds = 10
    val_accs = [None for _ in range(kfolds)]
    test_accs = [None for _ in range(kfolds)]

    # Resolve the model class up front so a typo fails before any data is loaded.
    model_name = cfg.model['name']
    model_cls = globals().get(model_name)
    if not (isinstance(model_cls, type) and issubclass(model_cls, Module)):
        raise ValueError(f'Unknown model name {model_name!r}: no Module subclass of that name in this module')

    dataset = load(cfg)
    data = dataset[0]
    data.to(cfg.setup['device'])

    # Original (num_nodes, kfolds) split masks; data.*_mask are overwritten per fold below.
    masks = [data.train_mask, data.val_mask, data.test_mask]

    dropout = cfg.model['dropout']
    hidden_channels = cfg.model['hidden_channels']
    hidden_layers = cfg.model['hidden_layers']

    for k in range(kfolds):
        # Model (fresh for every split)
        model = model_cls(dataset.num_features, hidden_channels, hidden_layers, dataset.num_classes,
                          dropout, cfg.model['act_fn'], cfg.model, data)
        model.to(cfg.setup['device'])

        # Split Masks (column k of each split matrix)
        train_k, val_k, test_k = (m[:, k] for m in masks)
        data.train_mask = index_to_mask(train_k, data.num_nodes)
        data.val_mask = index_to_mask(val_k, data.num_nodes)
        data.test_mask = index_to_mask(test_k, data.num_nodes)
        # Tensor .sum() rather than builtin sum(), which iterates element by element in Python.
        total = int(data.train_mask.sum() + data.val_mask.sum() + data.test_mask.sum())
        denom = max(total, 1)
        n_train, n_val, n_test = int(train_k.sum()), int(val_k.sum()), int(test_k.sum())

        print(f'{dataset} fold {k} '
            f'| train({100*n_train/denom:.2f}) '
            f'| \tval({100*n_val/denom:.2f})'
            f'| \ttest({100*n_test/denom:.2f})'
            f'| \ttrv({n_train+n_val})'
        )

        # Train
        if cfg.setup['train']:
            val_loss, val_acc = run_training(cfg, data, model)
            print(f'val({100*val_acc:2.2f})')
            val_accs[k] = val_acc.item()

        # Test
        test_loss, test_acc = test(cfg, data, model)
        print(f'test({100*test_acc:2.2f})')
        test_accs[k] = test_acc.item()

    # Folds that were not trained keep val_accs[k] = None; skip them instead of crashing np.mean.
    summary = _summarize_folds(val_accs, test_accs)
    print(summary)
    wandb.log(summary)

    return 1

#----------------------------------------------------------------------------------------------------------------------------------------------------

@hydra.main(version_base=None, config_path="/root/workspace/sct-gnn/config/", config_name="pyg-planetoid")
def run_planetoid(cfg):
    """Hydra entry point: open a (disabled) wandb run, finalise ``cfg``, print it, and run all folds."""
    run_config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    dataset_name = cfg.load['dataset']
    wandb.init(
        config=run_config,
        entity='',
        mode='disabled',
        name='sct-conv-' + dataset_name,
        project='pyg-planetoid',
        tags=['test', 'sct-conv', dataset_name, '48/32/20'],
    )

    setup(cfg)
    print(OmegaConf.to_yaml(cfg))
    run_folds(cfg)
    return 1

#----------------------------------------------------------------------------------------------------------------------------------------------------

if __name__ == '__main__':
    # @hydra.main supplies cfg from the config file and CLI overrides.
    run_planetoid()