import numpy as np
from models import *
import torch.nn.functional as F
import torch
import deeprobust.graph.utils as utils
from torch.nn.parameter import Parameter
from tqdm import tqdm
import scipy.sparse as sp
import pandas as pd
import matplotlib.pyplot as plt
import torch.optim as optim
from copy import deepcopy
from utils import reset_args
from torch_geometric.utils import to_scipy_sparse_matrix, from_scipy_sparse_matrix, dropout_adj, is_undirected, \
    to_undirected
from torch_geometric.utils import k_hop_subgraph
import torch_geometric
from utils import KHopAggregator, flip_edges_with_degree_preservation
from feat_agent import FeatAgent
import math

import time

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
import os, matplotlib
from torch import nn


# matplotlib.use('Agg')


class GraphAgent(FeatAgent):

    def __init__(self, data_all, args):
        """Keep the run configuration and obtain the model to adapt.

        Args:
            data_all: Stored unchanged on ``self.data_all``.
            args: Run configuration, stored on ``self.args``.

        Note:
            The parent constructor is not called here. ``pretrain_model`` is
            not defined in this class (presumably inherited from
            ``FeatAgent`` -- confirm).
        """
        self.args, self.data_all = args, data_all
        self.device = 'cuda'
        self.model = self.pretrain_model()

    def setup_params(self, data):
        """Freeze every model parameter and set the row-diversity weight.

        Args:
            data: Not used by this method.
        """
        for weight in self.model.parameters():
            weight.requires_grad = False

        # Weight of the low-rank (row-diversity) regularisation term.
        self.row_diversity_weight = 0.9

    def learn_graph(self, data, test_id, visualization=False):
        """Tune a feature prompt on one graph and evaluate the frozen model.

        The kind of prompt depends on ``args.mlp_prompt`` and
        ``args.prompt_comp``: a per-node tensor, one shared row ("UPF"), a
        set of prompt tokens ("All in One"), a ``(nnodes, args.hidden)``
        tensor ("MultiG"), or the output of a ``PromptEncoder`` (empty
        ``prompt_comp``). Model parameters are frozen; only the parameters
        registered in ``self.optimizer_feat`` are stepped. The loop stops
        early once ``|mean(prompt embedding)|`` has not improved for more
        than ``args.patience`` consecutive epochs.

        Args:
            data: Graph object providing ``graph['node_feat']``,
                ``graph['edge_index']`` and ``label`` (and ``mask`` when
                ``args.dataset == 'elliptic'``).
            test_id: Forwarded to ``visualize_embedding`` for file naming.
            visualization: When True, save embedding plots before and after
                prompting.

        Returns:
            tuple: ``(evaluate_single(...), output, labels, prompt)``. For
            the 'elliptic' dataset ``output`` and ``labels`` are indexed by
            ``data.mask``.
        """
        print('====learning on this graph===')
        args = self.args
        self.setup_params(data)
        model = self.model
        model.eval()  # should set to eval

        from utils import get_gpu_memory_map
        mem_st = get_gpu_memory_map()

        self.data = data
        nnodes = data.graph['node_feat'].shape[0]
        d = data.graph['node_feat'].shape[1]
        feat, labels = data.graph['node_feat'].to(self.device), data.label.to(self.device)  # .squeeze()
        ncls = len(torch.unique(labels))
        print(nnodes)
        self.num_nodes = nnodes

        def stored_prompt():
            # Static prompt tensor when one exists; None makes apply_prompt
            # fall back to generate_prompt (the PromptEncoder path).
            return self.delta_feat if hasattr(self, 'delta_feat') else None

        if not args.mlp_prompt:
            # Per-node additive prompt, initialised close to zero.
            delta_feat = Parameter(torch.FloatTensor(nnodes, d).to(self.device))
            delta_feat.data.fill_(1e-7)
            self.optimizer_feat = torch.optim.Adam([delta_feat], lr=args.lr_feat)
            self.delta_feat = delta_feat

        if args.prompt_comp == "UPF":
            # One shared prompt row broadcast to every node.
            glb_delta_feat = Parameter(torch.FloatTensor(1, d).to(self.device))
            glb_delta_feat.data.fill_(1)
            self.optimizer_feat = torch.optim.Adam([glb_delta_feat], lr=args.lr_feat)
            # NOTE(review): this broadcast is evaluated once, so the values in
            # self.delta_feat do not follow later optimizer steps on
            # glb_delta_feat -- confirm whether that is intended.
            self.delta_feat = torch.add(glb_delta_feat, torch.zeros_like(feat))

        if args.mlp_prompt:
            if args.prompt_comp == "":
                self.prompt_layers = args.prompt_layers

                from LRA import PromptEncoder
                model_hidden = args.hidden
                attn_ratio = args.attn_ratio
                drop_out = args.dropout
                vnodes_rate = args.virtual_nodes_ratio
                vnodes = ncls * vnodes_rate if args.virtual_nodes_ratio != 0 else 1

                self.prompt_encoder = PromptEncoder(args, d, model_hidden, self.prompt_layers,
                                                    model_hidden // attn_ratio,
                                                    nnodes, vnodes, drop_out).to(self.device)
                # Zero-initialise every parameter whose name contains
                # 'weight' or 'bias'.
                for name, param in self.prompt_encoder.named_parameters():
                    if 'weight' in name:
                        torch.nn.init.zeros_(param)
                    elif 'bias' in name:
                        torch.nn.init.zeros_(param)
                self.optimizer_feat = torch.optim.Adam(self.prompt_encoder.parameters(), lr=args.lr_feat)

            elif args.prompt_comp == "All in One":
                threshold = 0.6
                n_prompt_tokens = args.virtual_nodes_ratio * ncls
                delta_feat = [Parameter(torch.FloatTensor(1, d).to(self.device)) for _ in range(n_prompt_tokens)]
                combined_prompt = torch.zeros_like(feat)
                for x in delta_feat:
                    x.data.fill_(1 / n_prompt_tokens)
                    # Each token contributes to a node only where the
                    # sigmoid similarity exceeds the threshold.
                    prompt_weight = nn.Sigmoid()(x @ feat.t())
                    prompt_weight = nn.ReLU()(prompt_weight - threshold)
                    combined_prompt += prompt_weight.t() @ x

                # NOTE(review): combined_prompt is built once before the loop;
                # its values are not recomputed after optimizer steps.
                self.delta_feat = combined_prompt
                self.optimizer_feat = torch.optim.Adam(delta_feat, lr=args.lr_feat)

            elif args.prompt_comp == "MultiG":
                # NOTE(review): this tensor is neither filled nor moved to
                # self.device here -- confirm it is initialised elsewhere.
                self.delta_feat = Parameter(torch.FloatTensor(nnodes, args.hidden))
                self.optimizer_feat = torch.optim.Adam([self.delta_feat], lr=args.lr_feat)

        model = self.model
        for param in model.parameters():
            param.requires_grad = False
        model.eval()  # should set to eval as test-time tuning with prompt

        edge_index = data.graph['edge_index'].to(self.device)
        self.edge_index, self.feat, self.labels = edge_index, feat, labels
        self.edge_weight = torch.ones(self.edge_index.shape[1]).to(self.device)

        # The tuning loop itself runs without explicit edge weights.
        edge_weight = None

        # ---- test-time tuning ----
        patience = args.patience
        now_wait = 0
        best_delta_mean = float("inf")

        if visualization:
            print("Visualization of Feat...")
            self.visualize_embedding(model.get_embed(
                feat, edge_index, edge_weight),
                labels, data, test_id, input_mode="feat_before_prompt")

        start = time.time()
        for it in tqdm(range(args.epochs)):
            self.optimizer_feat.zero_grad()
            # NOTE(review): the prompt is applied here only when args.prompt is
            # falsy; otherwise raw features are passed to test_time_loss
            # (defined outside this class) -- confirm the flag's meaning.
            if not args.prompt:
                prompted = self.apply_prompt(feat, stored_prompt(), edge_index)
                loss = self.test_time_loss(model, prompted, edge_index, edge_weight)
            else:
                loss = self.test_time_loss(model, feat, edge_index, edge_weight)

            if args.mlp_prompt and args.prompt_comp == "":
                row_diversity_loss, row_similarities, expected_similarities = self.low_rank_node_penalty()
                loss = loss + row_diversity_loss

            # "All in One" re-uses the graph built before the loop, so it must
            # be retained across backward calls.
            loss.backward(retain_graph=args.prompt_comp in ["All in One"])

            if args.mlp_prompt and args.prompt_comp == "":
                # Manual correction of E's gradient using the row-similarity residual.
                self.prompt_encoder.E.grad = self.prompt_encoder.E.grad - 2 * self.row_diversity_weight * torch.matmul(
                    row_similarities - expected_similarities, self.prompt_encoder.E)

            if not args.mlp_prompt:
                delta_mean = self.delta_feat.mean().item()
                embed_delta_mean = model.get_embed(self.delta_feat, edge_index).mean().item()
                print(
                    f'Epoch {it}, Loss: {loss.item()}, prompt: {delta_mean:.4f}, prompt embed: {embed_delta_mean:.4f}')
            else:
                delta_mean = self.generate_prompt(feat, edge_index).mean().item()
                embed_delta_mean = model.get_embed(self.generate_prompt(feat, edge_index),
                                                   edge_index).mean().item()
                print(
                    f'Epoch {it}, Loss: {loss.item()}, prompt: {delta_mean:.4f}, prompt embed: {embed_delta_mean:.4f}')

            self.optimizer_feat.step()

            if args.debug == 2 or args.debug == 3:
                # Use the same prompt selection as the loss above. The local
                # `delta_feat` is unbound on the PromptEncoder path and is not
                # the optimised prompt under "UPF".
                output = model.predict(self.apply_prompt(feat, stored_prompt(), edge_index), edge_index,
                                       edge_weight)
                print('Debug Test:', self.evaluate_single(model, output, labels, data, verbose=0))

            # Early stopping on the magnitude of the mean prompt embedding.
            if abs(embed_delta_mean) < best_delta_mean:
                best_delta_mean = abs(embed_delta_mean)
                now_wait = 0
            else:
                now_wait += 1

            if now_wait > patience:
                print("epoch patience out! ")
                break

        gpu_mem = get_gpu_memory_map()
        print(f'Mem used: {int(gpu_mem[args.gpu_id]) - int(mem_st[args.gpu_id])}MB')

        with torch.no_grad():
            loss = self.test_time_loss(model,
                                       self.apply_prompt(feat, stored_prompt(), edge_index),
                                       edge_index, edge_weight)
        print('Final Loss:', loss.item())
        output = model.predict(
            self.apply_prompt(feat, stored_prompt(), edge_index),
            edge_index, edge_weight)

        if visualization:
            print("Visualization of Prompt...")
            self.visualize_embedding(model.get_embed(
                self.generate_prompt(feat, edge_index), edge_index),
                labels, data, test_id, input_mode="prompt")

            print("Visualization of After Prompt Embedding...")
            self.visualize_embedding(model.get_embed(
                self.apply_prompt(feat, stored_prompt(), edge_index),
                edge_index, edge_weight),
                labels, data, test_id, input_mode="feat_after_prompt")

        end = time.time()
        print('Test:')
        print(f"Tuning GOAT time costs: {end - start} seconds.")

        if args.show_SVD and args.mlp_prompt and args.prompt_comp == "":
            self.visualize_LR()

        if args.dataset == 'elliptic':
            return self.evaluate_single(model, output, labels, data), output[data.mask], labels[data.mask], self.generate_prompt(feat, edge_index)
        else:
            return self.evaluate_single(model, output, labels, data), output, labels, self.generate_prompt(feat, edge_index)

    def aggregate_k_hop_features(self, x0, edge_index, k):
        """Run a ``KHopAggregator`` over ``x0`` ``k`` times in succession.

        Args:
            x0: Input features handed to the first aggregation step.
            edge_index: Connectivity passed to every aggregation step.
            k: Number of times the aggregator is applied.

        Returns:
            The result of the last aggregation step (``x0`` itself when
            ``k`` is 0).
        """
        hop = KHopAggregator()
        aggregated = x0
        for _ in range(k):
            aggregated = hop(aggregated, edge_index)
        return aggregated

    def generate_prompt(self, x0, edge_index):
        """Return the prompt for ``x0``.

        When ``self.delta_feat`` exists it is returned as-is. Otherwise the
        model embeddings of layers ``1..self.prompt_layers`` are concatenated
        along the last dimension and passed, together with ``x0``, to
        ``self.prompt_encoder``.

        Args:
            x0: Node features.
            edge_index: Graph connectivity used for the layer embeddings.

        Returns:
            The stored prompt tensor or the encoder output.
        """
        if hasattr(self, "delta_feat"):
            return self.delta_feat

        # One embedding per backbone layer, shallowest first.
        layer_embeds = [
            self.model.get_embed(x0, edge_index, output_layer=layer)
            for layer in range(1, self.prompt_layers + 1)
        ]
        stacked = torch.cat(layer_embeds, dim=-1)
        return self.prompt_encoder(x0, stacked)

    def apply_prompt(self, x, P=None, edge_index=None):
        """Add a prompt to the features ``x``.

        Args:
            x: Node features.
            P: Explicit prompt. When None, the prompt is obtained from
                ``generate_prompt(x, edge_index)``.
            edge_index: Connectivity, only used when ``P`` is None.

        Returns:
            ``x + prompt``; or ``x`` unchanged when ``P`` is given and
            ``args.prompt_comp`` is "MultiG".
        """
        if P is None:
            return x + self.generate_prompt(x, edge_index)
        if self.args.prompt_comp in ["MultiG"]:
            # An explicit prompt is not added at the input for "MultiG".
            return x
        return x + P

    def augment(self, strategy='dropedge', p=0.5, edge_index=None, edge_weight=None, with_delta=False, mlp_delta=False,
                keep_connection=False):
        """Embed an augmented view of the stored graph with the frozen model.

        Args:
            strategy: Augmentation name. Without ``with_delta``: 'shuffle',
                'dropedge', 'dropnode', 'rwsample', 'dropmix', 'dropfeat',
                'featnoise'. With ``with_delta``: 'shuffle', 'dropedge',
                'flipedge', 'dropnode', 'dropmix', 'dropfeat', 'featnoise'.
            p: Augmentation strength (drop probability, or the noise
                standard deviation for 'featnoise').
            edge_index: Connectivity of the graph to augment.
            edge_weight: Optional edge weights.
            with_delta: When False, return a single embedding of the
                (prompted, if a static prompt exists) features. When True,
                return embeddings without and with the prompt plus the
                embedding of the prompt itself.
            mlp_delta: With ``with_delta``, ignore ``self.delta_feat`` and
                let ``apply_prompt`` generate the prompt.
            keep_connection: Unused; kept for interface compatibility.

        Returns:
            ``output`` when ``with_delta`` is False, otherwise
            ``(output, output_with_delta, delta_embed)``.

        Raises:
            ValueError: If ``strategy`` is not supported in the selected mode
                (previously this surfaced as ``UnboundLocalError``).

        Notes:
            * Without ``with_delta`` every strategy uses the same base
              features: ``self.feat`` plus ``self.delta_feat`` when that
              attribute exists and ``prompt_comp`` is not "MultiG".
              Some strategies used to read ``self.delta_feat``
              unconditionally, which failed when no static prompt exists.
            * With ``with_delta`` the base features are always the raw
              ``self.feat``. 'dropnode' and 'dropfeat' used to add the
              prompt to the base as well, so it was applied twice in
              ``output_with_delta`` and leaked into ``output``.
            * Node masks are sampled on the feature tensor's device rather
              than through ``torch.cuda.FloatTensor``.
        """
        model = self.model

        def drop_nodes(x):
            # Zero whole rows: each node is kept when its uniform draw exceeds p.
            keep = torch.rand(len(x), device=x.device) > p
            return x * keep.view(-1, 1)

        def add_noise(x):
            # Gaussian noise with mean 0 and standard deviation p.
            mean, std = 0, p
            noise = torch.randn(x.size()) * std + mean
            return x + noise.to(x.device)

        if not with_delta:
            use_static = hasattr(self, 'delta_feat') and self.args.prompt_comp not in ["MultiG"]
            feat = self.feat + self.delta_feat if use_static else self.feat

            if strategy == 'shuffle':
                idx = np.random.permutation(feat.shape[0])
                return model.get_embed(feat[idx, :], edge_index, edge_weight)
            if strategy == "dropedge":
                edge_index, edge_weight = dropout_adj(edge_index, edge_weight, p=p)
                return model.get_embed(feat, edge_index, edge_weight)
            if strategy == "dropnode":
                return model.get_embed(drop_nodes(feat), edge_index, edge_weight)
            if strategy == "rwsample":
                import augmentor as A
                walk_length = 1 if self.args.dataset in ['twitch-e', 'elliptic'] else 10
                aug = A.RWSampling(num_seeds=1000, walk_length=walk_length)
                x2, edge_index2, edge_weight2 = aug(feat, edge_index, edge_weight)
                return model.get_embed(x2, edge_index2, edge_weight2)
            if strategy == "dropmix":
                feat = drop_nodes(feat)
                edge_index, edge_weight = dropout_adj(edge_index, edge_weight, p=p)
                return model.get_embed(feat, edge_index, edge_weight)
            if strategy == "dropfeat":
                # Dropout acts on the raw features only; the prompt is added afterwards.
                dropped = F.dropout(self.feat, p=p)
                feat = dropped + self.delta_feat if use_static else dropped
                return model.get_embed(feat, edge_index, edge_weight)
            if strategy == "featnoise":
                return model.get_embed(add_noise(feat), edge_index)
            raise ValueError(f"Unsupported augmentation strategy {strategy!r} for with_delta=False")

        feat = self.feat
        if not mlp_delta and hasattr(self, 'delta_feat'):
            delta_feat = self.delta_feat
        else:
            delta_feat = None

        # 'featnoise' calls get_embed without edge weights; all others pass them.
        pass_weight = True
        if strategy == 'shuffle':
            idx = np.random.permutation(feat.shape[0])
            feat = feat[idx, :]
        elif strategy == "dropedge":
            edge_index, edge_weight = dropout_adj(edge_index, edge_weight, p=p)
        elif strategy == "flipedge":
            # The original edge weights are kept for the flipped edges.
            edge_index = flip_edges_with_degree_preservation(edge_index, p)
        elif strategy == "dropnode":
            feat = drop_nodes(feat)
        elif strategy == "dropmix":
            feat = drop_nodes(feat)
            edge_index, edge_weight = dropout_adj(edge_index, edge_weight, p=p)
        elif strategy == "dropfeat":
            feat = F.dropout(feat, p=p)
        elif strategy == "featnoise":
            feat = add_noise(feat)
            pass_weight = False
        else:
            raise ValueError(f"Unsupported augmentation strategy {strategy!r} for with_delta=True")

        weight_args = (edge_weight,) if pass_weight else ()
        output = model.get_embed(feat, edge_index, *weight_args)
        output_with_delta = model.get_embed(self.apply_prompt(feat, delta_feat, edge_index), edge_index,
                                            *weight_args)
        delta_embed = model.get_embed(self.generate_prompt(feat, edge_index), edge_index)
        return output, output_with_delta, delta_embed

    def visualize_embedding(self, embedding, labels=None, test_data=None, test_id=0, mode="SCATTER",
                            input_mode="feat_after_prompt"):
        """Project an embedding to 2-D with t-SNE and save a scatter plot.

        The image is written to
        ``results/visualization/<args.dataset>/<input_mode>_embeddings_vis_<test_id>.png``.

        Args:
            embedding: Tensor of node embeddings; detached and moved to CPU.
            labels: Optional tensor used to colour the points. When None the
                points are drawn without a colour map or colour bar.
            test_data: Unused; kept for interface compatibility.
            test_id: Suffix of the output file name.
            mode: "SCATTER" for points only, "KDE" to draw a filled density
                estimate underneath the points.
            input_mode: Prefix of the output file name.

        Raises:
            ValueError: If ``mode`` is neither "KDE" nor "SCATTER".
        """
        if mode not in ("KDE", "SCATTER"):
            raise ValueError(f"mode must be 'KDE' or 'SCATTER', got {mode!r}")

        embedding = embedding.detach().cpu().numpy()
        if labels is not None:
            labels = labels.detach().cpu().numpy()

        vis_dir = f'results/visualization/{self.args.dataset}'
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(vis_dir, exist_ok=True)

        # t-SNE projection to two components.
        tsne = TSNE(n_components=2, random_state=self.args.seed)
        embedding_2d = tsne.fit_transform(embedding)
        x = embedding_2d[:, 0]
        y = embedding_2d[:, 1]

        sns.set(context="paper", style="white")

        fig, ax = plt.subplots(figsize=(20, 20))
        try:
            if mode == "KDE":
                fig.patch.set_facecolor('#fff7e0')
                ax.set_facecolor('#fff7e0')
                # `fill` replaces the deprecated `shade` argument of kdeplot.
                sns.kdeplot(x=x, y=y, ax=ax, cmap="YlOrBr", fill=True, bw_adjust=0.5)

            if labels is not None:
                scatter = ax.scatter(x, y, c=labels, cmap='Spectral', alpha=0.8)
                cbar = fig.colorbar(scatter, ax=ax)
                cbar.set_label('Labels')
            else:
                ax.scatter(x, y, alpha=0.8)

            ax.set_aspect('equal')
            for spine in ax.spines.values():
                spine.set_visible(True)
                spine.set_color('black')

            sns.despine()
            fig.savefig(os.path.join(vis_dir, f'{input_mode}_embeddings_vis_{test_id}.png'), dpi=400,
                        bbox_inches='tight')
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close(fig)

    def visualize_LR(self):
        """Print and optionally plot the singular values of the encoder's ``E``.

        The first ten singular values are always printed. When
        ``args.visualization`` is truthy, two figures are shown: the singular
        values, and their cumulative proportion.
        """
        E = self.prompt_encoder.E
        # Only the singular values are needed, so skip computing U and V.
        # torch.linalg.svdvals replaces the deprecated torch.svd here.
        singular_values = torch.linalg.svdvals(E.detach()).cpu().numpy()
        print(singular_values[:10])

        if not self.args.visualization:
            return

        plt.figure(figsize=(8, 4))
        plt.plot(singular_values, marker='o', linestyle='-', linewidth=1)
        plt.title('Singular Values of Low-rank Matrix E')
        plt.xlabel('Singular Value Index')
        plt.ylabel('Singular Value')
        plt.grid(True)
        plt.show()

        # Cumulative proportion; an all-zero spectrum is plotted as zeros
        # instead of dividing by zero.
        total = np.sum(singular_values)
        running = np.cumsum(singular_values)
        cumulative_variance = running / total if total > 0 else np.zeros_like(running)

        plt.figure(figsize=(8, 4))
        plt.plot(cumulative_variance, marker='o', linestyle='-', linewidth=1)
        plt.title('Cumulative Proportion of Singular Values')
        plt.xlabel('Singular Value Index')
        plt.ylabel('Cumulative Proportion')
        plt.grid(True)
        plt.show()

    def low_rank_node_penalty(self, ):  # might be useful
        """Penalise deviation of ``E @ E^T`` from the identity matrix.

        Returns:
            tuple: ``(loss, gram, identity)`` where ``gram`` is
            ``E @ E^T`` for the encoder's ``E``, ``identity`` is an identity
            matrix of size ``prompt_encoder.n`` on the same device, and
            ``loss`` is ``row_diversity_weight * sum((gram - identity) ** 2)``.
        """
        basis = self.prompt_encoder.E
        gram = basis @ basis.transpose(-2, -1)
        identity = torch.eye(self.prompt_encoder.n).to(gram.device)
        penalty = self.row_diversity_weight * (gram - identity).pow(2).sum()
        return penalty, gram, identity


def inner(t1, t2, epsilon=1e-15):
    """Return the mean of ``1 - cosine similarity`` over paired rows.

    Each row is divided by its L2 norm plus ``epsilon`` before the row-wise
    dot product is taken, so all-zero rows contribute a distance of 1.
    """

    def unit_rows(t):
        # Row norms as a column vector so the division broadcasts per row.
        row_norms = t.norm(dim=1).reshape(-1, 1)
        return t / (row_norms + epsilon)

    cosine = torch.sum(unit_rows(t1) * unit_rows(t2), dim=1)
    return torch.mean(1 - cosine)


def diff(t1, t2, epsilon=1e-15):
    """Return half the mean squared distance between row-normalised inputs.

    Each row is divided by its L2 norm plus ``epsilon``; the squared
    differences are summed per row, averaged over rows and halved.
    """

    def unit_rows(t):
        # Row norms as a column vector so the division broadcasts per row.
        row_norms = t.norm(dim=1).reshape(-1, 1)
        return t / (row_norms + epsilon)

    gap = unit_rows(t1) - unit_rows(t2)
    per_row = torch.sum(gap ** 2, dim=1)
    return 0.5 * torch.mean(per_row)


class UnionFind:
    """Disjoint-set forest over ``0..n-1`` with path compression and union by rank."""

    def __init__(self, n):
        """Create ``n`` singleton sets; every element starts as its own root."""
        self.parent = [node for node in range(n)]
        self.rank = [0 for _ in range(n)]

    def find(self, x):
        """Return the root of ``x`` and point every node on the path at it."""
        # First pass: walk up to the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: re-point each visited node directly at the root.
        node = x
        while node != root:
            above = self.parent[node]
            self.parent[node] = root
            node = above
        return root

    def union(self, x, y):
        """Merge the sets of ``x`` and ``y``; on equal rank the root of ``x`` wins."""
        keep, absorb = self.find(x), self.find(y)
        if keep == absorb:
            return
        if self.rank[keep] < self.rank[absorb]:
            # The higher-ranked root always becomes the parent.
            keep, absorb = absorb, keep
        tie = self.rank[keep] == self.rank[absorb]
        self.parent[absorb] = keep
        if tie:
            self.rank[keep] += 1

    def connected(self, x, y):
        """Return True when ``x`` and ``y`` share a root."""
        root_x = self.find(x)
        root_y = self.find(y)
        return root_x == root_y
