import argparse
import copy
import math
import os
import random

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from grakel import Graph as GrakelGraph, GraphKernel
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid, WikipediaNetwork
from torch_geometric.utils import dense_to_sparse

from models import GCN,GIN,GAT,ChebNet,GraphSAGE,Graphormer

# Default project and data directory constants


# ------------------ Hyperparameters ------------------
def get_command_line_parser():
    """Build the command-line parser for the experiment script.

    Returns:
        argparse.ArgumentParser: parser exposing the dataset, model,
        optimisation and training-schedule options with their defaults.
    """
    option_specs = [
        ('--dataset_name', dict(type=str, default='PubMed',
                                choices=['Cora', 'CiteSeer', 'PubMed', 'Squirrel', 'Chameleon'],
                                help='Specific dataset name or variant')),
        # Model hyperparameters
        ('--seed', dict(type=int, default=42,
                        help='Random seed for reproducibility')),
        ('--model_type', dict(type=str, default='GCN',
                              choices=['GCN', 'ChebNet', 'GIN', 'GAT', 'GraphSAGE', 'Graphormer', 'GraphGAN'],
                              help='Type of graph model to use')),
        ('--target_layer', dict(type=int, default=1,
                                choices=[1, 2],
                                help='Which layer to target (1 or 2)')),
        ('--gen_mode', dict(type=str, default='feat',
                            choices=['adj', 'feat', 'both'],
                            help='Generation mode: adjacency matrix, features, or both')),
        ('--hidden_channels', dict(type=int, default=32,
                                   help='Number of hidden channels in the model')),
        # Optimization parameters
        ('--lr_model', dict(type=float, default=0.005,
                            help='Learning rate for the model')),
        ('--wd_model', dict(type=float, default=5e-4,
                            help='Weight decay for the model optimizer')),
        ('--lr_gen', dict(type=float, default=0.0005,
                          help='Learning rate for the generator')),
        # Training schedule
        ('--num_epochs_model', dict(type=int, default=10000,
                                    help='Number of epochs to train the model')),
        ('--num_epochs_gen', dict(type=int, default=10000,
                                  help='Number of epochs to train the generator')),
        ('--patience1', dict(type=int, default=500,
                             help='Early stopping patience for the model')),
        ('--patience2', dict(type=int, default=100,
                             help='Early stopping patience for the generator')),
        ('--threshold', dict(type=float, default=0.5,
                             help='Threshold value for classification')),
    ]

    parser = argparse.ArgumentParser(description="Graph model and generator configuration")
    # Options are registered in table order so the generated --help is unchanged.
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser

if __name__ == '__main__':
    parser = get_command_line_parser()
    args = parser.parse_args()
    # ------------------ Fix random seed ------------------
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # ------------------ Load data ------------------
    # Honour --dataset_name (default 'PubMed') instead of a hard-coded value,
    # which previously made the flag a silent no-op.
    DATASET_NAME = args.dataset_name
    if DATASET_NAME in ['Cora', 'CiteSeer', 'PubMed']:
        dataset = Planetoid(root=f'data/{DATASET_NAME}', name=DATASET_NAME)
    elif DATASET_NAME in ['Squirrel', 'Chameleon']:
        dataset = WikipediaNetwork(root=f'data/{DATASET_NAME}', name=DATASET_NAME)
    else:
        raise ValueError(f"Unknown dataset {DATASET_NAME!r}")
    data = dataset[0].to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    device = data.x.device
    # Per-feature statistics of the real features; the feature generators
    # defined further down use them to initialise their parameters.
    mean = data.x.mean(dim=0, keepdim=True)
    std  = data.x.std(dim=0, keepdim=True)
    # If the masks are 2-D (presumably one column per predefined split --
    # TODO confirm), keep only column `split_id` so they index nodes directly.
    split_id = 0
    if data.train_mask.dim() == 2:
        data.train_mask = data.train_mask[:, split_id]
        data.val_mask   = data.val_mask[:,   split_id]
        data.test_mask  = data.test_mask[:,  split_id]

    # def compute_dist_matrix(data, max_dist=10):
    #
    #     G = to_networkx(data, to_undirected=True)
    #
    #     lengths = dict(nx.all_pairs_shortest_path_length(G, cutoff=max_dist))
    #     N = data.num_nodes
    #
    #     dist = torch.full((N, N), max_dist, dtype=torch.long)
    #     for i, nbrs in lengths.items():
    #         for j, d in nbrs.items():
    #             dist[i, j] = d
    #     return dist

    #  data
    # data.dist = compute_dist_matrix(data, max_dist=10)





    # Per-run metrics, summarised after the loop.
    cosine_sims = []
    match_ratios = []
    Sorce = []  # combined score per run (name kept: it is referenced later in the script)
    for run in range(5):
        # Derive a distinct, reproducible seed per run from --seed.  Reseeding
        # every run with the literal 42 ignored --seed and made all runs start
        # from the same random state, so the reported std was meaningless.
        # Run 0 with the default seed still uses 42.
        run_seed = args.seed + run
        random.seed(run_seed)
        np.random.seed(run_seed)
        torch.manual_seed(run_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(run_seed)

        # ------------------ Train & save base model ------------------
        model_classes = {
            'GCN': GCN,
            'ChebNet': ChebNet,
            'GIN': GIN,
            'GAT': GAT,
            'GraphSAGE': GraphSAGE,
            'Graphormer': Graphormer,
        }
        if args.model_type not in model_classes:
            # Reachable e.g. for 'GraphGAN', which the parser accepts but
            # which has no model class here.
            raise ValueError(
                f"args.model_type must be one of {list(model_classes)}, "
                f"got {args.model_type!r}")
        base_model = model_classes[args.model_type](
            dataset.num_node_features, args.hidden_channels, dataset.num_classes)
        base_model = base_model.to(device)
        opt = torch.optim.Adam(base_model.parameters(), lr=args.lr_model, weight_decay=args.wd_model)

        def train_base():
            """Run one full-batch optimisation step on the base model.

            Returns:
                float: the NLL loss on the training nodes for this step.
            """
            base_model.train()
            opt.zero_grad()
            predictions = base_model(data)
            train_nodes = data.train_mask
            step_loss = F.nll_loss(predictions[train_nodes], data.y[train_nodes])
            step_loss.backward()
            opt.step()
            return step_loss.item()

        @torch.no_grad()  # evaluation only: do not build an autograd graph
        def test_base():
            """Evaluate the base model on the train/val/test masks.

            Returns:
                list[float]: accuracies in the order [train, val, test].
                A mask that selects no nodes yields NaN instead of raising
                ZeroDivisionError.
            """
            base_model.eval()
            logits = base_model(data)
            accs = []
            for mask in [data.train_mask, data.val_mask, data.test_mask]:
                pred = logits[mask].max(1)[1]
                total = mask.sum().item()
                correct = pred.eq(data.y[mask]).sum().item()
                accs.append(correct / total if total else float('nan'))
            return accs

        print(f"Training {args.model_type}...")

        # Early stopping tracks the loss returned by train_base(); the weights
        # from the lowest-loss epoch are the ones saved to disk.
        best_loss = float('inf')
        patience_counter = 0
        best_state_dict = None

        for epoch in range(1, args.num_epochs_model + 1):
            loss = train_base()
            tr, va, te = test_base()

            if epoch % 10 == 0:
                print(f"Epoch {epoch:03d}, Loss={loss:.4f}, Train={tr:.4f}, Val={va:.4f}, Test={te:.4f}")

            if loss < best_loss:
                best_loss = loss
                # Deep copy: state_dict() returns references to the live
                # tensors, which keep changing as training continues.
                best_state_dict = copy.deepcopy(base_model.state_dict())
                patience_counter = 0
            else:
                patience_counter += 1

            # --patience1 is documented as the *model* patience; the loop used
            # --patience2 (the generator patience) by mistake.
            if patience_counter >= args.patience1:
                print(f"Early stopping at epoch {epoch}. Best loss: {best_loss:.4f}")
                break

        # Create the output directory first so a long training run is not
        # lost to a missing folder at save time.
        os.makedirs('output', exist_ok=True)
        torch.save(best_state_dict, 'output/model_save.pth')
        print(f"Saved best model (loss={best_loss:.4f}) to model_save.pth")

        # ------------------ Generators ------------------

        # Fraction of non-zero entries in the real feature matrix.
        density_global_X = data.x.ne(0).float().mean().item()


        def compute_density_global_A(edge_index: torch.LongTensor, num_nodes: int) -> float:
            """Return the undirected edge density of a graph.

            Each edge is counted once regardless of direction or duplication,
            and self-loops are ignored.  The count is divided by the number of
            unordered node pairs, ``num_nodes * (num_nodes - 1) / 2``.

            Args:
                edge_index: tensor whose rows 0 and 1 hold the two endpoints
                    of every edge.
                num_nodes: number of nodes in the graph.

            Returns:
                The density as a Python float; ``0.0`` when the graph has
                fewer than two nodes (no node pair exists).
            """
            max_pairs = num_nodes * (num_nodes - 1) / 2
            if max_pairs <= 0:
                # Previously this case raised ZeroDivisionError.
                return 0.0

            edges = torch.stack([edge_index[0], edge_index[1]], dim=1)  # [E, 2]
            # Sort each pair so that (i, j) and (j, i) become the same row.
            edges = torch.sort(edges, dim=1).values
            # Keep i < j only, which drops self-loops.
            edges = edges[edges[:, 0] < edges[:, 1]]
            unique_edges = torch.unique(edges, dim=0)

            return float(unique_edges.size(0) / max_pairs)


        # Undirected edge density of the real graph.
        N = data.num_nodes  # int
        edge_index = data.edge_index  # torch.LongTensor, shape [2, E]
        density_global_A = compute_density_global_A(edge_index, N)


        class BinarizeSTE(torch.autograd.Function):
            """Hard threshold whose backward pass is the identity.

            ``forward`` maps entries ``>= threshold`` to 1.0 and all others to
            0.0; ``backward`` passes the incoming gradient to ``input``
            unchanged and gives no gradient to ``threshold``.
            """

            @staticmethod
            def forward(ctx, input, threshold):
                # input: (M,) in (0,1), threshold: scalar
                hard_mask = torch.ge(input, threshold)
                return hard_mask.float()

            @staticmethod
            def backward(ctx, grad_output):
                grad_input, grad_threshold = grad_output, None
                return grad_input, grad_threshold


        class TriangularSTEAdjacency(nn.Module):
            """Learnable symmetric binary adjacency matrix without self-loops.

            One logit is learned per strictly-upper-triangular entry.  The
            forward pass keeps the top ``k`` entries by probability, where
            ``k`` follows from a density parameter, and mirrors them into the
            lower triangle.  Gradients reach the logits through a
            straight-through estimator.
            """

            def __init__(self,
                         num_nodes: int,
                         density_global: float,
                         s: float = 10.0,
                         lambda_reg: float = 1e-3):
                """
                Args:
                    num_nodes: number of nodes ``n``; the result is ``n x n``.
                    density_global: initial edge density, clamped to
                        ``[1e-6, 1 - 1e-6]`` before conversion to a logit.
                    s: scale applied to the logits before the sigmoid.
                    lambda_reg: weight of the ``p * (1 - p)`` term stored in
                        ``last_margin``.
                """
                super().__init__()
                self.n = num_nodes

                # Row/column indices of the strictly upper triangle.
                # NOTE: plain attributes, not buffers, so `.to(device)` does
                # not move them; they are used only as index tensors.
                self.tri_i, self.tri_j = torch.triu_indices(num_nodes, num_nodes, offset=1)
                self.M = self.tri_i.numel()

                self.logits = nn.Parameter(
                    torch.randn(self.M) * math.sqrt(2.0 / self.M)
                )

                # Store the density as a logit so its sigmoid stays in (0, 1).
                init = max(1e-6, min(density_global, 1 - 1e-6))
                logit_density = math.log(init / (1.0 - init))
                self.log_density = nn.Parameter(torch.tensor(logit_density))

                self.s = s
                self.lambda_reg = lambda_reg

                # Set by forward(); None until the first call.
                self.last_margin = None

            def forward(self):
                """Return the ``n x n`` adjacency matrix used during training.

                Also stores ``mean(p * (1 - p)) * lambda_reg`` in
                ``self.last_margin``.
                """
                p = torch.sigmoid(self.s * self.logits)  # (M,)

                density = torch.sigmoid(self.log_density)  # scalar in (0,1)

                # Number of edges to keep.  The density is read via .item(),
                # so no gradient flows to log_density through k.
                k = max(1, int((density.detach().item() * self.M)))

                topk_vals, _ = torch.topk(p, k)
                threshold = topk_vals.min().detach()

                b_hard = BinarizeSTE.apply(p, threshold)  # (M,) in {0,1}
                # Value of b_hard in the forward pass, gradient of p in the
                # backward pass.
                b = p + (b_hard - p).detach()  # STE

                margin = torch.mean(p * (1 - p)) * self.lambda_reg
                self.last_margin = margin

                A = torch.zeros(self.n, self.n, device=p.device)
                A[self.tri_i, self.tri_j] = b
                A[self.tri_j, self.tri_i] = b

                return A

            @torch.no_grad()
            def inference_adj(self):
                """Return the hard 0/1 adjacency matrix (no gradient tracking).

                Uses the same top-k thresholding as ``forward``.  The original
                implementation built the matrix but returned nothing.
                """
                p = torch.sigmoid(self.s * self.logits)
                density = torch.sigmoid(self.log_density)
                k = max(1, int((density.item() * self.M)))
                topk_vals, _ = torch.topk(p, k)
                threshold = topk_vals.min()

                b = (p >= threshold).float()
                A = torch.zeros(self.n, self.n, device=p.device)
                A[self.tri_i, self.tri_j] = b
                A[self.tri_j, self.tri_i] = b
                return A

        class FeatureGenerator_discrete(nn.Module):
            #for Cora,Citeseer,Squirrel,Chameleon dataset
            """Learnable binary node-feature matrix with a straight-through gradient.

            One logit is learned per (node, feature) entry.  The forward pass
            keeps the top ``k`` entries by probability, where ``k`` follows
            from a density parameter initialised from ``density_global_X``.
            The threshold in use is also written to ``args.threshold_value``.
            """

            def __init__(self, num_nodes, num_features,
                         s=10.0, lambda_reg=1e-3):
                super().__init__()

                # Start from Gaussian noise scaled by the per-feature
                # statistics (`mean`, `std`) of the real data.
                noise = torch.randn(num_nodes, num_features, device=mean.device)
                start_logits = mean.repeat(num_nodes, 1) + std.repeat(num_nodes, 1) * noise
                self.logits = nn.Parameter(start_logits)

                # Store the density as a logit so its sigmoid stays in (0, 1).
                init = torch.clamp(torch.tensor(density_global_X, device=mean.device), 1e-6, 1-1e-6)
                self.log_density = nn.Parameter(torch.log(init / (1 - init)))

                self.s = s
                self.lambda_reg = lambda_reg

                # Set by forward(); None until the first call.
                self.last_margin = None

            def forward(self):
                """Return the binarised features and update ``last_margin``."""
                probs = torch.sigmoid(self.s * self.logits)
                density = torch.sigmoid(self.log_density)

                flat_probs = probs.view(-1)
                keep = max(1, int(density.detach().item() * flat_probs.numel()))
                # The smallest of the top-`keep` probabilities is the cut-off.
                args.threshold_value = torch.topk(flat_probs, keep).values.min()

                hard = (probs >= args.threshold_value).float()
                # Value of `hard` in the forward pass, gradient of `probs`
                # in the backward pass.
                ste_out = probs + (hard - probs).detach()

                self.last_margin = torch.mean(probs * (1 - probs)) * self.lambda_reg
                return ste_out

            def discretize(self, p):
                """Binarise probabilities ``p`` with the same top-k rule as ``forward``."""
                flat_probs = p.view(-1)
                density_value = torch.sigmoid(self.log_density).detach().item()
                keep = max(1, int(density_value * flat_probs.numel()))
                args.threshold_value = torch.topk(flat_probs, keep).values.min()
                return (p >= args.threshold_value).float()

        class FeatureGenerator_Continuous(nn.Module):
            #for PubMed dataset
            """Learnable non-negative continuous node-feature matrix.

            The parameter is initialised as ``mean + std * noise`` using the
            per-feature statistics of the real data, and the forward pass
            returns its ReLU.
            """

            def __init__(self, num_nodes, feat_dim):
                """
                Args:
                    num_nodes: number of rows of the generated matrix.
                    feat_dim: number of columns; must broadcast against
                        `mean` and `std`.
                """
                super().__init__()
                # Size the parameter from the constructor arguments rather
                # than from the global `data.x`, which the original did while
                # ignoring both arguments.
                noise = torch.randn(num_nodes, feat_dim,
                                    device=mean.device, dtype=mean.dtype)
                self.param = nn.Parameter(mean + std * noise)

            def forward(self):
                """Return the generated features, clamped to be non-negative."""
                return torch.relu(self.param)



        # ------------------ Load fixed model & compute real target ------------------
        # Rebuild the trained architecture and load the saved weights; this
        # copy is put in eval mode and used as the reference network.
        fixed_cls = type(base_model)
        fixed = fixed_cls(dataset.num_node_features, args.hidden_channels, dataset.num_classes)
        saved_weights = torch.load('output/model_save.pth')
        fixed.load_state_dict(saved_weights)
        fixed = fixed.to(device)
        fixed.eval()

        def get_target(model, x, ei, ew, layer):
            """Return the activations of `model` at the requested layer.

            Args:
                model: network exposing `conv1` and, for layer 2, `conv2`.
                x: node features fed to `conv1`.
                ei: edge index passed to every convolution.
                ew: edge weights; forwarded as `edge_weight` only when
                    `model` is a GCN or ChebNet, otherwise ignored.
                layer: 1 returns ReLU(conv1); any other value returns
                    log_softmax(conv2(ReLU(conv1))).
            """
            takes_edge_weight = isinstance(model, (GCN, ChebNet))

            def _apply(conv, inputs):
                # Only GCN/ChebNet convolutions receive the edge weights.
                if takes_edge_weight:
                    return conv(inputs, ei, edge_weight=ew)
                return conv(inputs, ei)

            hidden = F.relu(_apply(model.conv1, x))
            if layer == 1:
                return hidden

            # Dropout is explicitly disabled here (training=False).
            hidden = F.dropout(hidden, training=False)
            return F.log_softmax(_apply(model.conv2, hidden), dim=1)



        # def get_target(model, x, ei, ew, layer):
        #     #for Res Model
        #     if layer == 1:
        #         h1 = model.conv1(x, ei)
        #         r1 = model.res1(x) if model.res1 is not None else x
        #         h1 = h1 + model.alpha1 * r1
        #
        #         #return h1
        #         return F.relu(h1)
        #     else:
        #         h1 = model.conv1(x, ei)
        #         res1 = model.res1(x) if model.res1 is not None else x
        #         h1 = h1 + model.alpha1 * res1
        #         h1 = F.relu(h1)
        #
        #         h1 = F.dropout(h1, p=0.5, training=model.training)
        #
        #
        #         h2 = model.conv2(h1, ei)
        #         res2 = model.res2(h1) if model.res2 is not None else h1
        #         h2 = h2 + model.alpha2 * res2
        #         return h2
        #         #return F.log_softmax(h2, dim=1)

        # Reference activations of the frozen model on the real graph and features.
        with torch.no_grad():
            H_real = get_target(fixed, data.x, data.edge_index, None, args.target_layer)

        # ------------------ Init & train generator ------------------
        # Build only the generators the mode needs and collect their
        # parameters (adjacency first, then features) for one optimiser.
        params = []
        if args.gen_mode in ['adj', 'both']:
            genG = TriangularSTEAdjacency(data.num_nodes,density_global_A).to(device)
            params.extend(genG.parameters())
        if args.gen_mode in ['feat', 'both']:
            #featG = FeatureGenerator_discrete(data.num_nodes, dataset.num_node_features).to(device)
            featG = FeatureGenerator_Continuous(data.num_nodes, dataset.num_node_features).to(device)
            params.extend(featG.parameters())

        optG = torch.optim.Adam(params, lr=args.lr_gen)

        print("Training generator...")
        # for epoch in range(1, args.num_epochs_gen + 1):
        #     optG.zero_grad()
        #     if args.gen_mode == 'adj':
        #         A = genG(); ei, ew = dense_to_sparse(A); x_in = data.x
        #     elif args.gen_mode == 'feat':
        #         ei, ew = data.edge_index, None; x_in = featG()
        #     else:
        #         A = genG(); ei, ew = dense_to_sparse(A); x_in = featG()
        #     H_gen = get_target(fixed, x_in, ei, ew, args.target_layer)
        #     lossG = torch.norm(H_gen - H_real) / torch.norm(H_real)
        #     lossG = lossG + featG.margin_loss
        #     lossG.backward()
        #     optG.step()
        #     if epoch % 10 == 0:
        #         print(f"Gen Epoch {epoch:03d}, Loss={lossG:.4f}")



        # Best generator outputs seen so far, judged by the relative
        # reconstruction error.
        best_loss = float('inf')
        best_x_in = None
        best_A = None
        best_ei = None
        best_ew = None
        no_improve_count = 0

        for epoch in range(1, args.num_epochs_gen + 1):
            optG.zero_grad()

            # Assemble the model inputs for the current mode; whatever is not
            # generated is taken from the real data.
            if args.gen_mode == 'adj':
                A = genG()
                ei, ew = dense_to_sparse(A)
                x_in = data.x
            elif args.gen_mode == 'feat':
                ei, ew = data.edge_index, None
                x_in = featG()
            else:  # 'both'
                A = genG()
                ei, ew = dense_to_sparse(A)
                x_in = featG()

            # Relative L2 distance between generated and real activations.
            H_gen = get_target(fixed, x_in, ei, ew, args.target_layer)
            lossG = torch.norm(H_gen - H_real) / torch.norm(H_real)
            #lossG = lossG + featG.margin_loss
            lossG.backward()
            optG.step()

            # The loss was computed before the step, so it describes the
            # x_in / A / ei / ew produced in this iteration.
            loss_value = lossG.item()
            if loss_value < best_loss:
                best_loss = loss_value
                best_x_in = x_in.clone().detach()
                if args.gen_mode in ['adj', 'both']:
                    best_A = A.clone().detach()
                    # Detach so the snapshot does not keep this epoch's
                    # autograd graph alive.
                    best_ei, best_ew = ei.detach(), ew.detach()
                no_improve_count = 0
            else:
                no_improve_count += 1

            if epoch % 10 == 0:
                print(f"Gen Epoch {epoch:03d}, Loss={loss_value:.4f} "
                      f"(best={best_loss:.4f}, no_improve={no_improve_count})")

            # --patience2 is documented as the *generator* patience; the loop
            # used --patience1 (the model patience) by mistake.
            if no_improve_count >= args.patience2:
                print(f"Early stopping at epoch {epoch} "
                      f"(no improvement in last {args.patience2} epochs).")
                break

        print(f"Training finished. Best loss={best_loss:.4f}.")

        # Expose the best snapshot under the names the evaluation code reads.
        if args.gen_mode == 'adj':
            A = best_A
            ei, ew = best_ei, best_ew
            x_in = data.x
        elif args.gen_mode == 'feat':
            ei, ew = data.edge_index, None
            x_in = best_x_in
        else:  # 'both'
            A = best_A
            ei, ew = best_ei, best_ew
            x_in = best_x_in






        # ------------------ Evaluation ------------------
        def to_nx(ei, num):
            """Build an undirected networkx graph from an edge index.

            Args:
                ei: tensor whose transpose lists one (source, target) pair per row.
                num: number of nodes; nodes ``0 .. num-1`` are always added,
                    so isolated nodes are kept.
            """
            graph = nx.Graph()
            graph.add_nodes_from(range(num))
            edge_pairs = ei.t().cpu().tolist()
            graph.add_edges_from(edge_pairs)
            return graph

        realG = to_nx(data.edge_index, data.num_nodes)
        # Without an adjacency generator the "generated" graph is the real one.
        genNx = realG
        if args.gen_mode in ['adj', 'both']:
            # NOTE(review): this reads the generator's current state, not the
            # best snapshot kept during training -- confirm that is intended.
            A_cont = genG().detach()
            A_bin  = (A_cont > args.threshold).float()
            ei2, _ = dense_to_sparse(A_bin)
            genNx  = to_nx(ei2, data.num_nodes)

        if args.gen_mode in ['feat', 'both']:

            # `ei` / `ew` were set right after generator training to match the
            # best state: the real graph in 'feat' mode, the best generated
            # graph in 'both' mode.  They are deliberately not reset to the
            # real graph here, which made the 'both' loss use the wrong graph.
            H_gen = get_target(fixed, x_in, ei, ew, args.target_layer)
            lossG = torch.norm(H_gen - H_real) / torch.norm(H_real)
            print(lossG)
            # Name the saved features after the current model and dataset
            # instead of a fixed "GCN_PubMed" (identical for the defaults).
            torch.save(x_in, f"output/{args.model_type}_{DATASET_NAME}.pt")

            # Mean per-node cosine similarity between generated and real features.
            v1 = x_in.view(x_in.size(0), -1)
            v2 = data.x.view(data.x.size(0), -1)
            cos_sims = F.cosine_similarity(v1, v2, dim=1)
            f_sim = cos_sims.mean().item()
            print(f"Post-binarized feature cosine sim = {f_sim:.4f} (~{f_sim*100:.2f}%)")

            with torch.no_grad():

                # Predictions of the frozen model on the real data.
                logits_orig = fixed(data)
                pred_orig = logits_orig.argmax(dim=1)

                # Metamer input.  Only 'feat' and 'both' reach this point, so
                # the former 'adj' branch was unreachable and is gone.
                if args.gen_mode == 'feat':
                    x_gen, ei_gen = x_in, data.edge_index
                else:  # 'both'
                    # NOTE(review): uses the generators' current state
                    # (featG(), ei2), not the best snapshot -- confirm.
                    x_gen, ei_gen = featG(), ei2

                data_gen = Data(x=x_gen, edge_index=ei_gen).to(device)
                #data_gen.dist = data.dist

                # Predictions of the frozen model on the metamer.
                logits_gen = fixed(data_gen)
                pred_gen = logits_gen.argmax(dim=1)

                # Fraction of nodes with the same predicted class.
                match_ratio = pred_orig.eq(pred_gen).float().mean().item()
                print(f"Prediction match ratio = {match_ratio:.4f} (~{match_ratio * 100:.2f}%)")

                # Only the top-left 100 x 100 corner is plotted.
                data_x_np = data.x.cpu().numpy()[:100, :100]
                x_gen_np = x_gen.cpu().numpy()[:100, :100]

                plt.figure(figsize=(8, 6))
                plt.imshow(data_x_np, aspect='auto')
                plt.title('Heatmap of data.x (Original Features)')
                plt.xlabel('Features')
                plt.ylabel('Nodes')
                plt.colorbar()
                plt.tight_layout()
                plt.savefig('output/data_x_heatmap.png')
                plt.close()

                plt.figure(figsize=(8, 6))
                plt.imshow(x_gen_np, aspect='auto')
                plt.title('Heatmap of x_gen (Generated Features)')
                plt.xlabel('Features')
                plt.ylabel('Nodes')
                plt.colorbar()
                plt.tight_layout()
                plt.savefig('output/x_gen_heatmap.png')
                plt.close()

                D = data_x_np - x_gen_np

                # Symmetric colour limits keep zero at the centre of the
                # diverging colormap.
                abs_max = np.max(np.abs(D))

                plt.figure(figsize=(8, 6))
                plt.imshow(D, aspect='auto', vmin=-abs_max, vmax=abs_max, cmap='RdBu_r')
                plt.title('Heatmap of Difference (data.x - x_gen)')
                plt.xlabel('Features')
                plt.ylabel('Nodes')
                plt.colorbar(label='Difference Value')
                plt.tight_layout()
                plt.savefig('output/difference_heatmap.png')
                plt.show()
                # Close after showing, like the other two figures, so figures
                # do not accumulate across runs.
                plt.close()

            cosine_sims.append(f_sim)
            match_ratios.append(match_ratio)
            # Combined score: f_sim*match_ratio + (1 - f_sim)*(1 - match_ratio).
            tt = f_sim*match_ratio+(1-f_sim)*(1-match_ratio)
            Sorce.append(tt)

        if args.gen_mode in ['adj', 'both']:

            # Every node gets the same label "0", so the kernels compare
            # structure only.
            real_labels = {n: "0" for n in realG}
            gen_labels = {n: "0" for n in genNx}
            g_real = GrakelGraph(nx.to_dict_of_lists(realG), node_labels=real_labels)
            g_gen  = GrakelGraph(nx.to_dict_of_lists(genNx),  node_labels=gen_labels)

            # Weisfeiler-Lehman kernel over a vertex-histogram base kernel.
            wl = GraphKernel(kernel=[{"name": "weisfeiler_lehman", "n_iter": 5},
                                      {"name": "vertex_histogram"}], normalize=True)
            K_wl = wl.fit_transform([g_real, g_gen])
            sim_wl = K_wl[0,1]
            print(f"WL (structure-only) sim = {sim_wl:.4f} (~{sim_wl*100:.2f}%)")

            # Shortest-path kernel.
            sp = GraphKernel(kernel=[{"name": "shortest_path"}], normalize=True)
            K_sp = sp.fit_transform([g_real, g_gen])
            sim_sp = K_sp[0,1]
            print(f"Shortest‐Path sim = {sim_sp:.4f} (~{sim_sp*100:.2f}%)")

    print("\n===== Summary over 5 runs =====")
    if cosine_sims:
        print("Post-binarized feature cosine sim: mean={:.4f}, std={:.4f}"
              .format(np.mean(cosine_sims), np.std(cosine_sims)))
        print("Prediction match ratio:           mean={:.4f}, std={:.4f}"
              .format(np.mean(match_ratios), np.std(match_ratios)))
        print("Score:           mean={:.4f}, std={:.4f}"
              .format(np.mean(Sorce), np.std(Sorce)))
    else:
        # The metric lists are filled only when gen_mode is 'feat' or 'both';
        # averaging empty lists would emit warnings and print NaN.
        print("No feature-based metrics were collected "
              "(they are only computed for gen_mode 'feat' or 'both').")

