import gurobipy as gp
from gurobipy import GRB
import numpy as np


# Our method: exact task grouping solved as a mixed-integer program with Gurobi
def MIP(S, n):
    """Split the tasks of ``S`` into ``n`` groups by solving a nonconvex MIQP.

    With ``m = len(S)`` tasks, the model maximises

        sum_j (1 / |G_j|) * sum_{i, k in G_j} S[i, k]

    where the inner sum runs over ordered pairs of members of group ``j``
    (``i == k`` included). Constraints:

    * every task is in at least one group (groups may overlap);
    * every group is non-empty;
    * no two groups contain exactly the same set of tasks.

    Args:
        S: ``m x m`` score matrix indexed as ``S[i, k]`` (assumed to be a NumPy
            array or another type supporting tuple indexing).
        n: number of groups.

    Returns:
        ``(groups, objective)``. ``groups[j]`` is the list of task indices in
        group ``j``. ``objective`` is the model objective value. If Gurobi
        finds no feasible solution, every group is an empty list and
        ``objective`` is ``None``.
    """
    m = len(S)

    model = gp.Model()
    model.setParam("OutputFlag", 0)
    # The objective and several constraints multiply decision variables.
    model.setParam("NonConvex", 2)

    # X[i, j] == 1 iff task i belongs to group j.
    X = model.addVars(m, n, vtype=GRB.BINARY, name="X")
    # y[j] plays the role of 1 / |group j| (enforced by constraint2).
    y = model.addVars(n, vtype=GRB.CONTINUOUS, name="y")
    # z[i, j, k] == 1 iff tasks i and k are both in group j (constraint5).
    z = model.addVars(m, n, m, vtype=GRB.BINARY, name="z")

    def group_size(j):
        # Build a fresh expression for each use rather than sharing one object.
        return gp.quicksum(X[i, j] for i in range(m))

    objective_expr = gp.quicksum(
        y[j]
        * gp.quicksum(z[i, j, k] * S[i, k] for k in range(m) for i in range(m))
        for j in range(n)
    )
    model.setObjective(objective_expr, sense=GRB.MAXIMIZE)

    # Each task is assigned to at least one group.
    for i in range(m):
        model.addConstr(
            gp.quicksum(X[i, j] for j in range(n)) >= 1, name=f"constraint1_{i}"
        )

    # y[j] * |group j| == 1, i.e. y[j] is the reciprocal of the group size.
    for j in range(n):
        model.addConstr(group_size(j) * y[j] == 1, name=f"constraint2_{j}")

    # Each group has at least one task.
    for j in range(n):
        model.addConstr(group_size(j) >= 1, name=f"constraint3_{j}")

    # No two groups are identical. For binaries, a + b - 2ab is XOR(a, b), so
    # the groups must differ on at least one task.
    for j1 in range(n):
        for j2 in range(j1 + 1, n):
            model.addConstr(
                gp.quicksum(
                    X[i, j1] + X[i, j2] - 2 * X[i, j1] * X[i, j2] for i in range(m)
                )
                >= 1,
                name=f"constraint4_{j1}_{j2}",
            )

    # z[i, j, k] equals the product X[i, j] * X[k, j].
    for i in range(m):
        for j in range(n):
            for k in range(m):
                model.addConstr(
                    z[i, j, k] - X[i, j] * X[k, j] == 0, name=f"constraint5_{i}_{j}_{k}"
                )

    model.optimize()

    groups = [[] for _ in range(n)]
    if model.status != GRB.OPTIMAL:
        print("No optimal solution found.")
    # Reading objVal / X without any solution raises inside gurobipy.
    if model.SolCount == 0:
        return groups, None
    for j in range(n):
        # Binary values can come back as e.g. 0.99999, so threshold at 0.5.
        groups[j] = [i for i in range(m) if X[i, j].X > 0.5]
    return groups, model.objVal

# =================================================================================================
# =================================================================================================

# TAG: exhaustive enumeration of candidate task groups and of groupings built from them
def gen_task_combinations(revised_integrals, tasks, rtn, index, path, path_dict):
    """Recursively record every non-empty subset of ``tasks`` into ``rtn``.

    Each subset is keyed by its member names joined with ``"|"`` in list
    order (e.g. ``"0|2|3"``). Multi-task keys map to a dict giving, for each
    member ``t``, the sum of ``revised_integrals[t][other]`` over the other
    members. Single-task keys map to ``{task: -1e6}``.

    Args:
        revised_integrals: nested mapping ``revised_integrals[a][b]`` of
            pairwise scores.
        tasks: task names, used in order.
        rtn: output dict, filled in place. Multi-task subsets are inserted
            before their extensions; a singleton is inserted after every
            subset that starts with it.
        index: first position in ``tasks`` that may still be appended.
        path: key of the subset built so far (``""`` at the top level).
        path_dict: per-task scores of the subset built so far.

    Returns:
        None.
    """
    for pos in range(index, len(tasks)):
        task = tasks[pos]
        # Work on a private copy so sibling branches never see each other's edits.
        scores = dict(path_dict)
        scores[task] = 0.0

        if not path:
            # This task starts a brand-new subset.
            combo = task
        else:
            for earlier in path_dict:
                scores[earlier] += revised_integrals[earlier][task]
                scores[task] += revised_integrals[task][earlier]
            combo = f"{path}|{task}"
            rtn[combo] = scores

        gen_task_combinations(revised_integrals, tasks, rtn, pos + 1, combo, scores)

        # Singletons get a large negative score, recorded only after the
        # recursion (which worked on its own copy of ``scores``).
        if "|" not in combo:
            scores[task] = -1e6
            rtn[combo] = scores


def select_groups(index, cur_group, best_group, best_val, splits, num_tasks, rtn_tup):
    """Depth-first search for the best collection of at most ``splits`` groups.

    A collection is scored only when its groups together cover ``num_tasks``
    distinct tasks. Its score is the mean, over tasks, of each task's highest
    score across the chosen groups. A collection replaces the incumbent only
    when its score is strictly larger, so the first best collection found
    wins ties.

    Args:
        index: first position in ``rtn_tup`` still eligible to be added.
        cur_group: dict ``group_key -> {task: score}`` chosen so far.
        best_group: output dict, overwritten in place with the best collection.
        best_val: one-element list holding the best score (updated in place).
        splits: maximum number of groups in a collection.
        num_tasks: total number of tasks that must be covered.
        rtn_tup: list of ``(group_key, {task: score})`` candidates, where
            ``group_key`` is task names joined by ``"|"``.

    Returns:
        None.
    """
    covered = {task for name in cur_group for task in name.split("|")}
    if len(covered) == num_tasks:
        per_task_best = dict.fromkeys(covered, -1e6)
        for scores in cur_group.values():
            for task, score in scores.items():
                per_task_best[task] = max(per_task_best[task], score)
        avg = np.mean(list(per_task_best.values()))

        if avg > best_val[0]:
            best_val[0] = avg
            best_group.clear()
            best_group.update(cur_group)

    if len(cur_group) == splits:
        return

    for pos in range(index, len(rtn_tup)):
        name, scores = rtn_tup[pos]
        extended = dict(cur_group)
        extended[name] = scores
        if len(extended) <= splits:
            select_groups(
                pos + 1, extended, best_group, best_val, splits, num_tasks, rtn_tup
            )


def TAG(tag_info, n):
    """Choose at most ``n`` task groups by exhaustive search.

    Every non-empty subset of tasks becomes a candidate group. Within a
    multi-task group, each member is scored by its average of
    ``tag_info[member, other]`` over the other members. Single-task groups
    score ``-1e6``. The search then returns the collection of at most ``n``
    candidate groups that covers every task and maximises the mean over tasks
    of each task's best score across the chosen groups. This is exponential
    in the number of tasks.

    Args:
        tag_info: square matrix of pairwise scores; ``tag_info[i, j]`` is
            credited to task ``i`` for sharing a group with task ``j``.
            Nested lists/tuples are converted with ``np.asarray``; other
            inputs must support ``tag_info[i, j]`` indexing.
        n: maximum number of groups.

    Returns:
        ``(groups, value)``. ``groups`` is a list of lists of integer task
        indices. ``value`` is the best score found, or ``-100000000`` if no
        covering collection was found.
    """
    if isinstance(tag_info, (list, tuple)):
        # Plain nested sequences do not support tuple indexing.
        tag_info = np.asarray(tag_info)

    num_tasks = len(tag_info)
    tasks = [str(t) for t in range(num_tasks)]
    revised_integrals = {
        tasks[i]: {tasks[j]: tag_info[i, j] for j in range(num_tasks)}
        for i in range(num_tasks)
    }

    # rtn maps "a|b|c" subset keys to {task: summed pairwise score}; it is
    # filled in place (the helper returns nothing).
    rtn = {}
    gen_task_combinations(
        revised_integrals, tasks=tasks, rtn=rtn, index=0, path="", path_dict={}
    )

    # Turn summed scores into averages over the other members of each group.
    for group, scores in rtn.items():
        if "|" in group:
            others = len(group.split("|")) - 1
            for task in scores:
                scores[task] /= others

    # Internal sanity check: one entry per non-empty subset of tasks.
    assert len(rtn) == 2**num_tasks - 1

    rtn_tup = list(rtn.items())
    selected_group = {}
    selected_val = [-100000000]

    select_groups(
        index=0,
        cur_group={},
        best_group=selected_group,
        best_val=selected_val,
        splits=n,
        num_tasks=num_tasks,
        rtn_tup=rtn_tup,
    )
    groups = [[int(t) for t in key.split("|")] for key in selected_group]
    return groups, selected_val[0]


def MIP_constr(S, n, min_count, max_count):
    """Split the tasks of ``S`` into ``n`` size-bounded groups via a nonconvex MIQP.

    With ``m = len(S)`` tasks, the model maximises

        sum_j (1 / |G_j|) * sum_{i, k in G_j} S[i, k]

    where the inner sum runs over ordered pairs of members of group ``j``
    (``i == k`` included). Constraints:

    * every task is in at least one group (groups may overlap);
    * no two groups contain exactly the same set of tasks;
    * each group holds between ``min_count`` and ``max_count`` tasks, or
      exactly ``max_count`` when ``min_count == max_count``.

    Non-emptiness of each group follows from ``y[j] * |G_j| == 1``.

    Args:
        S: ``m x m`` score matrix indexed as ``S[i, k]`` (assumed to be a NumPy
            array or another type supporting tuple indexing).
        n: number of groups.
        min_count: lower bound on the group size.
        max_count: upper bound on the group size.

    Returns:
        ``(groups, objective)``. ``groups[j]`` is the list of task indices in
        group ``j``. ``objective`` is the model objective value. If Gurobi
        finds no feasible solution, every group is an empty list and
        ``objective`` is ``None``.
    """
    m = len(S)

    model = gp.Model()
    model.Params.OutputFlag = 0  # 0 means no output, 1 (default) means standard output
    # The objective and several constraints multiply decision variables.
    model.setParam("NonConvex", 2)

    # X[i, j] == 1 iff task i belongs to group j.
    X = model.addVars(m, n, vtype=GRB.BINARY, name="X")
    # y[j] plays the role of 1 / |group j| (enforced by constraint2).
    y = model.addVars(n, vtype=GRB.CONTINUOUS, name="y")
    # z[i, j, k] == 1 iff tasks i and k are both in group j (constraint5).
    z = model.addVars(m, n, m, vtype=GRB.BINARY, name="z")

    def group_size(j):
        # Build a fresh expression for each use rather than sharing one object.
        return gp.quicksum(X[i, j] for i in range(m))

    objective_expr = gp.quicksum(
        y[j]
        * gp.quicksum(z[i, j, k] * S[i, k] for k in range(m) for i in range(m))
        for j in range(n)
    )
    model.setObjective(objective_expr, sense=GRB.MAXIMIZE)

    # Each task is assigned to at least one group.
    for i in range(m):
        model.addConstr(
            gp.quicksum(X[i, j] for j in range(n)) >= 1,
            name=f"constraint1_{i}",
        )

    if min_count != max_count:
        # At least min_count tasks per group.
        for j in range(n):
            model.addConstr(group_size(j) >= min_count, name=f"constraint7_1_{j}")
        # A group can never exceed m tasks, so the upper bound only matters below m.
        if max_count != m:
            for j in range(n):
                model.addConstr(group_size(j) <= max_count, name=f"constraint7_2_{j}")
    else:
        # Exactly max_count tasks per group.
        for j in range(n):
            model.addConstr(group_size(j) == max_count, name=f"constraint7_1_{j}")

    # y[j] * |group j| == 1, i.e. y[j] is the reciprocal of the group size.
    for j in range(n):
        model.addConstr(group_size(j) * y[j] == 1, name=f"constraint2_{j}")

    # z[i, j, k] equals the product X[i, j] * X[k, j].
    for i in range(m):
        for j in range(n):
            for k in range(m):
                model.addConstr(
                    z[i, j, k] - X[i, j] * X[k, j] == 0, name=f"constraint5_{i}_{j}_{k}"
                )

    # No two groups are identical. For binaries, a + b - 2ab is XOR(a, b).
    for j1 in range(n):
        for j2 in range(j1 + 1, n):
            model.addConstr(
                gp.quicksum(
                    X[i, j1] + X[i, j2] - 2 * X[i, j1] * X[i, j2] for i in range(m)
                )
                >= 1,
                name=f"constraint4_{j1}_{j2}",
            )

    model.optimize()

    groups = [[] for _ in range(n)]
    if model.status != GRB.OPTIMAL:
        print("No optimal solution found.")
    # Reading objVal / X without any solution raises inside gurobipy.
    if model.SolCount == 0:
        return groups, None
    for j in range(n):
        # Binary values can come back as e.g. 0.99999, so threshold at 0.5.
        groups[j] = [i for i in range(m) if X[i, j].X > 0.5]
    return groups, model.objVal


import random


def _random_group_sizes(total, m, lo, hi, exact):
    """Draw ``m`` random sizes in ``[lo, hi]`` whose sum is at least ``total``.

    With ``exact=True`` the sizes sum to exactly ``total``. The caller must
    ensure ``m * hi >= total`` (and ``m * lo <= total`` when ``exact``);
    otherwise the adjustment loops cannot finish.
    """
    sizes = [random.randint(lo, hi) for _ in range(m)]
    while sum(sizes) < total:
        growable = [j for j, s in enumerate(sizes) if s < hi]
        sizes[random.choice(growable)] += 1
    while exact and sum(sizes) > total:
        shrinkable = [j for j, s in enumerate(sizes) if s > lo]
        sizes[random.choice(shrinkable)] -= 1
    return sizes


def _count_subsets(n, lo, hi, cap):
    """Count subsets of an ``n``-set with size in ``[lo, hi]``, stopping at ``cap``."""
    count, binom = 0, 1  # binom == C(n, k) after each update below
    for k in range(1, hi + 1):
        binom = binom * (n - k + 1) // k
        if k >= lo:
            count += binom
            if count >= cap:
                break
    return count


def random_sampling_with_const(n, m, min_size, max_size, allow_repeats=True):
    """Randomly build ``m`` distinct groups over tasks ``0..n-1`` with bounded sizes.

    Guarantees of the result:

    * every task appears in at least one group;
    * every group has between ``min_size`` and ``max_size`` tasks. A lower
      bound of at least 1 is always applied, so no group is empty;
    * no two groups contain the same set of tasks.

    Args:
        n: number of tasks.
        m: number of groups.
        min_size: minimum group size.
        max_size: maximum group size.
        allow_repeats: if True a task may appear in several groups; if False
            the groups are disjoint and partition the tasks.

    Returns:
        A list of ``m`` sorted lists of task indices.

    Raises:
        ValueError: if no grouping satisfying the constraints exists.
        RuntimeError: if random sampling fails to produce ``m`` distinct
            groups after many attempts (only for very tight constraints).
    """
    message = (
        "Impossible to distribute tasks with given min/max sizes and number of groups"
    )
    lo = max(min_size, 1)
    hi = min(max_size, n)
    if m < 1 or lo > hi or hi * m < n:
        raise ValueError(message)

    if not allow_repeats:
        # Disjoint groups: sizes must add up to exactly n.
        if lo * m > n:
            raise ValueError(message)
        sizes = _random_group_sizes(n, m, lo, hi, exact=True)
        order = random.sample(range(n), n)
        groups, start = [], 0
        for size in sizes:
            groups.append(sorted(order[start:start + size]))
            start += size
        # Non-empty disjoint groups are automatically distinct.
        return groups

    # Overlapping groups: there must be at least m distinct admissible subsets.
    if _count_subsets(n, lo, hi, m) < m:
        raise ValueError(message)

    all_tasks = set(range(n))
    for _ in range(10000):
        sizes = _random_group_sizes(n, m, lo, hi, exact=False)
        # One slot per place in a group. Sizes sum to >= n, so every task
        # lands in one random slot and the tasks are all covered.
        slots = [j for j, size in enumerate(sizes) for _ in range(size)]
        random.shuffle(slots)
        members = [set() for _ in range(m)]
        for task, j in zip(range(n), slots):
            members[j].add(task)
        # Top each group up to its target size with tasks it does not hold yet.
        for j, group in enumerate(members):
            group.update(random.sample(sorted(all_tasks - group), sizes[j] - len(group)))
        if len({frozenset(group) for group in members}) == m:
            return [sorted(group) for group in members]

    raise RuntimeError("Could not sample distinct groups; constraints are too tight")
