import os
import numpy as np

import argparse
from ortools.linear_solver import pywraplp
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from scipy.spatial.distance import pdist, squareform

def bpp_ortool_opt(weights, capacities):
    """Solve a one-dimensional bin packing instance to optimality with OR-Tools.

    The model is the textbook MIP formulation: binary ``x[i, j]`` says that
    item ``i`` is placed in bin ``j``, binary ``y[j]`` says that bin ``j`` is
    used, and the number of used bins is minimised.

    Args:
        weights: Sequence of item weights, indexable by item position.
        capacities: Capacity shared by every bin (a single number, despite
            the plural name).

    Returns:
        A list with one entry per non-empty bin; each entry is the list of
        the weights packed into that bin. An empty list is returned when
        there are no items or when the solver does not report an optimal
        solution.

    Raises:
        RuntimeError: If the SCIP backend cannot be created.
    """
    num_items = len(weights)
    if num_items == 0:
        # Nothing to pack; avoid building an empty model.
        return []

    items = list(range(num_items))
    # One candidate bin per item is always sufficient (worst case: every
    # item ends up alone in its own bin).
    bins = items

    # Create the MIP solver with the SCIP backend.
    solver = pywraplp.Solver.CreateSolver('SCIP')
    if solver is None:
        # CreateSolver signals an unavailable backend by returning None.
        raise RuntimeError('SCIP solver backend is not available.')

    # Variables
    x = {}
    for i in items:
        for j in bins:
            x[(i, j)] = solver.IntVar(0, 1, 'x_%i_%i' % (i, j))

    y = {}
    for j in bins:
        y[j] = solver.IntVar(0, 1, 'y[%i]' % j)

    # Constraints
    # Each item must be in exactly one bin.
    for i in items:
        solver.Add(sum(x[i, j] for j in bins) == 1)

    # The amount packed in each bin cannot exceed its capacity; an unused
    # bin (y[j] == 0) has zero capacity and therefore holds nothing.
    for j in bins:
        solver.Add(
            sum(x[(i, j)] * weights[i] for i in items) <= y[j] * capacities)

    # Objective
    # Minimize the number of bins used.
    solver.Minimize(solver.Sum([y[j] for j in bins]))
    status = solver.Solve()

    solutions = []
    if status == pywraplp.Solver.OPTIMAL:
        for j in bins:
            # Solution values are floats that are only integral up to the
            # solver tolerance, so threshold at 0.5 instead of comparing
            # for exact equality.
            if y[j].solution_value() > 0.5:
                print(f'Bin {j}')
                bin_weight = 0
                bin_items = []
                for i in items:
                    if x[i, j].solution_value() > 0.5:
                        print(
                            f"Item {i} weight: {weights[i]}"
                        )
                        bin_weight += weights[i]
                        bin_items.append(weights[i])

                # Skip bins whose packed weight sums to zero.
                if sum(bin_items) > 0:
                    print('Bin number', j)
                    print('  Items packed:', bin_items)
                    print('  Total weight:', bin_weight)
                    solutions.append(bin_items)
    else:
        print('The problem does not have an optimal solution.')
    return solutions


# def generate_cs_data(capacity, num_bin = 10):
#     blocks = [8, 12, 6, 9, 15, 4, 12, 13, 11, 10]
#     solutions = []
#     np.random.shuffle(blocks)
#     for i in range(num_bin):
#         cutting_pt = np.random.choice(np.arange(2, capacity - 2), blocks[i] - 1, replace=False)
#         cutting_pt.sort()
#         items = list(np.append(cutting_pt, [capacity]) - np.append([0], cutting_pt))
#         solutions.append(items)
#     return solutions

class BPPDataset(Dataset):
    """
    Synthetic class-constrained bin packing (CCBPP) dataset.

    Instances are generated "backwards" from a known packing: NUM_BINS bins
    are each filled (almost) to capacity with random items, every item gets
    one or more class labels, and the items are then shuffled. The arrays are
    cached as ``.npy`` files under ``data_dir`` and re-loaded on later runs.

    Args:
        mode: Split name; used as the prefix of the cache directory and to
            pick the sample layout in ``__getitem__`` ("train" vs. anything
            else).
        data_size: Number of instances to generate / expose.
        N: Nominal number of items. It is part of the cache directory name
            and is passed to the generator, which currently ignores it (the
            generator always emits NUM_BINS * SLOTS_PER_BIN item slots).
        B: Bin capacity.
        C: Number of distinct classes drawn per bin.
        Q: Total number of classes to draw from.
        M: Number of class labels attached to each item.
        data_type: Tag appended to the cache directory name unless it is
            "normal".
        solver: Stored on the instance; not called by this class.
        data_dir: Root directory of the on-disk cache.
    """

    # Number of bins in every generated instance; also the value reported as
    # the 'solutions' target of non-training samples.
    NUM_BINS = 20
    # Every generated bin is zero-padded up to this many item slots.
    SLOTS_PER_BIN = 10

    def __init__(self, mode, data_size, N, B, C, Q, M, data_type="normal",
                 solver=bpp_ortool_opt, data_dir = 'ccbpp_data/'):
        self.data_size = data_size

        self.num_items = N
        self.capacity = B
        self.solve = True
        self.solver = solver
        self.mode = mode
        # Item sizes are drawn between these fractions of the bin capacity.
        self.min_rate = 0.10
        self.max_rate = 0.25
        self.C = C
        self.Q = Q
        self.M = M

        # Cache directory: <mode>B.._C.._N.._Q..[_M<M>][_<data_type>]
        dir_name = '{}B{}_C{}_N{}_Q{}'.format(mode, B, C, N, Q)
        if M != 1:
            dir_name += '_M{}'.format(M)
        if data_type != 'normal':
            dir_name += '_{}'.format(data_type)
        data_path = os.path.join(data_dir, dir_name)

        # Everything is stored in a single batch file; guard the division so
        # an empty dataset does not raise ZeroDivisionError.
        save_batch = data_size
        num_batches = self.data_size // save_batch if save_batch else 0

        data_list = []
        if os.path.exists(os.path.join(data_path, 'weights_0.npy')):    # data already exists
            for i in range(num_batches):
                weights_list = list(np.load(os.path.join(data_path, 'weights_{}.npy'.format(i))))
                heatmaps_list = list(np.load(os.path.join(data_path, 'heatmaps_{}.npy'.format(i))))
                # The saved solutions file is not read back.
                data_list.append({'weights_list': weights_list, 'heatmaps_list': heatmaps_list, 'solutions': []})
        else:                                                           # data not exists, generate data and save
            if not os.path.exists(data_path):
                os.makedirs(data_path)
            for i in range(num_batches):
                batch_data = self._generate_bpp_data(save_batch)
                np.save(os.path.join(data_path, 'weights_{}.npy'.format(i)), np.array(batch_data['weights_list']))
                np.save(os.path.join(data_path, 'heatmaps_{}.npy'.format(i)), batch_data['heatmaps_list'])
                np.save(os.path.join(data_path, 'solutions_{}.npy'.format(i)), batch_data['solutions'])
                data_list.append(batch_data)
        self.data = self._get_ttl_data(data_list)

        # The cache path does not encode data_size, so a cache written by an
        # earlier run may hold fewer samples than requested now.
        available = len(self.data['weights_list'])
        if available < self.data_size:
            import warnings
            warnings.warn(
                'Requested {} samples but only {} are available in {}; '
                'the dataset length is clamped.'.format(self.data_size, available, data_path))

    def __len__(self):
        """Return the number of usable samples.

        This is the requested ``data_size``, clamped to the number of samples
        actually held so that indexing up to ``len(self)`` never fails.
        """
        return min(self.data_size, len(self.data['weights_list']))

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors.

        'weights' is the per-item feature array transposed so that the item
        axis comes last. In "train" mode the dict also carries 'heatmaps'
        (the per-bin item indices); otherwise 'heatmaps' is an empty list and
        'solutions' holds the constant bin count NUM_BINS.
        """
        weights = torch.from_numpy(np.array(self.data['weights_list'][idx]).transpose(1, 0)).float()

        if self.mode == "train":
            heatmaps = torch.from_numpy(np.array(self.data['heatmaps_list'][idx])).long() if self.solve else None
            return {'weights': weights, 'heatmaps': heatmaps}

        # The per-sample target is not loaded from disk; every generated
        # instance uses NUM_BINS bins, so that constant is reported instead.
        solutions = torch.from_numpy(np.array([self.NUM_BINS])).long() if self.solve else None
        return {'weights': weights, 'heatmaps': [], 'solutions': solutions}

    def _get_ttl_data(self, data_list, max_bin=20, max_item=10):
        """Concatenate per-batch dicts into a single dict of flat lists.

        Only 'weights_list' and 'heatmaps_list' are merged; 'solutions' is
        left empty. ``max_bin`` and ``max_item`` are accepted but unused.
        """
        ttl_data = {'weights_list': [], 'heatmaps_list': [], 'solutions': []}
        for batch in data_list:
            ttl_data['weights_list'] += batch['weights_list']
            ttl_data['heatmaps_list'] += batch['heatmaps_list']

        return ttl_data

    def _generate_cs_data(self, num_item, capacity, delta = 3):
        """Build NUM_BINS bins, each filled to within ``delta`` of ``capacity``.

        Args:
            num_item: Unused; kept for interface compatibility.
            capacity: Bin capacity.
            delta: Maximum slack tolerated in a finished bin.

        Returns:
            A list with one array per bin. Each row of an array is one item
            slot: the item size followed by its class label(s) (one label
            when ``self.M == 1``, ``self.M`` distinct labels otherwise).
            Bins are zero-padded to SLOTS_PER_BIN rows.
        """
        items = np.arange(self.min_rate * capacity, self.max_rate * capacity + 1)

        solutions = []
        for _ in range(self.NUM_BINS):
            used_capacity = 0
            item_list = []
            while used_capacity < capacity:
                item = np.random.choice(items)
                if used_capacity + item <= capacity - delta:
                    # Still room for more items afterwards.
                    used_capacity += item
                    item_list.append(item)
                elif used_capacity + item <= capacity:
                    # Lands within `delta` of a full bin: accept and finish.
                    item_list.append(item)
                    break
                else:  # overshoot: throw the bin away and start over
                    used_capacity = 0
                    item_list = []

            # Pad with zero-size items so every bin has the same slot count.
            item_list.extend([0] * (self.SLOTS_PER_BIN - len(item_list)))

            # Classes allowed in this bin: C distinct classes out of Q.
            class_set = np.random.choice(np.arange(self.Q), size=self.C, replace=False, p=None)
            if self.M == 1:
                class_index = np.random.choice(np.arange(self.C), size=len(item_list), replace=True, p=None)
                class_rows = class_set[class_index]
            else:
                # M distinct classes per item, drawn from the bin's class set.
                per_item = [
                    np.random.choice(class_set, size=self.M, replace=False, p=None)
                    for _ in range(len(item_list))
                ]
                class_rows = np.array(per_item).transpose(1, 0)
            solutions.append(np.vstack([item_list, class_rows]).transpose(1, 0))

        return solutions

    def _generate_heatmap(self, solutions):
        """Shuffle the items of a packing and record which bin each went to.

        Args:
            solutions: List of per-bin arrays as produced by
                ``_generate_cs_data``.

        Returns:
            A tuple ``(weights, results, num_bin)`` where ``weights[k]`` is
            the feature row of the item placed at shuffled position ``k``,
            ``results[b]`` lists the shuffled positions belonging to bin
            ``b``, and ``num_bin`` is the number of bins.
        """
        num_bin = len(solutions)
        all_solutions = []
        for s in solutions:
            all_solutions += s.tolist()
        num_item = len(all_solutions)

        # Random permutation of item positions.
        item_list = list(range(num_item))
        np.random.shuffle(item_list)

        weights = [0] * num_item
        results = []
        passed_item = 0
        for s in solutions:
            bin_size = len(s)
            # The next `bin_size` shuffled positions belong to this bin.
            bin_result = item_list[passed_item:passed_item + bin_size]
            for offset, position in enumerate(bin_result):
                weights[position] = all_solutions[passed_item + offset]
            passed_item += bin_size
            results.append(bin_result)

        return weights, results, num_bin

    def _generate_bpp_data(self, save_batch, data_type ='cs'):
        """Generate ``save_batch`` instances.

        Args:
            save_batch: Number of instances to generate.
            data_type: Generator to use; only 'cs' is implemented.

        Returns:
            A dict with 'weights_list', 'heatmaps_list' and 'solutions'
            (the bin count of each instance), one entry per instance.

        Raises:
            NotImplementedError: If ``data_type`` is not 'cs'.
        """
        weights_list = []
        results_list = []
        heatmaps_list = []

        data_iter = tqdm(range(save_batch), unit='data')

        for i in data_iter:
            data_iter.set_description('Weights & Values %i/%i' % (i+1, self.data_size))
            if data_type == 'cs':
                solutions = self._generate_cs_data(self.num_items, self.capacity)
            else:
                raise NotImplementedError

            weights, heatmap, target_value = self._generate_heatmap(solutions)

            weights_list.append(weights)
            results_list.append(target_value)
            heatmaps_list.append(heatmap)

        return {'weights_list': weights_list,  'heatmaps_list': heatmaps_list, 'solutions':results_list}

def build_arg_parser():
    """Build the command-line parser for CCBPP data generation.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``--type``, ``--size``,
        ``--N``, ``--B``, ``--C``, ``--Q`` and ``--M``.
    """
    parser = argparse.ArgumentParser(description="CCBPP Data Generation")

    parser.add_argument('--type', default="train", type=str, help='Dataset mode, e.g. "train"')
    parser.add_argument('--size', default=6400, type=int, help='Number of instances to generate')
    parser.add_argument('--N', type=int, default=200, help='Number of items in CCBPP')
    parser.add_argument('--B', type=int, default=100, help='Bin capacity in CCBPP')
    parser.add_argument('--C', type=int, default=5, help='Number of distinct classes per bin in CCBPP')
    parser.add_argument('--Q', type=int, default=10, help='Total number of item classes in CCBPP')
    parser.add_argument('--M', type=int, default=1, help='Number of classes assigned to each item in CCBPP')
    return parser


if __name__ == "__main__":

    params = build_arg_parser().parse_args()

    # The size option is registered as --size, so it is read as params.size;
    # the mode given with --type is forwarded instead of being ignored.
    dataset = BPPDataset(params.type,
                         params.size,
                         params.N,
                         params.B,
                         params.C,
                         params.Q,
                         params.M
                         )
