import numpy as np
import torch
from env import Network
import gurobipy as gp
# Module-level Gurobi model shared by every solve_*_Gurobi / regression
# function in this file; they all add their variables and constraints to
# this single object.
m = gp.Model()
# Silence Gurobi's console log.
m.Params.LogToConsole = 0
# NonConvex=2 tells Gurobi to accept non-convex quadratic terms instead of
# rejecting the model (per the Gurobi parameter docs) — presumably needed for
# the quadratic confidence-set constraint / log formulations below.
m.setParam('NonConvex', 2)
from scipy.optimize import minimize
from scipy.optimize import LinearConstraint, NonlinearConstraint

def is_pos_def(x):
    """Return True if ``x`` is positive definite.

    Positive definiteness is a property of the quadratic form ``v^H x v``,
    which depends only on the Hermitian part ``(x + x^H) / 2``. Checking the
    eigenvalues of ``x`` itself is wrong for non-symmetric input: for
    ``[[1, 10], [0, 1]]`` the eigenvalues are 1 and 1, yet ``v = (1, -1)``
    gives ``v^T x v = -8``. ``eigvals`` can also return complex values. For
    symmetric/Hermitian input the result is the same as testing the
    eigenvalues of ``x`` directly.

    Stacked input of shape (..., M, M) is accepted; the result is True only
    if every matrix in the stack is positive definite.

    Raises:
        numpy.linalg.LinAlgError: if the trailing two dimensions are not square.
    """
    x = np.asarray(x)
    if x.ndim < 2 or x.shape[-1] != x.shape[-2]:
        raise np.linalg.LinAlgError(
            'Last 2 dimensions of the array must be square')
    hermitian_part = (x + np.swapaxes(x, -1, -2).conj()) / 2
    # eigvalsh is the real-eigenvalue solver for Hermitian matrices.
    return bool(np.all(np.linalg.eigvalsh(hermitian_part) > 0))

# Determinant of a diagonal matrix.
def diag_det(diag_matrix: np.ndarray):
    """Return the determinant of a diagonal matrix.

    ``diag_matrix`` may be either the 1-D sequence of diagonal entries or the
    full 2-D diagonal matrix. In the 2-D case only the main diagonal is used;
    iterating a 2-D array directly would multiply whole rows element-wise and
    return a vector rather than the determinant.

    Entries are multiplied left to right starting from the integer 1, so an
    empty input returns 1.
    """
    entries = diag_matrix
    if np.ndim(entries) == 2:
        entries = np.diagonal(entries)
    ans = 1
    for x in entries:
        ans = ans * x
    return ans

def solve_ucb_SciPy(S, A, func, sigma, theta_hat, beta, phi_v):
    """Find the optimistic (UCB) theta inside the confidence ellipsoid with SciPy.

    Solves::

        max_theta  theta . phi_v
        s.t.       0 <= (theta - theta_hat)^T sigma (theta - theta_hat) <= beta**2
                   0 <= phi_all_s @ theta <= 1          (element-wise)
                   sum(phi_all_s @ theta) in [0.9999, 1.0001]

    for every one-hot (state, action) pair, where
    ``phi_all_s = func(state_vec, action_vec).T``. This is the same objective
    that ``solve_ucb_Gurobi`` maximizes; SciPy only minimizes, so the
    objective is negated.

    Args:
        S: number of states (one-hot encoded).
        A: number of actions (one-hot encoded).
        func: feature network, called on (1, S) and (1, A) CUDA float tensors.
        sigma: matrix defining the confidence ellipsoid.
        theta_hat: centre of the ellipsoid; ``theta_hat / sum(theta_hat)`` is
            the starting point.
        beta: radius of the ellipsoid.
        phi_v: vector whose inner product with theta is maximized.

    Returns:
        ``res.x`` from ``scipy.optimize.minimize(method='trust-constr')``, or
        None if SciPy raised ValueError (infs/NaNs in the problem data).
    """
    action_space = np.eye(A)
    state_space = np.eye(S)
    constraints = []

    def confidence_set(theta):
        diff = theta - theta_hat
        return np.dot(np.dot(sigma, diff), diff)

    constraints.append(NonlinearConstraint(confidence_set, 0, beta ** 2))  # bound of confidence set
    for i in range(S):
        state_vec = torch.from_numpy(np.array([state_space[i]])).float().cuda()
        for j in range(A):
            action_vec = torch.from_numpy(np.array([action_space[j]])).float().cuda()
            # The features are constants for the solver; no autograd graph is needed.
            with torch.no_grad():
                phi_all_s = func(state_vec, action_vec).T
            constraints.append(LinearConstraint(np.sum(phi_all_s, axis=0), 0.9999, 1.0001, keep_feasible=True))  # sum of probs = 1
            constraints.append(LinearConstraint(phi_all_s, 0.0, 1.0, keep_feasible=True))  # 0 <= each prob <= 1
    try:
        # Minimizing -theta.phi_v maximizes it: UCB picks the most optimistic theta.
        res = minimize(lambda theta: -np.dot(theta, phi_v),
                       x0=theta_hat / sum(theta_hat),
                       method='trust-constr',
                       constraints=constraints)
    except ValueError:
        print('Something contains infs or NaNs')
        return None
    return res.x

def solve_ucb_Gurobi(S, A, func, sigma, theta_hat, beta, phi_v):
    """Find the optimistic (UCB) theta inside the confidence ellipsoid with Gurobi.

    Solves::

        max_theta  phi_v @ theta
        s.t.       (theta - theta_hat)^T sigma (theta - theta_hat) <= beta**2
                   0 <= phi[k] @ theta <= 1        for every (s, a, k)
                   sum_k phi[k] @ theta == 1       for every (s, a)

    where ``phi = func(s, a).T`` for one-hot state/action tensors. ``theta``
    is created with Gurobi's default variable bounds.

    The problem is built in the module-level model ``m``, which is emptied
    first.

    Args:
        S: number of states.
        A: number of actions.
        func: feature network, called on (1, S) and (1, A) CUDA float tensors.
        sigma: ellipsoid matrix, indexed as ``sigma[i][j]``.
        theta_hat: centre of the ellipsoid (length d).
        beta: ellipsoid radius.
        phi_v: length-d objective vector.

    Returns:
        np.ndarray of length d holding the optimal theta. Reading the
        solution presumably raises a Gurobi error if no solution was found.
    """
    action_space = np.eye(A)
    state_space = np.eye(S)

    # `m` is shared by every Gurobi solver in this file. Empty it so that
    # variables/constraints from earlier calls do not pile up (an old
    # infeasible constraint would make every later solve infeasible) and so
    # that the first len(theta_hat) variables read back below are this call's
    # theta rather than the first call's.
    for items in (m.getGenConstrs(), m.getQConstrs(), m.getConstrs(), m.getVars()):
        if items:
            m.remove(items)
    m.update()

    theta = m.addMVar(len(theta_hat), name='theta')
    m.setObjective(phi_v @ theta, gp.GRB.MAXIMIZE)

    # Quadratic confidence-set constraint.
    lhs = gp.QuadExpr()
    for i in range(len(theta_hat)):
        for j in range(len(theta_hat)):
            lhs += (theta[i] - theta_hat[i]) * sigma[i][j] * (theta[j] - theta_hat[j])
    m.addConstr(lhs <= beta ** 2, name='quad_constr')

    # The constraints of the probability simplex.
    for i in range(S):
        s = torch.from_numpy(np.array([state_space[i]])).float().cuda()
        for j in range(A):
            a = torch.from_numpy(np.array([action_space[j]])).float().cuda()
            with torch.no_grad():
                phi = func(s, a).T
                for k in range(S):
                    m.addConstr(phi[k] @ theta <= 1.0)
                    m.addConstr(phi[k] @ theta >= 0.0)
                total_prob = gp.quicksum([phi[k] @ theta for k in range(phi.shape[0])])
                m.addConstr(total_prob == 1)

    m.optimize()
    maximizer = [v.X for v in m.getVars()]
    print('Obj: %g' % m.ObjVal)
    return np.array(maximizer)[0:len(theta_hat)]

def solve_rbmle_Gurobi(S, A, func, phi_list, phi_v_bias):
    """Reward-biased maximum-likelihood estimate of theta with Gurobi.

    Maximizes::

        sum_i log(theta @ phi_list[i]) + theta @ phi_v_bias

    over ``0 <= theta <= 10`` subject to the probability-simplex constraints
    ``0 <= phi[k] @ theta <= 1`` and ``sum_k phi[k] @ theta == 1`` for every
    one-hot (s, a), where ``phi = func(s, a).T``.

    Each log term is modelled with auxiliary variables
    ``z[i] = theta @ phi_list[i]`` and ``u[i] = log(z[i])`` through
    ``addGenConstrLog``; one log per observation keeps the objective
    separable. (Taking a single log of the product of all likelihoods is an
    alternative that was tried and left out.)

    The problem is built in the module-level model ``m``, which is emptied
    first.

    Args:
        S: number of states.
        A: number of actions.
        func: feature network, called on (1, S) and (1, A) CUDA float tensors.
        phi_list: 2-D array, one feature row per observed transition.
        phi_v_bias: length-d bias vector (the reward-bias term).

    Returns:
        np.ndarray with the optimal theta (length ``phi_list.shape[1]``).
    """
    action_space = np.eye(A)
    state_space = np.eye(S)
    t = phi_list.shape[0]

    # `m` is shared by every Gurobi solver in this file. Empty it so that
    # earlier calls' variables/constraints neither accumulate nor get read
    # back below as this call's theta.
    for items in (m.getGenConstrs(), m.getQConstrs(), m.getConstrs(), m.getVars()):
        if items:
            m.remove(items)
    m.update()

    theta = m.addMVar(len(phi_v_bias), ub=10,
                      vtype=gp.GRB.CONTINUOUS, name='theta')
    bias = theta @ phi_v_bias

    # One log per observation: u[i] = log(z[i]), z[i] = theta @ phi_list[i].
    z = m.addMVar(t, vtype=gp.GRB.CONTINUOUS, name='z')
    # z[i] is a likelihood (presumably in (0, 1] under the simplex
    # constraints), so log(z[i]) <= 0. Gurobi variables default to lb=0,
    # which would force every z[i] to 1; u must be unbounded below.
    u = m.addMVar(t, lb=-gp.GRB.INFINITY, vtype=gp.GRB.CONTINUOUS, name='u')
    for i in range(t):
        m.addConstr(theta @ phi_list[i] == z[i])
        m.addGenConstrLog(z[i], u[i])
    m.setObjective(gp.quicksum(u[i] for i in range(t)) + bias, gp.GRB.MAXIMIZE)

    # The constraints of the probability simplex.
    for i in range(S):
        s = torch.from_numpy(np.array([state_space[i]])).float().cuda()
        for j in range(A):
            a = torch.from_numpy(np.array([action_space[j]])).float().cuda()
            with torch.no_grad():
                phi = func(s, a).T
                for k in range(S):
                    m.addConstr(phi[k] @ theta <= 1.0)
                    m.addConstr(phi[k] @ theta >= 0.0)
                total_prob = gp.quicksum([phi[k] @ theta for k in range(phi.shape[0])])
                m.addConstr(total_prob == 1)

    # Solve the parameter; theta was added first, so it leads getVars().
    m.optimize()
    maximizer = [v.X for v in m.getVars()]
    print('Obj: %g' % m.ObjVal)
    return np.array(maximizer)[0:phi_list.shape[1]]

def solve_rbmle_regression(S, A, func, phi_list, n_list, phi_v_bias):
    """Reward-biased ridge-regression estimate of theta with Gurobi.

    Maximizes::

        - sum_i (n_list[i] - phi_list[i] @ theta)**2
        + phi_v_bias @ theta
        - theta @ theta

    subject to ``0 <= phi[k] @ theta <= 1`` for every (s, a, k) and
    ``0 <= sum_k phi[k] @ theta <= 1`` for every (s, a), where
    ``phi = func(s, a).T``. ``theta`` uses Gurobi's default variable bounds.

    The problem is built in the module-level model ``m``, which is emptied
    first.

    Args:
        S: number of states.
        A: number of actions.
        func: feature network, called on (1, S) and (1, A) CUDA float tensors.
        phi_list: 2-D array, one feature row per observation.
        n_list: regression targets, one per row of ``phi_list``.
        phi_v_bias: length-d bias vector, or a scalar when d == 1.

    Returns:
        np.ndarray with the optimal theta.
    """
    action_space = np.eye(A)
    state_space = np.eye(S)
    # A scalar bias (e.g. np.float64) means a one-dimensional theta; promote it
    # to a length-1 vector so len() and indexing below work.
    if np.ndim(phi_v_bias) == 0:
        phi_v_bias = np.array([phi_v_bias])
    num_features = len(phi_v_bias)
    num_samples = phi_list.shape[0]

    # `m` is shared by every Gurobi solver in this file. Empty it so that only
    # this call's theta is in the model when the solution is read back.
    for items in (m.getGenConstrs(), m.getQConstrs(), m.getConstrs(), m.getVars()):
        if items:
            m.remove(items)
    m.update()

    theta = m.addMVar(num_features)
    residuals = [
        n_list[i] - gp.quicksum(theta[j] * phi_list[i, j] for j in range(phi_list.shape[1]))
        for i in range(num_samples)
    ]
    m.setObjective(- gp.quicksum(residuals[i] * residuals[i] for i in range(num_samples))
                   + gp.quicksum(theta[j] * phi_v_bias[j] for j in range(num_features))
                   - gp.quicksum(theta[j] * theta[j] for j in range(num_features)),
                   gp.GRB.MAXIMIZE)

    # Probability constraints; the per-(s, a) total is only bounded to
    # [0, 1] here, not forced to equal 1.
    for i in range(S):
        s = torch.from_numpy(np.array([state_space[i]])).float().cuda()
        for j in range(A):
            a = torch.from_numpy(np.array([action_space[j]])).float().cuda()
            with torch.no_grad():
                phi = func(s, a).T
                for k in range(S):
                    m.addConstr(phi[k] @ theta <= 1, "c1")
                    m.addConstr(phi[k] @ theta >= 0, "c2")
                total_prob = sum([
                    gp.quicksum(theta[col] * phi[row, col] for col in range(phi.shape[1]))
                    for row in range(phi.shape[0])
                ])
                m.addConstr(total_prob <= 1.0, "c3")
                m.addConstr(total_prob >= 0.0, "c4")
    m.optimize()
    maximizer = [v.X for v in m.getVars()]
    print('Obj: %g' % m.ObjVal)
    return np.array(maximizer)

def solve_rbmle_approximated(S, A, func, phi_list, phi_v_bias):
    """Maximize the reward-biased log-likelihood with SciPy's trust-constr.

    The maximized quantity is::

        sum(log(theta . phi) for phi in phi_list) + phi_v_bias . theta

    subject to, for every one-hot (state, action) pair with
    ``phi_all_s = func(state_vec, action_vec).T``: each entry of
    ``phi_all_s @ theta`` in [0, 1] and their sum in [0.9999, 1.0001]. The
    search starts from the uniform vector of length ``len(phi_list[0])``.

    Returns:
        ``res.x`` from SciPy, or None if SciPy raised ValueError.
    """
    states = np.eye(S)
    actions = np.eye(A)

    def neg_biased_log_likelihood(theta):
        # SciPy minimizes, so return the negated reward-biased likelihood.
        theta_row = np.array([theta])
        log_terms = [np.log(np.dot(theta_row, phi)) for phi in phi_list]
        return -(np.sum(log_terms) + np.dot(phi_v_bias, theta))

    simplex_constraints = []
    for s_idx in range(S):
        for a_idx in range(A):
            s_vec = torch.from_numpy(np.array([states[s_idx]])).float().cuda()
            a_vec = torch.from_numpy(np.array([actions[a_idx]])).float().cuda()
            phi_all_s = func(s_vec, a_vec).T
            simplex_constraints += [
                # the probabilities of one (s, a) sum to 1
                LinearConstraint(np.sum(phi_all_s, axis=0), 0.9999, 1.0001, keep_feasible=True),
                # every single probability lies in [0, 1]
                LinearConstraint(phi_all_s, 0.0, 1.0, keep_feasible=True),
            ]

    dim = len(phi_list[0])
    start = np.ones(dim) / dim
    try:
        res = minimize(neg_biased_log_likelihood,
                       x0=start,
                       method='trust-constr',
                       constraints=simplex_constraints)
    except ValueError:
        print('Something contains infs or NaNs')
        return None
    return res.x

def solve_mle(S, A, func, phi_list):
    """Maximum-likelihood estimate of theta under probability-simplex constraints.

    Maximizes ``sum(log(theta . phi) for phi in phi_list)`` with SciPy's
    ``trust-constr`` method, starting from the uniform vector of length
    ``len(phi_list[0])``. For every one-hot (state, action) pair with
    ``phi_all_s = func(state_vec, action_vec).T``, each entry of
    ``phi_all_s @ theta`` is kept in [0, 1] and their sum in
    [0.9999, 1.0001].

    Returns:
        ``res.x`` from SciPy, or None if SciPy raised ValueError.
    """
    constraints = []
    state_space = np.eye(S)
    action_space = np.eye(A)

    # Negated (plain, not reward-biased) log-likelihood; SciPy minimizes.
    # Only the observed-transition terms are needed, so no network forward
    # passes happen inside the objective.
    def obj(theta):
        theta_row = np.array([theta])
        log_likelihood = np.sum([np.log(np.dot(theta_row, phi)) for phi in phi_list])
        return -log_likelihood

    # Define the constraints.
    for i in range(S):
        state_vec = torch.from_numpy(np.array([state_space[i]])).float().cuda()
        for j in range(A):
            action_vec = torch.from_numpy(np.array([action_space[j]])).float().cuda()
            # The features are constants for the solver; no autograd graph is needed.
            with torch.no_grad():
                phi_all_s = func(state_vec, action_vec).T
            constraints.append(LinearConstraint(np.sum(phi_all_s, axis=0), 0.9999, 1.0001, keep_feasible=True))  # sum of probs = 1
            constraints.append(LinearConstraint(phi_all_s, 0.0, 1.0, keep_feasible=True))  # 0 <= each prob <= 1
    try:
        res = minimize(obj,
                       np.ones(len(phi_list[0])) / len(phi_list[0]),
                       method='trust-constr',
                       constraints=constraints)
    except ValueError:
        print('Something contains infs or NaNs')
        return None
    return res.x

def solve_rbmle_exact(S, A, R, func, phi_list, state, alpha, gamma = 0.9, max_iterations=10**6, delta=10**-4):
    """Reward-biased MLE of theta using the exact optimal value as the bias.

    Maximizes::

        sum(log(theta . phi) for phi in phi_list) + alpha * V_theta[state]

    where ``V_theta`` comes from value iteration on the transition tensor
    ``P[s][a] = theta . func(s, a)`` with rewards ``R`` (indexed
    ``[s, a, s']``, as the einsum below requires) and discount ``gamma``.
    Value iteration starts from ``1 / (1 - gamma)`` and stops when the
    max-norm change drops below ``delta`` or after ``max_iterations`` sweeps.

    The search uses SciPy ``trust-constr`` from the uniform vector, subject to
    each entry of ``func(s, a).T @ theta`` in [0, 1] and their sum in
    [0.9999, 1.0001] for every one-hot (s, a).

    Returns:
        ``res.x`` from SciPy, or None if SciPy raised ValueError.
    """
    constraints = []
    state_space = np.eye(S)
    action_space = np.eye(A)

    # func(s, a) does not depend on theta: evaluate it once per (s, a) here
    # instead of S*A forward passes on every objective evaluation.
    phi_grid = []
    for i in range(S):
        s = torch.from_numpy(np.array([state_space[i]])).float().cuda()
        row = []
        for j in range(A):
            a = torch.from_numpy(np.array([action_space[j]])).float().cuda()
            with torch.no_grad():
                row.append(func(s, a))
        phi_grid.append(row)

    # Negated reward-biased likelihood (SciPy minimizes).
    def obj(theta):
        theta_row = np.array([theta])
        log_likelihood = np.sum([np.log(np.dot(theta_row, phi)) for phi in phi_list])
        P = [[np.dot(theta_row, phi)[0].tolist() for phi in row] for row in phi_grid]

        V = np.ones(S) / (1 - gamma)
        for _ in range(max_iterations):
            previous_value_fn = V.copy()
            Q = np.einsum('ijk,ijk -> ij', P, R + gamma * V)
            V = np.max(Q, axis=1)
            if np.max(np.abs(V - previous_value_fn)) < delta:
                break
        return -(log_likelihood + alpha * V[state])

    # Define the constraints from the same precomputed features.
    for i in range(S):
        for j in range(A):
            phi_all_s = phi_grid[i][j].T
            constraints.append(LinearConstraint(np.sum(phi_all_s, axis=0), 0.9999, 1.0001, keep_feasible=True))  # sum of probs = 1
            constraints.append(LinearConstraint(phi_all_s, 0.0, 1.0, keep_feasible=True))  # 0 <= each prob <= 1
    try:
        res = minimize(obj,
                       np.ones(len(phi_list[0])) / len(phi_list[0]),
                       method='trust-constr',
                       constraints=constraints)
    except ValueError:
        print('Something contains infs or NaNs')
        return None
    return res.x

if __name__ == '__main__':
    # Smoke run: fit theta with the exact RBMLE solver for a randomly
    # initialised network and random rewards, then build the transition
    # table induced by the fitted theta.
    np.random.seed(1)
    torch.manual_seed(1)
    S = 50
    A = 2
    K = 3
    state_space = np.eye(S)
    action_space = np.eye(A)

    func = Network(S+A, (S, K), temperature=1)

    # A single observed transition: row 1 of func(state 0, action 0).T.
    state_vec = torch.from_numpy(np.array([state_space[0]])).float().cuda()
    action_vec = torch.from_numpy(np.array([action_space[0]])).float().cuda()
    with torch.no_grad():
        phi_list = [func(state_vec, action_vec).T[1]]

    R = np.random.rand(S, A, S)
    state = 3
    alpha = np.sqrt(2)
    # solve_rbmle_exact is the solver taking (S, A, R, func, phi_list, state,
    # alpha, gamma); solve_rbmle_approximated has a different signature.
    maximizer = solve_rbmle_exact(S, A, R, func, phi_list, state, alpha, gamma=0.9)
    print(maximizer)

    if maximizer is None:
        print('Solver failed; no transition table to build.')
    else:
        theta_row = np.array([maximizer])
        transitions = []
        for i in range(S):
            s = torch.from_numpy(
                np.array([state_space[i]])).float().cuda()
            row = []
            for j in range(A):
                a = torch.from_numpy(
                    np.array([action_space[j]])).float().cuda()
                with torch.no_grad():
                    phi = func(s, a)
                row.append(np.dot(theta_row, phi)[0].tolist())
            transitions.append(row)
