import math
import time

import numpy as np
import argparse
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.autograd import Variable
from bpp_data_generator import BPPDataset
from full_cn import FC
import torch.optim as optim
from sklearn.utils.class_weight import compute_class_weight


from bpp_env import bpp_env
from utils import *

def run_eas_heatmap(dataloader, problem_size, config):
    """
    Run per-batch active search with a freshly initialised FC heatmap model.

    For every batch yielded by ``dataloader`` a new ``FC`` model and Adam
    optimizer are created. The model is trained with a REINFORCE-style loss
    (batch-mean reward as baseline) for up to ``config.max_iter`` iterations.
    The best (largest) reward seen for the batch is negated and stored as the
    cost of every dataset entry in that batch.

    Args:
        dataloader: iterable of dicts holding a ``'weights'`` tensor (as
            produced by ``BPPDataset``).
        problem_size: number of items; currently unused, kept for
            backward compatibility.
        config: parsed arguments. Reads ``test_size``, ``batch_size``,
            ``param_K``, ``max_iter``, ``N``, ``B``, ``C``, ``Q``, ``M``,
            ``fit``, ``gpu`` and, optionally, ``target_bins`` (default 20):
            the search for a batch stops early once the best reward reaches
            ``-target_bins``.

    Returns:
        np.ndarray of shape ``(config.test_size,)`` with one cost per entry.
    """
    start_time = time.time()
    dataset_size = config.test_size
    loader_batch_size = config.batch_size
    # K: number of samples per search iteration. Cast to int because it is
    # used as a tensor dimension in view()/expand().
    batch_size = int(config.param_K)
    # Rewards appear to be negated bin counts (costs are stored as
    # -max_reward), so reaching -target_bins means the target was met.
    target_reward = -getattr(config, "target_bins", 20)
    use_cuda = bool(getattr(config, "gpu", False)) and torch.cuda.is_available()

    assert batch_size <= dataset_size

    instance_costs = np.zeros(dataset_size)

    iterator = tqdm(dataloader, unit='Batch')
    for i_batch, sample_batched in enumerate(iterator):
        max_reward = -np.inf

        # A fresh model is built for every batch, so no weights are shared
        # between the searches of different batches.
        model = FC(config.N)
        if use_cuda:
            # Move the model before constructing the optimizer, as the
            # torch.optim documentation recommends.
            model.cuda()
            cudnn.benchmark = True
        optimizer = optim.Adam(
            [p for p in model.parameters() if p.requires_grad],
            lr=5e-4, weight_decay=1e-4)

        # Load the instances
        ###############################################
        test_item_batch = sample_batched['weights']

        edge_feats = get_edge_feat(
            sample_batched['weights'].cpu().numpy()[:, 1:, :],
            config.Q, config.C, config.M)
        # The computed features are overwritten with a uniform value, so the
        # model input carries no instance-specific edge information.
        edge_feats[:, :, :] = 1.0 / config.N
        edge_feats = torch.Tensor(edge_feats).type(torch.FloatTensor)

        if use_cuda:
            edge_feats = edge_feats.cuda()

        # Start the search
        ###############################################
        for step in range(config.max_iter):
            optimizer.zero_grad()

            # NOTE(review): transform_input is re-applied every iteration to
            # its own previous output. This is only correct if it is
            # idempotent -- confirm; otherwise hoist it above this loop.
            test_item_batch, edge_feats = transform_input(
                test_item_batch, None, edge_feats, config.M, batch_size)

            heatmap = model(edge_feats)

            reward, log_probs = model.bin_sample(
                test_item_batch, config.B, config.Q, config.C, config.M,
                heatmap, config)

            rewards = np.array(reward)
            reward_tensor = torch.FloatTensor(rewards).view(batch_size, -1).expand(batch_size, config.N)
            if use_cuda:
                reward_tensor = reward_tensor.cuda()
            # REINFORCE with the mean sampled reward as the baseline.
            advantage = reward_tensor - reward_tensor.mean()
            loss = (-advantage * log_probs).sum(dim=1).mean()

            max_reward = max(max_reward, np.max(rewards))

            # Stop once the target is reached *or beaten*; an exact-equality
            # test would keep searching after finding a better solution.
            if max_reward >= target_reward:
                print(f"Find better solution: {-max_reward} in {step}epoch")
                break

            loss.backward()
            optimizer.step()
            if (step + 1) % 10 == 0:
                print(
                    f"{i_batch}batch, {step+1}epoch,  LOSS : {loss.item():.3f}, CURRENT BIN VALUE: {np.mean(-max_reward)}")

        # Every entry of this DataLoader batch gets the batch's best cost; the
        # slice is truncated automatically for a short final batch.
        start = i_batch * loader_batch_size
        instance_costs[start:start + loader_batch_size] = -max_reward

    print("For CCBPP-N%d-B%d-C%d-Q%d:" % (config.N, config.B, config.C, config.Q))
    print("%s-EAS %d iterations: %.4f seconds for each instance." % (
        config.fit, config.max_iter, (time.time() - start_time) / config.test_size))
    print("total value is %.4f. " % (np.mean(instance_costs)))

    return instance_costs

parser = argparse.ArgumentParser(description="Pytorch implementation of EAS Tabular")

parser.add_argument('--test_size', default=200, type=int, help='Test data size')
parser.add_argument('--batch_size', default=1, type=int, help='Batch size')
# GPU
parser.add_argument('--gpu', default=True, action='store_true', help='Enable gpu')

# CCBPP
parser.add_argument('--N', type=int, default=200, help='Number of items in CCBPP')
parser.add_argument('--B', type=int, default=100, help='capacity of item in CCBPP')
parser.add_argument('--C', type=int, default=5, help='compartments limit of item in CCBPP')
parser.add_argument('--Q', type=int, default=10, help='Number of classes of item in CCBPP')
parser.add_argument('--M', type=int, default=1, help='Number of classes of item in CCBPP')

parser.add_argument('--fit', type=str, default="FF", help='fit scheme in CCBPP')
parser.add_argument('--use_gumbel', type=bool, default=False, help='use gumbel sampling or not')

# Network
parser.add_argument('--embedding_size', type=int, default=64, help='Embedding size')
parser.add_argument('--hiddens', type=int, default=64, help='Number of hidden units')
parser.add_argument('--max-iter', type=int, default=30, help='Number of max-iter of EAS')

#EAS params
parser.add_argument('-param_K', default=32, type=float)


params = parser.parse_args()


if params.gpu and torch.cuda.is_available():
    USE_CUDA = True
else:
    USE_CUDA = False

dataset = BPPDataset("test",
                     params.test_size,
                     params.N,
                     params.B,

                     params.C,
                     params.Q,
                     params.M
                     )

dataloader = DataLoader(dataset,
                        batch_size=params.batch_size,
                        shuffle=False,
                        num_workers=4)

instance_costs = run_eas_heatmap(dataloader, params.N, params)
print(np.mean(instance_costs))
