import numpy as np
import pandas as pd
import torch
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt
from snorkel.labeling.model import LabelModel
from scipy.stats import norm

import os
import sys
from label_functions import *
sys.path.append("..") 
from utils import DefineDevice, FindRowIndex, GetP_Y_Z, SuppressPrints
from tensors import GetLogLossTensor, GetPrecisionRecallTensor
from models import LogReg, TrainModelCI, CIRisk
from bound_expectation import BoundExpectation
from sentence_transformers import SentenceTransformer
from os import walk
from sklearn.model_selection import train_test_split
import math


def labeler(label):
    """Map an annotation label to a binary target.

    Returns 0 when `label` equals the string 'noHate' and 1 for anything else.
    """
    return 0 if label == 'noHate' else 1
    
def _read_split(path, filenames, annot):
    """Read every text file of one data split and look up its binary label.

    Returns a pair (texts, labels) of equal-length lists ordered like `filenames`.
    """
    texts = []
    labels = []
    for file in filenames:
        with open(path + '/' + file, 'r') as f:
            texts.append(f.read().rstrip())
        # The annotation table is matched on the file name with '.txt' removed.
        labels.append(labeler(annot.loc[annot.file_id==file.replace('.txt','')]['label'].iloc[0]))
    return texts, labels


def _embed_texts(texts, feature_extractor, device, verbose):
    """Encode texts one by one and stack the embeddings into a double tensor on `device`."""
    return torch.stack([torch.tensor(feature_extractor.encode([t])[0])
                        for t in tqdm(texts, disable=not verbose)]).double().to(device)


def _weak_label_matrix(texts, terms, bert_model, bert_tokenizer, roberta_model, roberta_tokenizer, device, verbose):
    """Apply the four labeling functions to every text; one row per text, one column per function."""
    return torch.tensor([[textblob_sentiment_lf(text),
                          terms_lf(text, terms),
                          bert_hate_lf(text, bert_model, bert_tokenizer, device),
                          roberta_toxicity_lf(text, roberta_model, roberta_tokenizer, device)]
                         for text in tqdm(texts, disable=not verbose)])


def run_exp3(train_label_model, threshs, device, random_state, verbose=False):
    """Run the hate-speech experiment: train a weakly supervised classifier and
    bound its recall and precision on the test split.

    Parameters
    ----------
    train_label_model : if truthy, P_Y_Z is taken from a snorkel `LabelModel`
        fitted on all weak labels; otherwise it is computed by `GetP_Y_Z` from
        the true labels Y and the weak-label configurations Z.
    threshs : iterable of thresholds compared (`>=`) against column 1 of the
        end model's output. It is iterated once per bound ('lower', 'upper'),
        so it should be re-iterable (e.g. a list or array).
    device : device passed to `.to(device)` and to the project helpers.
    random_state : seed for the train/validation split and the label model.
    verbose : if True, print stage banners and show progress bars.

    Returns
    -------
    (bounds, precrecs)
        bounds['centers'][bound][target] : list with one value per threshold.
        bounds['ics'][bound][target]     : list of [low, high] per threshold.
        precrecs['centers'][target]      : empirical values on the test split.
        precrecs['ics'][target]          : list of [low, high] normal-approximation intervals.
        with bound in {'lower', 'upper'} and target in {'recall', 'precision'}.
    """

    ### Some fixed params ###
    tol = 1e-4
    max_epochs = 1e4
    weight_decays_ws = np.logspace(0,-3,10)
    conf = .95
    approx_error = .001

    ### Data ###
    if verbose: print("\n >>>>>> Data prep <<<<<<")

    ## Loading language models and bank of terms used to obtain WLs ##
    #Loading BERT for hate speech detection (https://huggingface.co/IMSyPP/hate_speech_en)
    bert_tokenizer = BertTokenizer.from_pretrained('IMSyPP/hate_speech_en')
    bert_model = BertForSequenceClassification.from_pretrained('IMSyPP/hate_speech_en').to(device)

    #Loading Roberta for toxicity detection
    #https://huggingface.co/s-nlp/roberta_toxicity_classifier?text=I+like+you.+I+love+you
    roberta_tokenizer = RobertaTokenizer.from_pretrained('s-nlp/roberta_toxicity_classifier')
    roberta_model = RobertaForSequenceClassification.from_pretrained('s-nlp/roberta_toxicity_classifier').to(device)

    #Loading hate/offensive terms
    terms = list(pd.read_csv('../data/hate-speech-and-offensive-language/lexicons/refined_ngram_dict.csv').ngram)

    ## Loading data ##
    #https://github.com/Vicomtech/hate-speech-dataset
    train_path = '../data/hate-speech-dataset/sampled_train'
    train_filenames = next(walk(train_path), (None, None, []))[2]

    test_path = '../data/hate-speech-dataset/sampled_test'
    test_filenames = next(walk(test_path), (None, None, []))[2]

    annot = pd.read_csv('../data/hate-speech-dataset/annotations_metadata.csv')

    ## Preparing data ##
    train, Y_train = _read_split(train_path, train_filenames, annot)
    test, Y_test = _read_split(test_path, test_filenames, annot)

    feature_extractor = SentenceTransformer('all-MiniLM-L6-v2')

    X_train = _embed_texts(train, feature_extractor, device, verbose)
    X_test = _embed_texts(test, feature_extractor, device, verbose)
    Y_train = torch.tensor(Y_train).to(device)
    Y_test = torch.tensor(Y_test).to(device)
    L_train = _weak_label_matrix(train, terms, bert_model, bert_tokenizer,
                                 roberta_model, roberta_tokenizer, device, verbose)
    L_test = _weak_label_matrix(test, terms, bert_model, bert_tokenizer,
                                roberta_model, roberta_tokenizer, device, verbose)

    X_train, X_val,\
    Y_train, Y_val,\
    L_train, L_val = train_test_split(X_train, Y_train, L_train, test_size=.1, random_state=random_state)

    # Standardise every split with the statistics of the (post-split) training set.
    train_mean = X_train.mean(axis=0)
    train_std = X_train.std(axis=0)
    X_test = (X_test-train_mean)/train_std
    X_val = (X_val-train_mean)/train_std
    X_train = (X_train-train_mean)/train_std

    Y = torch.hstack((Y_train, Y_val, Y_test))
    X = torch.vstack((X_train, X_val, X_test))

    ## Weak labels ##
    L = torch.vstack((L_train, L_val, L_test))

    ## Creating Z from L ##
    # Each distinct row of L is one weak-label configuration; Z is the index of
    # that configuration in set_Z_aux.
    set_Z_aux = torch.unique(L, dim=0)
    Z_train = torch.tensor([FindRowIndex(set_Z_aux, l) for l in L_train]) #used to train ws model
    Z_val = torch.tensor([FindRowIndex(set_Z_aux, l) for l in L_val]) #used to validate hyperpar. of ws model
    Z_test = torch.tensor([FindRowIndex(set_Z_aux, l) for l in L_test]) #used to compute the bounds on the test split
    # L is the row-wise stack of the three splits, so Z is the concatenation of
    # the per-split indices; no need to search every row a second time.
    Z = torch.cat((Z_train, Z_val, Z_test), dim=0)

    ## Defining supp(Y) and supp(Z) ##
    set_Y_aux = torch.unique(Y_train).tolist() #in this exp, it should be [0,1]
    set_Y = torch.tensor(range(len(set_Y_aux)))
    set_Z = torch.tensor(range(set_Z_aux.shape[0]))

    # Fraction of positive labels over train+val+test; reused as the class
    # balance of the label model and as the recall denominator below.
    pos_rate = Y.float().mean().item()

    ### Estimating P_Y_Z ###
    if verbose: print("\n >>>>>> Estimating P_Y_Z <<<<<<")
    if train_label_model:
        label_model = LabelModel(cardinality=set_Y.shape[0], verbose=False)
        with SuppressPrints():
            label_model.fit(L_train = L, n_epochs=1000, class_balance=[1-pos_rate, pos_rate], seed=random_state)
        P_Y_Z = torch.tensor(label_model.predict_proba(L=set_Z_aux)).T
    else:
        P_Y_Z = GetP_Y_Z(Y, Z, set_Y, set_Z)
    P_Y_Z = P_Y_Z.double().to(device)

    ### Training WS model ###
    if verbose: print("\n >>>>>> Validating and training end model <<<<<<")
    # Pick the weight decay with the smallest validation risk, then retrain a
    # fresh model with it.
    val_losses = []
    for weight_decay in tqdm(weight_decays_ws, disable=not verbose):
        model = LogReg(X_train.shape[1], set_Y.shape[0]).double().to(device)
        model = TrainModelCI(model, X_train, Z_train, set_Z, P_Y_Z, weight_decay=weight_decay, tol=tol, max_epochs=max_epochs, device=device)
        val_losses.append(CIRisk(GetLogLossTensor(model, X_val), Z_val, set_Z, P_Y_Z, device).item())

    model_ws = LogReg(X_train.shape[1], set_Y.shape[0]).double().to(device)
    model_ws = TrainModelCI(model_ws, X_train, Z_train, set_Z, P_Y_Z, weight_decay=weight_decays_ws[np.argmin(val_losses)], tol=tol, max_epochs=max_epochs, device=device)

    ### Computing bounds for recall and precision ###
    # NOTE(review): the banner below says "accuracy" although this stage bounds
    # recall and precision; the string is kept unchanged.
    if verbose: print("\n >>>>>> Computing bounds for accuracy <<<<<<")

    targets = ['recall', 'precision']
    bounds = {'centers': {}, 'ics': {}}
    precrecs = {'centers': {}, 'ics': {}}

    # Loop invariants.
    epsilon = approx_error/np.log(set_Y.shape[0])
    z_score = norm.ppf((conf+1)/2)
    n = X_test.shape[0]

    for bound in ['lower', 'upper']:
        bounds['centers'][bound] = {target: [] for target in targets}
        bounds['ics'][bound] = {target: [] for target in targets}
        # precrecs does not really depend on 'bound': it is rebuilt on every
        # pass so that only one copy (from the last pass) is returned.
        precrecs['centers'] = {target: [] for target in targets}
        precrecs['ics'] = {target: [] for target in targets}

        for thresh in tqdm(threshs, disable = not verbose):
            tensor = GetPrecisionRecallTensor(model_ws, X_test, thresh=thresh)

            temp = BoundExpectation(bound, tensor,
                                    Z_test, set_Z, P_Y_Z,
                                    conf=conf, epsilon=epsilon,
                                    tol=tol, max_epochs=max_epochs, device=device)

            # Precision denominator: fraction of all samples predicted positive.
            # NaN (instead of 0) makes the divisions below yield NaN rather
            # than raise or return inf.
            pred_rate = (model_ws(X)[:,1]>=thresh).float().mean().item()
            if pred_rate==0: pred_rate = np.nan

            # Empirical frequency of (predicted positive AND truly positive) on
            # the test split, with its normal-approximation half width.
            precrec = ((model_ws(X_test)[:,1]>=thresh).float() * Y_test.float()).mean().item()
            half_width = z_score*((precrec*(1-precrec))/n)**.5

            # The same quantity is rescaled by a different denominator p for
            # each target.
            for target, p in (('recall', pos_rate), ('precision', pred_rate)):
                bounds['centers'][bound][target].append(temp[0]/p)
                delta = (temp[0] - temp[1][0])/p
                bounds['ics'][bound][target].append([temp[0]/p-delta, temp[0]/p+delta])

                delta = half_width/p
                precrecs['centers'][target].append(precrec/p)
                precrecs['ics'][target].append([precrec/p-delta, precrec/p+delta])

    return bounds, precrecs
