import numpy as np
import pandas as pd
import torch
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt
from snorkel.labeling.model import LabelModel
import os
import sys
sys.path.append("..") 
from utils import DefineDevice, FindRowIndex, GetP_Y_Z, SuppressPrints
from tensors import GetLogLossTensor, GetAccuracyTensor
from models import LogReg, TrainModelCI, CIRisk
from bound_expectation import BoundExpectation
from scipy.stats import norm
    
    
def _load_wrench_split(dataset, k, device):
    """Load the pickled train/val/test split of a WRENCH classification dataset.

    Args:
        dataset: name of the folder under ``../data/wrench_class/`` that holds
            ``processed_data.pickle`` (path is relative to the current working
            directory).
        k: keep only the first ``k`` weak-label columns (``None`` keeps all).
        device: device the features and true labels are moved to.

    Returns:
        Three tuples ``(X_train, X_val, X_test)``, ``(Y_train, Y_val, Y_test)``
        and ``(L_train, L_val, L_test)``. The weak-label matrices are built with
        ``torch.tensor`` and are NOT moved to ``device``.
    """
    path = os.path.join('..', 'data', 'wrench_class', dataset, 'processed_data.pickle')
    with open(path, 'rb') as handle:
        # NOTE: pickle can execute arbitrary code; only load locally generated files.
        dic = pickle.load(handle)

    splits = ('train', 'val', 'test')
    X = tuple(dic['X_' + s].to(device) for s in splits)
    Y = tuple(dic['Y_' + s].to(device) for s in splits)
    L = tuple(torch.tensor(dic['L_' + s][:, :k]) for s in splits)
    return X, Y, L


def _estimate_p_y_z(train_label_model, L, set_Z_aux, Y, Z, set_Y, set_Z, random_state, device):
    """Estimate P(Y|Z) and return it as a float64 tensor on ``device``.

    If ``train_label_model`` is true, a Snorkel ``LabelModel`` is fitted on the
    weak labels of all splits (``L``), with the class balance taken from the
    true labels ``Y``. Its predicted probabilities for every unique weak-label
    pattern (rows of ``set_Z_aux``) are then transposed to shape (|Y|, |Z|).
    Otherwise ``GetP_Y_Z`` is computed from the true labels; presumably it
    returns the same (|Y|, |Z|) layout. TODO confirm.
    """
    if train_label_model:
        label_model = LabelModel(cardinality=set_Y.shape[0], verbose=False)
        with SuppressPrints():
            # class_balance=[1-p, p] assumes binary labels in {0, 1}.
            p = Y.float().mean().item()
            # Snorkel's LabelModel API is documented to take numpy arrays.
            label_model.fit(L_train=L.cpu().numpy(), n_epochs=1000,
                            class_balance=[1 - p, p], seed=random_state)
        P_Y_Z = torch.tensor(label_model.predict_proba(L=set_Z_aux.cpu().numpy())).T
    else:
        P_Y_Z = GetP_Y_Z(Y, Z, set_Y, set_Z)
    return P_Y_Z.double().to(device)


def _select_ws_model(X_train, Z_train, X_val, Z_val, set_Y, set_Z, P_Y_Z,
                     weight_decays, tol, max_epochs, device, verbose):
    """Train one weakly supervised LogReg per weight decay; return the best.

    "Best" means the lowest validation ``CIRisk`` of the log-loss on
    ``(X_val, Z_val)``. On ties, the first (largest) weight decay wins, as with
    ``np.argmin``. NaN validation losses (diverged runs) are ignored. The
    trained model is returned directly rather than being refit.

    Raises:
        RuntimeError: if every candidate produced a NaN validation loss.
    """
    best_model, best_loss = None, None
    for weight_decay in tqdm(weight_decays, disable=not verbose):
        model = LogReg(X_train.shape[1], set_Y.shape[0]).double().to(device)
        model = TrainModelCI(model, X_train, Z_train, set_Z, P_Y_Z, weight_decay=weight_decay,
                             tol=tol, max_epochs=max_epochs, device=device)
        val_loss = CIRisk(GetLogLossTensor(model, X_val), Z_val, set_Z, P_Y_Z, device).item()
        # np.argmin would select a NaN loss; skip those explicitly.
        if np.isnan(val_loss):
            continue
        if best_model is None or val_loss < best_loss:
            best_model, best_loss = model, val_loss
    if best_model is None:
        raise RuntimeError("All weak-supervision models produced NaN validation losses.")
    return best_model


def _accuracy_with_ci(probs_pos, Y_true, thresh, z, n):
    """Return the accuracy of ``probs_pos > thresh`` vs ``Y_true`` and its Wald CI.

    ``z`` is the normal quantile for the desired confidence and ``n`` the
    number of evaluated points. The interval is ``[acc - delta, acc + delta]``
    with ``delta = z * sqrt(acc * (1 - acc) / n)``.
    """
    acc = ((probs_pos > thresh).long() == Y_true).float().mean().item()
    delta = z * ((acc * (1 - acc)) / n) ** .5
    return acc, [acc - delta, acc + delta]


def run_exp1(dataset, train_label_model, k, threshs, device, random_state, verbose=False):
    """Bound the test accuracy of a weakly supervised classifier and compare it to the truth.

    Pipeline:
      1. Load the dataset split and keep the first ``k`` labeling functions.
      2. Map each weak-label row to an index ``Z`` among the unique rows.
      3. Estimate P(Y|Z), with Snorkel's LabelModel or from the true labels.
      4. Select a LogReg end model by weight decay on the validation split.
      5. For each threshold, compute lower/upper bounds on the test accuracy
         via ``BoundExpectation``, plus the actual test accuracy with a Wald CI.

    The accuracy computation uses column 1 of the model output, so binary
    labels {0, 1} are assumed.

    Args:
        dataset: dataset folder name under ``../data/wrench_class/``.
        train_label_model: if true, estimate P(Y|Z) with a Snorkel LabelModel;
            otherwise use ``GetP_Y_Z`` on the true labels.
        k: number of leading weak-label columns to keep.
        threshs: iterable of thresholds applied to the model's class-1 output.
            It is materialized once, so generators are accepted.
        device: torch device for features, labels and models.
        random_state: seed passed to the Snorkel LabelModel.
        verbose: print progress headers and show tqdm bars.

    Returns:
        ``(bounds, accs)``, two dicts shaped ``{'centers': {'lower': arr,
        'upper': arr}, 'ics': {'lower': arr, 'upper': arr}}`` with one entry per
        threshold. For ``bounds``, centers/ics are elements 0/1 of
        ``BoundExpectation``'s result. For ``accs``, centers are the actual test
        accuracies and ics the (len(threshs), 2) Wald intervals; these are the
        same for 'lower' and 'upper'.
    """
    ### Some fixed params ###
    tol = 1e-4
    max_epochs = 1e4
    weight_decays_ws = np.logspace(0, -3, 10)
    conf = .95
    approx_error = .001

    # Iterated several times below; a generator would otherwise be exhausted.
    threshs = list(threshs)

    ### Loading data ###
    if verbose: print("\n >>>>>> Data prep <<<<<<")
    (X_train, X_val, X_test), (Y_train, Y_val, Y_test), (L_train, L_val, L_test) = \
        _load_wrench_split(dataset, k, device)
    L = torch.vstack((L_train, L_val, L_test))

    # Creating Z from L: index of each weak-label row among the unique rows #
    set_Z_aux = torch.unique(L, dim=0)
    Z_train = torch.tensor([FindRowIndex(set_Z_aux, l) for l in L_train])  # used to train ws model
    Z_val = torch.tensor([FindRowIndex(set_Z_aux, l) for l in L_val])  # used to validate hyperpar. of ws model
    Z_test = torch.tensor([FindRowIndex(set_Z_aux, l) for l in L_test])  # used to compute the accuracy bounds
    # L stacks train/val/test in this order, so Z is the same concatenation.
    Z = torch.hstack((Z_train, Z_val, Z_test))
    Y = torch.hstack((Y_train, Y_val, Y_test))

    # Defining supp(Y) and supp(Z) #
    set_Y_aux = torch.unique(Y_train).tolist()  # in this exp, it should be [0,1]
    set_Y = torch.tensor(range(len(set_Y_aux)))
    set_Z = torch.tensor(range(set_Z_aux.shape[0]))

    ### Estimating P_Y_Z ###
    if verbose: print("\n >>>>>> Estimating P_Y_Z <<<<<<")
    P_Y_Z = _estimate_p_y_z(train_label_model, L, set_Z_aux, Y, Z, set_Y, set_Z,
                            random_state, device)

    ### Training WS model ###
    if verbose: print("\n >>>>>> Validating and training end model <<<<<<")
    model_ws = _select_ws_model(X_train, Z_train, X_val, Z_val, set_Y, set_Z, P_Y_Z,
                                weight_decays_ws, tol, max_epochs, device, verbose)

    ### Computing bounds for accuracy ###
    if verbose: print("\n >>>>>> Computing bounds for accuracy <<<<<<")

    # The actual test accuracy does not depend on the bound direction:
    # compute it once per threshold, without tracking gradients.
    z_score = norm.ppf((conf + 1) / 2)
    n_test = X_test.shape[0]
    with torch.no_grad():
        probs_pos = model_ws(X_test)[:, 1]
    acc_results = [_accuracy_with_ci(probs_pos, Y_test, thresh, z_score, n_test) for thresh in threshs]
    acc_centers = [acc for acc, _ in acc_results]
    acc_ics = [ic for _, ic in acc_results]

    epsilon = approx_error / np.log(set_Y.shape[0])
    bounds = {'centers': {}, 'ics': {}}
    accs = {'centers': {}, 'ics': {}}

    for bound in ['lower', 'upper']:
        centers, ics = [], []
        for thresh in tqdm(threshs, disable=not verbose):
            tensor = GetAccuracyTensor(model_ws, X_test, thresh=thresh)
            temp = BoundExpectation(bound, tensor,
                                    Z_test, set_Z, P_Y_Z,
                                    conf=conf, epsilon=epsilon,
                                    tol=tol, max_epochs=max_epochs, device=device)
            centers.append(temp[0])
            ics.append(temp[1])

        bounds['centers'][bound] = np.array(centers)
        bounds['ics'][bound] = np.array(ics)
        # np.array copies, so 'lower' and 'upper' never alias each other.
        accs['centers'][bound] = np.array(acc_centers)
        accs['ics'][bound] = np.array(acc_ics)

    return bounds, accs