import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from graph.dataset import load
from graph.argparser import args
from graph.igsd import IGSD

class GCNLayer(nn.Module):
    """Batched graph-convolution layer computing ``PReLU(adj @ fc(feat) + bias)``.

    Args:
        in_ft: Size of each input node feature vector.
        out_ft: Size of each output node feature vector.
        bias: When true, a learnable bias (initialised to zero) is added
            after the adjacency product; otherwise no bias is used.
    """

    def __init__(self, in_ft, out_ft, bias=True):
        super().__init__()
        # The projection has no bias of its own; the optional bias below is
        # applied after aggregation in ``forward``.
        self.fc = nn.Linear(in_ft, out_ft, bias=False)
        self.act = nn.PReLU()

        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.zeros(out_ft, dtype=torch.float32))

        # Visit this module and its children; only nn.Linear is affected.
        self.apply(self.weights_init)

    def weights_init(self, m):
        """Xavier-initialise the weight of an ``nn.Linear`` and zero its bias.

        Modules of any other type are left untouched.
        """
        if not isinstance(m, nn.Linear):
            return
        torch.nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)

    def forward(self, adj, feat):
        """Project ``feat``, propagate it through ``adj`` and activate.

        Both arguments must be 3-D batched tensors because the product is
        computed with ``torch.bmm``.
        """
        projected = self.fc(feat)
        aggregated = torch.bmm(adj, projected)
        if self.bias is not None:
            aggregated += self.bias
        return self.act(aggregated)


class GCN(nn.Module):
    """Stack of ``GCNLayer`` modules with a sum readout after every layer.

    Args:
        in_ft: Input node feature size consumed by the first layer.
        out_ft: Output size of every layer (and input size of all layers
            after the first).
        num_layers: Requested depth. One layer is always created, followed
            by ``num_layers - 1`` further layers.
    """

    def __init__(self, in_ft, out_ft, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # First layer maps in_ft -> out_ft, the remaining ones out_ft -> out_ft.
        input_sizes = [in_ft] + [out_ft] * max(num_layers - 1, 0)
        self.layers = nn.ModuleList(
            [GCNLayer(size, out_ft).to(args.device) for size in input_sizes]
        )

    def forward(self, adj, feat, mask=None):
        """Run all layers and collect per-layer graph readouts.

        ``mask`` is accepted for interface compatibility and is not used.

        Returns:
            A pair ``(hidden, readout)``: the last layer's node output and
            the per-layer sums over dimension 1 concatenated along the last
            dimension.
        """
        hidden = self.layers[0](adj, feat)
        readouts = [torch.sum(hidden, 1)]
        for depth in range(1, self.num_layers):
            hidden = self.layers[depth](adj, hidden)
            readouts.append(torch.sum(hidden, 1))
        return hidden, torch.cat(readouts, -1)


class MLP(nn.Module):
    """Three Linear+PReLU stages summed with a linear shortcut of the input.

    Args:
        in_ft: Size of the last dimension of the input.
        out_ft: Size of the last dimension of the output (also the width of
            every hidden stage).
    """

    def __init__(self, in_ft, out_ft):
        super().__init__()
        stages = []
        width = in_ft
        for _ in range(3):
            stages.append(nn.Linear(width, out_ft))
            stages.append(nn.PReLU())
            width = out_ft
        self.ffn = nn.Sequential(*stages)
        self.linear_shortcut = nn.Linear(in_ft, out_ft)

    def forward(self, x):
        """Return ``ffn(x) + linear_shortcut(x)``."""
        transformed = self.ffn(x)
        shortcut = self.linear_shortcut(x)
        return transformed + shortcut


class Model(nn.Module):
    """Two-view encoder: one GCN per graph view plus shared MLP heads.

    ``gnn1`` encodes the ``adj`` view and ``gnn2`` the ``diff`` view
    (presumably a diffusion matrix - confirm against ``load``). ``mlp1``
    projects node-level outputs and ``mlp2`` projects the concatenated
    per-layer graph readouts; both heads are shared between the two views.

    Args:
        n_in: Input node feature size.
        n_h: Hidden size of the encoders and projection heads.
        num_layers: Number of GCN layers per encoder.
    """

    def __init__(self, n_in, n_h, num_layers):
        super(Model, self).__init__()
        self.mlp1 = MLP(1 * n_h, n_h)
        # The graph readout concatenates one n_h-sized sum per GCN layer.
        self.mlp2 = MLP(num_layers * n_h, n_h)
        self.gnn1 = GCN(n_in, n_h, num_layers)
        self.gnn2 = GCN(n_in, n_h, num_layers)

    def forward(self, adj, diff, feat, mask):
        """Encode both views and project node- and graph-level outputs.

        Returns:
            ``(lv1, gv1, lv2, gv2)``: node-level and graph-level projections
            for the ``adj`` view followed by those for the ``diff`` view.
        """
        # GCN.forward takes (adj, feat, mask). The previous positional calls
        # passed (feat, adj, mask), which swapped the propagation matrix and
        # the node features; keyword arguments make the binding explicit.
        lv1, gv1 = self.gnn1(adj=adj, feat=feat, mask=mask)
        lv2, gv2 = self.gnn2(adj=diff, feat=feat, mask=mask)

        lv1 = self.mlp1(lv1)
        lv2 = self.mlp1(lv2)

        gv1 = self.mlp2(gv1)
        gv2 = self.mlp2(gv2)

        return lv1, gv1, lv2, gv2

    def embed(self, feat, adj, diff, mask):
        """Return the detached sum of both views' graph-level projections.

        Note that the parameter order here (``feat`` first) differs from
        ``forward`` (``adj`` first).
        """
        __, gv1, __, gv2 = self.forward(adj, diff, feat, mask)
        return (gv1 + gv2).detach()

def _select_rows(values, indices):
    """Pick ``values[i]`` for every ``i`` in *indices*, keeping list inputs as lists."""
    if isinstance(values, (np.ndarray, torch.Tensor)):
        return values[indices]
    return [values[i] for i in indices]


def _svm_cross_val_accuracies(x, y, n_splits=10):
    """Score embeddings with a grid-searched LinearSVC under stratified K-fold CV."""
    from sklearn.svm import LinearSVC
    from sklearn.metrics import accuracy_score

    params = {'C': [0.001, 0.01, 0.1, 1, 10, 100, 1000]}
    kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=None)
    accuracies = []
    for train_index, test_index in kf.split(x, y):
        classifier = GridSearchCV(LinearSVC(), params, cv=5, scoring='accuracy', verbose=0)
        classifier.fit(x[train_index], y[train_index])
        accuracies.append(accuracy_score(y[test_index], classifier.predict(x[test_index])))
    return accuracies


def train(dataset, gpu, num_layer=4, epoch=40, batch=64):
    """Train an IGSD model on *dataset* and evaluate its graph embeddings.

    The model is trained with Adam and early stopping on the mean epoch loss;
    the best weights are checkpointed to
    ``{args.model}-{dataset}-{gpu}-{batch}.pkl`` and reloaded before the
    embeddings are extracted and scored with a 10-fold LinearSVC protocol.

    Args:
        dataset: Dataset name forwarded to ``load``.
        gpu: Only used to tag the checkpoint file name.
        num_layer: Number of GCN layers in the online encoder.
        epoch: Maximum number of training epochs.
        batch: Mini-batch size (number of graphs per step).

    Returns:
        Mean test accuracy over the 10 cross-validation folds.
    """
    nb_epochs = epoch
    batch_size = batch
    patience = 20
    lr = 0.001
    l2_coef = 0.0
    hid_units = 512
    args.projection_size = hid_units

    # feat: [B,N,feat_dim]
    adj, diff, feat, labels, num_nodes = load(dataset)

    feat = torch.FloatTensor(feat).to(args.device)
    diff = torch.FloatTensor(diff).to(args.device)
    adj = torch.FloatTensor(adj).to(args.device)
    labels = torch.LongTensor(labels)

    feat_dim = feat.shape[2]
    n_graphs = adj.shape[0]

    online_encoder = GCN(feat_dim, hid_units, num_layer)
    model = IGSD(online_encoder, feat_dim, num_layer, args.projection_size, args.projection_hidden_size)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_coef)

    model.to(args.device)

    checkpoint_path = f'{args.model}-{dataset}-{gpu}-{batch_size}.pkl'
    checkpoint_saved = False
    cnt_wait = 0
    best = 1e9

    for _ in range(nb_epochs):
        running_loss = 0.0
        n_batches = 0
        train_idx = np.arange(n_graphs)
        np.random.shuffle(train_idx)

        for start in range(0, n_graphs, batch_size):
            model.train()
            optimizer.zero_grad()

            batch_idx = train_idx[start: start + batch_size]
            # The mask must follow the shuffled graph selection; a plain
            # positional slice would pair graphs with other graphs' entries.
            mask = _select_rows(num_nodes, batch_idx)

            loss = model(adj[batch_idx], feat[batch_idx], diff=diff[batch_idx], mask=mask)
            neg_loss = model.neg_loss(adj[batch_idx], feat[batch_idx], diff=diff[batch_idx], mask=mask)
            loss = loss - args.beta * neg_loss

            loss.backward()
            model.update_moving_average()
            optimizer.step()

            # Accumulate a Python float so no autograd history is retained.
            running_loss += loss.item()
            n_batches += 1

        # Average over the batches actually processed in this epoch.
        epoch_loss = running_loss / max(n_batches, 1)

        if epoch_loss < best:
            best = epoch_loss
            cnt_wait = 0
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
        else:
            cnt_wait += 1

        if cnt_wait == patience:
            break

    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=args.device))
    else:
        # Happens with zero epochs or a loss that never compared below the
        # initial bound (e.g. NaN); avoid loading a stale or missing file.
        print(f'No checkpoint written for {checkpoint_path}; evaluating current weights')

    # Evaluation only: no gradients are needed for the embeddings.
    with torch.no_grad():
        embeds = model.embed(adj, diff, feat)

    x = embeds.cpu().detach().numpy()
    y = labels.numpy()

    accuracies = _svm_cross_val_accuracies(x, y)
    print(np.mean(accuracies), np.std(accuracies))
    return np.mean(accuracies)


if __name__ == '__main__':
    # Grid search over depth, batch size and epoch budget, repeated for a
    # fixed list of seeds; tracks the best single run and the best mean.
    import warnings

    warnings.filterwarnings("ignore")
    print(args)

    gpu = 1
    layers = [2, 8, 12]
    epochs = [20, 40, 100]
    seeds = [123, 132, 321, 312, 231]
    # A positive ablation batch size replaces the whole batch-size grid.
    if args.ablation_batch > 0:
        batches = [args.ablation_batch]
    else:
        batches = [16, 32, 64, 128, 256, 512]
    datasets = [args.dataset]

    best_acc = -1
    best_d, best_l, best_b, best_e, best_seed = '', -1, -1, -1, -1
    best_mean_acc = -1
    best_accs = []

    for d in datasets:
        print(f'####################{d}####################')
        for l in layers:
            for b in batches:
                for e in epochs:
                    res = []
                    for seed in seeds:
                        # Re-seed before every run so each configuration
                        # sees the same sequence of seeds.
                        torch.manual_seed(seed)
                        torch.backends.cudnn.deterministic = True
                        torch.backends.cudnn.benchmark = False
                        np.random.seed(seed)
                        print(f'Dataset: {d}, Layer:{l}, Batch: {b}, Epoch: {e}, Seed: {seed}, neg_prob: {args.neg_prob}, beta: {args.beta}')
                        acc = train(d, gpu, l, e, b)
                        res.append(acc)
                        if acc > best_acc:
                            best_acc = acc
                            best_d, best_l, best_b, best_e, best_seed = d, l, b, e, seed
                    mean_acc = sum(res)/len(res)
                    if mean_acc > best_mean_acc:
                        best_mean_acc = mean_acc
                        best_accs = res
                    print("Best mean accuracy and list", best_mean_acc, best_accs)
        print('################################################')
    print("Best acc{}".format(str(best_acc)))
    print(f'Dataset: {best_d}, Layer:{best_l}, Batch: {best_b}, Epoch: {best_e}, Seed: {best_seed}')