import torch
import torch.nn as nn

import pandas as pd

from torch_geometric.data import HeteroData, Data, InMemoryDataset

from torch_geometric.nn import MessagePassing, GCNConv, GCN
from torch_geometric.utils import add_self_loops, degree, contains_isolated_nodes, remove_isolated_nodes
import torch_geometric.transforms as transforms

from sklearn.preprocessing import MinMaxScaler
from sklearn.neighbors import NearestNeighbors

from scipy.spatial import distance_matrix
import numpy as np



import os

import importlib

def predefine(x, s, y, input_dim, sensitive_size, epochs=100, lr=0.01,
              save_path='weights/predefine.pt'):
    """Fit a logistic-regression mapping ``[x, s] -> y`` and save its weights.

    A single ``nn.Linear(input_dim + sensitive_size, 1)`` is trained full-batch
    with Adam and ``BCEWithLogitsLoss`` on the concatenation of ``x`` and ``s``
    (along dim 1); the layer's ``state_dict`` is then written to ``save_path``.

    Args:
        x: node features; concatenated with ``s`` along dim 1, so 2-D
            (presumably (N, input_dim) — confirm against callers).
        s: sensitive attributes, presumably (N, sensitive_size).
        y: binary targets with N elements; cast to the logits' dtype so
            integer labels are accepted.
        input_dim: width of ``x``.
        sensitive_size: width of ``s``.
        epochs: number of full-batch optimisation steps.
        lr: Adam learning rate.
        save_path: file receiving the linear layer's ``state_dict``; missing
            parent directories are created.
    """

    def _train(epoch, model, optimizer, criterion, x, s, y):
        """Run one full-batch optimisation step and print its loss."""
        model.train()
        y_pred = model(x, s)
        # BCEWithLogitsLoss needs a floating target with the logits' dtype.
        loss = criterion(y_pred, y.view_as(y_pred).to(y_pred.dtype))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f'Epoch:{epoch}, loss:{loss.item()}')

    class _LinearMapping(nn.Module):
        """Linear map from the concatenation ``[x, s]`` to one logit per row."""

        def __init__(self, input_dim, sensitive_size):
            super(_LinearMapping, self).__init__()
            self.fc1 = nn.Linear(input_dim + sensitive_size, 1)

        def forward(self, x, s):
            xs = torch.cat((x, s), dim=1)
            out = self.fc1(xs)
            return torch.flatten(out)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = _LinearMapping(input_dim, sensitive_size)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()
    x, s, y = x.to(device), s.to(device), y.to(device)

    for epoch in range(epochs):
        _train(epoch, model, optimizer, criterion, x, s, y)

    # torch.save does not create directories; make sure the target exists.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.fc1.state_dict(), save_path)
    
def predefine_s(args, s, z, edge_index, epochs=200):
    """Fit a GCN predicting ``s`` from ``z`` and resample ``s`` from it.

    A PyG ``GCN(args.input_size, 16, 3, args.sensitive_size)`` is trained
    full-batch with Adam and ``BCEWithLogitsLoss`` to predict ``s`` from node
    features ``z`` over ``edge_index``. Each entry of the returned tensor is
    then drawn as 1.0 with probability ``sigmoid(logit)`` and 0.0 otherwise.
    The means of the input and of the resampled attribute are printed.

    Args:
        args: namespace providing ``input_size`` and ``sensitive_size``.
        s: observed sensitive attribute used as the training target.
        z: node features fed to the GCN.
        edge_index: graph connectivity used by the GCN.
        epochs: number of full-batch optimisation steps.

    Returns:
        A float tensor of 0.0/1.0 samples shaped like the GCN output.
    """
    def _train(epoch, model, optimizer, criterion, s, z):
        """Run one full-batch optimisation step and print its loss."""
        model.train()
        s_pred = model(z, edge_index)
        # BCEWithLogitsLoss needs a floating target with the logits' dtype.
        loss = criterion(s_pred, s.view_as(s_pred).to(s_pred.dtype))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f'Epoch:{epoch}, loss:{loss.item()}')

    gnn = GCN(args.input_size, 16, 3, args.sensitive_size)
    optimizer = torch.optim.Adam(gnn.parameters(), lr=0.01)
    criterion = nn.BCEWithLogitsLoss()
    for epoch in range(epochs):
        _train(epoch, gnn, optimizer, criterion, s, z)

    gnn.eval()
    print(torch.mean(s))
    # Inference only: no autograd graph is needed for the resampling below.
    with torch.no_grad():
        logits = gnn(z, edge_index)
    prob_s = torch.sigmoid(logits)
    # rand_like keeps the uniform draws on prob_s's device and dtype.
    s_random = torch.rand_like(prob_s)
    sampled = torch.where(s_random < prob_s, 1.0, 0.0)

    print(torch.mean(sampled))
    return sampled



def generate_adj_matrix_random(n_nodes, threshold):
    """Sample a symmetric random 0/1 adjacency matrix.

    One uniform draw in [0, 1) is taken per strictly-upper-triangular entry
    and mirrored below the diagonal; an entry is 1 when its value exceeds
    ``threshold``. Diagonal entries are 0 before thresholding, so they are
    only 1 for a negative ``threshold``.

    Returns:
        An (n_nodes, n_nodes) tensor of 0.0/1.0 in the default float dtype.
    """
    draws = torch.rand(n_nodes, n_nodes)
    strictly_upper = draws.triu(diagonal=1)
    symmetric = strictly_upper + strictly_upper.transpose(0, 1)
    return (symmetric > threshold).to(symmetric.dtype)

def generate_adj_matrix_knn(node_features, sa, node_labels=None, k=101):
    """Build a symmetric k-nearest-neighbour adjacency matrix.

    Neighbours are computed on the row-wise concatenation of
    ``node_features`` and ``sa``. The directed kNN relation is symmetrised as
    a union (i ~ j if either is among the other's k nearest neighbours), and
    self-loops are removed.

    Args:
        node_features: 2-D node feature tensor.
        sa: sensitive attributes, concatenated with ``node_features`` on dim 1.
        node_labels: unused; kept so existing positional callers still work.
        k: neighbours per node (the query point itself counts as one, as
            sklearn does when ``X`` is passed); clamped to the node count.

    Returns:
        A float32 (N, N) tensor of 0.0/1.0 with a zero diagonal.
    """
    node = torch.cat((node_features, sa), dim=1).detach().cpu().numpy()
    # sklearn raises if n_neighbors exceeds the number of samples.
    n_neighbors = min(k, node.shape[0])
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, n_jobs=-1).fit(node)
    directed = torch.from_numpy(nbrs.kneighbors_graph(node).toarray()).float()
    # Symmetrise with max(A, A^T). Mirroring only the upper triangle dropped
    # j in kNN(i) for j < i unless the relation was mutual, which made the
    # graph depend on node index order.
    adj_matrix = torch.maximum(directed, directed.T)
    adj_matrix.fill_diagonal_(0.0)
    return adj_matrix

def generate_adj_matrix_distance(node_features, sa, thresh=0.25):
    """Build a similarity-threshold ``edge_index`` over ``[node_features, sa]``.

    Similarity is ``1 / (1 + euclidean_distance)``. For each node i, every
    other node j whose similarity exceeds ``thresh`` times i's largest
    similarity to another node (the second-largest value in its row, the
    largest being self-similarity 1) gets a directed edge (i, j).

    Args:
        node_features: 2-D node feature tensor.
        sa: sensitive attributes, concatenated with ``node_features`` on dim 1.
        thresh: fraction of the row's best non-self similarity to exceed.

    Returns:
        A (2, E) ``torch.long`` edge index, sorted by (source, target); shape
        (2, 0) when there are no edges or fewer than two nodes.
    """
    x = torch.cat((node_features, sa), dim=1).detach().cpu().numpy()
    if x.shape[0] < 2:
        return torch.empty((2, 0), dtype=torch.long)

    sim = 1 / (1 + distance_matrix(x, x))
    max_sim = np.sort(sim, axis=1)[:, -2]
    mask = sim > thresh * max_sim[:, None]
    # Exclude self by index only, matching the original `neig != ind` check.
    np.fill_diagonal(mask, False)
    rows, cols = np.nonzero(mask)
    return torch.from_numpy(np.stack((rows, cols))).long()
    

def generate_pyg_edge(adj_matrix):
    """Convert a dense adjacency matrix into a PyG-style ``edge_index``.

    Every strictly positive entry becomes one column holding its coordinates,
    in row-major order; the result is a ``torch.long`` tensor of shape
    (adj_matrix.dim(), E).
    """
    positive = adj_matrix > 0
    coords = torch.nonzero(positive)  # (E, ndim)
    return coords.T

class RealNetDataset(InMemoryDataset):
    """In-memory PyG dataset for a real graph built from feature similarity.

    ``process`` connects nodes with ``generate_adj_matrix_distance`` at a
    threshold of 0.6 over ``[node_features, sa]`` and stores a single
    ``Data`` with ``x``, ``z``, ``edge_index``, ``y`` and ``s``.
    """

    def __init__(self, args, node_features, z_features, sa, node_labels, transform=None, pre_transform=None, pre_filter=None, **kwargs):
        # Assigned before super().__init__: PyG's Dataset constructor runs
        # process() when the processed file is missing, and process() reads these.
        (self.args, self.node_features, self.z_features,
         self.sa, self.node_labels) = (args, node_features, z_features, sa, node_labels)
        super().__init__(args.root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return [self.args.filename]

    @property
    def processed_file_names(self):
        name = f'{self.args.dataset}_real_{self.args.gtseed}_k{self.args.k}.pt'
        return [name]

    def process(self):
        """Build the similarity graph, apply optional filters/transforms and save."""
        edge_index = generate_adj_matrix_distance(self.node_features, self.sa, 0.6)
        graph = Data(x=self.node_features, z=self.z_features, edge_index=edge_index, y=self.node_labels, s=self.sa)
        graph.validate(raise_on_error=True)

        graphs = [graph]
        if self.pre_filter is not None:
            graphs = list(filter(self.pre_filter, graphs))
        if self.pre_transform is not None:
            graphs = list(map(self.pre_transform, graphs))

        collated, slices = self.collate(graphs)
        torch.save((collated, slices), self.processed_paths[0])
        


class NetDataset(InMemoryDataset):
    """In-memory PyG dataset holding a semi-synthetic graph with counterfactuals.

    ``process`` builds a kNN graph over ``[node_features, sa]``. It forms node
    features as ``z + a + noise_x``, where ``a`` is a scaled offset from a
    randomly initialised GCN applied to the ±1-encoded sensitive attribute.
    Binary labels are thresholded outputs of a mapping function loaded from
    ``util.mapping_function``. Besides the factual ``x``/``y``, it stores
    interventional features/labels with ``s`` set to all-ones (``do_pos_*``)
    and all-zeros (``do_neg_*``).
    """

    def __init__(self, args, node_features, z_features, sa, node_labels, mapping_function, transform=None, pre_transform=None, pre_filter=None):
        # Assigned before super().__init__: PyG's Dataset constructor runs
        # process() when the processed file is missing, and process() reads these.
        self.args = args
        self.node_features = node_features
        self.z_features = z_features
        self.sa = sa
        self.node_labels = node_labels
        self.mapping_function = mapping_function
        super().__init__(args.root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return [self.args.filename]

    @property
    def processed_file_names(self):
        return [f'{self.args.dataset}_{self.mapping_function}_{self.args.gtseed}_k{self.args.k}.pt']

    def get_mapping(self):
        """Instantiate the mapping class named ``self.mapping_function``.

        The class is looked up in ``util.mapping_function`` and constructed with
        ``(args.input_size, args.sensitive_size, args.gen_A)``.
        """
        mapping = getattr(importlib.import_module('util.mapping_function'), self.mapping_function)
        return mapping(self.args.input_size, self.args.sensitive_size, self.args.gen_A)

    def process(self):
        """Generate the semi-synthetic graph and its counterfactuals, then save.

        Raises:
            ValueError: if the generated factual labels are all 0 or all 1.
        """
        node_features, sa, node_labels = self.node_features, self.sa, self.node_labels
        # NOTE(review): self.z_features is not used here; ``z`` below is set to
        # node_features — confirm this is intended.

        # Pass k by keyword: the third positional parameter of
        # generate_adj_matrix_knn is ``node_labels``. The former positional
        # call therefore ignored args.k and always used k=101. Processed files
        # produced before this fix are cached under ``k{args.k}`` but were built
        # with k=101; delete them so they are regenerated.
        adj_matrix = generate_adj_matrix_knn(node_features, sa, node_labels, k=self.args.k)
        edge_index = generate_pyg_edge(adj_matrix)

        if self.args.mode == 'predefine':
            # Replace the sensitive attribute by one resampled from a GCN fitted on the features.
            sa = predefine_s(self.args, sa, node_features, edge_index)

        data = Data(x=node_features, z=node_features, edge_index=edge_index, y=node_labels, s=sa)
        data.validate(raise_on_error=True)

        # Interventions on s, in the same {0, 1} encoding as data.s. The factual
        # path feeds data.s to the mapping and 2 * s - 1 to the GCN; the
        # counterfactual paths now do exactly the same. Previously do_neg was -1
        # and was also fed to the mapping as -1.
        do_pos = torch.ones(data.s.shape)
        do_neg = torch.zeros(data.s.shape)

        # Randomly initialised GCN (not trained here) turning the ±1-encoded
        # sensitive attribute into an additive feature offset.
        conv = GCN(self.args.sensitive_size, 8, 3, data.x.shape[1])
        with torch.no_grad():
            a = conv(2 * data.s - 1, data.edge_index) * self.args.coff_A
            a_pos = conv(2 * do_pos - 1, data.edge_index) * self.args.coff_A
            a_neg = conv(2 * do_neg - 1, data.edge_index) * self.args.coff_A

        # Uniform noise on [-0.5, 0.5) scaled by args.scale, shared by the
        # factual and both counterfactual worlds.
        noise_y = (torch.rand_like(node_labels) - 0.5) * self.args.scale
        noise_x = (torch.rand_like(node_features) - 0.5) * self.args.scale

        x_fact = data.z + a + noise_x
        x_pos = data.z + a_pos + noise_x
        x_neg = data.z + a_neg + noise_x

        if self.args.mode == 'predefine':
            predefine(x_fact, data.s, data.y, self.args.input_size, self.args.sensitive_size)

        mapping = self.get_mapping()
        modified_y = torch.where((mapping(x_fact, data.s, edge_index) + noise_y) > 0.0, 1.0, 0.0)
        do_pos_y = torch.where((mapping(x_pos, do_pos, edge_index) + noise_y) > 0.0, 1.0, 0.0)
        do_neg_y = torch.where((mapping(x_neg, do_neg, edge_index) + noise_y) > 0.0, 1.0, 0.0)

        label_rate = modified_y.mean()
        print(label_rate)
        # Explicit check rather than `assert`, so it still runs under `python -O`.
        if not 0.0 < label_rate.item() < 1.0:
            raise ValueError(
                f'Generated labels are degenerate (mean={label_rate.item()}); '
                'all nodes received the same label.')

        data.update_tensor(x_fact, 'x')
        data.update_tensor(modified_y, 'y')
        data.update_tensor(x_pos, 'do_pos_x')
        data.update_tensor(x_neg, 'do_neg_x')
        data.update_tensor(do_pos_y, 'do_pos_y')
        data.update_tensor(do_neg_y, 'do_neg_y')

        data.validate(raise_on_error=True)
        data_list = [data]
        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        
class GCNMessagePassing(MessagePassing):
    """Gradient-free GCN layer with symmetric normalisation over ``s``.

    ``s`` is mapped to ``2 * s - 1``, linearly projected without bias, and
    summed over neighbours (self-loops added) with weights
    ``deg(src)^-1/2 * deg(dst)^-1/2``. The whole forward pass runs under
    ``torch.no_grad()``, so ``self.lin`` receives no gradients from it.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')
        self.lin = nn.Linear(in_channels, out_channels, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the projection weights."""
        self.lin.reset_parameters()

    def forward(self, s, edge_index, z):
        """Propagate the projected, ±1-mapped ``s`` over ``edge_index``.

        ``z`` is accepted but unused.
        """
        with torch.no_grad():
            signed = 2 * s - 1
            looped_edges, _ = add_self_loops(edge_index, num_nodes=signed.shape[0])
            projected = self.lin(signed)

            src, dst = looped_edges
            deg = degree(dst, projected.shape[0], dtype=projected.dtype)
            inv_sqrt = deg.pow(-0.5)
            # Zero-degree nodes would give inf; neutralise them.
            inv_sqrt[inv_sqrt == float('inf')] = 0
            norm = inv_sqrt[src] * inv_sqrt[dst]

            return self.propagate(looped_edges, x=projected, norm=norm)

    def message(self, x_j, norm):
        """Scale each neighbour message by its edge normalisation weight."""
        return x_j * norm.unsqueeze(-1)
    
# NOTE(review): dead code kept inside a string literal; it never executes.
# It references parse_args, LinearMapping and LinearMappingNeighbor, which are
# not defined or imported in this file. It also calls NetDataset with fewer
# arguments than NetDataset.__init__ requires, so it would need updating
# before being re-enabled.
""" if __name__ == '__main__':
    args = parse_args()
    # print(args.root)
    root = os.path.join('data', 'credit')
    mapping_function = LinearMapping(23,22)
    mapping_function = LinearMappingNeighbor(23, 22)
    pre_transform = transforms.Compose([transforms.RemoveIsolatedNodes(),
                                    transforms.RandomNodeSplit(split='train_rest', num_val=0, num_test=0.2)])
    dataset = NetDataset(args, mapping_function, pre_transform=pre_transform)
    print(dataset[0]) """
        
    