"""
Adapted from prior work; see paper references
"""

try:
    import mosek 
    import mosek.fusion 
    from   mosek.fusion import * 
    mosek_available = True
except ImportError:
    mosek_available = False

try:
    import gurobipy as gp
    gurobi_available = True
except ImportError:
    gurobi_available = False

try:
    import networkx as nx
    networkx_available = True
except ImportError:
    networkx_available = False

try:
    from gswalk_kernel_weights import (
        GSwalk_poly_aug_many,
        GSwalk_kernel_many,
        GSwalk_poly_aug_fast,
    )
    gswalk_available = True
except ImportError:
    gswalk_available = False

import numpy as np
import numpy.linalg
import scipy
import scipy.linalg
from scipy.spatial.distance import pdist, squareform
from functools import reduce
import time

def CompleteRand(n, B = 1):
    """
    Sample B assignment vectors from a completely randomized design: a
    uniformly random arrangement of n//2 controls (-1) and n//2 treated (+1)
    
    Args:
        n: number of subjects
        B: number of draws
    Returns:
        list of B lists of +/-1 denoting assignment; each inner list holds
        2*(n//2) entries (so one entry fewer than n when n is odd)
    """
    half = n // 2
    # Balanced label vector; a fresh copy is shuffled in place for each draw.
    template = np.array([-1] * half + [1] * half)
    draws = []
    for _ in range(B):
        labels = template.copy()
        np.random.shuffle(labels)
        draws.append(labels.tolist())
    return draws

def Bernoulli(n, B = 1):
    """
    Sample B assignment vectors from a Bernoulli design, where every subject
    independently receives -1 or +1 with equal probability
    
    Args:
        n: number of subjects
        B: number of draws
    Returns:
        list of B lists (each of length n) of +/-1 denoting assignment
    """
    signs = [-1, 1]
    # One np.random.choice call per draw, in draw order.
    return [np.random.choice(signs, size=n).tolist() for _ in range(B)]

def BlockOrthant(x, B = 1):
    """
    Draw B assignments of subjects from a design blocking on the sign of each
    covariate (i.e., block by orthant)
    
    Within every orthant an even number of subjects is split evenly between
    -1 and +1.  When an orthant holds an odd number of subjects, one of them
    (chosen uniformly at random) is moved to a pool of leftovers, and the
    pool is then split evenly as well.  If the pool itself has odd size
    (which happens exactly when n is odd), the single unpaired slot receives
    a fair coin flip, so the returned vector always has n entries in
    {-1, +1}.
    
    Args:
        x: n by d array of covariates
        B: number of draws
    Returns:
        list of lists of +/-1 denoting assignment
    """
    n,d = x.shape
    nar = np.arange(n)
    # Encode each subject's sign pattern as an integer: one bit per covariate,
    # set when the covariate is negative.
    quad = reduce(lambda z,y: 2*z+y, np.signbit(x.T))
    # The partition into orthants depends only on x, so build it once
    # rather than once per draw.
    blocks = [nar[quad==j].tolist() for j in set(quad)]
    A = []
    for b in range(B):
        leftovers = []
        assgn = np.zeros(n,dtype=np.int_)
        for block in blocks:
            idx = list(block)   # copy: the leftover is removed below
            if len(idx)%2 != 0:
                idxleft = idx[np.random.randint(len(idx))]
                leftovers.append(idxleft)
                idx.remove(idxleft)
            nb2 = len(idx)//2
            zz = np.array([-1,]*nb2+[1,]*nb2,dtype=np.int_)
            np.random.shuffle(zz)
            assgn[idx] = zz
        nb2 = len(leftovers)//2
        pool = [-1,]*nb2+[1,]*nb2
        if len(leftovers)%2 != 0:
            # An odd pool cannot be balanced; without this extra label the
            # assignment below fails with a shape mismatch.
            pool.append(int(np.random.choice([-1, 1])))
        zz = np.array(pool,dtype=np.int_)
        np.random.shuffle(zz)
        assgn[leftovers] = zz
        A.append(assgn.tolist())
    return A

def PairwiseMatch(x, B = 1):
    """
    Draw B assignments of subjects from the optimal pairwise-matched design
    with respect to the Mahalanobis metric
    
    The matching is computed once; each draw then independently randomizes
    which member of every matched pair gets +1 and which gets -1.  Subjects
    that end up in no pair are assigned -1 in every draw.
    
    Args:
        x: n by d array of covariates
        B: number of draws
    Returns:
        list of lists of +/-1 denoting assignment
    Raises:
        ImportError: if NetworkX could not be imported
    """
    if not networkx_available:
        raise ImportError('NetworkX not available.')
    s = safeinvcov(x)
    n = len(x)
    D = squareform(pdist(x, 'mahalanobis', VI = s))
    # Negate the distances so that a maximum-weight matching (with maximum
    # cardinality requested) minimizes the total within-pair distance.
    m = nx.matching.max_weight_matching(nx.Graph(-D),True)
    # The matching may come back as a set of (i, j) tuples or as a dict
    # (presumably depending on the NetworkX version); normalize to a dict.
    if isinstance(m, set):
        m = {i:j for i,j in m}
    # Extract each matched pair exactly once, up front.  A dict-style
    # matching may list a pair under both endpoints, hence the `seen` set
    # (a set keeps the membership test O(1)).
    pairs = []
    seen = set()
    for i in m:
        if i in seen:
            continue
        j = m[i]
        seen.add(i)
        seen.add(j)
        pairs.append((i, j))
    A = []
    for b in range(B):
        zz = np.array([0,1])
        result = np.zeros(n)
        for i, j in pairs:
            # Random orientation of the pair: one member gets 1, the other 0.
            np.random.shuffle(zz)
            result[i] = zz[0]
            result[j] = zz[1]
        A.append([1 if t>0.5 else -1 for t in result])
    return A

def safeinvcov(x):
    """
    Safely invert the sample covariance matrix
    
    Falls back to the Moore-Penrose pseudo-inverse whenever the covariance
    is rank deficient, so constant or collinear covariates do not produce
    infinities or numerically meaningless inverses.
    
    Args:
        x: n by d array of covariates
    Returns:
        d by d array: the inverse of the covariance of the columns of x, or
        its pseudo-inverse when the covariance is singular
    """
    n,d = x.shape
    if d==1:
        # NOTE: np.var defaults to ddof=0 while np.cov below defaults to
        # ddof=1, so the two branches differ by a factor of n/(n-1); kept
        # as-is to avoid changing existing results.
        var = np.var(x)
        # The pseudo-inverse of a zero variance is 0 (instead of dividing
        # by zero).
        inv = 0. if var == 0 else 1./var
        return np.array(inv).reshape((1,1))
    covar = np.cov(x,rowvar=0)
    # A determinant compared exactly against 0. misses matrices that are
    # singular only up to rounding error; use a tolerance-based rank test.
    if np.linalg.matrix_rank(covar) < d:
        return scipy.linalg.pinv(covar)
    return scipy.linalg.inv(covar)

def RerandMR(x, B = 1, p = 0.01):
    """
    Sample B assignment vectors from the re-randomized design a la Morgan &
    Rubin with (exact) acceptance probability p
    
    int(B/p) balanced candidate assignments are drawn, each is scored by the
    Mahalanobis norm of its weighted covariate sum, and the B lowest-scoring
    candidates are returned (with all signs flipped).
    
    Args:
        x: n by d array of covariates
        B: number of draws
        p: acceptance probability
    Returns:
        list of lists of +/-1 denoting assignment
    """
    n_candidates = int(float(B)/p)
    precision = safeinvcov(x)
    half = len(x)//2
    labels = np.array([-1]*half + [1]*half)
    scored = []
    for _ in range(n_candidates):
        # The same array is re-shuffled in place for every candidate.
        np.random.shuffle(labels)
        diff = np.dot(labels, x)/float(half)
        imbalance = np.dot(np.dot(diff.T, precision), diff)
        scored.append((imbalance, labels.tolist()))
    # Stable sort on the imbalance score only.
    scored.sort(key=lambda pair: pair[0])
    accepted = scored[:B]
    return [[-1 if t>0 else 1 for t in assignment]
            for _, assignment in accepted]

def QuadMatch(K, B = 1):
    """
    Return the top B solutions in increasing objective value to the following
    quadratic optimization problem (with symmetry eliminated)
    minimize    u^T K u
    subject to  u in {-1, +1}^n
                sum_i u_i = 0
                u_1 = -1
    
    The problem is solved with Gurobi over binary variables z = (u+1)/2.
    After each solve the incumbent is excluded with a cut and the model is
    re-optimized; enumeration stops early if no further solution can be
    retrieved, so fewer than B solutions may be returned.
    
    Args:
        K: n by n array representing a PSD matrix
        B: number of solutions
    Returns:
        list of lists of +/-1 denoting assignment (first entry of each
        assignment is always -1)
    Raises:
        ImportError: if Gurobi could not be imported
    """
    if not gurobi_available:
        raise ImportError('Gurobi not available.')
    n=len(K)
    k=n//2
    # With u = 2z - 1 (and K symmetric):
    #   u^T K u = 4 z^T K z - 4 (1^T K) z + 1^T K 1
    K1 = np.dot(np.ones((1,n)),K)
    K11 = np.dot(K1,np.ones((n,1)))
    m = gp.Model()
    m.setParam("OutputFlag", 0)
    m.setParam('Threads',1)
    # z[0] is the constant 0 (u_1 = -1); the remaining entries are variables.
    z=[0,]+[m.addVar(lb = 0., ub = 1., vtype=gp.GRB.BINARY) for i in range(n-1)]
    m.update()
    m.setObjective(gp.quicksum(float(4.*K[i,j])*z[i]*z[j] for i in range(1,n) 
        for j in range(1,n))+gp.quicksum(float(-4.*K1[0,i])*z[i]
        for i in range(1,n))+float(K11[0,0]), gp.GRB.MINIMIZE)
    # Exactly k subjects receive +1.
    m.addConstr(gp.quicksum(z[1:])==k)
    m.optimize()
    zs = [(0,)+tuple(1 if zz.x>.5 else 0 for zz in z[1:])]
    for b in range(B-1):
        try:
            # Cut off the incumbent: its k ones may not all be selected again.
            m.addConstr(gp.quicksum(zz for zz in z[1:] if zz.x>.5) <= k-1)
            m.optimize()
            zs.append((0,)+tuple(1 if zz.x>.5 else 0 for zz in z[1:]))
        except (gp.GurobiError, AttributeError):
            # No further solution is available (reading .x fails once the
            # model has no solution) or Gurobi reported an error: return
            # what has been found.  Anything else, including
            # KeyboardInterrupt, propagates.
            break
    return [[2*bit-1 for bit in sol] for sol in zs]

def PSOD(K):
    """
    Return the single best solution of the quadratic optimization problem
    (with symmetry eliminated)
    minimize    u^T K u
    subject to  u in {-1, +1}^n
                sum_i u_i = 0
                u_1 = -1
    
    Args:
        K: square array representing a PSD matrix (handed to QuadMatch)
    Returns:
        list of +/-1 denoting assignment (first is always -1 so always 
        randomize the sign; see PSODDraw below)
    """
    solutions = QuadMatch(K, 1)
    return solutions[0]

def SDPHeuristic(K, us, mixbound = 0.05):
    """
    Solve the semi-definite optimization problem in Algorithm 4.2
    
    Finds mixture weights t over the assignment vectors in us minimizing z
    subject to
        z*I - sum_i t_i * Ksqrt u_i u_i^T Ksqrt   is PSD
        sum_i t_i = 1
        0 <= t_i <= mixbound   (only t_i >= 0 when mixbound is None)
    where Ksqrt is the matrix square root of K.
    
    Args:
        K:        n by n array representing a PSD matrix
        us:       list of length-n assignment vectors
        mixbound: upper bound on each individual weight, or None for no
                  upper bound
    Returns:
        tuple (value of z, array of weights t, Mosek primal solution status)
    Raises:
        ImportError: if Mosek could not be imported
    """
    if not mosek_available:
        raise ImportError('Mosek not available.')
    n = len(K)
    # Matrix square root via eigendecomposition; imaginary parts are dropped
    # and negative eigenvalues are clipped to zero.
    (l,v)=np.linalg.eig(K)
    l=np.real(l)
    v=np.real(v)
    l[l<0]=0
    Ksqrt=np.dot(np.dot(v,np.diag(np.sqrt(l))),v.T)
    # np.float64 rather than the np.float_ alias, which NumPy 2.0 removed.
    ZZs = [Matrix.dense(np.dot(np.dot(Ksqrt, np.outer(zz, zz)), Ksqrt).astype(np.float64))
       for zz in us]
    I=Matrix.diag([1.,]*n)
    with Model("match") as M:
        M.setSolverParam('numThreads', 1)
        t = M.variable('t', len(ZZs), Domain.greaterThan(0.0)
            if mixbound is None else Domain.inRange(0.0, mixbound))
        z = M.variable("z",Domain.greaterThan(0.0))
        sum1cons=M.constraint(Expr.sum(t), Domain.equalsTo(1.0)) 
        opnormcons=M.constraint("z>=opnorm", Expr.sub(Expr.mul(z,I),
            reduce(Expr.add, (Expr.mul(t.index(i), ZZs[i])
            for i in range(len(ZZs)))) ), Domain.inPSDCone(n))
        M.objective(ObjectiveSense.Minimize, z)
        # Accept whatever solution the solver produced; the status is handed
        # back to the caller as the third element of the result.
        M.acceptedSolutionStatus(AccSolutionStatus.Anything) 
        M.solve()
        return (z.level()[0], t.level(), M.getPrimalSolutionStatus())

def MSODHeuristic(K, B):
    """
    Compute the MSOD as per heuristic Algorithm 4.3
    
    The top B solutions from QuadMatch are used as the support of the
    mixture, and SDPHeuristic supplies the mixture weights.
    
    Args:
        K: square array representing a PSD matrix
        B: number of top solutions to use
    Returns:
        weights for each assignment vector,
        list of lists of +/-1 denoting assignment
        (first of each assignment is always -1 so always randomize the sign;
         see MSODDraw below)
        If SDPHeuristic returns something other than a plain tuple, that
        value is returned unchanged instead.
    """
    us = QuadMatch(K, B)
    res = SDPHeuristic(K, us)
    if type(res) is tuple:
        # (objective value, weights, solver status): only the weights are
        # needed here.
        _, t, _ = res
        return (t, us)
    return res

def LinearKernel(x, normalize=True):
    """
    Build the Gram matrix of the linear kernel
    
    Args:
        x:          n by d array of covariates
        normalize:  if true, center the covariates and use the inner product
                    weighted by the inverse covariance from safeinvcov;
                    otherwise use the plain inner product of the rows
    Returns:
        n by n Gram matrix
    """
    if not normalize:
        return np.dot(x, x.T)
    precision = safeinvcov(x)
    centered = x - x.mean(0)
    whitened = np.dot(centered, precision)
    return np.dot(whitened, centered.T)

def GaussianKernel(x, s=1., normalize=True):
    """
    Build the Gram matrix of the Gaussian kernel, exp(-dist^2 / s^2)
    
    Args:
        x:          n by d array of covariates
        s:          bandwidth
        normalize:  if true, dist is the Mahalanobis distance computed by
                    scipy's pdist; otherwise it is the Euclidean distance
    Returns:
        n by n Gram matrix
    """
    if normalize:
        condensed = pdist(x, 'mahalanobis')**2
    else:
        condensed = pdist(x, 'sqeuclidean')
    # Expand the condensed distance vector to the full square matrix.
    sq_dists = squareform(condensed)
    return np.exp(-sq_dists / s**2)

def PolynomialKernel(x, deg=2, normalize=True):
    """
    Build the Gram matrix of the polynomial kernel: K(x, x') = (dot(x,x')/deg + 1)^deg
    
    Args:
        x:          n by d array of covariates
        deg:        degree
        normalize:  if true, center the covariates and use the inner product
                    weighted by the inverse covariance from safeinvcov
    Returns:
        n by n Gram matrix
    """
    if normalize:
        precision = safeinvcov(x)
        centered = x - x.mean(0)
        inner = np.dot(np.dot(centered, precision), centered.T)
    else:
        inner = np.dot(x, x.T)
    return (inner/float(deg) + 1.)**deg

def ExpKernel(x, normalize=True):
    """
    Build the Gram matrix of the exponential kernel, exp(dot(x, x'))
    
    Args:
        x:          n by d array of covariates
        normalize:  if true, center the covariates and use the inner product
                    weighted by the inverse covariance from safeinvcov
    Returns:
        n by n Gram matrix
    """
    if not normalize:
        return np.exp(np.dot(x, x.T))
    precision = safeinvcov(x)
    centered = x - x.mean(0)
    inner = np.dot(np.dot(centered, precision), centered.T)
    return np.exp(inner)

# Coefficients of the default control-outcome function f0 of
# RunOneSimulationExperiment, which evaluates
#     f0(xx) = C . xx[:2] + xx[:2]^T Z xx[:2]
# on the first two covariates of a subject.
Z = np.array([[0.5,-0.5],[0.5,-0.5]])  # matrix of the quadratic term
C = np.array([1.,-1.])                 # coefficients of the linear term
def RunOneSimulationExperiment(n, d,
    f0 = lambda xx: np.dot(C,xx[:2])+np.dot(xx[:2],np.dot(xx[:2],Z)),
    sigma = 0, its=500, seed = None, boot = 100, gsw_phi=0.5):
    """
    Run one replicate of the experiment in Examples 2.2 and 5.1 and return
    the conditional variance of the estimator under each design.
    
    Covariates are drawn uniformly from [-1,1]^d, `its` assignments are
    drawn from each design, control outcomes are generated as f0(x) plus
    Gaussian noise, and for each design the mean of (u . y0 / (n/2))^2 over
    its assignments u is reported.
    
    Args:
        n:     number of subjects (must be even)
        d:     dimension of covariates
        f0:    the conditional expectation function of control outcomes
               (the default reads the first two covariates)
        sigma: the standard deviation of the residuals
        its:   # of assignments drawn for each design
        seed:  random seed to set (if not None)
        boot:  the number of top solutions to use for MSOD/PSOD (unused
               while the MSOD/PSOD designs below are commented out)
        gsw_phi: trade-off parameter passed to GSwalk_poly_aug_many
    Returns:
        a dictionary of estimation variance under each of the designs, keyed
        by 'comprand', 'blocking', 'pairmatch', 'rerandom' and, when the
        gswalk module is importable, 'gswalk_poly_1' and 'gswalk_poly_2'
    Raises:
        ValueError: if n is odd
    Example:
        >>> variances = RunOneSimulationExperiment(30, 2, seed=0)
        >>> variances['comprand']
    """
    
    # The balanced samplers build vectors of length 2*(n//2); fail early
    # with a clear message rather than with a shape mismatch deep inside
    # one of them.
    if n % 2 != 0:
        raise ValueError("n must be even, got %r" % (n,))
    
    if seed is not None:
        np.random.seed(seed)
    
    n2 = n//2
    x  = np.random.rand(n*d).reshape((n,d))*2-1     # covariates: [-1,1]^d
    
    classic = {
        'comprand':  CompleteRand(n,its),
        'blocking':  BlockOrthant(x,its),
        'pairmatch': PairwiseMatch(x,its),
        'rerandom':  RerandMR(x,its,.01)
    }

    # Gram matrices; currently only referenced by the disabled PSOD/MSOD
    # code further down.
    Ks   = {
        'lin':  LinearKernel(x, normalize=False),
        'quad': PolynomialKernel(x, deg=2, normalize=False),
        'gaus': GaussianKernel(x, s=1.),
        'exp':  ExpKernel(x, normalize=False)
    }
    # MSODKs = [ 'gaus', 'exp' ]

    # gswalk_kernel = {}
    gswalk_poly = {}
    if gswalk_available:
        # t0 = time.time()
        # for key, gram in Ks.items():
        #     gswalk_assignments = GSwalk_kernel_many(
        #         GSwalk_poly_aug_fast,
        #         gram,
        #         gsw_phi,
        #         its
        #     )
        #     gswalk_kernel[f'gswalk_kernel_{key}'] = gswalk_assignments
        # # print(f"GSwalk runtime: {time.time()-t0:.3f}s")

        weights = [1.0, 0.5]
        for k in range(1,3):
            if len(weights) < k:
                raise ValueError(f"weights must have length at least {k}")
            # The first k weights are passed for the run with parameter k.
            gswalk_poly_assignments = GSwalk_poly_aug_many(
                x.T,
                k,
                gsw_phi,
                its,
                weights = weights[:k]
            )
            gswalk_poly[f'gswalk_poly_{k}'] = gswalk_poly_assignments
    
    # t0 = time.time()
    # psods     = {k: PSOD(Ks[k]) for k in Ks}
    # print(f"PSOD runtime: {time.time()-t0:.3f}s")
    # t0 = time.time()
    # msods     = {k: MSODHeuristic(Ks[k], boot) for k in MSODKs}
    # print(f"MSOD runtime: {time.time()-t0:.3f}s")
    
    # Control outcomes: f0 applied to each subject's covariates plus noise.
    y0 = np.array(list(map(f0, x))) + sigma*np.random.randn(n)
    
    condvar     = {}
    
    for k, design in list(classic.items()) + list(gswalk_poly.items()):
        us = np.array(design)           # one assignment vector per row
        stats = np.dot(us, y0)/n2       # statistic for each assignment
        condvar[k] = (stats**2).mean()
    
    # for k in Ks:
    #     u    = np.array(psods[k])
    #     stat = np.dot(u, y0)/n2
    #     condvar[k+'_PSOD'] = (stat**2)
    
    # for k in MSODKs:
    #     t = np.array(msods[k][0])
    #     z = np.array(msods[k][1])
        
    #     stats              = np.dot(z, y0)/n2
    #     condvar[k+'_MSOD']  = np.dot(t, stats**2)/t.sum()
    
    return condvar
