import logging
import os
import time
from typing import List

from torch_sparse import SparseTensor
import os.path as osp
from collections import defaultdict

import numpy as np
import torch
from torch import nn

from data import get_data, get_metric
from components.backbone import get_model
from utils.utils import adj_norm, loss_fn, pred_fn
from utils.utils import mask_to_index, index_to_mask
from components.layer import BaseMLP, AttentionChannelMixing
from utils.utils import Dict
from utils.utils import setup_seed

import torch.nn.functional as F
import torchmetrics
from torch_geometric.utils import scatter

import hydra
from omegaconf import OmegaConf, DictConfig

# Module-level logger used for the configuration banner, per-epoch progress
# and the final summary lines emitted by `main`.
log = logging.getLogger(__name__)
# Lower bound on the number of expert depths loaded by `main`: the depth
# actually used there is `max_hop` when it exceeds L_MAX, otherwise L_MAX.
L_MAX = 6


def get_best_hop(_accs, _test_mask, _logits=None, _y=None, start_hop=0, _log=False):
    """Pick the hop (depth) whose test score, averaged over runs, is highest.

    Args:
        _accs: per-node correctness tensor indexed as ``[run, hop, node]``.
        _test_mask: boolean test masks indexed as ``[split, node]``; run ``k``
            uses split ``k % num_splits``.
        _logits: optional raw outputs indexed as ``[run, hop, node, out]``.
            When the last dimension is 1 the score is a binary AUROC computed
            from these logits instead of the mean of ``_accs``.
        _y: labels, only read on the AUROC path.
        start_hop: first hop index to consider.
        _log: when True, print mean and std for every hop.

    Returns:
        ``(best_hop, best_acc, best_std)``. If no hop scores above 0 the
        initial values ``(0, 0, 0)`` are returned unchanged.
    """
    r = _test_mask.size(0)
    # Loop-invariant: whether the experts emit a single (binary) logit.
    use_auroc = _logits is not None and _logits.shape[-1] == 1

    best_acc, best_std, best_hop = 0, 0, 0
    for l in range(start_hop, _accs.size(1)):
        _test_acc = []
        for run in range(_accs.size(0)):
            _mask = _test_mask[run % r]
            if use_auroc:
                # A fresh metric per (run, hop) so no state leaks between calls.
                metric = torchmetrics.AUROC(task="binary")
                _preds = _logits[run, l, _mask]
                # Stay on the device the logits already live on rather than
                # forcing `.cuda()`, which fails on CPU-only hosts and copies
                # to the default GPU when another device is in use.
                _acc = metric(_preds, _y[_mask].to(_preds.device)).cpu()
            else:
                _acc = _accs[run, l, _mask].mean().item()
            _test_acc.append(_acc)

        hop_mean, hop_std = np.mean(_test_acc), np.std(_test_acc)
        if _log:
            print(f'{l}:  {hop_mean:.4f} +- {hop_std:.4f}')
        if hop_mean > best_acc:
            best_acc = hop_mean
            best_std = hop_std
            best_hop = l

    return best_hop, best_acc, best_std


def count_peaks(vecs):
    """Count strict local maxima in every row of a 2D tensor.

    An entry is a peak when it is strictly greater than both neighbours; the
    row is padded with ``-inf`` on each side, so the first/last entry only has
    to beat its single real neighbour. Flat maxima (plateaus) are therefore
    not counted.

    Args:
        vecs: 2D floating tensor laid out as ``(batch, length)``. The input is
            indexed as 2D but not validated.

    Returns:
        1D tensor with the number of peaks per row.
    """
    # l_shift[:, j] is the right-hand neighbour of vecs[:, j] and
    # r_shift[:, j] the left-hand one.
    l_shift = nn.functional.pad(vecs[:, 1:], (0, 1), "constant", float('-inf'))
    r_shift = nn.functional.pad(vecs[:, :-1], (1, 0), "constant", float('-inf'))
    peaks = (vecs > l_shift) & (vecs > r_shift)

    return peaks.sum(dim=1)


def check_dist(vecs, dim=0, eps=0.04):
    """Classify each accuracy-vs-hop curve by its shape.

    Args:
        vecs: 2D tensor of per-hop correctness values. With ``dim=0`` the
            curves run down the rows (one column per item); with ``dim=1``
            each row is one curve.
        dim: axis along which a single curve is laid out.
        eps: tolerance used for the all-correct / all-wrong thresholds and for
            the monotonicity checks.

    Returns:
        Tuple of seven boolean tensors (one entry per curve), in this order:
        all-correct, all-wrong, monotonically increasing, monotonically
        decreasing, single peak (including plateau curves with no strict
        peak), double peak, triple peak. The monotone masks exclude
        all-correct/all-wrong curves, and the peak masks additionally exclude
        monotone curves.
    """
    # Reduce along the curve axis; this was hard-coded to 0 before, which gave
    # mismatched shapes for dim=1.
    _mean = vecs.mean(dim=dim)
    _all_correct = _mean >= 1 - eps
    _all_wrong = _mean <= eps

    _diffs = torch.diff(vecs, dim=dim)
    _mono_inc = torch.all(_diffs >= 0 - eps, dim=dim)
    _mono_dec = torch.all(_diffs <= 0 + eps, dim=dim)

    # count_peaks expects one curve per row.
    _cnt_peaks = count_peaks(vecs.t() if dim == 0 else vecs)
    _plateau_peak = _cnt_peaks == 0
    # Parentheses are required: `|` binds tighter than `==`, so the
    # unparenthesised form compared `_cnt_peaks` against `1 | _plateau_peak`
    # and silently dropped the plateau case.
    _single_peak = (_cnt_peaks == 1) | _plateau_peak
    _double_peak = _cnt_peaks == 2
    _triple_peak = _cnt_peaks == 3

    _not_trivial = ~_all_correct & ~_all_wrong
    _not_mono = ~_mono_inc & ~_mono_dec

    return (
        _all_correct,
        _all_wrong,
        _mono_inc & _not_trivial,
        _mono_dec & _not_trivial,
        _single_peak & _not_trivial & _not_mono,
        _double_peak & _not_trivial & _not_mono,
        _triple_peak & _not_trivial & _not_mono,
    )


def _get_pagerank(adj, alpha_=0.15, epsilon_=1e-6, max_iter=100):
    """Approximate PageRank scores by power iteration.

    The adjacency is normalised with ``adj_norm(..., norm='rw')`` (no self
    loops) and transposed, then ``x <- alpha * s + (1 - alpha) * (A @ x)`` is
    iterated starting from the uniform vector ``s``.

    Args:
        adj: adjacency object accepted by ``adj_norm``; the result must
            support ``.t()``, ``.size(dim=0)``, ``.device()`` and ``@``
            (presumably a ``torch_sparse.SparseTensor`` -- confirm).
        alpha_: teleport probability.
        epsilon_: per-node L1 tolerance; iteration stops once the total L1
            change drops below ``num_nodes * epsilon_``.
        max_iter: maximum number of iterations.

    Returns:
        1D tensor with one score per node.
    """
    transition = adj_norm(adj, norm='rw', add_self_loop=False).t()

    num_nodes = transition.size(dim=0)
    uniform = torch.full((num_nodes,), 1.0 / num_nodes, device=transition.device()).view(-1, 1)
    tolerance = num_nodes * epsilon_

    scores = uniform.clone()
    for _ in range(max_iter):
        previous = scores
        scores = alpha_ * uniform + (1 - alpha_) * (transition @ previous)
        # converged once the L1 change is below the tolerance
        if (scores - previous).abs().sum() < tolerance:
            break

    return scores.view(-1)


def _get_LSI(data, max_hop=6):
    """Per-node distance between propagated features and a stationary target.

    Features are L1-normalised, then propagated ``max_hop`` times with the
    ``'sym'``-normalised adjacency. For every hop ``0..max_hop`` the L2
    distance of each node's feature row to ``feat_inf`` is recorded, where
    ``feat_inf`` is the feature average weighted by
    ``(deg + 1) / (num_edges + num_nodes)``.

    Args:
        data: graph object exposing ``x``, ``adj_t``, ``num_edges`` and
            ``num_nodes``.
        max_hop: number of propagation steps.

    Returns:
        Tensor with one row per node and ``max_hop + 1`` columns (hop 0 first).
    """
    base = F.normalize(data.x, p=1)
    propagate = adj_norm(data.adj_t, norm='sym')

    degree = data.adj_t.sum(1)
    stationary_weight = (degree + 1) / (data.num_edges + data.num_nodes)
    target = stationary_weight.view(1, -1) @ base

    # hop 0 is the un-propagated feature matrix itself
    per_hop = [base]
    for _ in range(max_hop):
        per_hop.append(propagate @ per_hop[-1])

    distances = torch.stack(
        [(hop_feat - target).norm(p=2, dim=1) for hop_feat in per_hop],
        dim=1,
    )

    del per_hop
    return distances


def _get_LSI2(data, max_hop=6):
    """Per-node distance between propagated features and the original ones.

    Features are L1-normalised and propagated ``max_hop + 1`` times with the
    ``'sym'``-normalised adjacency; after each step the L2 distance of every
    node's row to its un-propagated row is recorded. The trivial hop-0
    distance is not part of the output.

    Args:
        data: graph object exposing ``x`` and ``adj_t``.
        max_hop: the output covers propagation steps ``1..max_hop + 1``.

    Returns:
        Tensor with one row per node and ``max_hop + 1`` columns.
    """
    original = F.normalize(data.x, p=1)
    propagate = adj_norm(data.adj_t, norm='sym')

    current = original
    distances = []
    for _ in range(max_hop + 1):
        current = propagate @ current
        distances.append((current - original).norm(p=2, dim=1))

    return torch.stack(distances, dim=1)


def get_filtered_homophily(_data, node_mask):
    """Label agreement of each node's edges inside a node-induced subgraph.

    The adjacency is multiplied on both sides by ``diag(node_mask)`` so only
    entries whose two endpoints are selected keep a non-zero value. For those
    edges a 1/0 "same label" indicator is averaged per node, grouped by the
    second index returned by ``SparseTensor.coo()``.

    Args:
        _data: graph object exposing ``x``, ``y`` and ``adj_t``
            (``adj_t`` must provide ``to_torch_sparse_coo_tensor``).
        node_mask: 1D mask over nodes selecting the subgraph.

    Returns:
        1D tensor of length ``_data.y.size(0)`` holding the mean indicator per
        node (value for nodes without surviving edges is whatever ``scatter``
        with ``reduce='mean'`` yields -- presumably 0, confirm).
    """
    num_nodes = _data.x.size(0)
    diag_index = torch.arange(num_nodes, device=node_mask.device).repeat(2, 1)
    selector = torch.sparse_coo_tensor(
        indices=diag_index,
        values=node_mask,
        size=(num_nodes, num_nodes),
        dtype=torch.float32
    )

    # diag(mask) @ A @ diag(mask): zero out edges leaving the selected set
    full_adj = _data.adj_t.to_torch_sparse_coo_tensor()
    induced = torch.sparse.mm(selector, torch.sparse.mm(full_adj, selector))
    induced = SparseTensor.from_torch_sparse_coo_tensor(induced)

    src, dst, weight = induced.coo()
    surviving = weight.to(torch.bool)
    src, dst = src[surviving], dst[surviving]

    labels = _data.y
    same_label = torch.zeros(src.size(0), device=src.device)
    same_label[labels[src] == labels[dst]] = 1.
    return scatter(same_label, dst, 0, dim_size=labels.size(0), reduce='mean')


def split_hop_dataset(in_train_mask, in_val_mask, in_test_mask,
                      all_wrong, all_correct,
                      num_runs, mask_train: List[str], val_ratio=0.1,
                      overfit_masks=None, mask_tr_het_ratio=-1):
    """Derive the mixer's train/val/test masks from the experts' splits.

    Args:
        in_train_mask, in_val_mask, in_test_mask: boolean masks indexed as
            ``[split, node]``.
        all_wrong, all_correct: per-run node masks that can be removed from
            the mixer training set.
        num_runs: number of runs; a single input split is tiled this often.
        mask_train: names (``'all_wrong'`` / ``'all_correct'``) of the masks
            to exclude from training, or ``None`` to exclude nothing.
        val_ratio: ``<= 0`` swaps the roles (train on the experts' validation
            set, validate on their training set); ``>= 1`` keeps the experts'
            splits; anything in between keeps the first
            ``int(|val of split 0| * val_ratio)`` validation indices of each
            split for validation and trains on the rest.
        overfit_masks: masks of expert-training nodes to leave out; only read
            when ``0 < mask_tr_het_ratio <= 1``.
        mask_tr_het_ratio: only consulted for ``0 < val_ratio < 1``. ``0``
            adds the whole expert training set to the mixer training set, a
            value in ``(0, 1]`` adds it minus ``overfit_masks``, anything
            else adds nothing.

    Returns:
        ``(train_mask, val_mask, test_mask)`` for the mixer.
    """
    n_splits, n_nodes = in_train_mask.size(0), in_train_mask.size(1)
    mixer_test = in_test_mask

    if val_ratio <= 0:
        # roles swapped: fit on the experts' validation nodes
        mixer_train, mixer_val = in_val_mask, in_train_mask
    elif val_ratio >= 1:
        mixer_train, mixer_val = in_train_mask, in_val_mask
    else:
        # the hold-out size is derived from the first split only
        n_holdout = int(in_val_mask[0].sum() * val_ratio)
        per_split = []
        for split in range(n_splits):
            val_indices = mask_to_index(in_val_mask[split])
            per_split.append(index_to_mask(val_indices[n_holdout:], size=n_nodes))
        mixer_train = torch.stack(per_split, dim=0)
        mixer_val = in_val_mask & ~mixer_train

        # optionally re-use (part of) the experts' training nodes
        if mask_tr_het_ratio == 0:
            mixer_train = mixer_train | in_train_mask
        elif 0 < mask_tr_het_ratio <= 1:
            mixer_train = mixer_train | (in_train_mask & ~overfit_masks)

    if n_splits == 1:
        mixer_train, mixer_val, mixer_test = (
            m.tile(num_runs, 1) for m in (mixer_train, mixer_val, mixer_test)
        )

    if mask_train is not None:
        for key, excluded in (('all_wrong', all_wrong), ('all_correct', all_correct)):
            if key in mask_train:
                mixer_train = mixer_train & ~excluded

    return mixer_train, mixer_val, mixer_test


class AdaptMixer(nn.Module):
    """Mix several input channels with attention, then decode and classify.

    ``forward`` runs ``AttentionChannelMixing`` over the list of inputs,
    optionally normalises the mixed representation, passes it through the
    configured decoder and a final linear classifier.

    Note: ``encoders``, ``W`` and ``enc_residual`` are created in ``__init__``
    but are not referenced by ``forward``.
    """

    def __init__(self, config, xs_dims, num_classes):
        """Build the sub-modules.

        Args:
            config: mixer configuration; ``config.decoder`` (``hidden_dim``,
                ``norm`` and whatever ``get_model`` reads) and
                ``config.encoder`` (optional ``residual``) are used.
            xs_dims: feature width of every input channel.
            num_classes: output width of the classifier.
        """
        super().__init__()

        hidden = config.decoder.hidden_dim
        norm_kind = config.decoder.norm

        # per-channel encoders (single-layer MLPs keeping their last activation)
        self.enc_residual = config.encoder.get('residual', True)
        self.encoders = nn.ModuleList([
            BaseMLP(
                dim,
                hidden_channels=hidden,
                out_channels=hidden,
                num_layers=1,
                dropout=0,
                norm=norm_kind,
                keep_last_act=True,
            )
            for dim in xs_dims
        ])
        self.W = nn.Linear(len(self.encoders) * hidden, hidden)

        self.att_mixer = AttentionChannelMixing(xs_dims, hidden, layer_norm=True)

        # `feat_norm` only exists for the two supported norm kinds
        if norm_kind == 'layer':
            self.feat_norm = nn.LayerNorm(hidden, eps=1e-9)
        elif norm_kind == 'batch':
            self.feat_norm = nn.BatchNorm1d(hidden)

        self.decoder = get_model(Dict(OmegaConf.to_container(config.decoder)), hidden, hidden)
        self.classifier = nn.Linear(hidden, num_classes)

    def forward(self, xs, edge_index=None):
        """Return ``(class_logits, gating_weights)`` for the input channels.

        Args:
            xs: list of input tensors, one per channel.
            edge_index: forwarded to the decoder as a second argument when
                given; otherwise the decoder is called with features only.
        """
        mixed, gating_weights = self.att_mixer(xs)

        # decoding
        feat_norm = getattr(self, 'feat_norm', None)
        if feat_norm is not None:
            mixed = feat_norm(mixed)
        decoder_inputs = (mixed,) if edge_index is None else (mixed, edge_index)
        decoded = self.decoder(*decoder_inputs)

        # classification
        return self.classifier(decoded), gating_weights


def train_adaptive_mixer(sampler, optimizer, lr_scheduler,
                         xs, y, split_masks):
    """Run one full-batch optimisation step of the mixer.

    Args:
        sampler: the mixer model; called as ``sampler(xs)`` and expected to
            return ``(logits, extra)``.
        optimizer: optimiser over the mixer's parameters.
        lr_scheduler: stepped once after the optimiser when not ``None``.
        xs: model inputs.
        y: labels.
        split_masks: ``(train, val, test)`` masks; only the training mask is
            used for the loss.
    """
    sampler.train()
    optimizer.zero_grad()
    train_nodes, _val_nodes, _test_nodes = split_masks

    logits, _ = sampler(xs)
    loss_fn(logits[train_nodes], y[train_nodes]).backward()
    optimizer.step()

    if lr_scheduler is not None:
        lr_scheduler.step()


@torch.no_grad()
def test(sampler, xs, y, metric, masks):
    """Evaluate the mixer on several node subsets without tracking gradients.

    Args:
        sampler: the mixer model; called as ``sampler(xs)`` and expected to
            return ``(out, gating_weights)``.
        xs: model inputs.
        y: labels; replaced by the second value returned from ``pred_fn``
            before scoring.
        metric: metric object supporting ``reset()``, ``__call__`` and
            ``compute()``; it is reset before every subset.
        masks: iterable of node masks to score.

    Returns:
        ``(scores, losses, out, gating_weights)`` where ``scores`` and
        ``losses`` are lists aligned with ``masks``.
    """
    sampler.eval()
    out, gating_weights = sampler(xs)
    pred, y = pred_fn(out, y)

    def _evaluate(subset):
        # metric state is cleared so each subset is scored independently
        metric.reset()
        metric(pred[subset], y[subset])
        score = metric.compute()
        return score, loss_fn(out[subset], y[subset])

    results = [_evaluate(subset) for subset in masks]
    scores = [score for score, _ in results]
    losses = [loss for _, loss in results]

    return scores, losses, out, gating_weights


@hydra.main(version_base=None, config_path='conf', config_name='config')
def main(conf):
    """Train and evaluate the AdaptMixer on top of pre-computed expert logits.

    For every run, the stored per-depth expert logits are turned into mixer
    input channels, an ``AdaptMixer`` is trained with Adam, and the test score
    at the best validation epoch is recorded. The mean and standard deviation
    over runs are logged next to the best single-depth baseline returned by
    ``get_best_hop``; test scores, logits and gating weights are optionally
    saved under ``conf.proc_dir``.

    Args:
        conf: Hydra/OmegaConf configuration. The keys read here are
            ``dataset``, ``gnn``, ``train``, ``adapt_hop``, ``data_dir``,
            ``proc_dir`` and ``gpu``, plus the optional ``num_ens``,
            ``log_time``, ``log_curve``, ``log_logit`` and ``log_attn``.
    """
    data_conf = conf.dataset
    gnn_conf = conf.gnn
    train_conf = conf.train
    as_conf = conf.adapt_hop

    # Banner with the settings that matter for this run.
    log.critical(f'Moscat training config\n'
                 f'GNN Expert: \n'
                 f'\tmodel: {gnn_conf.model}\n'
                 f'\tarch: {gnn_conf.arch}\n'
                 f'Basic Params: \n'
                 f'\tsampler_lr: {train_conf.sampler_lr}\n'
                 f'\thidden_dim: {as_conf.decoder.hidden_dim}\n'
                 f'\tnorm: {as_conf.decoder.norm}\n'
                 f'Moscat Params: \n'
                 f'\tmax_hop: {as_conf.max_hop}\n'
                 f'\tval_ratio: {as_conf.val_ratio}\n'
                 f'\tmask_tr_het_ratio: {as_conf.get("mask_tr_het_ratio", -1)}\n'
                 f'\tmask_wrong: {as_conf.mask_train is not None and "all_wrong" in as_conf.mask_train}\n'
                 )

    # Output directories are created up front, whether or not the
    # corresponding `log_*` flag is set.
    dataset_dir = osp.join(conf.data_dir, 'pyg')
    curve_dir = osp.join(conf.proc_dir, 'curve')
    logit_dir = osp.join(conf.proc_dir, 'mixer_logit')
    attn_dir = osp.join(conf.proc_dir, 'mixer_attn')
    os.makedirs(curve_dir, exist_ok=True)
    os.makedirs(logit_dir, exist_ok=True)
    os.makedirs(attn_dir, exist_ok=True)
    proc_dir = conf.proc_dir

    ## load data
    data, num_features, num_classes, dataset_dir = get_data(root=dataset_dir, **data_conf)
    metric = get_metric(data_conf.name, num_classes)
    dataset = data_conf.name

    ## load logits & accs
    model, arch, mlp_arch = gnn_conf.model, gnn_conf.arch, gnn_conf.mlp_arch
    # Always load at least L_MAX expert depths, even when max_hop is smaller.
    num_depth = as_conf.max_hop if as_conf.max_hop > L_MAX else L_MAX

    if conf.get('num_ens', 1) > 1:
        # Ensemble logits come from a single file whose first two axes are
        # swapped after loading (presumably stored hop-major -- confirm).
        logits = torch.load(proc_dir + '/ens/{}_{}{}_conv{}_ens{}.pt'.format(dataset, model, arch,
                                                                    gnn_conf.best_hop, conf.num_ens))
        logits = torch.transpose(logits, 0, 1)
    else:
        logits = []
        for i in range(0, num_depth + 1):
            if i == 0:
                # Depth 0 is the MLP expert; its file name always uses `conv3`.
                t = torch.load(proc_dir + '/logit/{}_{}{}_conv{}.pt'.format(dataset, 'MLP', mlp_arch, 3))
            else:
                # Perturbed datasets use a file name suffixed with the
                # perturbation type and ratio.
                if 'ptb_ratio' in data_conf and data_conf['ptb_ratio'] > 0:
                    t = torch.load(
                        proc_dir + '/logit/{}_{}{}_conv{}_ptb{}{}.pt'.format(dataset, model, arch, i,
                                                                             data_conf['ptb_type'].upper(),
                                                                             data_conf['ptb_ratio']))
                else:
                    t = torch.load(proc_dir + '/logit/{}_{}{}_conv{}.pt'.format(dataset, model, arch, i))
            logits.append(t)

        ## handle cases when the GNN experts have fewer runs than the MLP:
        ## truncate the MLP logits to the run count of the deepest expert
        if logits[-1].size(0) < logits[0].size(0):
            logits[0] = logits[0][:logits[-1].size(0)]

        logits = torch.stack(logits, dim=1)  # (runs, hops, nodes, logits)

    # Per-node correctness of every expert: a single logit is thresholded at
    # 0, otherwise the arg-max class is compared with the label.
    if logits.shape[-1] == 1:
        accs = (logits > 0).squeeze(-1) == data.y
    else:
        accs = logits.argmax(dim=-1) == data.y
    accs = accs.float()

    ## get split mask per run
    train_mask = index_to_mask(data.train_mask, size=data.num_nodes).t()
    val_mask = index_to_mask(data.val_mask, size=data.num_nodes).t()
    test_mask = index_to_mask(data.test_mask, size=data.num_nodes).t()
    # A single split is promoted to shape (1, num_nodes).
    if train_mask.dim() == 1:
        train_mask = train_mask.unsqueeze(0)
        val_mask = val_mask.unsqueeze(0)
        test_mask = test_mask.unsqueeze(0)
    NUM_RUNS = accs.shape[0]

    ## get train masking
    # Per run, classify every node's accuracy-vs-hop curve; only the
    # all-correct (index 0) and all-wrong (index 1) masks are used below.
    dist_masks = [check_dist(accs[i], dim=0, eps=0.04) for i in range(NUM_RUNS)]
    all_correct = torch.stack([masks[0] for masks in dist_masks], dim=0).to(conf.gpu)
    all_wrong = torch.stack([masks[1] for masks in dist_masks], dim=0).to(conf.gpu)

    ## structural encoding
    disparity_dict = {
        'pagerank': _get_pagerank(data.adj_t.t()).view(-1, 1),
        'LSI': _get_LSI(data, max_hop=num_depth),
        'LSI2': _get_LSI2(data, max_hop=num_depth),
    }
    # Column layout of X: [0] pagerank, then the `_get_LSI` columns, then the
    # `_get_LSI2` columns (num_depth + 1 columns each).
    X = torch.cat([
        disparity_dict['pagerank'],
        disparity_dict['LSI'],
        disparity_dict['LSI2'],
    ], dim=1).to(conf.gpu)
    # Column-wise standardisation.
    # NOTE(review): a constant column has zero std and would produce NaN/inf.
    X = (X - X.mean(dim=0, keepdim=True)) / X.std(dim=0, keepdim=True)

    # Move everything to the training device (X was already moved above, so
    # its `.to` here is redundant).
    data.to(conf.gpu)
    X = X.to(conf.gpu)
    accs = accs.to(conf.gpu)
    logits = logits.to(conf.gpu)
    train_mask = train_mask.to(conf.gpu)
    val_mask = val_mask.to(conf.gpu)
    test_mask = test_mask.to(conf.gpu)
    metric.to(conf.gpu)

    # set up seed for preprocessing
    setup_seed(0)

    # By default, do not use experts' training set
    mask_tr_het_ratio, overfit_masks = as_conf.get('mask_tr_het_ratio', -1), []
    if 0 < mask_tr_het_ratio <= 1:
        for r in range(train_mask.shape[0]):
            # Homophily is measured inside the train+val induced subgraph.
            selected_nodes = train_mask[r] | val_mask[r]
            homo = get_filtered_homophily(data, selected_nodes)

            # Candidates: nodes whose homophily is below the mean over the
            # training nodes.
            train_homo = homo[train_mask[r]].mean()
            idx1 = mask_to_index(homo < train_homo)

            # Randomly pick a `mask_tr_het_ratio` fraction of the candidates.
            idx2 = idx1[torch.randperm(idx1.shape[0])[:int(idx1.shape[0] * mask_tr_het_ratio)]]
            overfit_masks.append(index_to_mask(idx2, homo.shape[0]))
        overfit_masks = torch.stack(overfit_masks, dim=0)

    ## training
    # From here on the three masks refer to the mixer's own splits.
    train_mask, val_mask, test_mask = split_hop_dataset(train_mask, val_mask, test_mask,
                                                        all_wrong, all_correct,
                                                        num_runs=NUM_RUNS,
                                                        mask_train=as_conf.mask_train,
                                                        val_ratio=as_conf.val_ratio,
                                                        overfit_masks=overfit_masks,
                                                        mask_tr_het_ratio=mask_tr_het_ratio
                                                        )

    total_train_time = 0.
    # NOTE(review): val_losses is shared by all runs and never reset; the
    # early-stopping window below only reads its most recent entries.
    val_losses = []
    best_tests_cls, best_tests_hop = [], []
    total_logit, total_attn = [], []
    for i in range(NUM_RUNS):

        log.info(f'------------------------Run {i}------------------------')
        setup_seed(i + 1)

        # Assemble the list of input channels for this run.
        feat_type, xs = as_conf.encoder.get('feat_type', ['logit_2']), []
        if 'disparity' in feat_type:
            xs.append(X)
        if 'node_feat' in feat_type:
            xs.append(data.x)
        if 'logit_2' in feat_type:  # Scope-aware Logit Augmentation
            lmax = as_conf.encoder.get('lmax', 6)
            deg_norm = as_conf.encoder.get('deg_norm', 'sym')
            adj_t = adj_norm(data.adj_t, norm=deg_norm, add_self_loop=False)

            # One channel per expert depth l in [min_hop, max_hop].
            for l in range(as_conf.min_hop, as_conf.max_hop + 1):
                _xs = [logits[i][l]]

                # Expert logits propagated 1..lmax additional times.
                for _l in range(1, lmax+1):
                    _xs.append(adj_t @ _xs[-1])

                if as_conf.encoder.get('pagerank', True):
                    _xs.append(X[:, 0].unsqueeze(-1))  # remove for gcnii
                # Column l of the `_get_LSI` block and column l of the
                # `_get_LSI2` block of X (see the layout note above).
                _xs.append(X[:, l+1].unsqueeze(-1))
                _xs.append(X[:, l+num_depth+2].unsqueeze(-1))
                xs.append(torch.cat(_xs, dim=1))

        xs_dims = [x.size(1) for x in xs]
        sampler = AdaptMixer(as_conf, xs_dims, num_classes).to(conf.gpu)

        optimizer = torch.optim.Adam(sampler.parameters(), lr=train_conf.sampler_lr,
                                     weight_decay=train_conf.weight_decay)
        lr_scheduler = None

        # NOTE(review): best_ts is never read. If no epoch past the
        # `log_epoch` warm-up reaches val_acc > 0, best_s stays an empty
        # defaultdict(float) and the formatting / `.cpu()` calls after the
        # epoch loop would fail on the float defaults.
        best_s, best_ts = defaultdict(float), defaultdict(float)
        for epoch in range(1, train_conf.sampler_epoch + 1):

            # Only the optimisation step is timed, not the evaluation.
            tik = time.time()

            # Run i uses split i modulo the number of available splits.
            r = train_mask.size(0)
            train_adaptive_mixer(
                sampler, optimizer, lr_scheduler,
                xs, data.y,
                split_masks=[train_mask[i % r], val_mask[i % r], test_mask[i % r]]
               )

            tok = time.time()
            total_train_time += tok - tik

            (train_acc, val_acc, test_acc), (train_loss, val_loss, test_loss), logit, gating_weights = test(
                sampler, xs, data.y, metric, [train_mask[i % r], val_mask[i % r], test_mask[i % r]])
            val_losses.append(val_loss)
            log.info(f'Epoch {epoch:03d}, Train: {train_acc: .4f}, '
                    f'Val: {val_acc: .4f}, Test: {test_acc: .4f}\n')

            # Model selection and early stopping only start after the warm-up.
            if epoch > train_conf.get('log_epoch', 50):
                if val_acc > best_s['sample_val_acc']:
                    best_s = {
                        'epoch': epoch,
                        'sample_train_acc': train_acc,
                        'sample_val_acc': val_acc,
                        'sample_test_acc': test_acc,
                        'logit': logit,
                        'gating_weights': gating_weights
                    }

                # Stop once the current validation loss exceeds the mean of
                # the `early_stopping` validation losses that preceded it.
                if 0 < train_conf.early_stopping < epoch:
                    tmp = torch.tensor(val_losses[-(train_conf.early_stopping + 1): -1])
                    if val_loss > tmp.mean().item():
                        break

        log.info(f"[Best Result] "
              f"Epoch: {best_s['epoch']:03d}, Train: {best_s['sample_train_acc']:.4f}, "
              f"Val: {best_s['sample_val_acc']:.4f}, Test: {best_s['sample_test_acc']:.4f}")

        best_tests_cls.append(best_s['sample_test_acc'])
        total_logit.append(best_s['logit'].cpu())
        total_attn.append(best_s['gating_weights'].cpu())

    best_tests_cls = torch.tensor(best_tests_cls)

    # Baseline: best single expert depth on the test masks; depth 0 (the MLP)
    # is excluded via start_hop=1.
    best_hop, best_acc, best_std = get_best_hop(accs, test_mask, logits, data.y, start_hop=1)
    log.critical(f'Baseline({best_hop}): {best_acc*100:.2f} ±{best_std*100:.2f}')
    log.critical(f'Test CLS: {best_tests_cls.mean()*100:.2f} ±{best_tests_cls.std()*100:.2f}')

    # Optional outputs, each guarded by its own config flag.
    if conf.get('log_time', False):
        log.critical(f'Total Train Time: {total_train_time/NUM_RUNS:.2f}s')
    if conf.get('log_curve', False):
        filename = '{}_{}{}_mix{}'.format(conf.dataset.name, gnn_conf.model, gnn_conf.arch, as_conf.max_hop)
        filename += '.pt'
        torch.save(best_tests_cls, osp.join(curve_dir, filename))
    if conf.get('log_logit', False):
        filename = '{}_{}{}'.format(conf.dataset.name, gnn_conf.model, gnn_conf.arch)
        filename += '.pt'
        torch.save(torch.stack(total_logit, dim=0), osp.join(logit_dir, filename))
    if conf.get('log_attn', False):
        filename = '{}_{}{}_{}-{}'.format(conf.dataset.name, gnn_conf.model, gnn_conf.arch, as_conf.min_hop, as_conf.max_hop)
        filename += '.pt'
        torch.save(torch.stack(total_attn, dim=0), osp.join(attn_dir, filename))


# Script entry point. `main` is called without arguments here; the
# `hydra.main` decorator on it is expected to supply `conf`.
if __name__ == '__main__':
    main()
