import torch
import torch.nn.functional as F
import torch.nn as nn
from copy import deepcopy

import numpy as np
from scipy.spatial.distance import pdist, squareform, cdist
from sklearn.metrics.pairwise import cosine_similarity
from concurrent.futures import ProcessPoolExecutor as PExec
from ortools.linear_solver import pywraplp
# Module-level process pool with a single worker, used by Gumbel_Search to run
# get_score over chunks of sampled rankings. NOTE(review): it is constructed at
# import time, and with max_workers=1 it offloads work but gives no parallelism.
_exec = PExec(1)

def loss_edges(y_pred_edges, y_edges, edge_cw):
    """Class-weighted negative log-likelihood for per-item edge predictions.

    Args:
        y_pred_edges: raw logits shaped (batch_size, num_items, num_class).
        y_edges: integer targets shaped (batch_size, num_items).
        edge_cw: per-class weights forwarded to ``nn.NLLLoss``.
    Returns:
        Scalar loss tensor.
    """
    # NLLLoss wants the class axis at dim 1: (B, N, C) -> (B, C, N).
    log_probs = F.log_softmax(y_pred_edges, dim=2).permute(0, 2, 1)
    criterion = nn.NLLLoss(edge_cw)
    return criterion(log_probs, y_edges)

def loss_edges_bpp(y_pred_edges, y_edges, edge_cw):
    """Class-weighted negative log-likelihood for pairwise edge predictions.

    Args:
        y_pred_edges: raw logits shaped (B, num_items, num_items, num_class).
        y_edges: integer targets shaped (B, num_items, num_items).
        edge_cw: per-class weights forwarded to ``nn.NLLLoss``.
    Returns:
        Scalar loss tensor.
    """
    # Normalise over the class axis, then move it to dim 1 as NLLLoss expects:
    # (B, N, N, C) -> (B, C, N, N).
    log_probs = F.log_softmax(y_pred_edges, dim=3).permute(0, 3, 1, 2).contiguous()
    criterion = nn.NLLLoss(edge_cw)
    return criterion(log_probs, y_edges)


def loss_mudularity(y_pred_edges, y_edges):
    """Negative modularity-style score of pairwise "same group" predictions.

    Args:
        y_pred_edges: logits shaped (B, N, N, 2); channel 1 is taken as the
            probability that the pair belongs together.
        y_edges: pairwise target/affinity tensor broadcastable to (B, N, N).
    Returns:
        ``mean(p) - mean(p * y_edges)``, i.e. the negated modularity.
    """
    pair_prob = F.softmax(y_pred_edges, dim=3)[:, :, :, 1]
    agreement = (pair_prob * y_edges).mean()
    return pair_prob.mean() - agreement

def gat_loss(y_pred_edges, y_edges):
    """Alpha-balanced binary focal loss.

    Args:
        y_pred_edges: predicted probabilities in [0, 1] (already passed
            through a sigmoid), same shape as ``y_edges``.
        y_edges: binary targets (0/1), any numeric dtype.
    Returns:
        Scalar mean focal loss. Positives are weighted by ``alpha`` and
        negatives by ``1 - alpha``. Every element is scaled by
        ``(1 - p_t) ** gamma`` so confident, correct predictions contribute less.
    """
    gamma = 2
    alpha = 0.95
    targets = y_edges.float()
    # Per-element BCE (= -log p_t). reduction='none' is essential: with the
    # default 'mean' the focal factor would be one global scalar.
    logp = nn.BCELoss(reduction='none')(y_pred_edges, targets)
    pt = torch.exp(-logp)  # probability assigned to the true class
    focal = (1 - pt) ** gamma * logp
    loss = alpha * focal * targets + (1 - alpha) * focal * (1 - targets)
    return torch.mean(loss)


def decode(items, knapsacks, y, mode = "tabu"):
    """Greedily pack items in descending order of predicted score.

    Items are visited from highest to lowest ``y``. Each one is added while
    it fits, and decoding stops at the first item that would exceed any
    capacity dimension.

    Args:
        items: array (num_items, 1 + dimension); column 0 is the item value,
            the remaining columns are its weights.
        knapsacks: capacities, broadcastable against one item's weight row
            (documented as shape (1, dimension)).
        y: per-item scores (num_items,); higher scores are packed first.
        mode: unused; kept for interface compatibility.
    Returns:
        (in_items, acc_val): packed item indices in packing order and their
        total value.
    """
    values = items[:, 0]
    weights = items[:, 1:]

    acc_val = 0
    in_items = []
    acc_weight = np.zeros_like(knapsacks)
    for idx in np.argsort(-y):
        # np.any replaces np.sometrue, which was removed in NumPy 2.0.
        if np.any(acc_weight + weights[idx, :] > knapsacks):
            break
        acc_weight += weights[idx, :]
        acc_val += values[idx]
        in_items.append(idx)
    return in_items, acc_val


def sample_decode(items, knapsacks, y, mode = "tabu"):
    """Stochastically pack items, starting from the highest-scoring one.

    The first item is ``argmax(y)``. Each later item is sampled from
    ``softmax(y)`` restricted to not-yet-chosen items. Decoding stops at the
    first chosen item that would exceed any capacity dimension, so the
    returned packing is always feasible (the argmax item is capacity-checked
    too).

    Args:
        items: array (num_items, 1 + dimension); column 0 is the item value,
            the remaining columns are its weights.
        knapsacks: capacities, broadcastable against one item's weight row.
        y: 1-D torch tensor of per-item scores (num_items,).
        mode: unused; kept for interface compatibility.
    Returns:
        (in_items, acc_val): chosen item indices as 0-d torch tensors in
        packing order, and their total value.
    """
    values = items[:, 0]
    weights = items[:, 1:]
    num_items = items.shape[0]

    mask = torch.zeros((num_items,), dtype=torch.bool, device=y.device)
    acc_val = 0
    in_items = []
    acc_weight = np.zeros_like(knapsacks)

    for step in range(num_items):
        if step == 0:
            current = y.argmax()
        else:
            # Build a masked copy instead of editing y in place.
            masked_score = y.masked_fill(mask, float('-inf'))
            probs = torch.softmax(masked_score, dim=0)
            current = torch.distributions.Categorical(probs).sample()
        mask[current] = True

        idx = int(current)
        # np.any replaces np.sometrue, which was removed in NumPy 2.0.
        if np.any(acc_weight + weights[idx, :] > knapsacks):
            break

        acc_weight += weights[idx, :]
        acc_val += values[idx]
        in_items.append(current)

    return in_items, acc_val


def simplest_greedy(items, knapsack):
    """Greedy knapsack heuristic by value per mean relative weight.

    Each item's utility is ``value / mean(weight / capacity)``. Items are
    packed in descending utility until the first one that would exceed any
    capacity dimension. Exactly filling a capacity is allowed.

    :param items: [num_items x (1 + item_dim)]; column 0 is the value.
    :param knapsack: [1 x item_dim] capacities.
    :return: (in_items, acc_val) - packed item indices (positions in
        ``items``) in packing order, and their total value.
    """
    value = items[:, 0]
    weights = items[:, 1:]
    utils = value / np.mean(weights / knapsack, axis=-1)
    priority = np.argsort(-utils)

    in_items = []
    acc_val = 0
    acc_weight = np.zeros_like(knapsack)
    for idx in priority:
        # Strict '>' so an item that exactly fills a dimension still fits.
        if np.any(acc_weight + weights[idx] > knapsack):
            break
        acc_weight += weights[idx]
        acc_val += value[idx]
        # Record the actual item index, not the loop position.
        in_items.append(idx)
    return in_items, acc_val

def sample_gumbel(score, sampling_number=500, eps=1e-10):
    """Perturb normalised scores with Gumbel noise.

    The score is divided by its range (skipped when the range is zero, so
    constant scores do not produce NaN). It is then centred and scaled by
    ``num_items``, and ``sampling_number`` independent Gumbel draws are
    added to it.

    Score : [num_items]
    Returns: [sampling_number x num_items] perturbed scores.
    """
    score = np.asarray(score)
    num_items = score.size

    span = np.max(score) - np.min(score)
    if span > 0:
        score = score / span
    score = (score - score.mean()) * num_items

    U = np.random.rand(sampling_number, num_items)
    # -log(-log(U)) is a standard Gumbel sample; eps guards log(0).
    return score - np.log(-np.log(U + eps) + eps)


def get_score(input):
    """Best total value over a batch of item orderings.

    ``input`` is a tuple ``(gumbel_rank, bags, knapsacks, values)``:
    ``gumbel_rank`` is a 2-D array whose rows are item orderings, ``bags``
    holds per-item weights, ``knapsacks`` the capacities and ``values`` the
    per-item values. Each row is packed in order until remaining capacity
    would go negative. The best packed value (0 if none) is returned.
    """
    gumbel_rank, bags, knapsacks, values = input
    _num_rows, _num_cols = gumbel_rank.shape  # enforces a 2-D ranking

    best = 0
    for ordering in gumbel_rank:
        remaining = np.array(knapsacks)
        total = 0
        for item in ordering:
            remaining -= bags[item]
            if np.any(remaining < 0):
                break
            total += values[item]
        best = max(best, total)
    return best


def Gumbel_Search(items, knapsacks, gumbel_rank,  num_chunks=16):
    """Evaluate sampled item orderings on the process pool and keep the best.

    Args:
        items : [num_items x 1 + item_dim], +1 for value of items
        gumbel_rank : output of get_rank : [sampling_number x num_items]
        knapsacks : [num_knapsacks x item_dim]
        num_chunks : number of ranking ROWS per submitted task (a chunk
            size, despite the name).
    Returns:
        Best packed value found by ``get_score`` across all rows (0 when
        ``gumbel_rank`` has no rows).
    """
    # Loop-invariant conversions, done once instead of per chunk.
    values = np.array(items[:, 0])
    bags = np.array(items[:, 1:])
    capacities = np.array(knapsacks)

    num_samples = gumbel_rank.shape[0]
    if num_samples == 0:
        return 0

    inputs = [
        (gumbel_rank[start:start + num_chunks], bags, capacities, values)
        for start in range(0, num_samples, num_chunks)
    ]
    # Submit all chunks once. Mapping inside the chunk-building loop would
    # re-run every earlier chunk on each iteration.
    return np.max(list(_exec.map(get_score, inputs)))

def gumbel_decode(items, knapsacks, y_pred, sampling_number):
    """Decode a knapsack packing by searching Gumbel-perturbed item orderings.

    Draws ``sampling_number`` noisy copies of ``y_pred`` with sample_gumbel,
    turns each into a descending-score item ordering, and returns the best
    value Gumbel_Search finds over those orderings.
    """
    perturbed_scores = sample_gumbel(y_pred, sampling_number=sampling_number)
    # Each row: item indices sorted by descending perturbed score.
    orderings = np.argsort(-perturbed_scores, axis=1)
    return Gumbel_Search(items, knapsacks, orderings)

def glop(items, knapsack):
    """Solve the multi-dimensional knapsack with OR-Tools' GLOP backend.

    GLOP is a linear-programming solver, so integrality of the 0/1
    variables is presumably relaxed (confirm against OR-Tools docs). Items
    whose solution value is ~1 are reported as packed; fractional items are
    treated as not packed.

    Args:
        items: array (num_items, 1 + item_dim); column 0 is the value, the
            rest are weights.
        knapsack: capacities indexed as ``knapsack[0][j]`` (shape (1, item_dim)).
    Returns:
        (solutions, total): a 0/1 list per item, and the summed value of the
        items marked 1. If the solver does not reach OPTIMAL, every entry is
        0 and total is 0.
    Raises:
        RuntimeError: if the GLOP backend is unavailable.
    """
    values = items[:, 0]
    weights = items[:, 1:]
    num_items, item_dim = weights.shape

    solver = pywraplp.Solver.CreateSolver('GLOP')
    if solver is None:
        raise RuntimeError("OR-Tools GLOP solver backend is not available")

    # x[i] = 1 if item i is packed.
    x = [solver.IntVar(0, 1, 'x_%i' % (i)) for i in range(num_items)]

    # Each weight dimension must stay within its capacity.
    for j in range(item_dim):
        solver.Add(
            sum(x[i] * weights[i, j] for i in range(num_items)) <= knapsack[0][j])

    objective = solver.Objective()
    for i in range(num_items):
        # float(), not int(): truncation would zero out fractional values.
        objective.SetCoefficient(x[i], float(values[i]))
    objective.SetMaximization()

    status = solver.Solve()
    if status == pywraplp.Solver.OPTIMAL:
        solutions = [1 if x[i].solution_value() > 1 - 1e-5 else 0
                     for i in range(num_items)]
    else:
        # Keep the length aligned with items so the value sum below is valid.
        solutions = [0] * num_items

    return solutions, sum(np.array(solutions) * values)

def preprocess_adj(A):
    '''
    Symmetrically normalise an adjacency matrix with self-loops.

    Computes ``D^-1/2 (A + I) D^-1/2``, where ``D`` is the diagonal degree
    matrix of ``A + I``. Infinite inverse square roots are replaced by 0.
    :param A: square adjacency matrix
    :return: normalised dense matrix of the same shape
    '''
    with_loops = A + np.eye(A.shape[0])
    inv_sqrt_degree = np.power(np.sum(with_loops, axis=1), -0.5)
    inv_sqrt_degree[np.isinf(inv_sqrt_degree)] = 0.
    scale = np.diag(inv_sqrt_degree)
    left = np.dot(scale, with_loops)
    return np.dot(left, scale)


def get_single_edge_feat(classes, Q, C, M):
    """Pairwise item similarity from multi-hot class membership.

    Args:
        classes: 2-D array (rows, num_items) of class labels in [-1, Q);
            label -1 is shifted to index 0 of the one-hot table and then
            dropped, so it contributes nothing (presumably padding).
        Q: number of classes.
        C: unused.
        M: ignored; the row count is re-derived from ``classes.shape``.
    Returns:
        (num_items, num_items) cosine similarity of the per-item class count
        vectors. Exact zeros are replaced by 0.5 and the diagonal is set to 0.
    """
    M, _num_items = classes.shape
    onehot_table = np.eye(Q + 1)
    per_row = [onehot_table[(classes[r] + 1).astype(int)][:, 1:] for r in range(M)]
    class_counts = np.array(per_row).sum(axis=0)

    similarity = cosine_similarity(class_counts, class_counts)
    similarity[similarity == 0] = 0.5
    np.fill_diagonal(similarity, 0)
    return similarity


def get_edge_feat(classes, Q, C, M):
    """Apply get_single_edge_feat to every instance along axis 0 and stack."""
    per_instance = [
        get_single_edge_feat(classes[b], Q, C, M) for b in range(classes.shape[0])
    ]
    return np.array(per_instance)


def greedy_decode(heat_map, weights, capacity):
    """Next-fit bin packing along a greedy walk over a heat map.

    Starts from a random item and repeatedly moves to the unvisited item
    with the highest ``heat_map[last]`` score. Each visited item goes into
    the current bin if it fits; otherwise a new bin is opened for it.

    Args:
        heat_map: (n, n) array of pairwise scores.
        weights: per-item weights (length n).
        capacity: bin capacity.
    Returns:
        Number of bins used (0 for no items).
    """
    n = len(heat_map)
    if n == 0:
        return 0

    last = np.random.randint(n)
    # The starting item occupies the first bin; its weight must be counted.
    used_capacity = weights[last]
    result = 1
    visited = np.zeros((n,), dtype=bool)
    visited[last] = True

    # Every step visits exactly one new item, so n - 1 steps remain.
    for _ in range(n - 1):
        priority = np.argsort(-heat_map[last])
        p = next(q for q in priority if not visited[q])
        if weights[p] + used_capacity <= capacity:
            used_capacity += weights[p]
        else:
            used_capacity = weights[p]
            result += 1
        visited[p] = True
        last = p
    return result


def FF_greedy_decode(heat_map, weights, capacity):
    """First-fit bin packing over an item order derived from a heat map.

    The order starts at item 0 and repeatedly steps to the unvisited item
    with the highest ``heat_map[current]`` score. Items are then placed, in
    that order, into the first bin with enough spare capacity.

    Returns:
        Number of bins used.
    """
    def _walk(scores):
        # Greedy nearest-by-score tour starting at item 0.
        size = len(scores)
        seen = np.zeros((size,))
        node = 0
        seen[node] = 1
        order = [node]
        while seen.sum() < size:
            for cand in np.argsort(-scores[node]):
                if seen[cand] != 1:
                    seen[cand] = 1
                    order.append(cand)
                    node = cand
                    break
        return order

    loads = []
    for item in _walk(heat_map):
        for b, load in enumerate(loads):
            if weights[item] + load <= capacity:
                loads[b] += weights[item]
                break
        else:
            # No existing bin fits: open a new one.
            loads.append(weights[item])
    return len(loads)



def FF(weights, capacity, classes, C,  priority):
    """First-fit bin packing with a per-bin limit on distinct classes.

    Args:
        weights: per-item weights.
        capacity: bin capacity.
        classes: 2-D array; column 0 holds each item's class label.
        C: maximum number of distinct classes allowed in one bin.
        priority: order in which items are placed.
    Returns:
        Number of bins used.
    """
    labels = classes[:, 0]
    bin_loads = []
    bin_labels = []

    for p in priority:
        item_label = labels[p]
        for b, load in enumerate(bin_loads):
            fits_weight = weights[p] + load <= capacity
            if fits_weight and len(set([item_label] + bin_labels[b])) <= C:
                bin_loads[b] += weights[p]
                if item_label not in bin_labels[b]:
                    bin_labels[b].append(item_label)
                break
        else:
            bin_loads.append(weights[p])
            bin_labels.append([item_label])
    return len(bin_loads)



def multi_FF(weights, capacity, classes, C,  priority):
    """First-fit bin packing where each item may carry several class labels.

    An item fits a bin when the weight stays within ``capacity`` and the
    union of the bin's labels with the item's labels has at most ``C``
    distinct entries.

    Args:
        weights: per-item weights.
        capacity: bin capacity.
        classes: per-item iterables of class labels.
        C: maximum distinct labels per bin.
        priority: order in which items are placed.
    Returns:
        Number of bins used.
    """
    loads = []
    label_sets = []
    for p in priority:
        for b in range(len(loads)):
            merged = set(label_sets[b]) | set(list(classes[p]))
            if weights[p] + loads[b] <= capacity and len(merged) <= C:
                loads[b] += weights[p]
                label_sets[b] = list(merged)
                break
        else:
            loads.append(weights[p])
            label_sets.append(list(classes[p]))
    return len(loads)



def NF(weights, capacity, classes, C, priority):
    """Next-fit bin packing with a per-bin limit on distinct classes.

    Only the most recently opened bin is considered. An item joins it when
    the weight stays within ``capacity`` and the bin would hold at most
    ``C`` distinct classes (same ``<= C`` rule as FF / multi_FF); otherwise
    a new bin is opened.

    Args:
        weights: per-item weights.
        capacity: bin capacity.
        classes: per-item hashable class labels (indexed as ``classes[p]``;
            NOTE(review): unlike FF this does not take column 0 of a 2-D array).
        C: maximum distinct classes per bin.
        priority: order in which items are placed.
    Returns:
        Number of bins used (0 for no items).
    """
    used_capacity = 0
    used_classes = []
    result = 0

    for p in priority:
        fits = (result > 0
                and weights[p] + used_capacity <= capacity
                and len(set([classes[p]] + used_classes)) <= C)
        if fits:
            used_capacity += weights[p]
            if classes[p] not in used_classes:
                used_classes.append(classes[p])
        else:
            used_capacity = weights[p]
            used_classes = [classes[p]]
            result += 1
    return result

def get_action_probabilities(state, weights, heatmap_pred, last_action, type = 'weight'):
    """Softmax over items, excluding masked ones.

    Args:
        state: tuple ``(bins, masks)``; ``masks`` is (batch_size, num_items)
            and entries equal to True are excluded (probability 0).
        weights: (batch_size, num_items) tensor; used as the score when
            ``type == 'weight'`` and to pick the output device.
        heatmap_pred: (batch_size, num_items, num_items) scores; row
            ``last_action`` is used when ``type == 'heatmap'``.
        last_action: per-batch index of the previously chosen item.
        type: ``'weight'`` or ``'heatmap'``.
    Returns:
        (batch_size, num_items) probabilities on ``weights.device``.
    Raises:
        ValueError: for an unknown ``type``.
    """
    bins, masks = state
    batch_size, num_items = weights.shape[0], weights.shape[1]

    if type == 'weight':
        score = weights.cpu()
    elif type == 'heatmap':
        score = heatmap_pred[torch.arange(batch_size), last_action, :].cpu()  # batch_size x num_items
    else:
        raise ValueError(f"unknown score type: {type!r}")

    # Element-wise mask. Indexing with .nonzero() coordinates would select
    # whole rows of the batch dimension instead.
    excluded = (masks == True).cpu()
    ninf_masks = torch.zeros([batch_size, num_items], dtype=torch.float)
    ninf_masks = ninf_masks.masked_fill(excluded, float('-inf'))
    probs = torch.softmax(score + ninf_masks, dim=1)  # batch_size x num_items

    return probs.to(weights.device)

def transform_input(test_item_batch, test_heatmap_batch, edge_feats, M, K):
    """Broadcast a single instance to a batch of K copies via ``expand``.

    ``test_item_batch`` becomes (K, M + 1, nof_items) and ``edge_feats``
    becomes (K, nof_items, nof_items). Both are expanded views, not copies.
    ``test_heatmap_batch`` is accepted but unused.
    """
    nof_items = test_item_batch.shape[2]
    item_view = test_item_batch.expand(K, M + 1, nof_items)
    feat_view = edge_feats.expand(K, nof_items, nof_items)
    return item_view, feat_view

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a float32 torch sparse COO tensor.

    Uses ``torch.sparse_coo_tensor`` instead of the deprecated legacy
    ``torch.sparse.FloatTensor`` constructor.
    """
    coo = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data)
    shape = torch.Size(coo.shape)
    return torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32)


def generate_heatmap(solutions, drop_out = 0.3):
    """Build a binary co-membership heat map from bin assignments.

    For each bin, the first ``int(drop_out * num_item)`` valid (non-negative)
    item indices are linked pairwise. Negative entries are skipped, which
    presumably marks padding. Self-links are removed.

    Args:
        solutions: integer tensor (batch_size, num_bin, num_item) of item
            indices per bin.
        drop_out: fraction of each bin's slots to use.
    Returns:
        LongTensor (batch_size, num_bin * num_item, num_bin * num_item)
        with 1 where two items share a selected cluster, 0 elsewhere.
    """
    batch_size, num_bin, num_item = solutions.shape
    size = num_bin * num_item

    heatmap = torch.zeros(batch_size, size, size).long()
    limit = int(drop_out * num_item)
    diag = torch.arange(size)
    for i in range(batch_size):
        for j in range(num_bin):
            row = solutions[i, j]
            # Drop negative padding before slicing; otherwise -1 would wrap
            # around and mark the last row/column.
            cluster = row[row >= 0][:limit].long().cpu()
            heatmap[i, cluster.unsqueeze(1), cluster.unsqueeze(0)] = 1
        heatmap[i, diag, diag] = 0
    return heatmap

