import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
# from torch_geometric.loader import DataLoader
import numpy as np
import argparse
import time
import os
from tqdm import tqdm
from sklearn.utils.class_weight import compute_class_weight

from bpp_model import GCN
from PointerNet import PointerNet
# from models import GNN
from bpp_data_generator import BPPDataset
from test_bpp import test
from utils import get_edge_feat, generate_heatmap

def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True`` because every non-empty string is truthy, so the option could
    never be switched off from the command line.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % value)


# Command-line configuration for the training script. The parsed namespace is
# exposed as the module-level ``params`` and read by every later stage.
parser = argparse.ArgumentParser(description="Pytorch implementation of Pointer-Net")

# Data
parser.add_argument('--train_size', default=6400, type=int, help='Training data size')
parser.add_argument('--val_size', default=320, type=int, help='Validation data size')
parser.add_argument('--batch_size', default=64, type=int, help='Batch size')

# Train
parser.add_argument('--nof_epoch', default=20, type=int, help='Number of epochs')
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--accumulation_steps', type=int, default=1, help='Accumulation Steps')
parser.add_argument('--drop_out', type=float, default=0.5,
                    help='Drop-out rate passed to generate_heatmap')

# GPU
# GPU use is on by default; ``--gpu`` is kept for backward compatibility and
# ``--no_gpu`` is the only way to force CPU execution.
parser.add_argument('--gpu', dest='gpu', default=True, action='store_true',
                    help='Enable gpu (default)')
parser.add_argument('--no_gpu', dest='gpu', action='store_false',
                    help='Disable gpu and run on CPU')

# CCBPP
# NOTE(review): --Q and --M share the same help text in the original script;
# their exact meaning is defined by BPPDataset / get_edge_feat - confirm there.
parser.add_argument('--N', type=int, default=200, help='Number of items in CCBPP')
parser.add_argument('--B', type=int, default=100, help='capacity of item in CCBPP')
parser.add_argument('--C', type=int, default=5, help='compartments limit of item in CCBPP')
parser.add_argument('--Q', type=int, default=10, help='Number of classes of item in CCBPP')
parser.add_argument('--M', type=int, default=1, help='Number of classes of item in CCBPP')

# Network
parser.add_argument('--model', type=str, default="GCN", choices=["GCN", "PointerNet"],
                    help='GCN/PointerNet')
parser.add_argument('--data_type', type=str, default="normal", help='normal/large/uniform')

parser.add_argument('--hiddens', type=int, default=64, help='Number of hidden units')
parser.add_argument('--gcn_layers', type=int, default=3, help='Number of GCN layers')

parser.add_argument('--fit', type=str, default="FF", help='fit scheme in BPP')
parser.add_argument('--search_type', type=str, default="sample", help='search type in BPP')
# ``--use_gumbel`` alone means True; ``--use_gumbel false`` now really disables it.
parser.add_argument('--use_gumbel', type=_str2bool, nargs='?', const=True, default=False,
                    help='use gumbel sampling or not')

params = parser.parse_args()

# Run on the GPU only when it was requested *and* CUDA is actually available.
USE_CUDA = bool(params.gpu and torch.cuda.is_available())
if USE_CUDA:
    print('Using GPU, %i devices.' % torch.cuda.device_count())

# Base name for saved checkpoints: it encodes the model type and the problem
# parameters, plus the data type as a suffix for anything other than 'normal'.
name_fields = (params.model, params.N, params.B, params.C, params.Q, params.M)
if params.data_type != 'normal':
    model_file_name = "CCBPP-%s-RandomFeat-N%d-B%d-C%d-Q%d-M%d_%s" % (
        name_fields + (params.data_type,))
else:
    model_file_name = "CCBPP-%s-RandomFeat-N%d-B%d-C%d-Q%d-M%d" % name_fields

# Build the network selected by --model. Both architectures take an input
# feature size of 2 + Q and the hidden size; GCN additionally takes the number
# of layers.
if params.model == "GCN":
    model = GCN(2 + params.Q, params.hiddens, params.gcn_layers)
elif params.model == "PointerNet":
    model = PointerNet(2 + params.Q, params.hiddens)
else:
    # Fail here with a clear message instead of a NameError on the first
    # later use of ``model``.
    raise ValueError(
        "Unknown --model %r; expected 'GCN' or 'PointerNet'" % (params.model,))
# To start from a previously saved checkpoint instead, load it here:
# model = torch.load(
#    'save_models/' + model_file_name + 'best.model')

# Training split, built from the problem parameters given on the command line.
dataset = BPPDataset(
    "train",
    params.train_size,
    params.N, params.B, params.C, params.Q, params.M,
    params.data_type,
)

# Shuffled mini-batches, loaded by four worker processes.
loader_options = {
    'batch_size': params.batch_size,
    'shuffle': True,
    'num_workers': 4,
}
dataloader = DataLoader(dataset, **loader_options)

if USE_CUDA:
    # Move the parameters to the GPU in place.
    model.cuda()
    gpu_ids = range(torch.cuda.device_count())
    # NOTE(review): ``net`` is not referenced anywhere else in the visible
    # file - the training loop calls ``model`` directly - so this
    # DataParallel wrapper appears to have no effect. Confirm before relying
    # on multi-GPU training.
    net = torch.nn.DataParallel(model, device_ids=gpu_ids)
    cudnn.benchmark = True

# Optimise only the parameters that require gradients.
trainable_params = (p for p in model.parameters() if p.requires_grad)
model_optim = optim.Adam(trainable_params, lr=params.lr, weight_decay=1e-4)

# Multiply the learning rate by 0.99 after every 20 calls to scheduler.step().
scheduler = optim.lr_scheduler.StepLR(model_optim, step_size=20, gamma=0.99)

# Scaled loss of every training step across all epochs.
losses = []

# Directory that receives the checkpoints. ``exist_ok=True`` avoids the
# check-then-create race of testing ``os.path.exists`` first.
directory = os.path.dirname("save_models/")
os.makedirs(directory, exist_ok=True)

# Per-epoch loss histories and the best validation results seen so far.
train_loss = []
valid_loss = []
best_valid_loss = 1e5
best_test_bins = 1e5

# Main training loop: one pass over ``dataloader`` per epoch with optional
# gradient accumulation, followed by a validation run and checkpointing.
start = time.time()
for epoch in range(params.nof_epoch):
    batch_loss = []
    iterator = tqdm(dataloader, unit='Batch')

    model.train()
    # Number of backward passes whose gradients have not been applied yet.
    pending_batches = 0

    for i_batch, sample_batched in enumerate(iterator):
        # The counter shown here is the epoch, not the batch.
        iterator.set_description('Epoch %i/%i' % (epoch + 1, params.nof_epoch))

        # ``Variable`` is a deprecated no-op wrapper, so tensors are used directly.
        train_item_batch = sample_batched['weights']

        heatmap = generate_heatmap(sample_batched['heatmaps'], drop_out=params.drop_out)
        train_heatmap_batch = heatmap

        edge_labels = heatmap.cpu().numpy().flatten()

        # Edge features are computed from every row after the first one.
        edge_feats = get_edge_feat(sample_batched['weights'].cpu().numpy()[:, 1:, :], params.Q, params.C, params.M)
        edge_feats = torch.Tensor(edge_feats).type(torch.FloatTensor)

        # Per-batch "balanced" class weights for the edge labels.
        edge_cw = compute_class_weight("balanced", classes=np.unique(edge_labels), y=edge_labels)
        edge_cw = torch.Tensor(edge_cw).type(torch.FloatTensor)
        # capacity = torch.Tensor(params.capacity)

        if USE_CUDA:
            train_item_batch = train_item_batch.cuda()
            train_heatmap_batch = train_heatmap_batch.cuda()

            edge_cw = edge_cw.cuda()
            edge_feats = edge_feats.cuda()

        # Row 0 is passed to the model as the weights; the remaining rows are
        # used as class indices.
        weights = train_item_batch[:, 0, :]

        # Shift the class indices by one so that column 0 of the one-hot
        # tensor collects the original index -1; that column is dropped
        # before the model call.
        class_ids = train_item_batch[:, 1:, :] + 1
        # Use the real batch size: the last batch of an epoch can be smaller
        # than params.batch_size.
        current_batch_size = train_item_batch.size(0)
        classes = torch.zeros(current_batch_size, params.N, params.Q + 1,
                              device=train_item_batch.device)
        classes.scatter_(2, class_ids.transpose(1, 2).long(), 1)

        y_preds, loss = model(weights, classes[:, :, 1:], params.B, train_heatmap_batch, edge_feats, edge_cw)
        loss = loss.mean()
        loss = loss / params.accumulation_steps  # Scale loss by accumulation steps
        loss.backward()
        pending_batches += 1

        loss_value = loss.item()
        losses.append(loss_value)
        batch_loss.append(loss_value)
        if (i_batch + 1) % params.accumulation_steps == 0:
            model_optim.step()
            model_optim.zero_grad()
            scheduler.step()
            pending_batches = 0
        iterator.set_postfix(loss='{}'.format(loss_value))

    # Apply gradients left over from an incomplete accumulation group so they
    # do not leak into the first update of the next epoch.
    if pending_batches:
        model_optim.step()
        model_optim.zero_grad()
        scheduler.step()

    test_bins = test(model, params, "valid")
    # valid_loss.append(test_loss)

    # Keep the checkpoint with the lowest validation result so far.
    if test_bins < best_test_bins:
        best_test_bins = test_bins
        torch.save(model, "save_models/" + model_file_name + "best.model")

    # Periodic snapshot every 50 epochs.
    # NOTE(review): the file name uses the 0-based epoch (e.g. "epoch_49" for
    # the 50th epoch); left unchanged so existing checkpoint names keep working.
    if (epoch + 1) % 50 == 0:
        torch.save(model, "save_models/" + model_file_name + "epoch_" + str(epoch) + ".model")

    # print(test_loss)
    train_loss.append(np.mean(batch_loss))
    iterator.set_postfix(loss=np.average(batch_loss))

# np.save('train_bpp_loss-RandomFeat-it%d-cap%d.npy' %
#         (params.N, params.B), train_loss)
# np.save('valid_bpp_loss-RandomFeat-it%d-cap%d.npy' %
#         (params.N, params.B), valid_loss)
