import torch.nn.functional as fn
import torch.nn as nn
import torch.autograd
import numpy as np
import tempfile
import random
import sys

from utils import eval_acc, eval_rocauc, get_rnd_seed, compute_avg_results
from dataset import load_nc_dataset
from model import SAF


def set_rng_seed(seed):
    """Seed every random number generator used by a run.

    Seeds Python's ``random`` module, NumPy's global generator, and the
    PyTorch CPU / current-CUDA / all-CUDA generators with ``seed`` (the CUDA
    calls are silently ignored when CUDA is unavailable), then asks cuDNN for
    deterministic kernels.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True


def print_both(args, str, end="\r\n"):
    """Write ``str`` to stderr, but only when ``args.is_print`` is truthy.

    Args:
        args: any object with an ``is_print`` attribute.
        str: the value to print. NOTE: this name shadows the builtin ``str``;
            it is kept unchanged because callers may pass it by keyword.
        end: terminator handed to ``print`` (defaults to ``"\\r\\n"``).

    Returns:
        0 when printing is disabled, otherwise None.
    """
    if args.is_print:
        print(str, file=sys.stderr, end=end)
        return None
    return 0


class EvalHelper:
    """Bundle one data split with an SAF model, its optimizer, loss and metric.

    Constructing an instance loads the dataset and the pre-computed eigen
    decomposition, moves everything onto the selected device, and builds the
    model and its Adam optimizer. ``run_epoch`` performs one full-batch
    training step; ``evaluate`` scores the model on the train/val/test nodes.
    """

    # Datasets trained with BCEWithLogitsLoss on one-hot targets and scored by
    # ROC-AUC; every other dataset uses NLLLoss on log-softmax outputs and
    # accuracy. Shared by __init__ and before_loss so the loss choice and the
    # target preparation can never disagree.
    _BCE_DATASETS = ('twitch-e', 'minesweeper', 'tolokers')

    def __init__(self, args, trn_idx, val_idx, tst_idx):
        """Load data, validate the split and build model + optimizer.

        Args:
            args: run configuration (device flag, data paths, model and
                optimizer hyper-parameters).
            trn_idx, val_idx, tst_idx: node indices of the split. They are
                converted to numpy arrays, checked to be pairwise disjoint,
                and then stored as tensors on the selected device.
        """
        use_cuda = torch.cuda.is_available() and not args.cpu
        dev = torch.device('cuda' if use_cuda else 'cpu')

        # load dataset
        dataset = load_nc_dataset(args)
        n = dataset.graph['num_nodes']
        # number of classes, assuming labels are 0..c-1
        c = dataset.label.max().item() + 1
        d = dataset.graph['node_feat'].shape[1]

        # verify the random splits are pairwise disjoint
        trn_idx, val_idx, tst_idx = np.array(trn_idx), np.array(val_idx), np.array(tst_idx)
        assert len(set(trn_idx).intersection(val_idx)) == 0
        assert len(set(trn_idx).intersection(tst_idx)) == 0
        assert len(set(val_idx).intersection(tst_idx)) == 0

        # load decomposed eigen-values & vectors. map_location makes a file that
        # was saved from CUDA tensors loadable on a CPU-only / --cpu run.
        dcp_dic = torch.load(f"{args.DATAPATH}eigen_dcp/{args.dataset}_{args.sub_dataset}_dcp.pt",
                             map_location=dev)

        # to-device
        self.dcpW, self.dcpU = dcp_dic["W"].to(dev), dcp_dic["U"].to(dev)
        trn_idx = torch.from_numpy(trn_idx).to(dev)
        val_idx = torch.from_numpy(val_idx).to(dev)
        tst_idx = torch.from_numpy(tst_idx).to(dev)
        dataset.label = dataset.label.to(dev)
        dataset.graph['edge_index'] = dataset.graph['edge_index'].to(dev)
        dataset.graph['node_feat'] = dataset.graph['node_feat'].to(dev)

        # model
        model = SAF(args, n, dev=dev, in_dim=d, hid_dim=args.hidden_channels, out_dim=c, dcpW=self.dcpW, dcpU=self.dcpU,
                    agg_layers=args.agg_K, agg_alpha=args.agg_alpha, drop_ama=args.drop_ama, sparse_eps=args.eps).to(dev)

        # optimizer: separate lr / weight-decay groups per sub-module
        optmz = torch.optim.Adam(
            [{'params': model.lin1.parameters(), 'weight_decay': args.reg, 'lr': args.lr},
             {'params': model.lin2.parameters(), 'weight_decay': args.reg, 'lr': args.lr},
             {'params': model.filter_prop.parameters(), 'weight_decay': args.filter_wd, 'lr': args.filter_lr},
             {'params': model.lin_f.parameters(), 'weight_decay': args.ama_wd, 'lr': args.ama_lr},
             {'params': model.lin_a.parameters(), 'weight_decay': args.ama_wd, 'lr': args.ama_lr}])

        # loss and metric: rocauc + BCE for the binary datasets, accuracy + NLL otherwise
        if args.dataset in self._BCE_DATASETS:
            loss_fn = nn.BCEWithLogitsLoss()
            eval_fn = eval_rocauc
        else:
            loss_fn = nn.NLLLoss()
            eval_fn = eval_acc

        self.dataset = dataset
        self.trn_idx, self.val_idx, self.tst_idx = trn_idx, val_idx, tst_idx
        self.model, self.optmz = model, optmz
        self.loss_fn, self.eval_fn = loss_fn, eval_fn
        self.c = c
        self.n = n

    def before_loss(self, args, out):
        """Return ``(prediction, target)`` in the form ``self.loss_fn`` expects.

        For BCE datasets the raw logits are paired with one-hot targets cast
        to ``out``'s dtype. Otherwise log-softmax over dim 1 is paired with
        the integer labels, as NLLLoss requires.
        """
        if args.dataset in self._BCE_DATASETS:
            true_label = fn.one_hot(self.dataset.label, self.c).type(out.dtype)
            return out, true_label
        return fn.log_softmax(out, dim=1), self.dataset.label

    def run_epoch(self, args, end='\n'):
        """Run one full-batch optimisation step on the training nodes.

        Returns:
            The training loss as a Python float.
        """
        self.model.train()
        self.optmz.zero_grad()

        out = self.model(self.dataset)
        pred, true_label = self.before_loss(args, out)
        loss = self.loss_fn(pred[self.trn_idx], true_label[self.trn_idx])

        loss.backward()
        self.optmz.step()

        # .item() synchronises with the device; do it once and reuse the value.
        loss_val = loss.item()
        print_both(args, "epoch-loss={:.4f}".format(loss_val), end=end)
        return loss_val

    def evaluate(self):
        """Score the model on the train, validation and test nodes.

        The forward pass runs in eval mode under ``torch.no_grad()``. Inference
        needs no autograd graph, and building one for the whole graph only
        wastes memory.

        Returns:
            ``(trn_acc, val_acc, tst_acc)`` as computed by ``self.eval_fn``.
        """
        self.model.eval()
        with torch.no_grad():
            out = self.model(self.dataset)
        label = self.dataset.label
        trn_acc = self.eval_fn(label[self.trn_idx], out[self.trn_idx])
        val_acc = self.eval_fn(label[self.val_idx], out[self.val_idx])
        tst_acc = self.eval_fn(label[self.tst_idx], out[self.tst_idx])
        return trn_acc, val_acc, tst_acc


def train_and_eval(args, trn_idx, val_idx, tst_idx):
    """Train an SAF model on one split with early stopping and score it.

    The model is trained for up to ``args.nepoch`` epochs. Each time the
    validation score improves, the state dict is checkpointed to an anonymous
    temporary file. Training stops once more than ``args.early`` consecutive
    epochs pass without improvement. The best checkpoint is then restored
    and evaluated.

    Returns:
        ``(val_acc, tst_acc, best_epoch)`` of the restored model. ``best_epoch``
        is 1-based, or 0 when no checkpoint was taken (e.g. ``nepoch == 0``),
        in which case the final model is evaluated as-is.
    """
    # fix random initialization
    set_rng_seed(args.rnd_seed)
    # build model
    agent = EvalHelper(args, trn_idx, val_idx, tst_idx)
    # trn and val
    wait_cnt, best_epoch = 0, 0
    # Start at -inf (not 0.0) so the first epoch is always checkpointed.
    # Otherwise a run whose validation score never exceeds 0 left the
    # checkpoint file empty and torch.load raised EOFError.
    best_val_acc = float('-inf')
    # A single temp file is rewritten in place for every new best model. The
    # `with` guarantees it is closed (and therefore deleted) when we are done.
    with tempfile.TemporaryFile() as best_model_sav:
        # epoch training
        for t in range(args.nepoch):
            agent.run_epoch(args, end=", ")
            trn_acc, val_acc, tst_acc = agent.evaluate()
            print_both(args, "epoch: {}/{}, trn-acc={:.4f}%, val-acc={:.4f}%, tst-acc={:.4f}".format(
                t + 1, args.nepoch, trn_acc * 100, val_acc * 100, tst_acc * 100))
            # training with early-stop
            if val_acc > best_val_acc:
                wait_cnt = 0
                best_val_acc = val_acc
                best_model_sav.seek(0)
                best_model_sav.truncate()
                torch.save(agent.model.state_dict(), best_model_sav)
                best_epoch = t + 1
            else:
                wait_cnt += 1
                if wait_cnt > args.early:
                    break
        # final results
        print_both(args, "Load selected model ...", end="\r\n")
        if best_epoch > 0:
            best_model_sav.seek(0)
            agent.model.load_state_dict(torch.load(best_model_sav))
    trn_acc, val_acc, tst_acc = agent.evaluate()
    print_both(args, "trn-acc={:.4f}%, val-acc={:.4f}%, tst-acc={:.4f}%".format(trn_acc * 100, val_acc * 100, tst_acc * 100))
    return val_acc, tst_acc, best_epoch


def Run_Exp(args):
    """Run ``train_and_eval`` on every stored random split and report test scores.

    ``args.sl_mode`` selects the split file: ``"full"`` reads the
    ``denseSplits`` file and ``"semi"`` reads the ``sparseSplits`` file. Any
    other value raises ``ValueError('Invalid sl_mode')``. Each split gets a
    fresh ``args.rnd_seed`` from ``get_rnd_seed()``. Per-run and aggregate
    results are printed to stdout.
    """
    split_suffix = {"full": "denseSplits", "semi": "sparseSplits"}.get(args.sl_mode)
    if split_suffix is None:
        raise ValueError('Invalid sl_mode')

    # NOTE: allow_pickle=True unpickles the file; only use trusted split files.
    split_path = f"{args.DATAPATH}random_splits/{args.dataset}_{args.sub_dataset}_{split_suffix}.npy"
    all_splits = np.load(split_path, allow_pickle=True)

    test_scores = []
    for run_id, split in enumerate(all_splits):
        args.rnd_seed = get_rnd_seed()
        val_acc, tst_acc, _ = train_and_eval(args, split["trn_idx"], split["val_idx"], split["tst_idx"])
        print(f"{run_id}-run: val_acc={val_acc}, tst_acc={tst_acc}")
        test_scores.append([tst_acc])

    r_mean, _, r_conf = compute_avg_results(np.concatenate(test_scores, axis=0))
    print("datname={}, sub_datname={}, avg_10_mean={:.4f}%+-{:.4f}".format(
        args.dataset, args.sub_dataset, r_mean * 100, r_conf * 100))
