from torch.utils.data import DataLoader


from tqdm import tqdm
import argparse
import time

from utils import NF, FF, multi_FF
from search_utils import *
from torch.autograd import Variable
from sklearn.utils.class_weight import compute_class_weight


from bpp_data_generator import BPPDataset
from utils import get_edge_feat

def test(model, params, mode="test"):
    """Evaluate a packing strategy on one dataset split and return its mean value.

    Builds a ``BPPDataset`` for ``mode``, iterates over it in batches, scores
    every batch with ``calculate_value`` and prints a short summary.

    Args:
        model: Callable network, invoked as
            ``model(weights, classes, params.B, None, edge_feats, None)``.
            Only called when ``params.search_type`` is ``"greedy"`` or
            ``"sample"``; it is still switched to eval mode unconditionally.
        params: Namespace with the run configuration. Attributes read here:
            ``gpu``, ``test_size`` / ``val_size``, ``N``, ``B``, ``C``, ``Q``,
            ``M``, ``data_type``, ``batch_size``, ``search_type``, ``fit``,
            ``use_gumbel``.
        mode: Dataset split name forwarded to ``BPPDataset``. ``"test"`` and
            ``"valid"`` are scored; ``"test"`` additionally reports the mean
            of the reference ``solutions`` stored with each batch.

    Returns:
        Mean over batches of the per-batch mean value for ``"test"`` /
        ``"valid"``. For any other mode nothing is scored and the mean of an
        empty list (NaN) is returned.
    """
    start_time = time.time()
    use_cuda = bool(params.gpu and torch.cuda.is_available())

    # The "test" split is sized by test_size (the same number the timing
    # report divides by); every other split is sized by val_size.
    data_size = params.test_size if mode == "test" else params.val_size

    dataset = BPPDataset(mode,
                         data_size,
                         params.N,
                         params.B,
                         params.C,
                         params.Q,
                         params.M,
                         params.data_type)

    dataloader = DataLoader(dataset,
                            batch_size=params.batch_size,
                            shuffle=False,
                            num_workers=4)

    scored = mode in ("test", "valid")
    needs_model = params.search_type in ("greedy", "sample")

    batch_loss = []  # never filled; kept so non-scored modes still return a mean
    result = []
    optimal_result = []

    model.eval()
    with torch.no_grad():
        print("----------------Start valid batch---------------------")
        iterator = tqdm(dataloader, unit='Batch')

        for sample_batched in iterator:
            # Row 0 holds the item weights; the remaining rows are fed to
            # get_edge_feat and used as class indices below.
            test_item_batch = sample_batched['weights']

            edge_feats = get_edge_feat(
                test_item_batch.cpu().numpy()[:, 1:, :],
                params.Q, params.C, params.M)
            edge_feats = torch.Tensor(edge_feats).type(torch.FloatTensor)

            if use_cuda:
                test_item_batch = test_item_batch.cuda()
                edge_feats = edge_feats.cuda()

            if needs_model:
                weights = test_item_batch[:, 0, :]
                # Size the one-hot buffer from the actual batch so a shorter
                # final batch does not leave surplus rows that mismatch
                # `weights`.
                classes = torch.zeros(test_item_batch.size(0),
                                      params.N,
                                      params.Q + 1,
                                      device=test_item_batch.device)
                # Labels are shifted by +1 before the scatter and column 0 is
                # dropped afterwards (presumably so a -1 label means
                # "no class" -- TODO confirm against the data generator).
                class_idx = (test_item_batch[:, 1:, :] + 1).transpose(1, 2).long()
                classes.scatter_(2, class_idx, 1)

                y_preds = model(weights, classes[:, :, 1:], params.B,
                                None, edge_feats, None)
            else:
                y_preds = None

            if scored:
                result.append(calculate_value(test_item_batch, y_preds, params))

                if mode == "test":
                    solutions_batch = sample_batched['solutions']
                    optimal_result.append(np.mean(solutions_batch.cpu().numpy()))

        print("----------------End valid batch---------------------")

    if not scored:
        return np.mean(batch_loss)

    print("For CCBPP-N%d-B%d-C%d-Q%d:" % (params.N, params.B, params.C, params.Q))
    print("use gumbel softmax: %s:" % (params.use_gumbel))
    print("%s-%s : %.4f seconds for each instance." % (
        params.fit, params.search_type, (time.time() - start_time) / data_size))
    print("total value is %.4f. " % (np.mean(result)))
    if mode == "test":
        print("optimal result is %.4f. " % (np.mean(optimal_result)))

    return np.mean(result)


def calculate_value(batch_items, y_pred_edges, params, probs_type='raw'):
    """Score one batch of instances and return the mean value.

    Args:
        batch_items: 3-D tensor indexed ``[instance, row, item]``. For each
            instance, row 0 is passed to the fit function as the item weights
            and rows ``1:`` (transposed) as the per-item class data.
        y_pred_edges: Model output for the batch; only used when
            ``params.search_type`` is ``"greedy"`` or ``"sample"``. For
            ``params.model == "GCN"`` it is normalised over dim 3 and channel 1
            is kept; for other models it is used as-is.
        params: Namespace providing ``search_type``, ``model``, ``use_gumbel``,
            ``fit``, ``M``, ``B``, ``C`` and ``Q``.
        probs_type: ``'raw'`` (softmax / gumbel-softmax probabilities) or
            ``'logits'`` (log-softmax). Only consulted for the GCN model.

    Returns:
        ``np.mean`` of the per-instance values.

    Raises:
        ValueError: If ``params.search_type`` or (for GCN) ``probs_type`` is
            not one of the supported options.
    """
    search_type = params.search_type
    if search_type not in ("decreasing", "random", "greedy", "sample"):
        # Fail fast: otherwise `value` below would be unbound, or silently
        # reuse the previous instance's value.
        raise ValueError("unsupported search_type: %r" % (search_type,))

    heatmaps = None
    if search_type in ("greedy", "sample"):
        if params.model == "GCN":
            if probs_type == 'raw':
                # Compute softmax over edge prediction matrix
                if params.use_gumbel:
                    y = F.gumbel_softmax(y_pred_edges, dim=3)  # B x V x V x voc_edges
                else:
                    y = F.softmax(y_pred_edges, dim=3)  # B x V x V x voc_edges
                # Consider the second dimension only
                y = y[:, :, :, 1]  # B x V x V
            elif probs_type == 'logits':
                # Compute logits over edge prediction matrix
                y = F.log_softmax(y_pred_edges, dim=3)  # B x V x V x voc_edges
                # Consider the second dimension only
                y = y[:, :, :, 1]  # B x V x V
                y[y == 0] = -1e-20  # Set 0s (i.e. log(1)s) to very small negative number
            else:
                raise ValueError("unsupported probs_type: %r" % (probs_type,))
        else:
            y = y_pred_edges
        # Convert once for the whole batch rather than once per instance.
        heatmaps = y.detach().cpu().numpy()

    # fit_func is only used by the "decreasing" and "random" baselines.
    if params.fit == "FF" and params.M == 1:
        fit_func = FF
    elif params.fit == "FF" and params.M > 1:
        fit_func = multi_FF
    else:
        fit_func = NF

    greedy_samples, sampling_samples = 1, 32
    items = batch_items.detach().cpu().numpy()

    result = []
    for i in range(batch_items.shape[0]):
        item = items[i]

        if search_type == "decreasing":
            # Visit items by decreasing weight.
            order = np.argsort(-item[0])
            value = fit_func(item[0], params.B, item[1:].transpose(1, 0), params.C, order)
        elif search_type == "random":
            order = np.random.permutation(len(item[0]))
            value = fit_func(item[0], params.B, item[1:].transpose(1, 0), params.C, order)
        elif search_type == "greedy":
            value = sample_bin_batch_search(item, params.B, params.Q, params.C, params.M,
                                            heatmaps[i, :, :], params.fit, greedy_samples)
        else:  # "sample"
            value = sample_bin_batch_search(item, params.B, params.Q, params.C, params.M,
                                            heatmaps[i, :, :], params.fit, sampling_samples)

        result.append(value)

    return np.mean(result)


#model = PointerNet(params.dimension,
#                   params.embedding_size,
#                   params.hiddens,
#                   params.dropout,
#                   params.alpha).to(device)


if __name__ == "__main__":
    # Script entry point: parse the run configuration, load the saved best
    # model for that configuration and print its score on the test split.

    def _str2bool(value):
        """Parse a command-line boolean; ``type=bool`` would treat "False" as True."""
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ("true", "t", "yes", "y", "1"):
            return True
        if text in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))

    parser = argparse.ArgumentParser(description="Pytorch implementation of Pointer-Net")

    # Data
    parser.add_argument('--test_size', default=200, type=int, help='Test data size')
    # test() reads params.val_size for non-"test" splits; default to test_size.
    parser.add_argument('--val_size', default=None, type=int,
                        help='Validation data size (defaults to --test_size)')
    parser.add_argument('--batch_size', default=50, type=int, help='Batch size')

    # GPU
    parser.add_argument('--gpu', default=True, action='store_true', help='Enable gpu')
    # --gpu defaults to True and can only store True, so provide an off switch.
    parser.add_argument('--no_gpu', dest='gpu', action='store_false', help='Disable gpu')

    # CCBPP
    parser.add_argument('--N', type=int, default=200, help='Number of items in CCBPP')
    parser.add_argument('--B', type=int, default=100, help='capacity of item in CCBPP')
    parser.add_argument('--C', type=int, default=5, help='compartments limit of item in CCBPP')
    parser.add_argument('--Q', type=int, default=10, help='Number of classes of item in CCBPP')
    # NOTE(review): the help text below duplicates --Q's; M's exact meaning is
    # not visible in this file -- confirm and correct the wording.
    parser.add_argument('--M', type=int, default=1, help='Number of classes of item in CCBPP')
    parser.add_argument('--data_type', type=str, default="normal", help='normal/large/uniform')

    parser.add_argument('--fit', type=str, default="FF", help='fit scheme in BPP')
    parser.add_argument('--search_type', type=str, default="random", help='search type in BPP')
    parser.add_argument('--use_gumbel', type=_str2bool, default=False,
                        help='use gumbel sampling or not')

    # Network
    parser.add_argument('--model', type=str, default="GCN", help='GCN/PointerNet')
    parser.add_argument('--embedding_size', type=int, default=64, help='Embedding size')
    parser.add_argument('--hiddens', type=int, default=64, help='Number of hidden units')

    params = parser.parse_args()
    if params.val_size is None:
        params.val_size = params.test_size

    model_file_name = "CCBPP-%s-RandomFeat-N%d-B%d-C%d-Q%d-M%d" % (
        params.model, params.N, params.B, params.C, params.Q, params.M)

    # Remap storages to CPU when CUDA is not going to be used, so a checkpoint
    # saved on a GPU machine still loads (test() keeps inputs on CPU then).
    # NOTE: torch.load unpickles the whole model object -- only load
    # checkpoints from a trusted source.
    use_cuda = params.gpu and torch.cuda.is_available()
    model = torch.load(
        'save_models/' + model_file_name + 'best.model',
        map_location=None if use_cuda else 'cpu')

    print(test(model, params, mode="test"))
