import itertools

import numpy as np


def convert_matrix(
    cost_matrix,
    supplier,
    demander,
):
    """Expand a supplier-by-demander cost matrix to one row per truck.

    Row ``i`` of ``cost_matrix`` is repeated ``supplier[i]`` times, so the
    result has ``sum(supplier)`` rows and ``len(demander)`` columns.

    Args:
        cost_matrix: Indexable as ``cost_matrix[i][j]`` for supplier ``i``
            and demander ``j``.
        supplier: Integer truck counts, one per supplier.
        demander: Per-demander amounts; only its length and its sum are used.

    Returns:
        A float ``numpy.ndarray`` of shape ``(sum(supplier), len(demander))``.

    Raises:
        AssertionError: If total supply differs from total demand.
    """
    assert np.sum(supplier) == np.sum(demander)
    n_demander = len(demander)
    expanded = np.zeros((np.sum(supplier), n_demander))

    # offsets[k] is the first expanded row belonging to supplier k.
    offsets = np.concatenate([[0], np.cumsum(supplier)])
    for row, (start, stop) in enumerate(zip(offsets[:-1], offsets[1:])):
        for col in range(n_demander):
            expanded[start:stop, col] = cost_matrix[row][col]
    return expanded


def convert_action_list(action_list, supplier, demander):
    """Expand aggregate actions into every distinct per-truck assignment.

    ``action[i][j]`` is read as the number of trucks supplier ``i`` sends to
    demander ``j``. For each supplier, every ordering of its trucks is tried
    and the trucks are handed out to the demanders in column order; the
    resulting 0/1 matrices (one row per truck, one column per demander) are
    de-duplicated while keeping first-seen order. The per-supplier patterns
    are then combined as a Cartesian product (supplier 0 varying slowest) and
    their rows are stacked in supplier order.

    Args:
        action_list: Iterable of actions, each indexable as ``action[i][j]``.
        supplier: Integer truck counts, one per supplier.
        demander: Per-demander amounts; only its length and its sum are used.

    Returns:
        ``numpy.ndarray`` holding every expanded action. For non-empty input
        the shape is ``(n_expanded, sum(supplier), len(demander))``.

    Raises:
        AssertionError: If total supply differs from total demand.
        IndexError: If a row of an action allocates more trucks than the
            corresponding supplier owns.
    """
    assert np.sum(supplier) == np.sum(demander)
    n_demander = len(demander)
    new_action_list = []

    for action in action_list:
        patterns_per_supplier = []
        for i in range(len(supplier)):
            n_trucks = supplier[i]
            # The amounts do not depend on the permutation, so read them once.
            amounts = [int(action[i][j]) for j in range(n_demander)]

            patterns = []
            # Hashable copies of the patterns already kept; avoids rescanning
            # the whole pattern list for every permutation.
            seen = set()
            for perm in itertools.permutations(range(n_trucks)):
                i_truck = [[0] * n_demander for _ in range(n_trucks)]
                done = 0
                for j, amount in enumerate(amounts):
                    # The next `amount` trucks in this ordering go to demander j.
                    for x in range(done, done + amount):
                        i_truck[perm[x]][j] = 1
                    done += amount
                key = tuple(map(tuple, i_truck))
                if key not in seen:
                    seen.add(key)
                    patterns.append(i_truck)
            patterns_per_supplier.append(patterns)

        # One pattern per supplier, concatenated row-wise in supplier order.
        for combo in itertools.product(*patterns_per_supplier):
            rows = [row for pattern in combo for row in pattern]
            new_action_list.append(np.array(rows))
    return np.array(new_action_list)


def make_new_action_list(
    supplier_dict,
    ret,
    supplier_num,
    final_output,
):
    """Append every combination of per-supplier patterns to ``final_output``.

    Starting from the prefix ``ret`` at level ``supplier_num``, one pattern
    is picked from ``supplier_dict[level]`` for each remaining level and
    concatenated onto the prefix. Combinations are emitted in depth-first
    order, so the lowest level varies slowest.

    Args:
        supplier_dict: Maps level index ``0..len-1`` to a list of patterns,
            each pattern being a list.
        ret: Prefix list accumulated so far.
        supplier_num: Level at which to start picking patterns.
        final_output: List that receives the completed combinations; it is
            mutated in place.

    Returns:
        The same ``final_output`` object that was passed in.
    """
    depth = len(supplier_dict)
    # Explicit stack of (prefix, level) pairs replaces the recursive calls.
    pending = [(ret, supplier_num)]
    while pending:
        prefix, level = pending.pop()
        if level == depth:
            final_output.append(prefix)
            continue
        children = [
            (prefix + pattern, level + 1) for pattern in supplier_dict[level]
        ]
        # Push in reverse so the first pattern is expanded first.
        pending.extend(reversed(children))
    return final_output


def generate_loss(
    regime,
    cost_matrix,
    supplier,
    demander,
    time_horizon,
    epsilon_adversary,
    sigma=0.01,
):
    """Draw uniform random losses around each cost for every round.

    For each round and each ``(i, j)`` pair, ``max(supplier[i], demander[j])``
    values are drawn uniformly from an interval centred on
    ``cost_matrix[i][j]`` with half-width ``min(1 - c, c)``. For a cost in
    ``[0, 1]`` that interval stays inside ``[0, 1]``.

    Args:
        regime: Not used by this function.
        cost_matrix: Indexable as ``cost_matrix[i][j]``.
        supplier: Per-supplier amounts.
        demander: Per-demander amounts.
        time_horizon: Number of rounds to generate.
        epsilon_adversary: Not used by this function.
        sigma: Not used by this function.

    Returns:
        A list with one entry per round; ``result[t][i][j]`` is a 1-D
        ``numpy.ndarray`` of samples.
    """
    n_supplier = len(supplier)
    n_demander = len(demander)
    sample_counts = [
        [max(supplier[i], demander[j]) for j in range(n_demander)]
        for i in range(n_supplier)
    ]

    loss_list = []
    for _ in range(time_horizon):
        round_losses = []
        for i in range(n_supplier):
            row_losses = []
            for j in range(n_demander):
                centre = cost_matrix[i][j]
                half_width = min(1 - centre, centre)
                row_losses.append(
                    np.random.uniform(
                        centre - half_width,
                        centre + half_width,
                        int(sample_counts[i][j]),
                    )
                )
            round_losses.append(row_losses)
        loss_list.append(round_losses)

    return loss_list


def convert_loss_list(
    loss_list,
    supplier,
    demander,
):
    """Re-index per-round losses from (supplier, demander, truck) to per-truck rows.

    For every round, the value ``loss[i][j][k]`` (``k < supplier[i]``) is
    written to row ``offset_i + k`` and column ``j`` of a new array, where
    ``offset_i`` is the number of trucks owned by the suppliers before ``i``.
    Samples beyond the first ``supplier[i]`` in each cell are ignored.

    Args:
        loss_list: Sequence of rounds, each indexable as ``loss[i][j][k]``.
        supplier: Integer truck counts, one per supplier.
        demander: Per-demander amounts; only its length is used.

    Returns:
        A list with one float array of shape
        ``(sum(supplier), len(demander), 1)`` per round.
    """
    n_rows = np.sum(supplier)
    n_cols = len(demander)

    converted = []
    for round_loss in loss_list:
        expanded = np.zeros((n_rows, n_cols, 1))
        offset = 0
        for i, n_trucks in enumerate(supplier):
            for j in range(n_cols):
                for k in range(n_trucks):
                    expanded[offset + k, j, 0] = round_loss[i][j][k]
            offset += n_trucks
        converted.append(expanded)

    return converted


def objective(m, lhat_sum, x_t, beta, n_i, gamma):
    """Evaluate the linear loss plus regularizer at the point ``x_t``.

    With ``r = x_t / n_i`` the value is::

        sum((m + lhat_sum) * x_t)
        + sum(beta * n_i * (r - 1 - log(r)
                            + gamma * (r + (1 - r) * log(1 - r))))

    The term ``(1 - r) * log(1 - r)`` is taken to be 0 at ``r == 1`` (its
    limit). Evaluated directly it would be ``0 * -inf = nan`` and turn the
    whole objective into NaN whenever any entry of ``x_t`` equals ``n_i``.

    Args:
        m: Scalar or array added to ``lhat_sum`` in the linear term.
        lhat_sum: Scalar or array, broadcastable with ``x_t``.
        x_t: Point at which to evaluate (scalar or array).
        beta: Scalar or array weight of the regularizer.
        n_i: Scalar or array that ``x_t`` is divided by.
        gamma: Weight of the ``r + (1 - r) * log(1 - r)`` term.

    Returns:
        The objective value as a NumPy scalar. Entries with ``r == 0`` still
        yield ``+inf`` and entries with ``r > 1`` still yield ``nan``.
    """
    ratio = x_t / n_i
    complement = 1 - ratio
    # Swap exact zeros for 1 so log() returns 0 there; every other entry
    # (including negative ones, which must stay NaN) is left untouched.
    safe_complement = np.where(complement == 0, 1.0, complement)
    complement_entropy = complement * np.log(safe_complement)

    linear_term = np.sum((m + lhat_sum) * x_t)
    regularizer = np.sum(
        beta
        * n_i
        * (ratio - 1 - np.log(ratio) + gamma * (ratio + complement_entropy))
    )
    return linear_term + regularizer


def OFTRL_CVXPY(
    m,
    lhat_sum,
    action_list,
    beta,
    n_i,
    gamma,
    p_initial,
):
    """Search for action weights that lower ``objective`` and sample an action.

    Starting from ``p_initial``, each iteration multiplies the current
    weights by ``1 + 0.005 * N(0, 1)`` noise, takes absolute values,
    renormalizes to sum to one, and keeps the candidate only if it gives a
    strictly lower objective at the induced mixture of actions. An action is
    then drawn at random using the final weights as probabilities.

    NOTE(review): the temperature is cooled every iteration but never enters
    the acceptance rule, and with the constants below it stays far above the
    minimum for all iterations, so the early ``break`` is not reached.
    NOTE(review): despite the name, nothing in this function uses CVXPY.

    Args:
        m: Passed through to ``objective``.
        lhat_sum: Passed through to ``objective``.
        action_list: Array whose first axis indexes actions and whose
            remaining two axes are read as ``shape[1]`` and ``shape[2]``.
        beta: Passed through to ``objective``.
        n_i: Passed through to ``objective``.
        gamma: Passed through to ``objective``.
        p_initial: Initial weight per action; it is used directly as the
            sampling distribution if no candidate is ever accepted.

    Returns:
        Tuple ``(x_t, a_t, p_t, chosen_action_index)``: the weighted mixture
        of actions, the sampled action, the final weights, and the index of
        the sampled action.
    """
    n_actions = len(action_list)
    supplier_num = action_list.shape[1]
    demander_num = action_list.shape[2]

    def mixture(weights):
        # Weighted sum of the actions, accumulated one action at a time.
        blend = np.zeros((supplier_num, demander_num))
        for idx, act in enumerate(action_list):
            blend = blend + weights[idx] * act
        return blend

    # Search parameters.
    temperature = 100
    min_temperature = 1
    cooling_rate = 0.999999
    max_iterations = 3000

    best_p = p_initial
    best_value = objective(m, lhat_sum, mixture(best_p), beta, n_i, gamma)

    flat_actions = action_list.reshape(n_actions, -1)
    for _ in range(max_iterations):
        # Propose a multiplicative perturbation of the current weights.
        noise = np.random.randn(n_actions)
        candidate_p = abs(best_p * (1 + 0.005 * noise))
        candidate_p /= np.sum(candidate_p)

        candidate_x = (candidate_p @ flat_actions).reshape(
            supplier_num, demander_num
        )
        candidate_value = objective(
            m, lhat_sum, candidate_x, beta, n_i, gamma
        )

        # Greedy acceptance: only strict improvements are kept.
        if candidate_value < best_value:
            best_p, best_value = candidate_p, candidate_value

        temperature = cooling_rate * temperature
        if temperature < min_temperature:
            break

    x_t = mixture(best_p)
    p_t = best_p
    chosen_action_index = np.random.choice(n_actions, 1, p=p_t)[0]
    a_t = action_list[chosen_action_index]

    return x_t, a_t, p_t, chosen_action_index
