import gurobipy as gp
from gurobipy import GRB
import numpy as np
import continuous_scores
import discrete_scores
import math
from utils import find_cycles

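'''
Defines and solves the restricted master problem (RMLP/RMIP) used in column generation.
Handles both the linear relaxation and the integer program; the latter uses a
lazy-constraint callback that adds cluster (cycle-elimination) constraints.
'''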

class MasterProblem:
    """Restricted master problem (RMLP / RMIP) of the column-generation scheme.

    One decision variable ``X[i][j]`` is created for every candidate pattern
    ``K[i][j, :]`` of node ``i``. Each node must pick exactly one pattern, and
    for every cluster ``C`` at most ``len(C) - 1`` members of ``C`` may pick a
    pattern that has a parent inside ``C``. In IP mode, cycles found in new
    incumbents are cut off lazily from a Gurobi callback.
    """

    def __init__(self, data, data_type, K, C_set, linear_relaxation, regu_Lambda, save_path):
        """Store the problem inputs; the Gurobi model itself is built by ``solve``.

        Args:
            data: dataset object; must expose ``n`` and ``ndata`` (and
                ``arity`` when ``data_type`` is not ``'C'``).
            data_type: ``'C'`` selects ``continuous_scores.cost``; any other
                value selects ``discrete_scores.cost``.
            K: per-node 2-D arrays; row ``j`` of ``K[i]`` is a 0/1 indicator of
                candidate pattern ``j`` for node ``i``. The columns equal to 1
                are used as the pattern's parent nodes in the cost and in the
                cluster constraints.
            C_set: clusters (collections of node indices) whose cluster
                constraints are added when the model is built.
            linear_relaxation: truthy -> continuous variables and a plain LP
                solve; falsy -> binary variables with lazy cluster cuts.
            regu_Lambda: weight of the complexity penalty in the pattern cost.
            save_path: path of the Gurobi log file.
        """
        self.data = data
        self.data_type = data_type
        self.K = K
        self.C_set = C_set
        self.linear_relaxation = linear_relaxation
        self.regu_Lambda = regu_Lambda
        self.save_path = save_path
        self.model = None
        self.X = None
        self.sol_X = None

    def _pattern_costs(self):
        """Return ``{i: [cost of pattern j of node i for each j]}`` (minimized by the model)."""
        ndata = self.data.ndata
        if self.data_type == 'C':
            cost = continuous_scores.cost
        else:  # 'D'
            cost = discrete_scores.cost

        cost_dic = {}
        for i in range(self.data.n):
            cost_dic[i] = []
            for j in range(self.K[i].shape[0]):
                J = np.where(self.K[i][j, :] == 1)[0]
                if self.data_type == 'C':
                    # Presumably cost() is a residual variance, making the first two
                    # terms a Gaussian negative log-likelihood; penalty grows with |J|.
                    value = (np.log(cost(self.data, J, i)) * ndata / 2
                             + (1 + np.log(2 * np.pi)) * ndata / 2
                             + self.regu_Lambda * len(J))
                else:  # 'D'
                    # Penalty is (arity[i] - 1) * prod(parent arities), i.e. the usual
                    # free-parameter count of a discrete conditional distribution.
                    value = (cost(self.data, J, i) * ndata
                             + self.regu_Lambda * (self.data.arity[i] - 1)
                             * math.prod([self.data.arity[p] for p in J]))
                cost_dic[i].append(value)
        return cost_dic

    def _cluster_pairs(self, C):
        """Return the ``(i, j)`` pairs whose variables enter the cluster constraint of ``C``.

        Pattern ``j`` of node ``i`` (``i`` in ``C``) is included when at least
        one of its parents also lies in ``C``.
        """
        cols = list(C)
        pairs = []
        for i in C:
            # One boolean per pattern of node i: does it use any parent inside C?
            touches = (self.K[i][:, cols] == 1).any(axis=1)
            pairs.extend((int(i), int(j)) for j in np.flatnonzero(touches))
        return pairs

    def _cluster_expr(self, C):
        """Linear expression summing the variables selected by ``_cluster_pairs(C)``."""
        return gp.quicksum(self.X[i][j] for i, j in self._cluster_pairs(C))

    def define_master_problem(self):
        """Build the Gurobi model (variables, objective, constraints) into ``self.model``/``self.X``."""
        n = self.data.n
        cost_dic = self._pattern_costs()

        # Model and solver parameters.
        self.model = gp.Model('Master Problem')
        self.model.Params.LogToConsole = 0
        self.model.Params.LogFile = self.save_path
        self.model.Params.OutputFlag = 1
        self.model.Params.Threads = 1
        self.model.update()

        # Variables: one per (node, candidate pattern).
        vtype = GRB.CONTINUOUS if self.linear_relaxation else GRB.BINARY
        self.X = [
            [self.model.addVar(lb=0, vtype=vtype, name=f'X{i}[{j}]') for j in range(len(self.K[i]))]
            for i in range(n)
        ]

        # Objective: total cost of the selected patterns.
        total_cost = gp.quicksum(self.X[i][j] * cost_dic[i][j] for i in range(n) for j in range(len(self.K[i])))
        self.model.setObjective(total_cost, GRB.MINIMIZE)

        # Each node selects exactly one pattern.
        for i in range(n):
            self.model.addConstr(gp.quicksum(self.X[i]) == 1, name=f'one[{i}]')

        # Cluster constraints: sum <= |C| - 1, written as -sum >= 1 - |C|.
        for C_index, C in enumerate(self.C_set):
            self.model.addConstr(-self._cluster_expr(C) >= 1 - len(C), name=f'cluster[{C_index}]')

    def set_initial_solution(self, sol_X):
        """Record a warm-start assignment ``sol_X[i][j]`` used by the IP solve."""
        self.sol_X = sol_X

    def solve(self):
        """Build and optimize the master problem.

        Returns:
            LP mode (``linear_relaxation`` truthy): ``(model, X)``.
            IP mode (``linear_relaxation`` falsy): ``(model, C_set, X)``, where
            ``C_set`` includes any clusters added by the lazy-cut callback.
        """
        n = self.data.n
        self.define_master_problem()

        if self.linear_relaxation:
            self.model.optimize()
            return self.model, self.X

        # IP: optional warm start from a previously recorded solution.
        if self.sol_X is not None:
            for i in range(n):
                for j, start in enumerate(self.sol_X[i]):
                    self.X[i][j].Start = start
        self.model.update()

        def cb(model, where):
            """On each new incumbent, cut off a cycle (if one is found) with a lazy cluster cut."""
            # NOTE: kept as a module-level global for backward compatibility with
            # any caller that reads it after solving.
            global graph
            if where != GRB.Callback.MIPSOL:
                return
            # Turn the incumbent into an adjacency matrix: row i is node i's chosen pattern.
            graph = np.zeros((n, n))
            for i in range(n):
                vals = [model.cbGetSolution(self.X[i][j]) for j in range(len(self.K[i]))]
                graph[i, :] = self.K[i][np.argmax(vals), :]
            C, self.C_set = find_cycles(graph, self.C_set)
            # An empty cluster would produce the infeasible cut 0 >= 1, so skip it.
            if C is not None and len(C) > 0:
                model.cbLazy(-self._cluster_expr(C) >= 1 - len(C))

        self.model.Params.lazyConstraints = 1
        self.model.optimize(cb)
        return self.model, self.C_set, self.X

