"""
gcn adapted from pygcn: https://github.com/tkipf/pygcn
"""

import torch as t
import torch.nn as nn
import torch.nn.functional as F
import math
import os; import sys; sys.path.append(os.getcwd())
import dataset.graph as GD
import traceback
import time
import json
import dkm.train_util as TU
import dataset.meta as DM
from warnings import warn

class GraphConvolution(nn.Module):
    """
    Single graph-convolution layer, similar to https://arxiv.org/abs/1609.02907.

    Computes ``spmm(adj, input @ weight) + bias``; ``adj`` is expected to be a
    (sparse) adjacency matrix accepted by ``torch.spmm``.

    Args:
        in_features: size of each input node feature vector.
        out_features: size of each output node feature vector.
        bias: if True (default), add a learnable bias of shape ``(out_features,)``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Allocated uninitialised (float32 on CPU); reset_parameters() fills them.
        self.weight = nn.Parameter(
            t.empty(in_features, out_features, dtype=t.float32, device='cpu'))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(t.empty(out_features, dtype=t.float32, device='cpu'))
        self.reset_parameters()

    def reset_parameters(self):
        """Redraw weight, then bias, uniformly from [-b, b] with b = 1/sqrt(out_features)."""
        bound = 1. / math.sqrt(self.weight.size(1))
        for param in (self.weight, self.bias):
            if param is not None:
                param.data.uniform_(-bound, bound)

    def forward(self, input, adj):
        """Project node features with ``weight``, then aggregate them over ``adj``."""
        aggregated = t.spmm(adj, t.mm(input, self.weight))
        return aggregated if self.bias is None else aggregated + self.bias

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.in_features} -> {self.out_features})"

# Activation functions selectable by name through the models' ``act`` argument.
act_dict = {
    'relu': F.relu,
    'id': lambda x: x,  # identity: no activation
}

class KipfWellingGCN(nn.Module):
    """Two-layer GCN in the style of Kipf & Welling (https://arxiv.org/abs/1609.02907).

    Pipeline: GraphConvolution -> norm -> activation -> dropout ->
    GraphConvolution -> norm -> log_softmax over classes (dim 1).

    Args:
        nfeat: input feature dimension.
        nhid: hidden dimension.
        nclass: number of output classes.
        dropout: dropout probability applied between the two layers.
        norm: ``'none'`` disables normalisation; any other value inserts a
            ``BatchNorm1d`` after each graph convolution.
        act: key into ``act_dict`` (``'relu'`` or ``'id'``).
    """

    def __init__(self, nfeat, nhid, nclass, dropout, norm='none', act='relu'):
        super().__init__()
        self.dropout = dropout
        self.gcn1 = GraphConvolution(nfeat, nhid)
        self.gcn2 = GraphConvolution(nhid, nclass)
        use_batchnorm = norm != 'none'
        self.bn1 = nn.BatchNorm1d(nhid, affine=True) if use_batchnorm else nn.Identity()
        self.bn2 = nn.BatchNorm1d(nclass, affine=True) if use_batchnorm else nn.Identity()
        self.act = act_dict[act]
        self._eye = None  # not read within this module; kept for compatibility

    def forward(self, x, adj):
        """Return per-node class log-probabilities."""
        hidden = self.act(self.bn1(self.gcn1(x, adj)))
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        logits = self.bn2(self.gcn2(hidden, adj))
        return F.log_softmax(logits, dim=1)


class GCN(nn.Module):
    """Residual GCN: linear input projection, ``nlayers`` residual blocks, linear classifier.

    Input stage:  ``h = act(dropout(input_linear(x)))``.
    Each block:   ``h = h + dropout(lin_2(act(dropout(lin_1(adj @ norm(h))))))``.
    Output stage: ``log_softmax(output_linear(output_bn(h)), dim=1)``.

    Args:
        nfeat: input feature dimension.
        nhid: hidden width used throughout the residual stack.
        nclass: number of output classes.
        dropout: dropout probability used at every dropout site.
        nlayers: number of residual blocks (must be >= 1).
        norm: ``'batch'`` for ``BatchNorm1d`` before each block and before the
            classifier; anything else means no normalisation.
        act: key into ``act_dict`` (``'relu'`` or ``'id'``).
    """

    def __init__(self, nfeat, nhid, nclass, dropout, nlayers=2, norm='none', act='relu'):
        super(GCN, self).__init__()
        assert nlayers >= 1
        # Creation order fixes the RNG draws of the Linear initialisers.
        self.input_linear = nn.Linear(nfeat, nhid)
        self.output_linear = nn.Linear(nhid, nclass)
        self.lin_1 = nn.ModuleList(nn.Linear(nhid, nhid) for _ in range(nlayers))
        self.lin_2 = nn.ModuleList(nn.Linear(nhid, nhid) for _ in range(nlayers))
        make_norm = (lambda: nn.BatchNorm1d(nhid)) if norm == 'batch' else nn.Identity
        self.bns = nn.ModuleList(make_norm() for _ in range(nlayers))
        self.output_bn = make_norm()
        self.dropout = dropout
        self.nlayers = nlayers
        self.act = act_dict[act]

    def forward(self, x, adj):
        """Return per-node class log-probabilities."""
        h = self.input_linear(x)
        h = self.act(F.dropout(h, self.dropout, training=self.training))

        for i in range(self.nlayers):
            norm_layer, ff_in, ff_out = self.bns[i], self.lin_1[i], self.lin_2[i]
            # graph propagation of the normalised state
            update = t.spmm(adj, norm_layer(h))
            # per-node feed-forward block
            update = F.dropout(ff_in(update), self.dropout, training=self.training)
            update = F.dropout(ff_out(self.act(update)), self.dropout, training=self.training)
            h = h + update

        return F.log_softmax(self.output_linear(self.output_bn(h)), dim=1)

def train_model(**kwargs):
    """Run ``_train_model(**kwargs)`` with failure reporting and wall-clock timing.

    Any ``Exception`` raised during training is printed with its traceback,
    followed by the marker line ``DKMTRAINFAIL``, instead of propagating.
    ``TOTAL_ELAPSED`` is printed in both cases.

    Returns:
        The stats dict returned by ``_train_model``, or ``None`` if training raised.
    """
    timestart = time.time()
    stats = None
    try:
        stats = _train_model(**kwargs)
    except Exception:
        traceback.print_exc()
        print("DKMTRAINFAIL")
    timeend = time.time()
    print("TOTAL_ELAPSED = ", timeend - timestart)
    return stats

def _masked_metrics(out, y, mask):
    """Return ``(accuracy, mean NLL loss)`` of log-probabilities ``out`` vs ``y`` on rows where ``mask`` is True.

    An empty mask yields ``nan`` accuracy instead of raising ZeroDivisionError.
    """
    out_m, y_m = out[mask], y[mask]
    n = int(mask.sum())
    correct = int((out_m.argmax(dim=1) == y_m).sum())
    acc = correct / n if n else float('nan')
    loss = F.nll_loss(out_m, y_m).item()
    return acc, loss


def _make_optimizer_and_scheduler(model, num_epochs):
    """Build Adam (weight decay 5e-4) and a linear-warmup -> cosine LR schedule.

    The LR ramps linearly from 1e-3 to 1e-2 over the first quarter of the
    epochs (at least one epoch), then follows ``CosineAnnealingLR`` towards 1e-5.
    Returns ``(optimizer, scheduler)``; call ``scheduler.step()`` once per epoch.
    """
    import warnings
    init_lr = 1e-3; max_lr = 1e-2; min_lr = 1e-5
    # At least one warmup epoch: with num_epochs < 4 a zero warmup would divide by zero below.
    warmup = max(1, num_epochs // 4)
    warnings.filterwarnings('ignore', message=r'.*scheduler\.step.*')
    optimizer = t.optim.Adam(model.parameters(), lr=max_lr, weight_decay=5e-4)
    start = init_lr / max_lr; end = 1.0
    get_warmup_lr = lambda x: start + (end - start) * x / warmup
    s1 = t.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=get_warmup_lr)
    # NOTE(review): T_max is num_epochs although the cosine phase only lasts
    # num_epochs - warmup epochs, so the LR appears to stop well above min_lr.
    # Kept as-is so results stay comparable with previously recorded runs.
    s2 = t.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=min_lr)
    scheduler = t.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[s1, s2], milestones=[warmup])
    return optimizer, scheduler


def _train_model(**kwargs):
    """Train one model on one dataset split and return its final-epoch metrics.

    Required kwargs: ``seed``, ``split``, ``dataset``, ``device``, ``model``
    (``'gcn'`` or ``'kipfgcn'``), ``nhidden``, ``dropout``, ``num_layers``,
    ``norm``, ``num_epochs``. Optional: ``normalize_features`` (default False),
    ``act`` (default ``'relu'``), ``dtype`` (``'float32'`` or ``'float64'``,
    default ``'float32'``). Note: sets torch's global default dtype.

    Training is full-batch: each epoch does one Adam step on the training
    nodes, then re-runs the whole graph in eval mode to report accuracy and
    NLL loss on the train/val/test masks.

    Returns:
        dict with ``train_acc``, ``val_acc``, ``test_acc``, ``train_loss``,
        ``val_loss``, ``test_loss`` from the last epoch, or ``None`` if
        ``num_epochs`` is 0.

    Raises:
        ValueError: for an unknown ``model`` or ``dtype``.
    """
    print("begin args ========")
    print(json.dumps(kwargs, indent=2))
    print("end args ========")

    seed = kwargs['seed']
    dtypes = dict(float32=t.float32, float64=t.float64)
    dtype_name = kwargs.get('dtype', 'float32')
    if dtype_name not in dtypes:
        raise ValueError(f"unknown dtype = {dtype_name}")
    typ = dtypes[dtype_name]
    t.set_default_dtype(typ)
    TU.set_seed(seed)

    normalize_features = kwargs.get('normalize_features', False)
    split = kwargs['split']
    ds = GD.get_dataset(kwargs['dataset'], split=split, normalize_features=normalize_features)
    device = kwargs['device']
    model_typ = kwargs['model']
    nhidden = kwargs['nhidden']
    dropout = kwargs['dropout']
    num_layers = kwargs['num_layers']
    norm = kwargs['norm']
    num_epochs = kwargs['num_epochs']
    act = kwargs.get('act', 'relu')

    X = ds.X.to(device, dtype=typ)
    y = ds.y.to(device, dtype=t.long)
    adj = ds.adj_sp.to(device, dtype=typ)
    num_class = ds.num_classes
    num_feat = ds.num_features
    train_mask = ds.train_mask.to(device, dtype=t.bool)
    val_mask = ds.val_mask.to(device, dtype=t.bool)
    test_mask = ds.test_mask.to(device, dtype=t.bool)
    if model_typ == 'gcn':
        model = GCN(num_feat, nhidden, num_class, dropout, nlayers=num_layers, norm=norm, act=act).to(device, dtype=typ)
    elif model_typ == 'kipfgcn':
        model = KipfWellingGCN(num_feat, nhidden, num_class, dropout, norm=norm, act=act).to(device, dtype=typ)
    else:
        raise ValueError(f"unknown model_typ = {model_typ}")

    optimizer, scheduler = _make_optimizer_and_scheduler(model, num_epochs)

    stats = None
    timestart = time.time()

    for epoch_i in range(num_epochs):
        epochstarttime = time.time()
        # train
        model.train()
        optimizer.zero_grad()
        out = model(X, adj)
        loss = F.nll_loss(out[train_mask], y[train_mask])
        loss.backward()
        optimizer.step()

        # Re-evaluate in eval mode so dropout is off; no_grad avoids building
        # an autograd graph that is never used.
        model.eval()
        with t.no_grad():
            out = model(X, adj)
        train_acc, train_loss = _masked_metrics(out, y, train_mask)
        val_acc, val_loss = _masked_metrics(out, y, val_mask)
        test_acc, test_loss = _masked_metrics(out, y, test_mask)
        stats = dict(train_acc=train_acc, val_acc=val_acc, test_acc=test_acc,
                     train_loss=train_loss, val_loss=val_loss, test_loss=test_loss)
        to_print = dict(epoch=epoch_i) | stats
        to_print = to_print | dict(lr=optimizer.param_groups[0]['lr'], t=time.time() - epochstarttime)

        print(TU.epoch_metrics_buf(to_print).getvalue())
        scheduler.step()

    print("========")
    elapsed = time.time() - timestart
    print("ELAPSED = ", elapsed)
    if stats is not None:
        print(TU.epoch_metrics_buf(stats).getvalue())
    print("========")
    print("TRAINLOOPEND")
    return stats

import hashlib
import pickle as pk
import os
def hash_dict(d):
    """Return the hex SHA-256 digest of ``json.dumps(d)`` (UTF-8 encoded).

    Used as the cache key for a hyperparameter configuration. Keys are NOT
    sorted, so equal dicts with different insertion order hash differently.
    This is deliberate: existing result files are keyed by these digests.
    """
    digest = hashlib.sha256(json.dumps(d).encode('utf-8'))
    return digest.hexdigest()
import fcntl
class Lock:
    """Exclusive inter-process lock using ``fcntl.flock`` on a lock file.

    Use as a context manager; entering blocks until the exclusive lock is
    acquired and returns the ``Lock`` itself. The lock file is created if it
    does not exist.
    """

    def __init__(self, lockfile="./lock.file"):
        self.lockfile = lockfile
        self.fp = None

    def __enter__(self):
        # Append mode creates a missing lock file and never truncates an
        # existing one (a read-only open would raise FileNotFoundError).
        fp = open(self.lockfile, 'a')
        try:
            fcntl.flock(fp.fileno(), fcntl.LOCK_EX)
        except BaseException:
            fp.close()  # don't leak the descriptor if locking fails/is interrupted
            raise
        self.fp = fp
        return self

    def __exit__(self, _type, value, tb):
        try:
            fcntl.flock(self.fp.fileno(), fcntl.LOCK_UN)
        finally:
            self.fp.close()
        return False  # never suppress exceptions from the with-body
def _load_results(fname):
    """Return the results dict pickled at ``fname``, or an empty dict if it does not exist.

    The caller is expected to hold the ``Lock`` guarding ``fname``.
    """
    if not os.path.exists(fname):
        return dict()
    # NOTE: unpickling is only acceptable because this file is written by this script.
    with open(fname, 'rb') as f:
        return pk.load(f)


def _save_results(fname, lockfile, hash_params, entry):
    """Store ``entry`` under ``hash_params`` in the results file, holding the lock.

    The file is re-read under the lock so entries written meanwhile by other
    processes (other configurations, or other seed/split runs of this one)
    are merged in rather than overwritten. The pickle is written to a temp
    file and moved into place, so a crash mid-write cannot truncate it.
    """
    dirname = os.path.dirname(fname)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    with Lock(lockfile=lockfile):
        on_disk = _load_results(fname)
        if hash_params in on_disk:
            ours = {(x['seed'], x['split']) for x in entry['metrics']}
            extra = [x for x in on_disk[hash_params]['metrics']
                     if (x['seed'], x['split']) not in ours]
            entry = dict(metrics=entry['metrics'] + extra, params=entry['params'])
        on_disk[hash_params] = entry
        tmp = fname + '.tmp'
        with open(tmp, 'wb') as f:
            pk.dump(on_disk, f)
        os.replace(tmp, fname)


def _search(default_params, grid_search_params, seeds_splits, fname):
    """Train every configuration x (seed, split) pair not yet recorded in ``fname``.

    Args:
        default_params: kwargs shared by all runs, passed to ``_train_model``.
        grid_search_params: list of dicts; each is merged over ``default_params``
            to form one configuration, identified by ``hash_dict`` of the merge.
        seeds_splits: iterable of ``(seed, split)`` pairs to run per configuration.
        fname: pickle file mapping config hash -> ``dict(metrics=[...], params=...)``;
            each metrics entry is a ``_train_model`` stats dict tagged with
            ``seed`` and ``split``.

    A run that raises is reported and skipped (nothing is recorded for it).
    On ``KeyboardInterrupt`` the runs already finished for the current
    configuration are saved, then the search stops.
    """
    lockfile = "lock.file"
    if os.path.exists(fname):
        with Lock(lockfile=lockfile):
            results = _load_results(fname)
    else:
        results = dict()

    for ps in grid_search_params:
        params = default_params | ps
        hash_params = hash_dict(params)

        statss = list(results[hash_params]['metrics']) if hash_params in results else []
        done = {(x['seed'], x['split']) for x in statss}
        interrupted = False
        for (seed, split) in seeds_splits:
            if (seed, split) in done:
                print(f"already trained with params/split/seed: -/{split}/{seed}")
                continue

            # Per-run kwargs; ``params`` itself stays free of seed/split because
            # it is stored alongside the metrics.
            run_params = params | dict(seed=seed, split=split)
            try:
                stats = _train_model(**run_params)
            except KeyboardInterrupt:
                print("Exiting")
                interrupted = True
                break
            except Exception:
                traceback.print_exc()
                warn(f"failed to train with params = {run_params}")
                continue
            if stats is None:
                warn(f"no epochs were run for params = {run_params}")
                continue
            stats['seed'] = seed
            stats['split'] = split
            statss.append(stats)
            done.add((seed, split))

        entry = dict(metrics=statss, params=params)
        results[hash_params] = entry
        _save_results(fname, lockfile, hash_params, entry)
        if interrupted:
            return

def grid_search(**kwargs):
    """Exhaustive GCN hyperparameter sweep on ``kwargs['dataset']``.

    Sweeps norm x model x nhidden x dropout x normalize_features (32
    configurations) over a dataset-dependent list of (seed, split) pairs,
    recording results in ``hyperparam_results/gcn_results/<dataset>.pkl``
    via ``_search``.
    """
    ds_name = kwargs['dataset']
    fname = f"hyperparam_results/gcn_results/{ds_name}.pkl"
    num_epochs = 200
    # Key order matters: results are keyed by hash_dict(json.dumps(params)),
    # so reordering these entries would orphan previously recorded runs.
    default_params = dict(
        dtype='float32',
        dataset=ds_name,
        num_layers=2,
        device='cuda',
        num_epochs=num_epochs,
        kernel='relu',
        normalize_features=True,
    )

    norms = ['none', 'batch']
    models = ['kipfgcn', 'gcn']
    hidden_sizes = [100, 200]
    dropouts = [0.0, 0.5]
    feature_norms = [True, False]
    grid_search_params = [
        dict(norm=n, dropout=d, nhidden=h, model=m, normalize_features=fn)
        for n in norms
        for m in models
        for h in hidden_sizes
        for d in dropouts
        for fn in feature_norms
    ]

    # Number of (seed, split) runs per configuration depends on the dataset family.
    if ds_name in DM.planetoid_datasets:
        seeds_splits = [(seed, 'public') for seed in range(10)]
    elif ds_name in DM.bigger_datasets:
        seeds_splits = [(seed, 'public') for seed in range(3)]
    else:
        seeds_splits = [
            (seed, f'test-{seed}-{split}')
            for split in range(10)
            for seed in range(2)
        ]

    print("number of combos to test", len(grid_search_params))
    _search(default_params, grid_search_params, seeds_splits, fname)

def final_acc(**kwargs):
    """Re-run the selected GCN configurations for ``kwargs['dataset']`` on seeds from 1000 upward.

    Two configurations are run through ``_search``, both recorded in
    ``hyperparam_results/final_gcn_results/<dataset>.pkl``: the defaults
    overridden by ``DM.best_gcn[dataset]``, and that result further
    overridden by ``DM.best_gcn_no_dropout[dataset]``. The seeds presumably
    start at 1000 to stay disjoint from the grid-search seeds.
    """
    ds_name = kwargs['dataset']
    fname = f"hyperparam_results/final_gcn_results/{ds_name}.pkl"
    num_epochs = 200
    # Key order matters: results are keyed by hash_dict(json.dumps(params)).
    base_params = dict(
        dtype='float32',
        dataset=ds_name,
        num_layers=2,
        device='cuda',
        num_epochs=num_epochs,
        kernel='relu',
        normalize_features=True,
    )
    default_params = base_params | DM.best_gcn[ds_name]

    grid_search_params = [{}]  # a single configuration: the params as given

    if ds_name in DM.planetoid_datasets:
        seeds_splits = [(seed, 'public') for seed in range(1000, 1010)]
    elif ds_name in DM.bigger_datasets:
        seeds_splits = [(seed, 'public') for seed in range(1000, 1005)]
    else:
        seeds_splits = [
            (seed, f'test-{seed}-{split}')
            for split in range(10)
            for seed in range(1000, 1005)
        ]

    print("number of combos to test", len(grid_search_params))
    _search(default_params, grid_search_params, seeds_splits, fname)

    no_dropout_params = default_params | DM.best_gcn_no_dropout[ds_name]
    print("number of combos to test", len(grid_search_params))
    _search(no_dropout_params, grid_search_params, seeds_splits, fname)


if __name__ == '__main__':

    import argparse

    parser = argparse.ArgumentParser(
        description="Train a GCN on one dataset split, or run the GCN "
                    "hyperparameter grid search / final evaluation.")
    parser.add_argument('--dataset', type=str, required=True)
    parser.add_argument('--split', type=str, default='', help='name or number of dataset split')
    parser.add_argument("--num-epochs", type=int, default=120)
    # A bare --device / --dtype without a value is rejected instead of becoming None.
    parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
    parser.add_argument('--dtype', type=str, default='float32', choices=['float32', 'float64'])
    parser.add_argument("--seed", type=int, default=0)
    # model params
    parser.add_argument("--model", type=str, default='gcn', choices=['gcn', 'kipfgcn'])
    parser.add_argument("--dropout", type=float, default=0.)
    parser.add_argument("--nhidden", type=int, default=64)
    parser.add_argument("--num-layers", type=int, default=2)
    parser.add_argument("--norm", type=str, default='none', choices=['none', 'batch'])
    parser.add_argument("--normalize-features", action='store_true')
    # Run modes are mutually exclusive; with neither flag, a single training run is done.
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument('--grid-search', action='store_true')
    mode.add_argument('--final-acc', action='store_true')
    args = parser.parse_args()
    args_dict = vars(args)

    def keyboard_interrupt(f):
        """Call ``f()``, printing "Exiting" instead of a traceback on Ctrl-C."""
        try:
            f()
        except KeyboardInterrupt:
            print("Exiting")

    if args.grid_search:
        keyboard_interrupt(lambda: grid_search(**args_dict))
    elif args.final_acc:
        keyboard_interrupt(lambda: final_acc(**args_dict))
    else:
        stats = train_model(**args_dict)