from ray import tune
from argparse import Namespace
import logging
import math
import random
import sys
import os
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
from torch.optim import SGD
from optimizers.new_adam import NewAdam
from optimizers.new_lamb import NewLAMB
from optimizers.new_lars import NewLARS
from optimizers.new_lookahead import NewLookahead
from optimizers.new_radam import NewRAdam
from optimizers.new_yogi import NewYogi

import fastergcn.preprocess as preprocess
from fastergcn.model import (
    add_diag,
    renormalize_adj,
    sparse_mx_to_torch_sparse_tensor,
    evaluate,
    calc_f1,
    ClusterGCN,
)

# Module-level hyperparameters. load_data() reads the dataset/model/partitioning
# settings below; the optimizer and its learning rate come from the Tune config.

# --- Dataset / graph preprocessing ------------------------------------------
data: str = "ppi"  # dataset identifier
diag_lambda: float = 0.0  # add_diag(...) is applied to the adjacencies only when > 0

# --- Model architecture (forwarded to ClusterGCN) ---------------------------
hidden_dim: int = 1024
num_layers: int = 3
dropout: float = 0.2
use_pp: bool = True
# When False, load_data also concatenates the raw features to the aggregated ones.
use_gcn: bool = False

# --- Cluster sampling ---------------------------------------------------------
use_cluster: bool = True
# Number of partitions per partitioning level; the first entry sizes idx_parts.
partition_list: list = [50]
# Clusters combined per training batch (1 => one cluster per optimizer step).
bsize_list: list = [1]

# --- Miscellaneous --------------------------------------------------------------
debug: bool = True  # verbosity flag
# The following are not read by the training code shown here; presumably kept
# for reference or for other entry points (TODO confirm).
lr: float = 0.01
batch_size: int = 16
epoch: int = 200
weight_decay: int = 0
opt: str = "adam"


def _build_partition_data(
    partition_list,
    bsize_list,
    train_part_adj_list,
    train_parts_list,
    train_set,
    labels,
    train_feats,
    use_feat,
):
    """Precompute per-cluster training inputs for every partitioning level.

    The four lists are walked in lockstep (``zip`` stops at the shortest).  For
    each cluster ``part`` of a level this collects:

    * the training nodes of the cluster (``only_train_parts``) and their row
      positions inside ``part`` (``train_odrs``);
    * the cluster's adjacency sub-matrix as a CUDA sparse tensor (``part_adjs``);
    * ``labels[part]``, i.e. the labels of *all* nodes of the cluster
      (``train_labels``);
    * when ``use_feat`` is False, ``train_feats[part, :]`` as a CUDA sparse
      tensor (``train_sparse_feats``).

    Returns:
        A list with one dict per partitioning level, holding the keys that
        ``ClusterGcnTask._train`` reads.
    """
    partition_data_list = []
    for psize, bsize, train_part_adj, train_parts in zip(
        partition_list, bsize_list, train_part_adj_list, train_parts_list
    ):
        only_train_parts = []
        train_odrs = []
        part_adjs = []
        train_labels = []
        train_sparse_feats = []
        for part in train_parts:
            train_nodes = []
            positions = []
            for nid, nd in enumerate(part):
                if nd in train_set:
                    train_nodes.append(nd)
                    positions.append(nid)
            only_train_parts.append(train_nodes)
            train_odrs.append(positions)

            # Adjacency restricted to the cluster's rows and columns.
            part_adjs.append(
                sparse_mx_to_torch_sparse_tensor(
                    train_part_adj[part, :][:, part]
                ).cuda()
            )
            train_labels.append(labels[part])
            if not use_feat:
                train_sparse_feats.append(
                    sparse_mx_to_torch_sparse_tensor(train_feats[part, :]).cuda()
                )

        partition_data = {
            "partition_size": psize,
            "balance_size": bsize,
            "only_train_parts": only_train_parts,
            "train_parts": train_parts,
            "train_odrs": train_odrs,
            "train_labels": train_labels,
            "part_adjs": part_adjs,
        }
        if not use_feat:
            partition_data["train_sparse_feats"] = train_sparse_feats
        partition_data_list.append(partition_data)
    return partition_data_list


def load_data(root):
    """Load the dataset, build the Cluster-GCN model and precompute training data.

    Reads the module-level settings ``data``, ``hidden_dim``, ``dropout``,
    ``use_gcn``, ``diag_lambda``, ``debug``, ``use_pp``, ``num_layers``,
    ``partition_list``, ``bsize_list`` and ``use_cluster``.  As a side effect,
    ``partition_list`` and ``bsize_list`` are rebound at module level to lists
    of ``int``.

    Args:
        root: data location forwarded to ``preprocess.load_data_new``.

    Returns:
        A 12-tuple ``(use_feat, idx_parts, partition_data_list, train_feats,
        test_feats, feats, labels, real_train_adj, real_full_batch_adj, test,
        model, loss_f)``.  ``model`` has been moved to CUDA.  ``labels``,
        ``real_full_batch_adj`` and, when ``use_feat`` is True, the feature
        tensors are float32 CPU tensors.  ``partition_data_list`` is empty when
        ``use_cluster`` is False.

    Raises:
        NotImplementedError: when node features are used (the feature matrix
            is not ``num_data x num_data``) and the training features are a
            scipy CSR matrix; that case is not implemented yet.  Raised before
            the model or any GPU tensors are created.
    """
    global partition_list, bsize_list
    partition_list = list(map(int, partition_list))
    bsize_list = list(map(int, bsize_list))

    (
        num_data,
        train_adj,
        full_adj,
        train_part_adj_list,
        _full_part_adj,
        feats,
        train_feats,
        test_feats,
        labels_new,
        train,
        val,
        test,
        train_parts_list,
        _full_parts,
    ) = preprocess.load_data_new(root, data, partition_list=partition_list)

    num_features = feats.shape[1]
    num_class = labels_new.shape[1]
    if debug:
        print(labels_new[0])

    if diag_lambda > 0.0:
        print("Add diag lambda {} to all adjs!".format(diag_lambda))
        train_adj = add_diag(train_adj, diag_lambda, list(val) + list(test))
        full_adj = add_diag(full_adj, diag_lambda)
        train_part_adj_list = [
            add_diag(adj, diag_lambda) for adj in train_part_adj_list
        ]
        # Recompute the aggregated features (A . X) with the updated adjacencies.
        train_feats = train_adj.dot(feats)
        test_feats = full_adj.dot(feats)

    if debug:
        print(train_adj[0, :])
        print("full")
        print(full_adj[0, :])
    print("train nonzeros", len(train_adj.data))
    print("full nonzeros", len(full_adj.data))
    print("train/val/test: %d/%d/%d" % (len(train), len(val), len(test)))

    if not use_gcn:
        # Append the raw features next to the aggregated ones.
        if isinstance(train_feats, sp.csr_matrix):
            train_feats = sp.hstack([train_feats, feats]).tocsr()
            test_feats = sp.hstack([test_feats, feats]).tocsr()
        else:
            train_feats = np.hstack([train_feats, feats])
            test_feats = np.hstack([test_feats, feats])

    real_train_adj = train_adj
    real_full_adj = full_adj
    if partition_list[0] > 1:
        print("using first partitioned graph!")
        train_adj = train_part_adj_list[0]
    print("train_adj shape", train_adj.shape, real_train_adj.shape)

    # A square num_data x num_data feature matrix is treated as "no features".
    use_feat = feats.shape != (num_data, num_data)
    if not use_feat:
        print("Feature is not used!")

    # Fail before allocating the model and GPU tensors.
    if use_feat and isinstance(train_feats, sp.csr_matrix):
        # TODO: support sparse train_feats when features are used.
        raise NotImplementedError("Sparse train_feats case not done yet")

    model = ClusterGCN(
        feats,
        num_features,
        hidden_dim,
        num_class,
        dropout=dropout,
        num_layers=num_layers,
        use_gcn=use_gcn,
        use_pp=use_pp,
        use_feat=use_feat,
    )
    model.cuda()

    print("Using multi-label loss")
    loss_f = nn.BCEWithLogitsLoss()
    labels = torch.tensor(labels_new, dtype=torch.float32)

    partition_data_list = []
    if use_cluster:
        partition_data_list = _build_partition_data(
            partition_list,
            bsize_list,
            train_part_adj_list,
            train_parts_list,
            set(train),
            labels,
            train_feats,
            use_feat,
        )
        print("cluster sampling is enabled. arg batch size is not used")
    print("start optimizing...")
    if debug:
        print(type(train_feats))

    if use_feat:
        # Kept on the CPU; ClusterGcnTask._train moves each batch to the GPU.
        feats = torch.from_numpy(feats).float()
        train_feats = torch.from_numpy(train_feats).float()
        test_feats = torch.from_numpy(test_feats).float()

    real_full_batch_adj = sparse_mx_to_torch_sparse_tensor(real_full_adj)
    idx_parts = list(range(partition_list[0]))

    return (
        use_feat,
        idx_parts,
        partition_data_list,
        train_feats,
        test_feats,
        feats,
        labels,
        real_train_adj,
        real_full_batch_adj,
        test,
        model,
        loss_f,
    )


class ClusterGcnTask(tune.Trainable):
    """Ray Tune trainable that trains the Cluster-GCN model built by ``load_data``.

    ``_setup`` selects the optimizer from the Tune config and loads the data.
    Each ``_train`` call runs one pass over the shuffled clusters of the first
    partitioning level, then reports the micro-F1 on the test split as ``"f1"``.
    """

    def _setup(self, config):
        """Select the optimizer, load data and model, and create the optimizer.

        Args:
            config: Tune config.  ``config["args"].optimizer`` selects the
                optimizer, ``config["data_prefix"]`` is passed to ``load_data``,
                and the remaining keys hold the optimizer hyperparameters.

        Raises:
            ValueError: if the optimizer name is not supported.  This is
                checked before any data is loaded.
        """
        print(config)
        self.config = config
        optimizer_cls, opt_args = self._optimizer_from_config(config)

        (
            self.use_feat,
            self.idx_parts,
            self.partition_data_list,
            self.train_feats,
            self.test_feats,
            self.feats,
            self.labels,
            self.real_train_adj,
            self.real_full_batch_adj,
            self.test,
            self.model,
            self.loss_f,
        ) = load_data(self.config["data_prefix"])
        # Training
        self.optimizer = optimizer_cls(self.model.parameters(), **opt_args)

    @staticmethod
    def _optimizer_from_config(config):
        """Return ``(optimizer_class, kwargs)`` chosen by ``config["args"].optimizer``.

        The optimizer is matched by name prefix.  SGD/LARS take ``momentum`` and
        ``lr``.  The adaptive optimizers take ``new_beta1``, ``new_beta2``,
        ``eps`` (cast to float) and ``lr``; Lookahead additionally takes ``k``
        and ``alpha``.

        Raises:
            ValueError: if the name matches no supported prefix.
        """
        opt_name = config["args"].optimizer
        if opt_name.startswith("SGD") or opt_name.startswith("LARS"):
            optimizer_cls = SGD if opt_name.startswith("SGD") else NewLARS
            return optimizer_cls, {"momentum": config["momentum"], "lr": config["lr"]}

        adaptive_optimizers = (
            ("Adam", NewAdam),
            ("RAdam", NewRAdam),
            ("Yogi", NewYogi),
            ("Lookahead", NewLookahead),
            ("LAMB", NewLAMB),
        )
        for prefix, optimizer_cls in adaptive_optimizers:
            if opt_name.startswith(prefix):
                break
        else:
            raise ValueError("Unsupported optimizer: {}".format(opt_name))

        opt_args = {
            "new_beta1": config["new_beta1"],
            "new_beta2": config["new_beta2"],
            "eps": float(config["eps"]),
            "lr": config["lr"],
        }
        if optimizer_cls is NewLookahead:
            opt_args["k"] = config["k"]
            opt_args["alpha"] = config["alpha"]
        return optimizer_cls, opt_args

    def _train(self):
        """Run one training pass over the clusters and evaluate on the test split.

        Returns:
            ``{"f1": <test micro-F1>, "early_stop": False}``.

        Raises:
            NotImplementedError: if the module-level ``use_cluster`` is False;
                only cluster-batched training is implemented.
        """
        if not use_cluster:
            raise NotImplementedError("ClusterGcnTask requires use_cluster=True")

        # Only the first partitioning level is used for training.
        partition_data = self.partition_data_list[0]
        np.random.shuffle(self.idx_parts)

        if partition_data["balance_size"] > 1:
            self._train_multi_cluster_batches(partition_data)
        else:
            self._train_single_cluster_batches(partition_data)

        self.model.eval()
        micro, _macro = evaluate(
            self.model,
            self.test,
            self.labels,
            self.real_full_batch_adj,
            self.test_feats,
            self.feats,
            self.use_feat,
        )
        self.model.train()
        return {"f1": micro, "early_stop": False}

    def _train_multi_cluster_batches(self, partition_data):
        """Take one optimizer step per group of ``balance_size`` shuffled clusters.

        The training nodes of the grouped clusters form the batch.  The batch
        adjacency is sliced from the full training adjacency and renormalized.
        Groups without any training node are skipped, because an empty batch
        would give a NaN loss.
        """
        psize = partition_data["partition_size"]
        bsize = partition_data["balance_size"]
        only_train_parts = partition_data["only_train_parts"]

        for i, st in enumerate(range(0, psize, bsize)):
            batch = []
            for pt_idx in self.idx_parts[st : st + bsize]:
                batch += only_train_parts[pt_idx]
            if not batch:
                continue

            self.optimizer.zero_grad()

            # NOTE: this works when clusters contain only training nodes
            batch_neigh = batch

            # NOTE: Better to do renormalization when adding off-diagonal edges
            batch_adj = sparse_mx_to_torch_sparse_tensor(
                renormalize_adj(self.real_train_adj[batch_neigh, :][:, batch_neigh])
            ).cuda()

            if isinstance(self.train_feats, sp.csr_matrix):
                batch_feats = sparse_mx_to_torch_sparse_tensor(
                    self.train_feats[batch_neigh, :]
                ).cuda()
                batch_orig_feats = None
            else:
                batch_feats = self.train_feats[batch_neigh, :].cuda()
                batch_orig_feats = self.feats[batch_neigh, :].cuda()

            logits = self.model(
                (None, None), (None, None), batch_adj, batch_feats, batch_orig_feats
            )
            loss = self.loss_f(logits, self.labels[batch].cuda())
            loss.backward()
            self.optimizer.step()

            if i % 20 == 0:
                micro, macro = calc_f1(
                    self.labels[batch].cpu(), logits.cpu().data.numpy(), True
                )
                print(
                    "batch %d loss %f F1 mi %f ma %f batch_neigh %d"
                    % (i, loss.item(), micro, macro, len(batch_neigh))
                )

    def _train_single_cluster_batches(self, partition_data):
        """Take one optimizer step per shuffled cluster, skipping clusters without training nodes.

        The model runs on the whole cluster, using its precomputed adjacency.
        The loss uses only the rows of the cluster's training nodes.
        """
        only_train_parts = partition_data["only_train_parts"]
        train_parts = partition_data["train_parts"]
        train_odrs = partition_data["train_odrs"]
        train_labels = partition_data["train_labels"]
        part_adjs = partition_data["part_adjs"]
        # Only present when load_data ran with use_feat == False.
        train_sparse_feats = partition_data.get("train_sparse_feats")

        for i, pid in enumerate(self.idx_parts):
            if not only_train_parts[pid]:
                continue
            self.optimizer.zero_grad()
            cluster_nodes = train_parts[pid]
            mapping = train_odrs[pid]
            # train_labels[pid] has one row per node of the cluster; keep only the
            # training nodes so the rows line up with logits[mapping].
            batch_labels = train_labels[pid][mapping].cuda()

            batch_adj = part_adjs[pid]
            if isinstance(self.train_feats, sp.csr_matrix):
                batch_feats = train_sparse_feats[pid]
                batch_orig_feats = None
            else:
                batch_feats = self.train_feats[cluster_nodes, :].cuda()
                batch_orig_feats = self.feats[cluster_nodes, :].cuda()

            logits = self.model(
                (None, None), (None, None), batch_adj, batch_feats, batch_orig_feats
            )
            train_logits = logits[mapping]

            loss = self.loss_f(train_logits, batch_labels)
            loss.backward()
            self.optimizer.step()

            if i % 20 == 0:
                micro, _macro = calc_f1(
                    batch_labels.cpu(), train_logits.cpu().data.numpy(), True
                )
                print("batch %d loss %f F1 %f" % (i, loss.item(), micro))

    def _save(self, checkpoint_dir):
        """Save model and optimizer state to ``checkpoint_dir/model_opt_state.pth``.

        Returns:
            ``checkpoint_dir``.
        """
        out_f = os.path.join(checkpoint_dir, "model_opt_state.pth")
        state = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }
        torch.save(state, out_f)
        return checkpoint_dir

    def _restore(self, checkpoint_dir):
        """Load model and optimizer state written by ``_save`` from ``checkpoint_dir``."""
        out_f = os.path.join(checkpoint_dir, "model_opt_state.pth")
        state = torch.load(out_f)
        self.model.load_state_dict(state["model"])
        self.optimizer.load_state_dict(state["optimizer"])
