from itertools import product
import random
import numpy as np
import json
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import normalize
from torch.optim.lr_scheduler import *
import time

def generate_knapsack_instance(n, W, ID):
    """
    Draw one random 0/1 Knapsack instance with n items and capacity W.

    Every weight is a uniform random integer in [1, W]. Every value is a
    uniform random integer within int(W/10) of its item's weight, raised to 1
    if the draw is smaller ("weakly correlated" instances, see
    http://cic.tju.edu.cn/faculty/gongxj/course/algorithm/doc/2005-Wherearethehardknapsackproblems.pdf).

    Parameters:
    n (int): The number of items.
    W (int): The capacity of the knapsack.
    ID (int): Counter value supplied by the caller; ID + 1 is returned as the
    identifier of the new instance.

    Returns:
    Tuple[int, List[int], List[int], int]: The capacity W, the list of item
    weights, the list of item values, and the identifier ID + 1.
    """
    spread = int(W / 10)

    # All weights are drawn before any value so the random stream is consumed
    # in a fixed order.
    weights = []
    for _ in range(n):
        weights.append(random.randint(1, W))

    values = []
    for weight in weights:
        drawn = random.randint(weight - spread, weight + spread)
        values.append(max(1, drawn))

    return W, weights, values, ID + 1


def find_closest_value(input_value, float_array):
    """
    Return the element of float_array with the smallest absolute distance to
    input_value.

    The earliest element wins when several are equally close. None is
    returned when float_array is empty.
    """
    best_match, smallest_gap = None, float('inf')

    for candidate in float_array:
        gap = abs(candidate - input_value)
        if gap < smallest_gap:
            best_match, smallest_gap = candidate, gap

    return best_match

def find_closest_value_index(input_value, float_array):
    """
    Return the position of the element of float_array with the smallest
    absolute distance to input_value.

    The earliest position wins when several elements are equally close. None
    is returned when float_array is empty.
    """
    best_position, smallest_gap = None, float('inf')

    for position, candidate in enumerate(float_array):
        gap = abs(candidate - input_value)
        if gap < smallest_gap:
            best_position, smallest_gap = position, gap

    return best_position

def generate_knapsack_instances(num_instances, min_items, max_items, min_weight, max_weight, filename, optima=False, max_item_proportion = 0.5, qualities = [0.73,0.76,0.79,0.82,0.85,0.88,0.91,0.94,0.97,1.0],tutorial=False, solutions_filename=""):
    """
    Generates and saves a specified number of knapsack instances to a JSON file.

    Parameters:
    num_instances (int): The number of instances to generate.
    min_items (int): The minimum number of items in an instance (currently not
    used: every instance has 18 items).
    max_items (int): The maximum number of items in an instance (currently not
    used: every instance has 18 items).
    min_weight (int): The minimum weight constraint of a knapsack instance.
    max_weight (int): The maximum weight constraint of a knapsack instance.
    filename (str): The name of the file to save the instances to.
    optima (bool): When True, every instance is solved by enumeration. The
    optimum and, for each entry of qualities, the value of the feasible
    solution whose value/optimum ratio is closest to that entry are stored
    with the instance. Instances whose best solution contains fewer than 4
    items are discarded and regenerated.
    max_item_proportion (float): Not used by this function.
    qualities (List[float]): Target value/optimum ratios. Entry j (1-based)
    produces the keys 'recommendation_value_q<j>' and, in the solutions file,
    'recommendation_items_q<j>'. The list is only read, never modified.
    tutorial (bool): When True the counter starts at -2, so two additional
    instances are generated.
    solutions_filename (str): File that receives the records including the
    recommended item selections. Nothing is written when it is empty.
    """
    instances = []
    solutions_storage = []
    i = -2 if tutorial else 0
    n = 18
    while i < num_instances:
        i += 1
        # n = random.randint(min_items, max_items)
        W = random.randint(min_weight, max_weight)
        W, weights, values, instance_id = generate_knapsack_instance(n, W, len(instances))
        if optima:
            solutions = knapsack_value_filter(W, weights, values, 0)
            solutions.sort(reverse=True)
            optimal_value, optimal_items = solutions[0]

            # Reject the instance before doing the (more expensive) quality
            # matching; the counter is rolled back so a replacement is drawn.
            if sum(optimal_items) < 4:
                i -= 1
                continue

            ratios = np.divide([value for value, _ in solutions], optimal_value)

            record = {
                'W': W,
                'weights': weights,
                'values': values,
                'optimum': optimal_value,
            }
            record_with_items = dict(record)
            for rank, quality in enumerate(qualities, start=1):
                rec_value, rec_items = solutions[find_closest_value_index(quality, ratios)]
                record['recommendation_value_q%d' % rank] = rec_value
                record_with_items['recommendation_value_q%d' % rank] = rec_value
                record_with_items['recommendation_items_q%d' % rank] = rec_items
            record['id'] = instance_id
            record_with_items['id'] = instance_id

            solutions_storage.append(record_with_items)
            instances.append(record)
        else:
            instances.append({
                'W': W,
                'weights': weights,
                'values': values,
                'id': instance_id
            })
        if i % 10000 == 0:
            print(i)

    with open(filename, 'w') as f:
        json.dump(instances, f)

    # The default solutions_filename is "", which cannot be opened; only write
    # the solutions file when a name was actually supplied.
    if solutions_filename:
        with open(solutions_filename, 'w') as f:
            json.dump(solutions_storage, f)

def knapsack_top_solutions(W, weights, values, num_solutions=5):
    """
    Enumerates every 0/1 selection of the items, keeps the selections whose
    total weight does not exceed W, and returns the best num_solutions of them.

    Parameters:
    W (int): The capacity of the knapsack.
    weights (List[int]): The weights of the items.
    values (List[int]): The values of the items.
    num_solutions (int): The number of top solutions to return.

    Returns:
    List[Tuple[int, Tuple[int, ...]]]: (total value, decision) pairs in
    descending order, where decision is a tuple of 0/1 flags, one per item.
    Pairs with equal value are ordered by their decision tuple, descending.
    """
    ranked = sorted(
        (
            (sum([v * pick for v, pick in zip(values, picks)]), picks)
            for picks in product((0, 1), repeat=len(weights))
            if sum([w * pick for w, pick in zip(weights, picks)]) <= W
        ),
        reverse=True,
    )
    return ranked[:num_solutions]

# def knapsack_value_filter(W, weights, values, min_value):
#     """
#     Finds all solutions to the 0/1 Knapsack problem for a given instance with
#     capacity W, item weights, and item values, where the total value is greater
#     than or equal to min_value.

#     Parameters:
#     W (int): The capacity of the knapsack.
#     weights (List[int]): The weights of the items.
#     values (List[int]): The values of the items.
#     min_value (int): The minimum total value of the solution.

#     Returns:
#     List[Tuple[int, List[int]]]: A list of tuples, where each tuple contains
#     the total value of the solution and a list of 0/1 decision variables
#     indicating which items are included in the solution.
#     """
#     # Generate all possible combinations of 0/1 decision variables
#     decisions = list(product([0, 1], repeat=len(weights)))

#     # Calculate the total value and weight for each solution
#     results = []
#     for decision in decisions:
#         total_weight = sum([w * d for w, d in zip(weights, decision)])
#         if total_weight <= W:
#             total_value = sum([v * d for v, d in zip(values, decision)])
#             if total_value >= min_value:
#                 results.append((total_value, decision))

#     # Sort the results by total value in descending order
#     results.sort(reverse=True)

#     return results

def knapsack_value_filter(W, weights, values, min_value):
    """
    List every item selection whose total weight fits into W (up to a tiny
    floating-point tolerance) and whose total value is at least min_value.

    Items are visited in descending (value, weight, index) order and
    selections are grown depth-first; a branch is abandoned as soon as its
    weight exceeds the capacity.

    Returns:
    List[Tuple[number, Tuple[int, ...]]]: (total value, decision) pairs in
    depth-first discovery order. decision holds one 0/1 flag per item, at the
    item's original position.
    """
    count = len(weights)
    ranked = sorted(zip(values, weights, range(count)), reverse=True)
    limit = W + .0000000000000002

    def expand(first, load, worth, picks):
        # Overweight partial selection: neither it nor any extension is emitted.
        if load > limit:
            return
        if worth >= min_value:
            yield worth, picks
        for position in range(first, count):
            item_value, item_weight, slot = ranked[position]
            extended = picks[:slot] + (1,) + picks[slot + 1:]
            yield from expand(position + 1, load + item_weight, worth + item_value, extended)

    return list(expand(0, 0, 0, (0,) * count))

def sample_knapsack_instances(instances, n, exclude_set=None):
    """
    Draw n distinct instances at random, skipping every instance that is
    equal to an element of exclude_set.

    Raises:
    ValueError: If fewer than n instances remain after the exclusion.
    """
    excluded = list() if exclude_set is None else exclude_set

    candidates = []
    for position, instance in enumerate(instances):
        if instance not in excluded:
            candidates.append(position)

    if n > len(candidates):
        raise ValueError("Not enough available instances to sample.")

    return [instances[position] for position in random.sample(candidates, n)]

def load_knapsack_instances(file_path: str):
    """Parse the JSON document stored at file_path and return its content."""
    with open(file_path, 'r') as handle:
        return json.load(handle)

def sample_solution(solution_list, top_k=10):
    """Pick one entry uniformly at random from the first top_k entries of solution_list."""
    shortlist = solution_list[:top_k]
    return random.choice(shortlist)

def get_min_value(model, W, weights, values):
    """
    Run model on one instance and return the value of its repaired prediction.

    The network input is [W] + weights + values + [sum(weights)] +
    [sum(values)]. The network output is turned into a 0/1 selection by
    get_valid_prediction.

    Returns:
    (total_value, solution) for the selection. If the selection's total weight
    exceeds W (beyond a tiny tolerance), an error is printed and False is
    returned instead of a tuple.
    """
    model.eval()
    features = [W] + weights + values + [sum(weights)] + [sum(values)]
    features = torch.from_numpy(np.array(features).astype(np.float32))
    scores = model(features).cpu().detach().numpy()

    solution = get_valid_prediction(scores, weights, W)
    total_value = sum([v * d for v, d in zip(values, solution)])
    total_weight = sum([w * d for w, d in zip(weights, solution)])

    if total_weight > (W + .0000000000000002):
        print("ERROR: total weight higher than constraint")
        print(total_weight)
        return False
    return total_value, solution

def get_model_decision_by_thresholding(sigmoid_values, weights, constraint):
    """
    Select every item whose score is at least 0.5.

    Returns:
    The 0/1 selection array when the selected weights sum to at most
    constraint, otherwise the integer 0.
    """
    picked = np.where(sigmoid_values >= 0.5, 1, 0)

    load = 0
    for position, flag in enumerate(picked):
        if flag == 1:
            load += weights[position]

    return 0 if load > constraint else picked

def get_valid_prediction(sigmoid_values, weights, constraint):
    """
    Turn per-item scores into a 0/1 selection that respects the capacity.

    Items are visited from the highest score to the lowest. An item is taken
    when it still fits (up to a tiny floating-point tolerance) and skipped
    otherwise. The walk stops at the first item whose score is not positive,
    or as soon as the accumulated weight equals constraint exactly.

    Returns:
    An integer numpy array with one 0/1 entry per item.
    """
    ranking = np.argsort(sigmoid_values)[::-1]
    chosen = np.zeros(len(weights))
    limit = constraint + .0000000000000002

    load = 0
    for item in ranking:
        if sigmoid_values[item] <= 0:
            break
        # Item does not fit any more: leave it out and try the next one.
        if load + weights[item] > limit:
            continue
        chosen[item] = 1
        load += weights[item]
        if load == constraint:
            break

    return chosen.astype(int)

def get_item_indices(solution,instance):
    """
    Map a selection expressed in sorted item order back to the instance's
    original item indices and compute its value in original units.

    solution[i] refers to the i-th item after the items have been sorted in
    descending order of the pair (value / (W * weight), weight / W). This is
    the ordering produced by sorted(zip(values, weights), reverse=True) on the
    normalized data in test() and generate_training_data().

    Parameters:
    solution: Sequence of 0/1 flags in sorted item order.
    instance (dict): Instance with the keys 'W', 'weights' and 'values'.

    Returns:
    Tuple[List[int], number]: The original indices of the selected items that
    fit into W (in sorted order) and the sum of their original values.
    Selected items that would exceed the capacity are skipped.
    """
    W, weights, values = instance['W'], instance['weights'], instance['values']
    values_norm = list(np.divide(values, (W * np.array(weights))))
    weights_norm = list(np.divide(weights, W))

    # Sort on the (value, weight) pair, not on the value alone: items with the
    # same normalized value but different weights must end up in the same
    # order as in the sorted data the solution was produced for, otherwise the
    # flags are attributed to the wrong items.
    sorted_indices = sorted(
        range(len(values)),
        key=lambda i: (values_norm[i], weights_norm[i]),
        reverse=True,
    )

    indexs = []
    total_value = 0
    total_weight = 0
    for position, chosen in enumerate(solution):
        if chosen != 1:
            continue
        original_index = sorted_indices[position]
        w = weights[original_index]
        if total_weight + w <= W + .0000000000000002:
            total_weight += w
            total_value += values[original_index]
            indexs.append(original_index)
    return indexs, total_value


def test(instances, model, k):
    """
    Evaluate model on a list of instances.

    Each instance is normalized (value / (W * weight), weight / W, capacity 1)
    and its items are sorted in descending (value, weight) order. The optimum
    is found by enumeration and the model output is turned into a feasible
    selection by get_valid_prediction.

    Parameters:
    instances: Instance dicts with the keys 'W', 'weights', 'values' and
    'optimum'.
    model: Network that is called on the instance's feature vector.
    k: Not used by this function.

    Returns:
    Tuple of (mean ratio between the original-unit value of the predicted
    selection and the stored optimum, share of instances where the predicted
    normalized value equals the enumerated optimum, sum of the BCE losses
    against the optimal selection, list of predicted selections).
    """
    cpu = torch.device("cpu")
    value_ratios = []
    optimal_hits = []
    losses = []
    predicted_solutions = []
    model.eval()
    model = model.to(cpu)
    bce = nn.BCELoss()
    with torch.no_grad():
        for instance in instances:
            # Normalize the instance and bring the items into sorted order.
            W, weights, values = instance['W'], instance['weights'], instance['values']
            values = list(np.divide(values, (W * np.array(weights))))
            weights = list(np.divide(weights, W))
            ranked = sorted(zip(values, weights), reverse=True)
            weights = [w for _, w in ranked]
            values = [v for v, _ in ranked]
            W = 1

            # Brute-force optimum of the normalized instance.
            solutions = knapsack_value_filter(W, weights, values, 0)
            solutions.sort(reverse=True)
            value_opt, best_items = solutions[0]

            features = [W] + weights + values + [sum(weights)] + [sum(values)]
            features = torch.from_numpy(np.array(features).astype(np.float32)).to(cpu)
            target = torch.tensor(list(best_items)).to(cpu)
            scores = model(features)
            loss = bce(scores.to(torch.float32), target.to(torch.float32))

            solution = get_valid_prediction(scores.cpu().detach().numpy(), weights, W)
            _, value_of_notnormed_solution = get_item_indices(solution, instance)

            total_value = sum([v * d for v, d in zip(values, solution)])
            total_weight = sum([w * d for w, d in zip(weights, solution)])
            if total_weight > W + .0000000000000002:
                print("ERROR IN TEST SOLUTION weight>W")
                continue

            predicted_solutions.append(solution)
            value_ratios.append(value_of_notnormed_solution / instance["optimum"])
            optimal_hits.append(1 if value_opt == total_value else 0)
            losses.append(loss)

    return np.mean(value_ratios), np.mean(optimal_hits), np.sum(losses), predicted_solutions

def generate_training_data(instances,models, top=0.1, min_value=0, lazy_prob= 0.5):
    """
    Build (feature vector, label) pairs for training from a list of instances.

    Each instance is normalized (value / (W * weight), weight / W, capacity 1)
    and its items are sorted in descending (value, weight) order. The label is
    either the selection predicted by the most recent model (with probability
    lazy_prob, when a model exists) or a selection sampled from the best
    feasible selections that are at least as good as the model's prediction.

    Parameters:
    instances: Instance dicts with the keys 'W', 'weights' and 'values'.
    models (list): Previously trained models; only the last one is used.
    top (float): Fraction of the number of all feasible selections that defines
    how many of the best remaining selections are sampled from (at least one).
    min_value: Ignored; the threshold is recomputed for every instance.
    lazy_prob (float): Probability of using the model's own selection as label.

    Returns:
    Tuple of (X, y, mean ratio between label value and best value, mean ratio
    between model value and best value (0 without a model), number of labels
    whose ratio equals 1, number of model predictions whose ratio equals 1).
    """
    X = []
    y = []
    achieved_values = []
    max_values = []
    model_values = []
    opt_count = 0
    has_model = len(models) > 0

    for instance in instances:
        W, weights, values = instance['W'], instance['weights'], instance['values']
        values = list(np.divide(values, (W * np.array(weights))))
        weights = list(np.divide(weights, W))
        sorted_data = sorted(zip(values, weights), reverse=True)
        weights = [w for v, w in sorted_data]
        values = [v for v, w in sorted_data]
        W = 1

        # The enumeration is exponential in the number of items, so it is run
        # once per instance and filtered instead of being run a second time.
        all_solutions = knapsack_value_filter(W, weights, values, 0)
        n_solutions = len(all_solutions)

        if has_model:
            min_value, min_solution = get_min_value(models[-1], W, weights, values)
        else:
            min_value = 0

        if min_value >= 0:
            # Any threshold >= 0 selects a subset of the threshold-0 list.
            solutions = [entry for entry in all_solutions if entry[0] >= min_value]
        else:
            solutions = knapsack_value_filter(W, weights, values, min_value)
        solutions.sort(reverse=True)

        best_value = solutions[0][0]
        max_values.append(best_value)
        if has_model:
            model_value = min_value / best_value
            if model_value == 1:
                opt_count += 1
            model_values.append(model_value)

        # The dice roll is always drawn so the random stream does not depend
        # on whether a model exists.
        lazy_diceroll = random.random()
        if has_model and lazy_diceroll < lazy_prob:
            solution = (min_value, min_solution)
        else:
            solution = sample_solution(solutions, top_k=max(1, round(top * n_solutions)))
        achieved_values.append(solution[0])

        # add instance-solution pair to training data
        X.append([W] + weights + values + [sum(weights)] + [sum(values)])
        y.append(list(solution[1]))

    sample_values = np.divide(np.array(achieved_values), np.array(max_values))
    mean = np.mean(sample_values)
    model_mean = np.mean(model_values) if has_model else 0
    sample_opts = np.count_nonzero(sample_values == 1)
    return X, y, mean, model_mean, sample_opts, opt_count

class Net(nn.Module):
    """
    Fully connected network mapping 2*k+3 input features to k sigmoid outputs.

    Layer widths: 2*k+3 -> 90 -> 550 -> 90 -> 84 -> k, with ReLU after every
    layer except the last, which is followed by a sigmoid.
    """

    def __init__(self, k):
        super().__init__()
        n_inputs = 2 * k + 3
        # Attribute names (fc3 is intentionally absent) and creation order are
        # kept so parameter names and initialization stay the same.
        self.fc1 = nn.Linear(n_inputs, 90)
        self.fc2 = nn.Linear(90, 550)
        self.fc4 = nn.Linear(550, 90)
        self.fc5 = nn.Linear(90, 84)
        self.fc6 = nn.Linear(84, k)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        for hidden in (self.fc1, self.fc2, self.fc4, self.fc5):
            x = F.relu(hidden(x))
        return self.sig(self.fc6(x))

class Data(Dataset):
    """Dataset over paired samples: X is stored as a float32 tensor, y as a tensor built from np.array(y)."""

    def __init__(self, X, y):
        features = np.array(X).astype(np.float32)
        targets = np.array(y)
        self.X = torch.from_numpy(features)
        self.y = torch.tensor(targets)
        # Number of samples = size of the first dimension of X.
        self.len = self.X.shape[0]

    def __getitem__(self, index):
        return self.X[index], self.y[index]

    def __len__(self):
        return self.len

def train(train_dataloader, k,epochs=12, lr=0.005, step_size=6, gamma = 0.1):
    """
    Train a fresh Net(k) on the CPU with Adam (weight decay 0.001) and binary
    cross-entropy loss, and return it.

    Parameters:
    train_dataloader: Iterable of (X, y) batches.
    k (int): Number of items, passed to Net.
    epochs (int): Number of passes over train_dataloader.
    lr (float): Learning rate for Adam.
    step_size, gamma: Not used; the learning rate is not changed during training.
    """
    cpu = torch.device("cpu")
    net = Net(k).to(cpu)
    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=0.001)
    bce = nn.BCELoss()
    net.train()

    for _ in range(epochs):
        epoch_loss = 0  # running loss of the epoch, kept for debugging only
        for features, targets in train_dataloader:
            features, targets = features.to(cpu), targets.to(cpu)
            optimizer.zero_grad()
            prediction = net(features)
            loss = bce(prediction.to(torch.float32), targets.to(torch.float32))
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()

    return net

def knapsack_ml_loop(num_iterations, path_to_instances,path_to_test_set, sample_size,k, top=0.1, lazy_prob = 0.1):
    """
    Iteratively grow a training pool, train a new model on it and evaluate it.

    In every iteration sample_size instances that are not yet in the pool are
    added. Labels for the whole pool are regenerated with the models trained
    so far, a new network is trained on them, and that network is evaluated on
    the test set.

    Parameters:
    num_iterations (int): Number of iterations.
    path_to_instances (str): JSON file with the instances to sample from.
    path_to_test_set (str): JSON file with the test instances.
    sample_size (int): Number of instances added per iteration; the batch size
    is max(1, int(sample_size/10)).
    k (int): Number of items per instance, passed to train and test.
    top (float): Passed to generate_training_data.
    lazy_prob (float): Passed to generate_training_data.

    Returns:
    Tuple of per-iteration lists (label mean ratios, model mean ratios, share
    of optimal labels, share of optimal model predictions, test mean ratios,
    test optimal shares, test losses, test selections) followed by the list of
    trained models.
    """
    instances = load_knapsack_instances(path_to_instances)
    test_set = load_knapsack_instances(path_to_test_set)

    models = []
    pool = []
    means, model_means = [], []
    sample_opt_count, model_opt_count = [], []
    test_means, test_opts, test_losses = [], [], []
    solutions = []

    for _ in range(num_iterations):
        # Extend the pool with instances that have not been used so far.
        fresh = sample_knapsack_instances(instances, sample_size, exclude_set=pool)
        pool = pool + fresh

        X_train, y_train, mean, model_mean, sample_opts, model_opts = generate_training_data(
            pool, models, top=top, min_value=0, lazy_prob=lazy_prob
        )
        n_labels = len(y_train)
        means.append(mean)
        model_means.append(model_mean)
        sample_opt_count.append(sample_opts / n_labels)
        model_opt_count.append(model_opts / n_labels)

        loader = DataLoader(
            dataset=Data(X_train, y_train),
            batch_size=max(1, int(sample_size / 10)),
            shuffle=True,
        )
        model = train(loader, k)
        models.append(model)

        test_mean, test_opt, test_loss, sol = test(test_set, model, k)
        test_means.append(test_mean)
        solutions.append(sol)
        test_opts.append(test_opt)
        test_losses.append(test_loss)

    return means, model_means, sample_opt_count, model_opt_count, test_means, test_opts, test_losses, solutions, models