from ExtensiveGame import ExtensiveGame
from CFR import CFR
from gurobipy import *
from SequenceNash import SequenceNash


class BestNashStatic:
    """Pick an (epsilon-)Nash strategy that is best/worst against a fixed opponent.

    The game tree loaded from ``fname`` is traversed once to index the
    sequences of ``player`` and the information sets of the opponent.  A linear
    program is then built with Gurobi in which

    * ``x`` holds one realization weight per sequence of ``player``
      (``x[0]`` is the empty sequence and is fixed to 1),
    * ``u`` / ``obj`` bound the value of ``x`` from below over the opponent's
      actions, and ``obj`` is required to be at least
      ``nash_value + epsilon`` (``nash_value`` comes from ``SequenceNash``),
    * ``new_u`` / ``new_obj`` evaluate ``x`` against the fixed opponent
      strategy; ``new_obj`` is maximized (``best=True``) or minimized.

    Node conventions used by the traversal: ``node.player == 3`` is a terminal
    node, ``2`` is a chance node, ``self.player`` is the optimizing agent and
    any other value is the opponent.
    """

    def __init__(self, fname, strategy, best=True, epsilon=0, player=0):
        """Load the game and store the solver configuration.

        :param fname: game file; passed to ``ExtensiveGame.load``,
            ``SequenceNash`` and ``CFR``.
        :param strategy: fixed opponent strategy, indexed by opponent
            information-set id and yielding one probability per action.
            It is shallow-copied, so the caller's container is not modified.
        :param best: maximize (``True``) or minimize (``False``) the value
            against ``strategy``.
        :param epsilon: offset added to the Nash value in the constraint
            ``obj >= nash_value + epsilon``.
        :param player: id of the optimizing player (compared to
            ``node.player``).
        """
        import copy

        # save game name
        self.fname = fname
        # Work on a shallow copy: the artificial root entry written below must
        # not leak into (or overwrite an entry of) the caller's container.
        self.fixed_strategy = copy.copy(strategy)
        # Pseudo information set -1 stands for the root and has one "action".
        self.fixed_strategy[-1] = [1]

        # Game variables
        self.game = ExtensiveGame()
        self.game.load(fname)

        # Variables to save
        self.is_to_node = {}     # (opp. i_set, action) -> child i_set ids and terminal tuples
        self.action_index = {}   # (parent sequence, own i_set) -> sequence index per action
        self.action = 1          # next free sequence index (0 is the empty sequence)
        self.i_set = 0           # next free dense index for opponent information sets
        self.i_sets = {}         # opponent i_set id -> dense index
        self.best = best
        self.epsilon = epsilon
        self.player = player

        # Constraints
        self.x = None
        self.obj = None
        self.u = None
        self.new_u = None
        self.m = None
        self.new_obj = None

        self.u_exp = None
        self.obj_exp = None
        # Just information variable
        self.nodes = 0

        # Saved solution in object
        self.solution = None

    def _reset_tree_state(self):
        """Clear everything ``create_variables`` accumulates during a traversal."""
        self.is_to_node = {}
        self.action_index = {}
        self.action = 1
        self.i_set = 0
        self.i_sets = {}
        self.nodes = 0

    def create_variables_public(self):
        """Traverse the game tree from the root and (re)build the index tables.

        The tables are reset first, so calling this repeatedly (or before
        ``solve``) does not register terminal nodes more than once.
        """
        self._reset_tree_state()
        self.create_variables(self.game.root, (-1, -1), 0, 1.0)

    def solve(self):
        """Index the tree, build and optimize the LP, and return its objective value.

        The result is also stored in ``self.solution``.
        """
        # Start from a clean slate; otherwise a second call would append every
        # terminal node again and duplicate its payoff term in the LP.
        self._reset_tree_state()
        self.create_variables(self.game.root, (-1, -1), 0, 1.0)
        self.solution = self.create_expresions()
        return self.solution

    def create_variables(self, node, last_node_action, last_node_strat, chance):
        """Recursively index the subtree rooted at ``node``.

        :param node: current tree node (reads ``player``, ``children`` and,
            depending on the node type, ``chance`` or ``i_set``).
        :param last_node_action: ``(i_set, action)`` of the opponent's most
            recent move on the path, ``(-1, -1)`` if there is none.
        :param last_node_strat: sequence index of the agent's most recent
            move on the path, ``0`` if there is none.
        :param chance: product of chance probabilities along the path.
        """
        self.nodes += 1
        if node.player == 3:
            # Terminal node: remember it under the opponent action leading here.
            self.is_to_node.setdefault(last_node_action, []).append(
                (node, chance, last_node_strat))
        elif node.player == 2:
            # Chance node: only the path probability changes.
            for child, next_chance in zip(node.children, node.chance):
                self.create_variables(child, last_node_action, last_node_strat,
                                      chance * next_chance)
        elif node.player == self.player:
            # My agent: allocate sequence indices the first time this
            # (parent sequence, information set) pair is seen.
            seq_key = (last_node_strat, node.i_set)
            create_strats = seq_key not in self.action_index
            if create_strats:
                self.action_index[seq_key] = []
            for i, child in enumerate(node.children):
                if create_strats:
                    # Allocation is interleaved with the recursion on purpose,
                    # so indices follow depth-first discovery order.
                    self.action_index[seq_key].append(self.action)
                    self.action += 1
                self.create_variables(child, last_node_action,
                                      self.action_index[seq_key][i], chance)
        else:
            # Opponent: link this information set to the previous opponent
            # action and give it a dense index on first sight.
            reached = self.is_to_node.setdefault(last_node_action, [])
            if node.i_set not in reached:
                reached.append(node.i_set)
                if node.i_set not in self.i_sets:
                    self.i_sets[node.i_set] = self.i_set
                    self.i_set += 1
            for i, child in enumerate(node.children):
                self.create_variables(child, (node.i_set, i), last_node_strat, chance)

    def create_expresions(self):
        """Build the LP from the index tables, optimize it and return ``objVal``.

        Requires ``create_variables`` to have populated the tables.  The solve
        status is not inspected; reading ``objVal`` is expected to fail if
        Gurobi produced no solution (e.g. an infeasible ``epsilon``).
        """
        seq_nash = SequenceNash(self.fname)
        nash_value = seq_nash.solve(self.player)

        self.m = Model("lp")

        # variables
        n_isets = len(self.i_sets)
        self.x = self.m.addVars(self.action, lb=0, ub=1, vtype=GRB.CONTINUOUS)
        self.obj = self.m.addVar(lb=-GRB.INFINITY, vtype=GRB.CONTINUOUS, name="objective")
        self.u = self.m.addVars(n_isets, lb=-GRB.INFINITY, vtype=GRB.CONTINUOUS)
        self.new_u = self.m.addVars(n_isets, lb=-GRB.INFINITY, vtype=GRB.CONTINUOUS)
        # Distinct name: two variables called "objective" make name-based
        # lookups and exported model files ambiguous.
        self.new_obj = self.m.addVar(lb=-GRB.INFINITY, vtype=GRB.CONTINUOUS,
                                     name="new_objective")
        self.m.update()

        sense = GRB.MAXIMIZE if self.best else GRB.MINIMIZE
        self.m.setObjective(self.new_obj, sense)

        # The worst-case value of x must reach the (shifted) Nash value.
        self.m.addConstr(self.obj >= (nash_value + self.epsilon))
        # The empty sequence is always played.
        self.m.addConstr(self.x[0] == 1)

        # Flow conservation: the sequences leaving an information set sum to
        # the weight of the sequence leading into it.
        for (parent_seq, _), sequences in self.action_index.items():
            expr = LinExpr()
            for seq in sequences:
                expr.add(self.x[seq])
            self.m.addConstr(expr == self.x[parent_seq])

        # Terminal values are negated for player 1 (presumably node.value is
        # stored from player 0's point of view -- confirm against ExtensiveGame).
        sign = -1 if self.player == 1 else 1

        # Group the per-action entries by opponent information set; insertion
        # order of is_to_node keeps the actions in their enumeration order.
        recreated = {}
        for (i_set, _), entries in self.is_to_node.items():
            recreated.setdefault(i_set, []).append(entries)

        # Value of x against the fixed opponent strategy.
        for i_set, per_action in recreated.items():
            high_expr = LinExpr()
            strategy_vector = self.fixed_strategy[i_set]
            for i, action_values in enumerate(per_action):
                expr = LinExpr()
                for val in action_values:
                    if isinstance(val, int):
                        # Child information set; use the dense index, exactly
                        # as is done for self.u below.
                        expr.add(self.new_u[self.i_sets[val]])
                    else:
                        node, chance, seq = val
                        expr.add(self.x[seq] * sign * node.value * chance)
                high_expr.add(expr * strategy_vector[i])
            if i_set == -1:
                self.m.addConstr(high_expr == self.new_obj)
            else:
                self.m.addConstr(high_expr == self.new_u[self.i_sets[i_set]])

        # Lower bound on the value of x over every opponent action.
        for (i_set, _), entries in self.is_to_node.items():
            expr = LinExpr()
            for val in entries:
                if isinstance(val, int):
                    expr.add(self.u[self.i_sets[val]])
                else:
                    node, chance, seq = val
                    expr.add(self.x[seq] * sign * node.value * chance)
            if i_set == -1:
                self.m.addConstr(expr >= self.obj)
            else:
                self.m.addConstr(expr >= self.u[self.i_sets[i_set]])

        self.m.setParam('OutputFlag', False)
        self.m.optimize()

        return self.m.objVal

    def objective(self, x):
        """Return the entry of ``x`` at index ``self.action``."""
        return x[self.action]

    def print_strategy(self):
        """Print the behavioral strategy derived from the solved ``x``.

        One line per ``(parent sequence, information set)`` key; each action
        probability is its sequence weight divided by the parent weight.
        """
        if self.solution is None:
            print("Instance not yet solved.")
            return
        for key, value in self.action_index.items():
            print(str(key) + ":", end="")
            parent_weight = self.x[key[0]].x
            if parent_weight < 0.000001:
                # Information set is (numerically) unreachable under x.
                print(" Parent strategy is 0", end="")
            else:
                for val in value:
                    print(" " + str(self.x[val].x / parent_weight), end="")
            print()

    def strategy_in_cfr_format(self):
        """Return the solved strategy written into a fresh ``CFR`` strategy table.

        Information sets whose parent sequence has (numerically) zero weight
        get a uniform distribution.  Prints a message and returns ``None`` if
        the instance has not been solved yet.
        """
        if self.solution is None:
            print("Instance not yet solved.")
            return None
        cfrqr = CFR(self.fname)
        cfrqr.create_isets()
        for (parent_seq, i_set), sequences in self.action_index.items():
            parent_weight = self.x[parent_seq].x
            if parent_weight < 0.000001:
                probs = [1 / len(sequences)] * len(sequences)
            else:
                probs = [self.x[seq].x / parent_weight for seq in sequences]
            cfrqr.strategy[self.player][i_set] = probs
        return cfrqr.strategy
