import numpy as np
from models import *
import torch.nn.functional as F
import torch.nn as nn
import torch
import deeprobust.graph.utils as utils
from torch.nn.parameter import Parameter
from tqdm import tqdm
import scipy.sparse as sp
from scipy import stats
import pandas as pd
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
from copy import deepcopy
from utils import reset_args
import random
from torch_geometric.utils import to_scipy_sparse_matrix, from_scipy_sparse_matrix, dropout_adj, is_undirected, \
    to_undirected
from sklearn.mixture import GaussianMixture
from sklearn.metrics.pairwise import cosine_similarity
# from sklearn.manifold import TSNE
import torch.jit
import seaborn as sns
from openTSNE import TSNE
import normflows as nf
from torch import distributions
import wandb
import socket
from torch.cuda.amp import autocast as autocast
import os
import joypy
from sklearn.tree import DecisionTreeClassifier
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import torchist
from collections import OrderedDict
import math
from collections import Counter


@torch.jit.script
def entropy(x: torch.Tensor) -> torch.Tensor:
    """Row-wise Shannon entropy of a batch of probability vectors.

    ``x`` is expected to hold probabilities (e.g. the output of ``softmax``),
    not log-probabilities: the logarithm is taken inside this function.
    A small constant keeps ``log`` finite for zero entries.
    Returns a 1-D tensor with one entropy value per row.
    """
    weighted_log = x * torch.log(x + 1e-15)
    return -weighted_log.sum(1)


def sim(t1, t2):
    """Return the cosine similarity between matching rows of ``t1`` and ``t2``.

    Both inputs are 2-D tensors of the same shape; each row is scaled to unit
    length (a tiny epsilon guards against zero-norm rows) and the row-wise dot
    product is returned as a 1-D tensor.
    """
    eps = 1e-15
    unit1 = t1 / (t1.norm(dim=1, keepdim=True) + eps)
    unit2 = t2 / (t2.norm(dim=1, keepdim=True) + eps)
    return torch.sum(unit1 * unit2, dim=1)


class GNN:
    """Pre-train a GNN and adapt its batch-norm layers on each test graph.

    ``data_all`` is indexed as ``data_all[0]`` (list of training graphs) and
    ``data_all[2]`` (list of test graphs).  Every graph object exposes
    ``graph['node_feat']``, ``graph['edge_index']`` and ``label``; graphs of the
    ``elliptic`` dataset additionally expose ``mask``.
    """

    def __init__(self, data_all, args):
        """Store the configuration and immediately pre-train (or load) the model."""
        # Fall back to the CPU so the class stays usable on machines without CUDA.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.args = args
        self.data_all = data_all
        # Filled by get_distribution_trains(); per-layer reference histograms.
        self.trains_p = None
        self.model = self.pretrain_model()

    def pretrain_model(self, verbose=True):
        """Build the backbone selected by ``args.model`` and train or load it.

        Args:
            verbose: when true, print the model architecture.

        Returns:
            The trained model, moved to ``self.device``.

        Raises:
            NotImplementedError: if ``args.model`` is not GCN, GAT or SAGE.
        """
        data_all = self.data_all
        args = self.args
        device = self.device
        first_graph = data_all[0][0]
        feat, labels = first_graph.graph['node_feat'], first_graph.label
        nfeat = feat.shape[1]
        # Tensor-level max instead of Python's max(), which iterates element-wise.
        nclass = labels.max().item() + 1

        if args.model == "GCN":
            model = GCN(nfeat=nfeat, nhid=args.hidden, dropout=args.dropout, nlayers=args.nlayers,
                        weight_decay=args.weight_decay, with_bn=True, lr=args.lr, save_mem=False,
                        nclass=nclass, device=device, args=args).to(device)
        elif args.model == "GAT":
            model = GAT(nfeat=nfeat, nhid=32, heads=4, lr=args.lr, nlayers=args.nlayers,
                        nclass=nclass, with_bn=True, weight_decay=args.weight_decay,
                        dropout=0.0, device=device, args=args).to(device)
        elif args.model == "SAGE":
            # The former fb100 / non-fb100 branches were identical, so one call suffices.
            model = SAGE(nfeat, 32, nclass, num_layers=args.nlayers,
                         dropout=0.0, lr=0.01, weight_decay=args.weight_decay,
                         device=device, args=args, with_bn=args.with_bn).to(device)
        else:
            raise NotImplementedError
        if verbose:
            print('---------Model-----------')
            print(model)

        checkpoint_dir = 'saved' if args.ood else 'saved_no_ood'
        filename = f'{checkpoint_dir}/{args.dataset}_{args.model}_s{args.seed}.pt'

        # Checkpoints are only read and written in debug mode.
        if args.debug and os.path.exists(filename):
            model.load_state_dict(torch.load(filename, map_location=self.device))
        else:
            train_iters = 500 if args.dataset == 'ogb-arxiv' else 200
            model.fit_inductive(data_all, train_iters=train_iters, patience=500, verbose=True)
            if args.debug:
                # Make sure the target directory exists so training is not lost on save.
                os.makedirs(checkpoint_dir, exist_ok=True)
                torch.save(model.state_dict(), filename)

        return model

    def contrastive_loss_plus(self, model, feat, edge_index, edge_weight=None, weights=None):
        """Contrastive loss between the original graph and two perturbed views.

        The positive pair is (original, edge-dropped) and the negative pair is
        (original, row-shuffled features); similarity is the row-wise cosine
        similarity computed by ``sim``.

        Returns:
            Scalar tensor ``-mean(log(pos / (pos + neg)))``.
        """
        edge_index_1, edge_weight_1 = dropout_adj(edge_index, edge_weight, p=0.05)
        output1 = model.get_embed_plus(feat, edge_index_1, edge_weight_1, weights)  # DropEdge
        output2 = model.get_embed_plus(feat, edge_index, edge_weight, weights)  # Original
        idx = np.random.permutation(feat.shape[0])
        shuf_fts = feat[idx, :]
        output3 = model.get_embed_plus(shuf_fts, edge_index, edge_weight, weights)  # Shuffle
        pos = torch.exp(sim(output2, output1))
        neg = torch.exp(sim(output2, output3))
        loss = -1.0 * torch.log(pos / (pos + neg)).mean()
        return loss

    def get_distribution_trains(self, x_s, bins_num=100):
        """Compute per-layer, per-feature reference histograms on the training graphs.

        The result is stored in ``self.trains_p`` and indexed as
        ``trains_p[layer][graph][feature] == (density, bin_edges)``.  The bin
        edges of the first training graph are reused for all later graphs so the
        histograms are directly comparable.

        Args:
            x_s: list of training graphs.
            bins_num: number of histogram bins per feature.
        """
        if not isinstance(x_s, list):
            print("Train set is not a list!")
            return
        model = deepcopy(self.model)
        model.eval()
        xs_p = []
        for n_bn in range(self.args.nlayers - 1):
            bns_p = []
            for x in x_s:
                edge_index = x.graph['edge_index'].to(self.device)
                feat = x.graph['node_feat'].to(self.device)
                # Only the values are needed, so skip building an autograd graph.
                with torch.no_grad():
                    h = model.get_embed_ith(feat, edge_index, i=n_bn + 1)
                h = h.detach().cpu().numpy()
                x_p = []
                for dim in range(h.shape[1]):
                    # Column `dim` = one feature over all nodes.  (`h[:][dim]` would
                    # select row `dim`, i.e. a single node's embedding.)
                    data = h[:, dim]
                    if bns_p:
                        bins = bns_p[0][dim][1]
                    else:
                        lo, hi = np.min(data), np.max(data)
                        pad = 0.25 * (hi - lo)
                        if pad == 0:
                            # Constant feature: avoid zero-width bins (same +-0.5
                            # convention numpy.histogram uses for a degenerate range).
                            pad = 0.5
                        bins = np.linspace(lo - pad, hi + pad, num=bins_num + 1)
                    x_p.append(np.histogram(data, bins=bins, density=True))
                bns_p.append(x_p)
            xs_p.append(bns_p)
        self.trains_p = np.array(xs_p, dtype=object)

    def differentiable_histogram(self, x, bins, density=True):
        """Histogram of the 1-D tensor ``x`` whose counts carry gradients w.r.t. ``x``.

        Each in-bin element contributes ``x / x.detach()``: the value is exactly 1
        (a plain count) while the gradient is ``1 / x``.  Elements equal to zero
        contribute a constant 1 (no gradient) instead of 0/0 = NaN.  Bins are
        half-open ``[low, high)``; values outside all bins are not counted.

        Args:
            x: 1-D tensor of samples.
            bins: equally spaced bin edges (tensor or array), length ``n_bins + 1``.
            density: if true, divide counts by ``len(x) * bin_width``.

        Returns:
            Tensor of length ``len(bins) - 1`` on ``self.device``.
        """
        w = bins[1] - bins[0]
        edges = bins.tolist()  # one host transfer instead of one `.item()` per edge
        x_const = x.detach()
        nonzero = x_const != 0
        safe_denominator = torch.where(nonzero, x_const, torch.ones_like(x_const))
        ratio = torch.where(nonzero, x / safe_denominator, torch.ones_like(x))
        p = torch.zeros(len(bins) - 1).to(self.device)
        for i, (low, high) in enumerate(zip(edges[:-1], edges[1:])):
            mask = ((x >= low) & (x < high)).float()
            p[i] = torch.sum(ratio * mask)
        if density:
            p = p / (len(x) * w)
        return p

    def js_divergence(self, p, q, KL):
        """Jensen-Shannon divergence between two histograms.

        Both inputs are normalised to sum to one first.  ``KL`` must be a
        ``KLDivLoss``-style callable taking ``(log_input, target)``.  The mixture
        is clamped before the logarithm so bins that are empty in both histograms
        contribute 0 instead of ``0 * -inf = NaN``.
        """
        q = q.float()
        p_norm = p / torch.sum(p)
        q_norm = q / torch.sum(q)
        M_log = ((p_norm + q_norm) / 2).clamp_min(1e-15).log()
        return 0.5 * KL(M_log, p_norm) + 0.5 * KL(M_log, q_norm)

    def bhattacharyya_distance(self, p, q):
        """Bhattacharyya distance between the softmax of two histograms."""
        q = q.float()
        # Explicit dim: the implicit choice is deprecated; -1 matches it for 1-D/2-D input.
        p_soft = F.softmax(p, dim=-1)
        q_soft = F.softmax(q, dim=-1)
        bdc = torch.sum(torch.sqrt(p_soft * q_soft))
        return -1.0 * torch.log(bdc)

    def get_distribution_distance(self, model, feat, edge_index, loss=False):
        """Distance between this graph's feature histograms and the training ones.

        For every hidden layer and feature, the histogram of the current graph is
        compared with the reference histogram of each training graph and the
        non-NaN, non-zero distances are averaged.  Requires
        ``get_distribution_trains`` to have been called.

        Args:
            model: model exposing ``get_embed_ith``.
            feat, edge_index: the graph to evaluate.
            loss: if true use the Bhattacharyya distance and return the
                differentiable sum; otherwise use JS divergence and return the
                detached per-feature matrix.

        Returns:
            Scalar tensor when ``loss`` is true, else a detached tensor of shape
            ``(nlayers - 1, hidden)`` (``hidden * 4`` for GAT).
        """
        kld = nn.KLDivLoss(reduction='sum')
        n_layers = self.args.nlayers - 1
        # GAT concatenates 4 heads, hence the wider hidden representation.
        width = self.args.hidden * 4 if self.args.model in ['GAT'] else self.args.hidden
        weights = torch.zeros((n_layers, width)).to(self.device)
        for n_bn in range(n_layers):
            h = model.get_embed_ith(feat, edge_index, i=n_bn + 1)
            layer_refs = self.trains_p[n_bn]
            for dim in range(h.shape[1]):
                bins = torch.tensor(layer_refs[0][dim][1]).to(self.device)
                # Column `dim`: one feature over all nodes (matches the reference histograms).
                p = self.differentiable_histogram(h[:, dim], bins=bins, density=True)
                a = torch.zeros(len(layer_refs)).to(self.device)
                for i, qs in enumerate(layer_refs):
                    q = torch.tensor(qs[dim][0]).to(self.device)
                    if loss:
                        dist = self.bhattacharyya_distance(p, q)
                    else:
                        dist = self.js_divergence(p, q, kld)
                    if not torch.isnan(dist):
                        a[i] = dist
                # Guard against 0/0 when every distance was NaN (or exactly zero).
                count = torch.count_nonzero(a).clamp(min=1)
                weights[n_bn, dim] = torch.sum(a) / count
        if loss:
            return torch.sum(weights)
        return weights.detach()

    def sample_sgld(self, x, a, model, buffer, e, weights=None):
        """Draw a feature sample with gradient steps on the model's energy.

        Starts from uniform noise (always when ``buffer`` is empty, otherwise with
        probability 0.05) or from a random buffer entry, then repeatedly moves the
        sample along the gradient of ``logsumexp(model(x))`` plus Gaussian noise.

        Args:
            x: reference feature tensor; only its shape/dtype/device are used.
            a: edge index passed to the model.
            model: model being adapted.
            buffer: list of previously returned samples.
            e: energy term of the real data to compare against.
            weights: unused; kept for interface compatibility.

        Returns:
            ``(sample, e_min)`` where ``e_min`` is the smallest ``e - energy(sample)``
            seen and ``sample`` the features that produced it.
        """
        noise_sigma = 0.01
        step_size = 1.0
        reinit_p = 0.05
        # Short-circuit keeps the random draw only for a non-empty buffer.
        reinit = (not buffer) or np.random.rand() < reinit_p
        if reinit:
            x_sample = torch.rand_like(x)  # U(0,1), already on x's device
        else:
            x_sample = random.choice(buffer)
        model.train()
        output_sample = model.forward(x_sample, a)
        e_sp = torch.logsumexp(output_sample, 1).mean()
        e_min = e - e_sp
        sample_min = x_sample
        epochs = 3 if self.args.dataset in ['fb100'] else 30
        for _ in range(epochs):
            # Gradient w.r.t. the input is taken in eval mode ...
            model.eval()
            x_copy = x_sample.clone().detach()
            x_copy.requires_grad = True
            energy = torch.logsumexp(model.forward(x_copy, a), 1).mean()
            grad = torch.autograd.grad(energy, [x_copy], retain_graph=True)[0]
            # ... while the energy used for the loss is evaluated in train mode.
            model.train()
            x_sample = x_sample + step_size * grad + noise_sigma * torch.randn_like(grad)
            output_sample = model.forward(x_sample, a)
            e_sp = torch.logsumexp(output_sample, 1).mean()
            if (e - e_sp) < e_min:
                e_min = e - e_sp
                sample_min = x_sample
        return sample_min, e_min

    def data_select(self, p):
        """Score pseudo-labels and pick the lowest-entropy nodes.

        Args:
            p: per-node class probabilities.

        Returns:
            ``(scores, id_sele)``: ``scores`` has the shape of ``p`` and is
            ``exp(p - 0.5)`` where ``p >= 0.8`` or ``p <= 0.2`` and 0 elsewhere;
            ``id_sele`` is a NumPy index array of the 60% of rows with the
            smallest entropy.
        """
        th_pos = 0.8
        th_neg = 0.2
        ent_select_p = 0.6
        ent = entropy(p).detach().cpu().numpy()
        id_sort = np.argsort(ent)
        end = int(len(ent) * ent_select_p)
        id_sele = id_sort[:end]
        mask_conf = (p >= th_pos) | (p <= th_neg)
        scores = mask_conf.int() * torch.exp(p - ((th_pos + th_neg) / 2))
        return scores, id_sele

    def crossentropyloss(self, p, q):
        """Mean over rows of ``-sum(p * log(q))`` (soft-target cross-entropy)."""
        ce = (p * torch.log(q + 1e-15)).sum(1)
        return -1.0 * torch.mean(ce)

    def _learn_masked_weights(self, model, feat, edge_index, p):
        """Compute distribution-distance weights and learn a binary mask over them.

        The mask logits are optimised with a relaxed Bernoulli sample per step
        (logistic noise, temperature 0.01) against a contrastive loss plus a KL
        term that ties the masked predictions to the original predictions ``p``.

        Returns:
            Detached weights with masked-out entries set to zero.
        """
        weights = self.get_distribution_distance(model, feat, edge_index)
        # Mask logits, initialised to 1 (sigmoid > 0.5, i.e. everything kept).
        bern = Parameter(torch.ones_like(weights))
        optimizer_mask = optim.Adam([bern], lr=self.args.lr_mask)
        temperature = 1e-2
        KLD = nn.KLDivLoss(reduction='batchmean')
        p_log = F.log_softmax(p, dim=-1)
        if self.args.model in ['GCN']:
            classify = model.layers[-1]
        elif self.args.model in ['SAGE', 'GAT']:
            classify = model.convs[-1]
        else:
            raise NotImplementedError

        for _ in range(self.args.learn_mask_epochs):
            model.train()
            optimizer_mask.zero_grad()
            delta = torch.rand(weights.shape[0], weights.shape[1]).to(self.device)  # U(0,1)
            mask = torch.sigmoid((torch.log(delta) - torch.log(1 - delta) + bern) / temperature)
            weights_mask = weights * mask
            loss_cl = self.contrastive_loss_plus(model, feat, edge_index, weights=weights_mask)
            h = model.get_embed_plus(feat, edge_index, weights=weights_mask)
            q = classify(h, edge_index)
            q_soft = F.softmax(q, dim=-1)
            loss_kl = KLD(p_log, q_soft)
            loss = loss_cl + self.args.loss_lambda * loss_kl
            loss.backward()
            optimizer_mask.step()

        # Harden the learned probabilities into a 0/1 mask.
        hard_mask = (torch.sigmoid(bern.detach()) > 0.5).to(weights.dtype)
        return weights.detach() * hard_mask

    def jem_training(self):
        """Adapt a copy of the model to every test graph and report accuracy.

        Per test graph: (1) load or learn the masked distribution weights,
        (2) adapt batch-norm statistics (skipped when ``args.method_bn`` is
        ``'BNPA'``), (3) fine-tune only the batch-norm parameters with a
        pseudo-label loss plus an energy-based term (skipped for ``'BNSA'``),
        (4) evaluate.  Prints per-graph and pooled scores; returns ``None``.

        Raises:
            NotImplementedError: for a dataset without an evaluation rule.
        """
        trains = self.data_all[0]
        tests = self.data_all[2]
        eval_func = self.model.eval_func
        accs = []
        y_te_all, out_te_all = [], []
        self.get_distribution_trains(trains, bins_num=self.args.bin_num)
        method = self.args.method_bn

        if not isinstance(tests, list):
            print("Test set is not a list!")
            return

        is_elliptic = self.args.dataset == 'elliptic'
        for i, dat in enumerate(tests):
            edge_index = dat.graph['edge_index'].to(self.device)
            feat = dat.graph['node_feat'].to(self.device)
            labels = dat.label.to(self.device)
            model = deepcopy(self.model)
            model.reset_bn_track_running_stats(stats=False)
            p = model.predict(feat, edge_index)

            # NOTE: cache files written before the per-feature histogram fix hold
            # weights computed from the wrong axis and should be deleted.
            weights_file = f'weights_mask/{self.args.model}_s{self.args.seed}_{self.args.dataset}_g{i}.pt'
            if os.path.exists(weights_file):
                # map_location lets a cache written on GPU be read on CPU and vice versa.
                weights = torch.load(weights_file, map_location=self.device)
            else:
                weights = self._learn_masked_weights(model, feat, edge_index, p)
                os.makedirs(os.path.dirname(weights_file), exist_ok=True)
                torch.save(weights, weights_file)

            # Stage 1: batch-norm statistics adaptation.
            epochs = 0 if method in ['BNPA'] else self.args.tta_epochs_bns
            for _ in range(epochs):
                model.train()
                model.adaptation_bn_statistic(feat, edge_index, weights=weights)

            # Stage 2: only the batch-norm parameters stay trainable.
            for param in model.parameters():
                param.requires_grad = False
            for param in model.bns.parameters():
                param.requires_grad = True
            optimizer = optim.Adam(model.parameters(), lr=self.args.lr_te, weight_decay=0)
            epochs = 0 if method in ['BNSA'] else self.args.tta_epochs

            # Pseudo-labels come from the predictions made before adaptation.
            gt = torch.softmax(p, 1)
            if is_elliptic:
                gt = gt[dat.mask]
            score, id_sele_ent = self.data_select(gt)
            target = (gt * score)[id_sele_ent]  # loop-invariant, so computed once
            replay_buffer = []

            for _ in range(epochs):
                model.train()
                optimizer.zero_grad()
                output = model.forward(feat, edge_index)
                pred = torch.softmax(output, dim=1)
                if is_elliptic:
                    pred = pred[dat.mask]
                loss_clf = self.crossentropyloss(target, pred[id_sele_ent])
                energy_base = torch.logsumexp(output, 1).mean()
                feat_sample, loss_gen = self.sample_sgld(feat, edge_index, model, replay_buffer, energy_base)
                replay_buffer.append(feat_sample)
                loss = loss_clf + loss_gen
                loss.backward()
                optimizer.step()

            output = model.predict(feat, edge_index)
            if self.args.dataset in ['cora', 'amazon-photo', 'twitch-e', 'fb100', 'ogb-arxiv', 'ogb-products']:
                y_true, y_out = labels, output
            elif self.args.dataset in ['elliptic']:
                y_true, y_out = labels[dat.mask], output[dat.mask]
            else:
                raise NotImplementedError
            accs.append(eval_func(y_true, y_out))
            y_te_all.append(y_true)
            out_te_all.append(y_out)

        print('BN adaptation JEM: Test accs:', accs)
        acc_te = eval_func(torch.cat(y_te_all, dim=0), torch.cat(out_te_all, dim=0))
        print(f'BN adaptation JEM: flatten test: {acc_te}')
