import gurobipy as gp
from gurobipy import GRB
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import GradientBoostingRegressor
import matplotlib.pyplot as plt


def min_max_feas(n,constr_coefs,constr_p0, obj_dir='max'):
    """Return the extreme value of a linear disparity over box-bounded policies.

    Solves, with Gurobi, the LP over continuous ``pi[0..n-1]`` with
    ``0 <= pi[i] <= 1`` of the objective
    ``sum_i constr_coefs[i] * pi[i] + constr_p0``.

    Parameters
    ----------
    n : int
        Number of decision variables; ``constr_coefs`` is indexed ``0..n-1``.
    constr_coefs : sequence of float
        Linear coefficient for each ``pi[i]``.
    constr_p0 : float
        Constant offset added to the objective.
    obj_dir : str, optional
        ``'min'`` minimises; any other value maximises. The default ``'max'``
        keeps the original behaviour. Same convention as ``min_equality``.

    Returns
    -------
    float
        The optimal objective value (``m.ObjVal``).
    """
    m = gp.Model("mip1")
    pi = m.addVars(n, vtype=GRB.CONTINUOUS)  # nonnegative (default lower bound)
    sense = GRB.MINIMIZE if obj_dir == 'min' else GRB.MAXIMIZE
    m.setObjective(gp.quicksum(constr_coefs[i] * pi[i] for i in range(n)) + constr_p0, sense)
    # Upper bound of the [0, 1] box.
    m.addConstrs((pi[i] <= 1 for i in range(n)))
    m.optimize()
    return m.ObjVal


def train_propensities_outcomes(x,t,y,propensitymodel,outcomemodel):
    """Fit a treatment-propensity model plus one outcome model per treatment arm.

    The propensity model is first fit on an 80/20 split
    (``random_state=42``). Its ``score`` on the held-out 20% is printed as a
    sanity check, and the model is then refit on all of ``x``/``t``. Separate
    outcome models are fit on the rows with ``t == 0`` and ``t == 1``.

    Parameters
    ----------
    x : array-like
        Features; rows are selected with boolean masks on ``t``.
    t : array-like
        Treatment indicator, compared against 0 and 1.
    y : array-like
        Outcomes aligned with ``x``.
    propensitymodel, outcomemodel : callable
        Zero-argument factories for estimators exposing ``fit``
        (``propensitymodel`` instances also need ``score``).

    Returns
    -------
    list
        ``[propensity_model, outcome_model_t0, outcome_model_t1]``.
    """
    X_fit, X_holdout, t_fit, t_holdout = train_test_split(x, t, test_size=0.2, random_state=42)
    propensity = propensitymodel()
    propensity.fit(X_fit, t_fit)
    # Held-out score, printed only to double-check the propensity fit.
    print(propensity.score(X_holdout, t_holdout))
    # The model that is returned is trained on all of the data.
    propensity.fit(x, t)

    outcome_by_arm = {}
    for arm in (1, 0):
        in_arm = t == arm
        model = outcomemodel()
        model.fit(x[in_arm], y[in_arm])
        outcome_by_arm[arm] = model
    return [propensity, outcome_by_arm[0], outcome_by_arm[1]]

def train_propensities_separate(x,r,t,A,As, propensitymodel):
    """Fit one treatment model on the ``r == 0`` rows and one on the ``r == 1`` rows.

    Parameters
    ----------
    x : array-like
        Features; rows are selected with the boolean mask ``r == value``.
    r : array-like
        Binary split indicator, compared against 0 and 1.
    t : array-like
        Target passed to each model's ``fit``.
    A, As : unused
        Accepted for signature compatibility only.
    propensitymodel : callable
        Zero-argument factory for an estimator with ``fit``.

    Returns
    -------
    list
        ``[model_fit_on_r0, model_fit_on_r1]``.
    """
    fitted = {}
    for value in (1, 0):
        rows = r == value
        estimator = propensitymodel()
        estimator.fit(x[rows], t[rows])
        fitted[value] = estimator
    return [fitted[0], fitted[1]]

# plotting  


def plot_tau_lr(tau,A,As,Anames,colors,Aname=None):
    """Overlay one density histogram of ``tau`` per group on a new 3x3 figure.

    Parameters
    ----------
    tau : array-like
        Values to histogram; each group's slice is ``tau[A == g]``.
    A : array-like
        Group label for each entry of ``tau``.
    As : sequence
        Group values to plot, in order.
    Anames : sequence
        Legend label for each entry of ``As`` (same order).
    colors : sequence
        Histogram colour for each entry of ``As`` (same order).
    Aname : str, optional
        Grouping-attribute name appended to the title as "tau by <Aname>".
        When ``None`` the title is just tau.
    """
    plt.figure(figsize=(3, 3))
    for i, g in enumerate(As):
        plt.hist(tau[A == g], alpha=0.5, density=True, label=Anames[i], color=colors[i])
    # Guard the default: concatenating None onto the title string raised TypeError.
    title = r'$\tau$' if Aname is None else r'$\tau$ by ' + str(Aname)
    # Title and legend are set once, after all groups have been drawn.
    plt.title(title)
    plt.legend()
    return

def fit_compliances_lr_ontau(A,As,tau,R,T):
    """Fit P(T | tau) with one logistic regression per (R value, group) pair.

    For every group ``g`` in ``As``, a ``LogisticRegression`` is fit on the
    single feature ``tau`` (reshaped to a column) against ``T``. This is done
    once for rows with ``R == 1`` and once for rows with ``R == 0``,
    restricted to ``A == g``.

    Returns
    -------
    list
        ``[ps_R0, ps_R1]``, where ``ps_R{r}[i]`` is the model fitted on
        ``R == r`` and group ``As[i]``.
    """
    def _fit(r_value, group):
        rows = (R == r_value) & (A == group)
        model = LogisticRegression()
        model.fit(tau[rows].reshape(-1, 1), T[rows])
        return model

    ps_R1 = []
    ps_R0 = []
    for g in As:
        ps_R1.append(_fit(1, g))
        ps_R0.append(_fit(0, g))
    return [ps_R0, ps_R1]

def plot_compliances(p_r,x,A,As,R,colors,Aname=None,labels=None):
    """Histogram, per group, of ``p_r[1]`` minus ``p_r[0]`` predicted probabilities.

    For each group ``g`` in ``As``, computes
    ``p_r[1].predict_proba(x[A == g])[:, 1] - p_r[0].predict_proba(x[A == g])[:, 1]``
    and overlays its density histogram on a single new 3x3 figure.

    Parameters
    ----------
    p_r : sequence of two fitted classifiers
        Indexed as ``p_r[0]`` and ``p_r[1]``; each needs ``predict_proba``
        (column 1 is used), e.g. the list from ``train_propensities_separate``.
    x : array-like
        Features; each group's rows are ``x[A == g]``.
    A : array-like
        Group label per row of ``x``.
    As : sequence
        Group values to plot.
    R : unused
        Accepted for signature compatibility only.
    colors : sequence
        Histogram colour per group.
    Aname : str, optional
        Attribute name appended to the title; omitted when ``None``.
    labels : sequence, optional
        Legend label per group; defaults to the group values in ``As``.
    """
    plt.figure(figsize=(3, 3))
    for i, g in enumerate(As):
        x_g = x[A == g]
        compliance_diff = p_r[1].predict_proba(x_g)[:, 1] - p_r[0].predict_proba(x_g)[:, 1]
        label = labels[i] if labels is not None else g
        plt.hist(compliance_diff, alpha=0.5, density=True, label=label, color=colors[i])

    # Guard the default: concatenating None onto the title string raised TypeError.
    suffix = '' if Aname is None else ', a= ' + str(Aname)
    plt.title(r'$p_{1\mid 1,a}-p_{1\mid 0,a}$' + suffix)
    plt.legend()
    # Lay out only once the axes have content; on the empty figure it did nothing.
    plt.tight_layout()
    return

def plot_compliances_lr(p_rs,A,As,R,tau,colors,Aname=None):
    """Plot per-group distributions from the tau-based compliance models.

    Draws three new 3x3 figures, each overlaying one density histogram per
    group in ``As``:

    1. ``p_{1a} - p_{0a}``: difference between the ``p_rs[1][i]`` and
       ``p_rs[0][i]`` predicted probabilities for that group.
    2. ``p_{1a}``: predictions of ``p_rs[1][i]``.
    3. ``p_{0a}``: predictions of ``p_rs[0][i]``.

    Parameters
    ----------
    p_rs : pair of sequences of fitted classifiers
        ``p_rs[r][i]`` is the model for ``r`` and group ``As[i]`` (the
        ``[ps_R0, ps_R1]`` layout returned by ``fit_compliances_lr_ontau``).
        Column 1 of ``predict_proba`` is used.
    A : array-like
        Group label per entry of ``tau``.
    As : sequence
        Group values to plot; also used as legend labels.
    R : unused
        Accepted for signature compatibility only.
    tau : numpy.ndarray
        Scores; each group's slice is reshaped to a single-feature column.
    colors : sequence
        Histogram colour per group.
    Aname : str, optional
        Attribute name appended to the titles; omitted when ``None``.
    """
    # Guard the default: concatenating None onto the title string raised TypeError.
    suffix = '' if Aname is None else ', a= ' + str(Aname)
    groups = list(As)

    # Predict once per group and reuse the arrays for all three figures.
    compr1_by_group = []
    compr0_by_group = []
    for i, g in enumerate(groups):
        tau_g = tau[A == g].reshape(-1, 1)
        compr1_by_group.append(p_rs[1][i].predict_proba(tau_g)[:, 1])
        compr0_by_group.append(p_rs[0][i].predict_proba(tau_g)[:, 1])

    def _overlay(values_by_group, title, xlabel=None):
        # One new 3x3 figure with a translucent density histogram per group;
        # labels, title and legend are set once after all groups are drawn.
        plt.figure(figsize=(3, 3))
        for i, (g, values) in enumerate(zip(groups, values_by_group)):
            plt.hist(values, alpha=0.5, density=True, label=g, color=colors[i])
        if xlabel is not None:
            plt.xlabel(xlabel)
        plt.title(title + suffix)
        plt.legend()

    diffs = [c1 - c0 for c1, c0 in zip(compr1_by_group, compr0_by_group)]
    # NOTE(review): this x-axis shows p_{1a}-p_{0a}, but the label reads
    # "distn. of tau". Kept as in the original; confirm the intended label.
    _overlay(diffs, r'$p_{1a}-p_{0a}$', xlabel=r'distn. of $\tau$')
    _overlay(compr1_by_group, r'distn. of $p_{1a}$')
    _overlay(compr0_by_group, r'distn. of $p_{0a}$')



def plot_compliance_on_tau(p_rs,A,As,R,tau,colors,Aname=None,labels=None):
    """Scatter, per group, the compliance difference against ``tau``.

    For each group ``g`` in ``As`` (index ``i``), plots ``tau[A == g]``
    against ``p_rs[1][i]`` minus ``p_rs[0][i]`` predicted probabilities
    (column 1 of ``predict_proba``) on one new 3x3 figure.

    Parameters
    ----------
    p_rs : pair of sequences of fitted classifiers
        ``p_rs[r][i]`` is the model for ``r`` and group ``As[i]``.
    A : array-like
        Group label per entry of ``tau``.
    As : sequence
        Group values to plot.
    R : unused
        Accepted for signature compatibility only.
    tau : numpy.ndarray
        Scores; each group's slice is reshaped to a single-feature column for
        prediction and used as-is for the x-coordinates.
    colors : sequence
        Marker colour per group.
    Aname : str, optional
        Attribute name appended to the title; omitted when ``None``.
    labels : sequence, optional
        Legend label per group; defaults to the group values in ``As``.
    """
    plt.figure(figsize=(3, 3))
    for i, g in enumerate(As):
        tau_g = tau[A == g]
        tau_col = tau_g.reshape(-1, 1)
        compliance_diff = p_rs[1][i].predict_proba(tau_col)[:, 1] - p_rs[0][i].predict_proba(tau_col)[:, 1]
        label = labels[i] if labels is not None else g
        plt.scatter(tau_g, compliance_diff, label=label, color=colors[i], s=0.25)

    # Guard the default: concatenating None onto the title string raised TypeError.
    suffix = '' if Aname is None else ', a= ' + str(Aname)
    plt.xlabel(r'$s$')
    plt.title(r'$p_{1a}-p_{0a}$ vs $s$' + suffix)
    plt.legend()


def setup_obj_coefs(mu_t0, mu_t1, p_t_mid_r, params):
    """Build the linear objective and disparity-constraint coefficients.

    For each column ``r`` in ``range(d_r)``, write ``p_r = p_t_mid_r[:, r]``:

    * objective coefficient:  ``p_r*(c11-c10)*mu_t1 + (1-p_r)*(c01-c00)*mu_t0``
    * objective constant:     mean over rows of ``p_r*c10 + (1-p_r)*c00``
    * constraint coefficient: ``ct * p_r * w``, with
      ``w = (A_==a_s[0])/p_as[0] - (A_==a_s[1])/p_as[1]``

    The constraint constant is ``mean(ct * p_t_mid_r[:, 0] * w)``.

    NOTE(review): ``c11``, ``c10``, ``c01``, ``c00``, ``d_r``, ``ct``, ``A_``,
    ``a_s`` and ``p_as`` are read as module-level globals. They are not defined
    in this module's visible source, so they must be placed in this module's
    namespace before calling. ``params`` is unused; presumably it was meant to
    carry these values. TODO confirm.

    Parameters
    ----------
    mu_t0, mu_t1 : numpy.ndarray
        Per-row values combined elementwise with the columns of ``p_t_mid_r``.
    p_t_mid_r : numpy.ndarray
        2-D; column ``r`` is used for each ``r`` in ``range(d_r)``.
    params : unused

    Returns
    -------
    list
        ``[obj_coefs, obj_constants, constr_coefs, constr_const]``.
        ``obj_coefs`` and ``constr_coefs`` have one column per ``r``,
        ``obj_constants`` has one entry per ``r``, and ``constr_const`` is a
        scalar.
    """
    # This module does not import numpy at top level; import it here so the
    # function does not raise NameError on ``np``.
    import numpy as np

    # Per-row group weight w, computed once instead of once per column.
    group_weight = (A_ == a_s[0]) / p_as[0] - (A_ == a_s[1]) / p_as[1]
    columns = range(d_r)

    obj_coefs = np.asarray([
        p_t_mid_r[:, r] * (c11 - c10) * mu_t1 + (1 - p_t_mid_r[:, r]) * (c01 - c00) * mu_t0
        for r in columns
    ]).T
    obj_constants = np.mean(np.asarray([
        p_t_mid_r[:, r] * c10 + (1 - p_t_mid_r[:, r]) * c00
        for r in columns
    ]).T, axis=0)
    constr_coefs = np.asarray([ct * p_t_mid_r[:, r] * group_weight for r in columns]).T
    # NOTE(review): only column 0 of p_t_mid_r enters the constant; confirm intended.
    constr_const = np.mean(ct * p_t_mid_r[:, 0] * group_weight)
    return [obj_coefs, obj_constants, constr_coefs, constr_const]


def min_equality(n, obj_coefs, obj_constants, constr_coefs, constr_const, eps = 0.0001, obj_dir = 'min'):
    """Optimise a randomised two-arm policy subject to a disparity tolerance.

    Builds a Gurobi LP over continuous variables ``pi[i, j]`` (``i`` in
    ``range(n)``, ``j`` in ``{0, 1}``), with these constraints:

    * ``0 <= pi[i, j] <= 1``
    * ``pi[i, 0] + pi[i, 1] == 1`` for every ``i``
    * ``-eps <= (1/n) * sum_i sum_j constr_coefs[i, j] * pi[i, j] <= eps``,
      as the constraints named ``'dp1'`` and ``'dp2'``

    The objective ``(1/n) * sum_i sum_j obj_coefs[i, j] * pi[i, j]`` is
    minimised when ``obj_dir == 'min'`` and maximised otherwise. After solving,
    ``m.ObjVal + obj_constants[0] + obj_constants[1]`` is printed (not returned).

    Parameters
    ----------
    n : int
        Number of rows; the coefficient arrays are indexed ``[i, 0]`` and
        ``[i, 1]`` for ``i < n``.
    obj_coefs, constr_coefs : array-like
        2-D coefficient arrays with (at least) two columns.
    obj_constants : sequence
        Only entries 0 and 1 are used, for the printed value.
    constr_const : unused
        Accepted for signature compatibility only.
    eps : float, optional
        Half-width of the allowed disparity band.
    obj_dir : str, optional
        ``'min'`` (default) minimises; any other value maximises.

    Returns
    -------
    list
        ``[m, pi]``: the solved ``gurobipy.Model`` and its ``pi`` variables.
        Reading ``m.ObjVal`` raises if the solve produced no solution
        (e.g. an infeasible ``eps``).
    """
    # The variables, objective and constraints are all written for exactly two
    # arms. This replaces the previously used (undefined in this module) global ``d_r``.
    n_arms = 2
    m = gp.Model("mip1")
    pi = m.addVars(n, n_arms, vtype=GRB.CONTINUOUS)  # nonnegative (default lower bound)
    sense = GRB.MINIMIZE if obj_dir == 'min' else GRB.MAXIMIZE
    m.setObjective(
        gp.quicksum(1.0/n*obj_coefs[i, 0]*pi[i, 0] + 1.0/n*obj_coefs[i, 1]*pi[i, 1] for i in range(n)),
        sense,
    )

    def _disparity():
        # Fresh expression for the group-disparity term, one per constraint side.
        return gp.quicksum(
            1.0/n*constr_coefs[i, 0]*pi[i, 0] + 1.0/n*constr_coefs[i, 1]*pi[i, 1] for i in range(n)
        )

    # | E[ T(pi)|A=a] - E[ T(pi)|A=b] | <= eps, written as two linear constraints.
    m.addConstr(_disparity() <= eps, name='dp1')
    m.addConstr(_disparity() >= -eps, name='dp2')
    # 0 <= pi <= 1 (the lower bound comes from addVars' default).
    m.addConstrs((pi[i, j] <= 1 for i in range(n) for j in range(n_arms)))
    # Each row's arm probabilities sum to one: pi_0 + pi_1 = 1.
    m.addConstrs(gp.quicksum(pi[i, j] for j in range(n_arms)) == 1 for i in range(n))
    m.update()
    m.optimize()
    # NOTE(review): both per-arm constants are added regardless of pi. Since
    # pi[i, 0] + pi[i, 1] == 1, the constant part may instead need to be
    # weighted by pi; confirm. The value is only printed.
    opt = m.ObjVal + obj_constants[0] + obj_constants[1]
    print(opt)
    return [m, pi]




