import numpy as np
from models import *
import torch.nn.functional as F
import torch
import deeprobust.graph.utils as utils
from torch.nn.parameter import Parameter
from tqdm import tqdm
import scipy.sparse as sp
import pandas as pd
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
from copy import deepcopy
from models.cop import COP
from utils import reset_args
from gtransform_adj import EdgeAgent
from torch_geometric.utils import to_scipy_sparse_matrix, from_scipy_sparse_matrix, dropout_adj, is_undirected, to_undirected
from gtransform_adj import *

class GraphAgent(EdgeAgent):
    """Test-time graph transformation agent built on ``EdgeAgent``.

    Learns a feature perturbation and, through the inherited edge machinery
    (``sample_random_block``, ``get_modified_adj``, ``project``, ...), a
    perturbation of the edge weights. The pretrained model's parameters stay
    frozen throughout. The feature perturbation is either:

    * a free per-node ``delta_feat`` parameter, or
    * a conditional ``COP`` module when ``args.cop`` is set.
    """

    def __init__(self, data_all, args):
        # NOTE(review): EdgeAgent.__init__ is not called; attributes used later
        # (do_synchronize, eps, epochs_resampling, ...) are presumably set by
        # setup_params / pretrain_model — confirm against gtransform_adj.
        self.device = 'cuda'
        self.args = args
        self.data_all = data_all
        self.model = self.pretrain_model()

    def learn_graph(self, data):
        """Adapt the features and structure of ``data`` at test time.

        Runs ``args.epochs // (args.loop_feat + args.loop_adj)`` rounds. Each
        round does ``args.loop_feat`` feature-update steps followed by
        ``args.loop_adj`` edge-weight update steps, all minimising
        ``self.test_time_loss`` with the model frozen. When
        ``args.loop_adj != 0``, a final discrete edge set is then sampled.
        Finally the model is evaluated on the adapted graph.

        Args:
            data: dataset object exposing ``graph['node_feat']``,
                ``graph['edge_index']`` and ``label``. Its features and edges
                are modified in place when ``args.noise_feature`` or
                ``args.noise_structure`` is positive.

        Returns:
            ``(self.evaluate_single(...) result, output, labels)``. For the
            ``'elliptic'`` dataset, ``output`` and ``labels`` are restricted to
            ``data.mask``.

        Raises:
            NotImplementedError: if ``args.ptb_rate > 0``. The evasion-attack
                path needs an adjacency matrix, which this edge_index-based
                pipeline never builds.
        """
        print('====learning on this graph===')
        self.setup_params(data)
        # Read args only after setup_params, in case it replaces self.args.
        args = self.args
        model = self.model
        model.eval() # should set to eval

        self.max_final_samples = 5

        if args.noise_feature > 0:
            add_feature_noise(data, args.noise_feature)
        if args.noise_structure > 0:
            add_structure_noise(data, args.noise_structure)
        if args.noise_feature > 0 or args.noise_structure > 0:
            feat, labels = data.graph['node_feat'].to(self.device), data.label.to(self.device)
            edge_index = data.graph['edge_index'].to(self.device)
            output = self.model.predict(feat, edge_index)
            print("===Test set results on noisy graph:")
            self.evaluate_single(self.model, output, labels, data)

        from utils import get_gpu_memory_map
        mem_st = get_gpu_memory_map()
        self.data = data
        nnodes = data.graph['node_feat'].shape[0]
        d = data.graph['node_feat'].shape[1]
        if args.cop: # conditional parameterization
            self.cop = COP(nfeat=d, device=self.device).to(self.device)
            cop = self.cop
            self.optimizer_feat = torch.optim.Adam(self.cop.parameters(), lr=args.lr_feat)
        else:
            delta_feat = Parameter(torch.FloatTensor(nnodes, d).to(self.device))
            self.delta_feat = delta_feat
            delta_feat.data.fill_(1e-7)
            self.optimizer_feat = torch.optim.Adam([delta_feat], lr=args.lr_feat)

        # Freeze the pretrained model: only the perturbations are optimised.
        for param in model.parameters():
            param.requires_grad = False
        model.eval() # should set to eval

        feat, labels = data.graph['node_feat'].to(self.device), data.label.to(self.device)
        edge_index = data.graph['edge_index'].to(self.device)
        self.edge_index, self.feat, self.labels = edge_index, feat, labels
        self.edge_weight = torch.ones(self.edge_index.shape[1]).to(self.device)

        if args.ptb_rate > 0:
            # The previous code called self.evasion_attack(adj, ...) with an
            # undefined `adj`, which always crashed with UnboundLocalError.
            raise NotImplementedError(
                'args.ptb_rate > 0 (evasion attack) is not supported by '
                'GraphAgent.learn_graph: no adjacency matrix is built here')

        # Edge budget. Halved, presumably because each undirected edge appears
        # twice in edge_index.
        n_perturbations = int(args.ratio * self.edge_index.shape[1] //2)
        print('n_perturbations:', n_perturbations)
        self.sample_random_block(n_perturbations)

        self.perturbed_edge_weight.requires_grad = True
        self.optimizer_adj = torch.optim.Adam([self.perturbed_edge_weight], lr=args.lr_adj)

        # The first round of feature updates uses the unweighted input graph.
        edge_weight = None

        for it in tqdm(range(args.epochs//(args.loop_feat+args.loop_adj))):
            # ---- feature updates ----
            for loop_feat in range(args.loop_feat):
                self.optimizer_feat.zero_grad()
                if args.cop:
                    delta_feat = cop(feat)
                loss = self.test_time_loss(model, feat+delta_feat, edge_index, edge_weight)

                if args.debug==2 or args.debug==3:
                    self.check_corr()
                loss.backward()

                if loop_feat == 0:
                    print(f'Epoch {it}, Loop Feat {loop_feat}: {loss.item()}')

                self.optimizer_feat.step()
                if args.debug==2 or args.debug==3:
                    output = model.predict(feat+delta_feat, edge_index, edge_weight)
                    print('Debug Test:', self.evaluate_single(model, output, labels, data, verbose=0))

            if args.cop:
                # Refresh the COP output. The value computed inside the loop
                # predates the last optimizer step, and it is undefined when
                # loop_feat == 0.
                with torch.no_grad():
                    delta_feat = cop(feat)
            # Features are held fixed (detached) while the edges are optimised.
            new_feat = (feat+delta_feat).detach()

            # ---- edge-weight updates ----
            for loop_adj in range(args.loop_adj):
                self.perturbed_edge_weight.requires_grad = True
                edge_index, edge_weight  = self.get_modified_adj()
                if torch.cuda.is_available() and self.do_synchronize:
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()

                loss = self.test_time_loss(model, new_feat, edge_index, edge_weight)

                gradient = grad_with_checkpoint(loss, self.perturbed_edge_weight)[0]
                if not args.existing_space:
                    if torch.cuda.is_available() and self.do_synchronize:
                        torch.cuda.empty_cache()
                        torch.cuda.synchronize()
                if loop_adj == 0:
                    print(f'Epoch {it}, Loop Adj {loop_adj}: {loss.item()}')

                with torch.no_grad():
                    self.update_edge_weights(n_perturbations, it, gradient)
                    self.perturbed_edge_weight = self.project(
                        n_perturbations, self.perturbed_edge_weight, self.eps)
                    del edge_index, edge_weight
                    # Resampling of search space (Algorithm 1, line 9-14)
                    if not args.existing_space:
                        if it < self.epochs_resampling - 1:
                            self.resample_random_block(n_perturbations)
                if it < self.epochs_resampling - 1:
                    # The projection/resampling above replaced the tensor, so
                    # re-enable grads and rebuild the optimizer around it.
                    self.perturbed_edge_weight.requires_grad = True
                    self.optimizer_adj = torch.optim.Adam([self.perturbed_edge_weight], lr=args.lr_adj)
            if args.loop_adj != 0:
                # The next round's feature updates see the current (detached) edge weights.
                edge_index, edge_weight  = self.get_modified_adj()
                edge_weight = edge_weight.detach()

        if args.cop:
            delta_feat = cop(feat)
        print(f'Epoch {it+1}: {loss}')
        gpu_mem = get_gpu_memory_map()
        print(f'Mem used: {int(gpu_mem[args.gpu_id])-int(mem_st[args.gpu_id])}MB')

        # Sample final discrete graph (Algorithm 1, line 16)
        if args.loop_adj != 0:
            edge_index, edge_weight = self.sample_final_edges(n_perturbations, data)

        with torch.no_grad():
            loss = self.test_time_loss(model, feat+delta_feat, edge_index, edge_weight)
        print('final loss:', loss.item())

        output = model.predict(feat+delta_feat, edge_index, edge_weight)
        print('Test:')

        if args.dataset == 'elliptic':
            return self.evaluate_single(model, output, labels, data), output[data.mask], labels[data.mask]
        else:
            return self.evaluate_single(model, output, labels, data), output, labels

    def augment(self, strategy='dropedge', p=0.5, edge_index=None, edge_weight=None):
        """Return ``self.model.get_embed`` of a randomly augmented graph view.

        Which features are used:

        * If ``self.delta_feat`` exists, ``self.feat`` plus the learned
          perturbation (``self.cop(self.feat)`` when ``args.cop`` is set,
          otherwise ``self.delta_feat``).
        * Otherwise, the raw ``self.feat``.

        Args:
            strategy: one of 'shuffle', 'dropedge', 'dropnode', 'rwsample',
                'dropmix', 'dropfeat', 'featnoise'.
            p: drop probability for the drop* strategies. For 'featnoise' it is
                the standard deviation of the added Gaussian noise.
            edge_index: edge index passed to ``get_embed`` (possibly after
                augmentation).
            edge_weight: edge weights passed to ``get_embed`` (possibly after
                augmentation).

        Raises:
            ValueError: if ``strategy`` is not one of the names above.
        """
        model = self.model
        delta_feat = None
        if hasattr(self, 'delta_feat'):
            # NOTE(review): in COP mode learn_graph does not set self.delta_feat,
            # so the COP branch is only reached if the attribute exists from
            # elsewhere — confirm this gate is intended.
            if self.args.cop:
                delta_feat = self.cop(self.feat)
            else:
                delta_feat = self.delta_feat
            feat = self.feat + delta_feat
        else:
            feat = self.feat

        if strategy == 'shuffle':
            idx = np.random.permutation(feat.shape[0])
            output = model.get_embed(feat[idx, :], edge_index, edge_weight)
        elif strategy == "dropedge":
            edge_index, edge_weight = dropout_adj(edge_index, edge_weight, p=p)
            output = model.get_embed(feat, edge_index, edge_weight)
        elif strategy == "dropnode":
            # Zero whole node rows with probability p (device-agnostic mask).
            keep = torch.rand(len(feat), device=feat.device) > p
            output = model.get_embed(feat * keep.view(-1, 1), edge_index, edge_weight)
        elif strategy == "rwsample":
            import augmentor as A
            if self.args.dataset in ['twitch-e', 'elliptic']:
                walk_length = 1
            else:
                walk_length = 10
            aug = A.RWSampling(num_seeds=1000, walk_length=walk_length)
            x2, edge_index2, edge_weight2 = aug(feat, edge_index, edge_weight)
            output = model.get_embed(x2, edge_index2, edge_weight2)
        elif strategy == "dropmix":
            keep = torch.rand(len(feat), device=feat.device) > p
            edge_index, edge_weight = dropout_adj(edge_index, edge_weight, p=p)
            output = model.get_embed(feat * keep.view(-1, 1), edge_index, edge_weight)
        elif strategy == "dropfeat":
            # Dropout is applied to the raw features only; the perturbation is added afterwards.
            dropped = F.dropout(self.feat, p=p)
            if delta_feat is not None:
                dropped = dropped + delta_feat
            output = model.get_embed(dropped, edge_index, edge_weight)
        elif strategy == "featnoise":
            mean, std = 0, p
            noise = torch.randn_like(feat) * std + mean
            output = model.get_embed(feat + noise, edge_index, edge_weight)
        else:
            raise ValueError(f'Unknown augmentation strategy: {strategy!r}')
        return output

def add_feature_noise(data, noise_ratio):
    """Replace the features of a random subset of nodes with N(0, 1) noise.

    Reseeds NumPy's global RNG with 0, so both the chosen nodes and the noise
    are reproducible. ``data.graph['node_feat']`` is modified in place.

    Args:
        data: object whose ``graph['node_feat']`` is an ``(n, d)`` tensor.
        noise_ratio: fraction of the ``n`` nodes to corrupt.

    Returns:
        A tensor shaped like the features. It is zero on untouched rows and
        ``noise - original`` on corrupted rows, so ``original + delta`` equals
        the new features.
    """
    np.random.seed(0)
    feat = data.graph['node_feat']
    n, d = feat.shape
    n_noisy = int(noise_ratio * n)
    # Draw the noise before the permutation to keep the original RNG order.
    # Move it onto the features' device so CUDA features work too.
    noise = torch.FloatTensor(np.random.normal(0, 1, size=(n_noisy, d))).to(feat.device)
    indices = np.random.permutation(np.arange(n))[:n_noisy]

    delta_feat = torch.zeros_like(feat)
    delta_feat[indices] = noise - feat[indices]
    feat[indices] = noise
    return delta_feat

def add_feature_noise_test(data, noise_ratio):
    """Replace the features of a random subset of *test* nodes with N(0, 1) noise.

    Reseeds NumPy's global RNG with 0. The nodes are shuffled with one random
    permutation, and the first ``int(noise_ratio * num_test_nodes)`` test
    nodes in that order are corrupted. ``data.x`` is modified in place.

    Args:
        data: object with an ``(n, d)`` tensor ``x`` and a ``test_mask``. The
            mask is either a boolean mask over the ``n`` nodes or an integer
            array of test-node indices.
        noise_ratio: fraction of the test nodes to corrupt.

    Returns:
        A tensor shaped like ``data.x``. It is zero on untouched rows and
        ``noise - original`` on corrupted rows.
    """
    np.random.seed(0)
    n, d = data.x.shape
    perm = np.random.permutation(np.arange(n))

    mask = data.test_mask
    if torch.is_tensor(mask):
        mask = mask.detach().cpu().numpy()
    mask = np.asarray(mask)
    if mask.dtype == np.bool_:
        is_test = mask
    else:
        # Integer indices of the test nodes rather than a boolean mask.
        is_test = np.zeros(n, dtype=bool)
        is_test[mask] = True

    # The old code did perm[test_mask], which picks permuted entries at the
    # mask *positions*, i.e. random nodes. Filtering perm by membership keeps
    # the random order but only yields test nodes.
    test_nodes = perm[is_test[perm]]
    n_noisy = int(noise_ratio * len(test_nodes))
    selected = test_nodes[:n_noisy]
    noise = torch.FloatTensor(np.random.normal(0, 1, size=(n_noisy, d))).to(data.x.device)

    delta_feat = torch.zeros_like(data.x)
    delta_feat[selected] = noise - data.x[selected]
    data.x[selected] = noise
    return delta_feat


def add_structure_noise(data, noise_ratio):
    """Randomly remove a fraction of the edges of ``data``'s graph in place.

    Reseeds NumPy's global RNG with 0, builds a scipy sparse adjacency from the
    edge index, and runs deeprobust's ``Random`` attack with ``type='remove'``.
    The attack removes ``int(noise_ratio * nnz) // 2`` edges, where ``nnz`` is
    the number of stored adjacency entries.

    Where the edge index is read from and written back to:

    * If ``data`` carries a ``graph`` dict with an ``'edge_index'`` key (the
      layout ``GraphAgent.learn_graph`` reads), that entry is updated. So is
      ``data.edge_index``, if that attribute is already set.
    * Otherwise ``data.edge_index`` is used.

    The new edge index stays on the original device.
    """
    np.random.seed(0)
    from deeprobust.graph.global_attack import Random
    from torch_geometric.utils import to_scipy_sparse_matrix, from_scipy_sparse_matrix

    graph = getattr(data, 'graph', None)
    in_graph = isinstance(graph, dict) and 'edge_index' in graph
    edge_index = graph['edge_index'] if in_graph else data.edge_index

    adj = to_scipy_sparse_matrix(edge_index)
    attacker = Random()
    attacker.attack(adj, n_perturbations=int(noise_ratio*adj.nnz)//2, type='remove')
    new_edge_index = from_scipy_sparse_matrix(attacker.modified_adj)[0].to(edge_index.device)

    if in_graph:
        graph['edge_index'] = new_edge_index
        if getattr(data, 'edge_index', None) is not None:
            data.edge_index = new_edge_index
    else:
        data.edge_index = new_edge_index

def compare_models(model1, model2):
    """Return True iff two modules have identical parameters.

    Parameters are compared pairwise in ``parameters()`` order. The models
    match only when they have the same number of parameters and every pair
    has the same shape and equal values.
    """
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())
    # zip() alone would silently ignore extra parameters in the longer model.
    if len(params1) != len(params2):
        return False
    for p1, p2 in zip(params1, params2):
        # Check shapes first so mismatched shapes return False instead of
        # broadcasting or raising.
        if p1.shape != p2.shape or bool(p1.data.ne(p2.data).any()):
            return False
    return True

def inner(t1, t2):
    """Mean cosine distance between corresponding rows of ``t1`` and ``t2``.

    Each row is divided by its L2 norm plus 1e-15, which guards against
    division by zero. The result is ``mean_i(1 - <u1_i, u2_i>)``.
    """
    def _unit(rows):
        return rows / (rows.norm(dim=1).view(-1, 1) + 1e-15)

    cosine = (_unit(t1) * _unit(t2)).sum(1)
    return (1 - cosine).mean()

def diff(t1, t2):
    """Half the mean squared distance between row-normalised ``t1`` and ``t2``.

    Each row is divided by its L2 norm plus 1e-15. The result is
    ``mean_i(0.5 * ||u1_i - u2_i||^2)``.
    """
    def _unit(rows):
        return rows / (rows.norm(dim=1).view(-1, 1) + 1e-15)

    gap = _unit(t1) - _unit(t2)
    return 0.5 * gap.pow(2).sum(1).mean()

def corr(t1, t2):
    """Cosine-style similarity of two tensors taken as whole (flattened) vectors.

    Returns a tuple ``(similarity formatted with two decimals, ||t1|| + 1e-15,
    ||t2|| + 1e-15)``. The two norms are returned as Python floats.
    """
    eps = 1e-15
    scale1 = t1.norm() + eps
    scale2 = t2.norm() + eps
    similarity = ((t1 / scale1) * (t2 / scale2)).sum()
    return '%.2f' % similarity.item(), scale1.item(), scale2.item()

