import math
import time

import numpy as np
import argparse
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.autograd import Variable
from bpp_data_generator import BPPDataset
from sklearn.utils.class_weight import compute_class_weight


from bpp_env import bpp_env
from utils import *

def run_eas_emb(dataloader, problem_size, config):
    """Run efficient active search (EAS) on every batch yielded by *dataloader*.

    For each batch a fresh copy of the trained model is loaded from
    ``'save_models/' + model_file_name + 'best.model'``. All parameters are
    frozen except ``model.mlp_edges.V`` (weight and bias), which is fine-tuned
    with Adam for up to ``config.max_iter`` iterations using a REINFORCE loss
    whose baseline is the mean sampled reward.

    Args:
        dataloader: iterable of sample dicts, each holding a ``'weights'``
            tensor. Row 0 of it is used as item weights and rows 1: as class
            ids (see the slicing below).
        problem_size: number of items per instance. Currently unused; kept so
            the call signature stays unchanged.
        config: parsed CLI namespace providing ``test_size``, ``param_K``,
            ``batch_size``, ``max_iter``, ``N``, ``B``, ``C``, ``Q``, ``M``,
            ``model``, ``use_gumbel``, ``decode_type`` and ``fit``.

    Returns:
        np.ndarray of length ``config.test_size``. Each entry holds the
        negated best reward found for the batch containing that instance.

    Relies on the module-level globals ``model_file_name`` and ``USE_CUDA``.
    """
    start_time = time.time()
    dataset_size = config.test_size
    # Rollouts sampled per instance per iteration. Cast to int because torch
    # size arguments reject floats (the CLI used to parse this as float).
    num_samples = int(config.param_K)
    loader_batch = config.batch_size

    assert num_samples <= dataset_size

    instance_costs = np.zeros(dataset_size)

    for i_batch, sample_batched in enumerate(tqdm(dataloader, unit='Batch')):
        max_reward = -np.inf

        # Reload the checkpoint for every batch so each search starts from the
        # original trained weights rather than the previous batch's updates.
        model = torch.load('save_models/' + model_file_name + 'best.model')
        model.eval()

        # Only the final edge-MLP layer is adapted during the search.
        model.requires_grad_(False)
        model.mlp_edges.V.weight.requires_grad_(True)
        model.mlp_edges.V.bias.requires_grad_(True)
        optimizer = torch.optim.Adam(
            [model.mlp_edges.V.weight, model.mlp_edges.V.bias],
            lr=5e-3, weight_decay=1e-4)

        # Load the instances
        test_item_batch = sample_batched['weights']
        edge_feats = get_edge_feat(sample_batched['weights'].cpu().numpy()[:, 1:, :],
                                   config.Q, config.C, config.M)
        edge_feats = torch.Tensor(edge_feats).type(torch.FloatTensor)

        if USE_CUDA:
            test_item_batch = test_item_batch.cuda()
            edge_feats = edge_feats.cuda()

        # Start the search
        for step in range(config.max_iter):
            optimizer.zero_grad()

            # NOTE(review): the outputs of transform_input are fed back into it
            # on the next iteration, so any transformation composes across
            # iterations — confirm this is intended.
            test_item_batch, edge_feats = transform_input(
                test_item_batch, None, edge_feats, config.M, num_samples)

            weights = test_item_batch[:, 0, :]

            # One-hot encode class ids. Ids are shifted by +1 so that an id of
            # -1 lands in slot 0, which is sliced away below (presumably -1
            # marks "no class" — confirm).
            classes = torch.zeros(num_samples, config.N, config.Q + 1,
                                  device=test_item_batch.device)
            classes.scatter_(2, (test_item_batch[:, 1:, :] + 1).transpose(1, 2).long(), 1)
            y_preds = model(weights, classes[:, :, 1:], config.B, None, edge_feats, edge_cw=None)

            if config.model == "GCN":
                if config.use_gumbel:
                    heatmap = F.gumbel_softmax(y_preds, dim=3)[:, :, :, 1]
                else:
                    heatmap = F.softmax(y_preds, dim=3)[:, :, :, 1]
            else:
                heatmap = y_preds

            reward, log_probs = model.bin_sample(test_item_batch, config.B, config.Q, config.C, config.M,
                                                 heatmap, config, type=config.decode_type)

            ret_list = np.array(reward)
            r_list = torch.FloatTensor(ret_list).view(num_samples, -1).expand(num_samples, config.N)
            if USE_CUDA:
                r_list = r_list.cuda()
            # REINFORCE with the mean sampled reward as the baseline.
            adv = r_list - r_list.mean()
            loss = (-adv * log_probs).sum(dim=1).mean()

            max_reward = max(max_reward, max(ret_list))

            # -20 is a hard-coded target reward; stop searching once reached.
            if max_reward == -20:
                print(f"Find better solution in {step} epoch! ")
                break

            loss.backward()
            optimizer.step()
            if (step + 1) % 10 == 0:
                print(
                    f"{i_batch}batch, {step+1}epoch,  LOSS : {loss.item():.3f}, CURRENT BIN VALUE: {np.mean(-max_reward)}")

        # Each dataloader batch covers exactly `loader_batch` instances, so
        # write only that slice (not a slice of length param_K).
        start = i_batch * loader_batch
        instance_costs[start:start + loader_batch] = -max_reward

    print("For CCBPP-N%d-B%d-C%d-Q%d:" % (config.N, config.B, config.C, config.Q))
    print("%s-EAS %d iterations: %.4f seconds for each instance." % (
        config.fit, config.max_iter, (time.time() - start_time) / config.test_size))
    print("total value is %.4f. " % (np.mean(instance_costs)))

    return instance_costs

def parse_bool(value):
    """Convert a command-line token such as 'True', 'false', '1' or 'no' to bool.

    argparse's ``type=bool`` treats every non-empty string (including
    'False') as True, so an explicit converter is required.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


# Guard the script body: DataLoader workers (num_workers > 0) re-import this
# module under the "spawn" start method, which would otherwise re-run it.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Pytorch implementation of EAS Tabular")

    parser.add_argument('--test_size', default=200, type=int, help='Test data size')
    parser.add_argument('--batch_size', default=1, type=int, help='Batch size')
    # GPU (on by default; --cpu turns it off, since store_true alone can't)
    parser.add_argument('--gpu', default=True, action='store_true', help='Enable gpu')
    parser.add_argument('--cpu', dest='gpu', action='store_false', help='Disable gpu')

    # CCBPP
    parser.add_argument('--N', type=int, default=200, help='Number of items in CCBPP')
    parser.add_argument('--B', type=int, default=100, help='capacity of item in CCBPP')
    parser.add_argument('--C', type=int, default=5, help='compartments limit of item in CCBPP')
    parser.add_argument('--Q', type=int, default=10, help='Number of classes of item in CCBPP')
    parser.add_argument('--M', type=int, default=1, help='Number of classes of item in CCBPP')

    parser.add_argument('--fit', type=str, default="FF", help='fit scheme in CCBPP')
    parser.add_argument('--use_gumbel', type=parse_bool, default=False, help='use gumbel sampling or not')
    parser.add_argument('--data_type', type=str, default="normal", help='normal/large/uniform')

    # Network
    parser.add_argument('--model', type=str, default="GCN", help='GCN/PointerNet')
    parser.add_argument('--embedding_size', type=int, default=64, help='Embedding size')
    parser.add_argument('--hiddens', type=int, default=64, help='Number of hidden units')
    parser.add_argument('--max-iter', type=int, default=30, help='Number of max-iter of EAS')
    parser.add_argument('--decode_type', type=str, default="argmax",
                        help='Decoding strategy passed to model.bin_sample (e.g. argmax)')

    # EAS params: param_K is used as a tensor dimension, so it must be an int.
    parser.add_argument('-param_K', default=32, type=int,
                        help='Number of solutions sampled per instance in each EAS iteration')

    params = parser.parse_args()

    model_file_name = "CCBPP-%s-RandomFeat-N%d-B%d-C%d-Q%d-M%d" % (
        params.model, params.N, params.B, params.C, params.Q, params.M)

    USE_CUDA = bool(params.gpu and torch.cuda.is_available())

    dataset = BPPDataset("test",
                         params.test_size,
                         params.N,
                         params.B,
                         params.C,
                         params.Q,
                         params.M,
                         params.data_type)

    dataloader = DataLoader(dataset,
                            batch_size=params.batch_size,
                            shuffle=False,
                            num_workers=4)

    instance_costs = run_eas_emb(dataloader, params.N, params)
    print(np.mean(instance_costs))
