from collections import defaultdict

import numpy as np
import torch
from z3.z3 import Int, Solver, Or, Sum, ToInt, Abs, sat, If, ToReal, Optimize, And


def convert_z_t_src_dst(z_t, z_t_old):
    """Slice the newly appended edges of ``z_t`` into source/destination lists.

    ``z_t`` is read as a two-row edge tensor: row 1 holds source ids and row 0
    holds destination ids. The first ``len(z_t_old[0])`` columns, which are
    the edges already present in ``z_t_old``, are skipped.

    NOTE(review): the sliced lists are computed but never returned, so this
    function always returns None. It looks unfinished; confirm the intended
    output before relying on it.
    """
    new_sources = z_t[1].cpu().tolist()
    new_destinations = z_t[0].cpu().tolist()
    num_old = len(z_t_old[0].cpu().tolist())
    new_sources = new_sources[num_old:]
    new_destinations = new_destinations[num_old:]
    return None


def check_predecessor_balance(x_n, z_t, rho):
    """Return destination nodes whose predecessor labels are too imbalanced.

    ``z_t[i][j] == 1`` marks an edge from source node ``j + 1`` to
    destination node ``len(z_t[0]) + i + 1`` (node ids are 1-based).
    ``x_n[u - 1][0]`` is the label of node ``u``. Only labels equal to 0 or 1
    are counted.

    For each destination with ``n = n0 + n1 > 0`` labelled predecessors, the
    imbalance ratio is ``floor(|n0 - n1| / 2) / (n / 2)``. A node is reported
    when that ratio exceeds ``rho``.

    Returns:
        ``(unsatisfied_dst_ids, predecessor_ids)``. The second element lists
        the predecessors of all reported nodes with duplicates removed.
        Returns ``(None, None)`` when no node violates the bound.
    """
    num_sources = len(z_t[0])
    # dst node id -> predecessor (source) node ids, in column order.
    preds_of = defaultdict(list)
    for row_idx, row in enumerate(z_t):
        dst_id = num_sources + row_idx + 1
        for col_idx, entry in enumerate(row):
            if int(entry) == 1:
                preds_of[dst_id].append(col_idx + 1)

    bad_nodes = []
    bad_sources = set()
    for dst_id, preds in preds_of.items():
        zeros = sum(1 for p in preds if x_n[p - 1][0] == 0)
        ones = sum(1 for p in preds if x_n[p - 1][0] == 1)
        total = zeros + ones
        if not total:
            # No 0/1-labelled predecessors: nothing to balance.
            continue
        ratio = np.floor(abs(zeros - ones) / 2) / (total / 2)
        if ratio > rho:
            bad_nodes.append(dst_id)
            bad_sources.update(preds)

    if not bad_nodes:
        return None, None
    return bad_nodes, list(bad_sources)


def solve_dag_constraint(x_n, z_t, src_list=None, unsatisfy=None, rho=0.5):
    """Repair a dense edge matrix so selected nodes have balanced predecessors.

    ``z_t[i][j]`` describes the edge from source node ``j + 1`` to destination
    node ``len(z_t[0]) + i + 1`` (1-based ids); a value of 1 is an edge.
    ``x_n[j][0]`` is read as the 0/1 label of source node ``j + 1``.

    For every selected row, a z3 ``Optimize`` problem chooses 0/1 values for
    the candidate entries such that:
      * at least one incoming edge remains;
      * ``floor(|n0 - n1| / 2) / (n / 2) <= rho``, where ``n0``/``n1`` count
        chosen predecessors labelled 0/1 and ``n = n0 + n1``. This is the same
        ratio ``check_predecessor_balance`` tests;
      * entries keep their original value wherever possible (soft
        constraints).

    Args:
        x_n: per-node labels, indexed ``x_n[node_id - 1][0]``.
        z_t: A x B nested sequence of 0/1 values.
        src_list: optional iterable of 1-based source ids that may be changed.
            When given, all other columns of a selected row are output as 0.
        unsatisfy: optional iterable of destination ids to re-solve. All other
            rows are copied unchanged.
        rho: maximum allowed imbalance ratio.

    Returns:
        A list of A row lists holding the solved values, or None. None is
        returned when the solver finds no model (unsat or timeout/unknown),
        when a selected row has no candidate column, or when the solver raises.
    """
    solver = Optimize()
    solver.set('timeout', 1000 * 1)
    dst_index = len(z_t[0])
    # Sets give O(1) membership tests inside the nested loops.
    allowed_src = None if src_list is None else {int(s) for s in src_list}
    selected_dst = None if unsatisfy is None else {int(d) for d in unsatisfy}

    # Row index -> {column index: z3 Int variable} for every entry the solver may set.
    # Keeping variables per row (instead of matching names by prefix, where
    # 'x_6' also matched 'x_60_...') ties each row's constraints to its own edges.
    row_edges = {}
    for i, row in enumerate(z_t):
        dst = dst_index + i + 1
        if selected_dst is not None and dst not in selected_dst:
            continue
        edges = {}
        for j, value in enumerate(row):
            if allowed_src is not None and (j + 1) not in allowed_src:
                continue
            var = Int(f'x_{dst}_{j + 1}')
            solver.add(Or(var == 0, var == 1))
            # Prefer the original value; Optimize minimises violated soft constraints.
            solver.add_soft(var == int(value))
            edges[j] = var
        if not edges:
            # "At least one incoming edge" cannot hold for a row with no candidates.
            return None
        row_edges[i] = edges

    for i, edges in row_edges.items():
        solver.add(Or([var == 1 for var in edges.values()]))
        n0 = Sum([If(And(var == 1, int(x_n[j][0]) == 0), 1, 0) for j, var in edges.items()])
        n1 = Sum([If(And(var == 1, int(x_n[j][0]) == 1), 1, 0) for j, var in edges.items()])
        diff_half = ToInt(Abs(ToReal(n0) - ToReal(n1)) / 2)
        # floor(|n0-n1|/2) / (n/2) <= rho, multiplied through by n/2 >= 0.
        # The former `(n0 + n1) / 2` was z3 *integer* division on Int terms:
        # it truncated odd n and divided by zero (left unconstrained) for n == 1.
        solver.add(ToReal(2 * diff_half) <= rho * ToReal(n0 + n1))

    try:
        if solver.check() != sat:
            return None
        model = solver.model()
        z = []
        for i, row in enumerate(z_t):
            edges = row_edges.get(i)
            if edges is None:
                z.append(row.copy())
            else:
                z.append([
                    model.eval(edges[j], model_completion=True).as_long() if j in edges else 0
                    for j in range(len(row))
                ])
        return z
    except Exception as e:
        print(f"Solver error: {e}")
        print(f'x_n: {x_n}, z_t: {z_t}')
        return None


def solve_edge_constraint(x_n, z_t_new, z_t_old, unsatisfy, rho):
    """Re-solve the newly added edges of an edge tensor under the balance constraint.

    Both edge tensors use row 0 for destination ids and row 1 for source ids
    (1-based). The first ``len(z_t_old[0])`` columns of ``z_t_new`` are treated
    as the old edges. Only the columns after them are converted to a dense
    matrix and passed to ``solve_dag_constraint``.

    Args:
        x_n: node labels forwarded to ``solve_dag_constraint``.
        z_t_new: 2 x m edge tensor (old edges followed by new edges).
        z_t_old: 2 x k edge tensor of previously existing edges.
        unsatisfy: destination ids to re-solve, or None for all new rows.
        rho: maximum allowed imbalance ratio.

    Returns:
        A 2 x m' tensor (row 0 dst, row 1 src) on ``z_t_new.device``. It holds
        the old edges of ``z_t_old`` followed by the solved new edges. Returns
        None when the solver found no solution.

    Raises:
        ValueError: if ``z_t_new`` contains no new edges (``min`` of an empty list).
    """
    num_old = len(z_t_old[0])
    src_list = z_t_new[1].cpu().tolist()[num_old:]
    dst_list = z_t_new[0].cpu().tolist()[num_old:]
    # Column j stands for source id j + 1. This assumes every source id is
    # smaller than every new destination id.
    src_index = min(dst_list) - 1
    # One row per destination id in [min(dst_list), max(dst_list)]. Sizing by
    # the id range (not the number of distinct ids) keeps dst - src_index - 1
    # in bounds when an id inside that range has no new edge; such a row
    # simply starts with all zeros. For contiguous ids the two sizes coincide.
    num_rows = max(dst_list) - src_index
    z_t = [[0] * src_index for _ in range(num_rows)]
    for src, dst in zip(src_list, dst_list):
        z_t[dst - src_index - 1][src - 1] = 1

    z_t_solution = solve_dag_constraint(x_n, z_t, src_list=src_list, unsatisfy=unsatisfy, rho=rho)
    if z_t_solution is None:
        return None

    src_list_new = z_t_old[1].cpu().tolist()
    dst_list_new = z_t_old[0].cpu().tolist()
    for i, row in enumerate(z_t_solution):
        for j, value in enumerate(row):
            if value == 1:
                src_list_new.append(j + 1)
                dst_list_new.append(src_index + 1 + i)
    src_new = torch.tensor(src_list_new, device=z_t_new.device)
    dst_new = torch.tensor(dst_list_new, device=z_t_new.device)
    return torch.stack((dst_new, src_new), dim=0)