# Necessary functions to run the experiments

# Import necessary packages
import numpy as np
from tqdm import tqdm
import os

# pytorch
import torch
from torch.optim import Adam
from torch import nn

# BERT from Huggingface
from transformers import BertTokenizer
from transformers import BertModel


# Sigmoid function
def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)) of `x`.

    The exponential is computed with `np.exp`, so NumPy arrays are handled
    elementwise and numeric scalars are accepted as well.
    """
    denominator = 1 + np.exp(-x)
    return 1 / denominator

# Generate synthetic data (X, a, y)
def generate_data(n,e0,e1,b,t):
    """Draw a synthetic sample of features, group indicator and binary label.

    Parameters:
        n:  number of rows to draw.
        e0: standard deviation of the zero-mean Gaussian noise added to the
            label score of rows with a == 0.
        e1: same as e0, for rows with a == 1.
        b:  offset added to x3 before it is thresholded at 0 to obtain a.
        t:  threshold on the sigmoid score; y is 1 where the score is >= t.

    Returns:
        (X, a, y) where X has shape (n, 3) and a, y are float arrays of 0./1.
    """
    # Features (x1, x2, x3) are jointly Gaussian with this mean and covariance
    mean = np.array([1,-1,0])
    cov = np.array([[1,0.05,0],[0.05,1,0],[0,0,0.05]])
    X = np.random.multivariate_normal(mean,cov,n)
    x1, x2, x3 = X[:,0], X[:,1], X[:,2]

    # a depends on x3 only: 1 where x3 + b is non-negative
    a = ((x3 + b) >= 0).astype('float')

    # One noise vector per group, each masked so it only affects its own rows
    noise_group0 = np.random.normal(0,e0,n) * (a==0)
    noise_group1 = np.random.normal(0,e1,n) * (a==1)

    # Label: threshold the sigmoid of the noisy feature sum
    score = sigmoid(x1 + x2 + x3 + noise_group0 + noise_group1)
    y = (score >= t).astype('float')

    return (X, a, y)

# Function to generate y_prob using random coefficients
def generate_y_hat(X,coeffs,exp,t):
    """Compute predicted probabilities and thresholded predictions.

    Parameters:
        X:      2-D feature array, one row per sample.
        coeffs: coefficient vector multiplied with the feature columns.
        exp:    experiment selector; when it equals 1 only the first two
                columns of X and the first two coefficients are used,
                otherwise all of X and coeffs are used.
        t:      threshold; y_hat is 1 where y_prob >= t.

    Returns:
        (y_prob, y_hat): the sigmoid scores and their 0./1. float
        thresholding.
    """
    if exp == 1:
        linear_score = np.dot(X[:,:2],coeffs[:2])
    else:
        linear_score = np.dot(X,coeffs)

    y_prob = sigmoid(linear_score)
    y_hat = (y_prob >= t).astype('float')

    return (y_prob, y_hat)

# Generate a_hat for experiment 2
def generate_a_hat(x3, b, e, imbalance = False):
    """Generate a perturbed estimate a_hat of the indicator (x3 + b >= 0).

    Parameters:
        x3:        values of the third feature (array-like; lists and tuples
                   are converted to NumPy arrays).
        b:         offset added to x3 before thresholding at 0.
        e:         when `imbalance` is false, the standard deviation of the
                   zero-mean Gaussian noise drawn per element; when
                   `imbalance` is true, a constant shift added to every
                   element instead of random noise.
        imbalance: selects the constant-shift mode described above.

    Returns:
        Float array of 0./1. that is 1 where x3 + b + noise >= 0.
    """
    # Plain Python sequences do not support elementwise `+`; arrays (and other
    # array-like objects that already do) are left untouched.
    if isinstance(x3, (list, tuple)):
        x3 = np.asarray(x3)

    if imbalance:
        noise = e
    else:
        noise = np.random.normal(0,e,len(x3))

    return ((x3 + b + noise) >= 0).astype('float')

# Dataset class for BERT
class Dataset(torch.utils.data.Dataset):
    """Torch dataset built from a DataFrame of names, labels and extra columns.

    Each item is a triple: the tokenized 'long_name' text, the 'a' label as a
    tensor, and a tensor holding that row's 'age', 'overall', 'y' and 'group'
    values.
    """

    def __init__(self, df):
        """Tokenize the texts and cache labels and auxiliary columns.

        df must provide the columns 'a', 'long_name', 'age', 'overall', 'y'
        and 'group'.
        """
        self.labels = list(df['a'])

        # Every text is padded/truncated to 5 tokens and returned as PyTorch tensors
        tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
        self.texts = [tokenizer(text,
                                padding='max_length', max_length = 5, truncation=True,
                                return_tensors="pt") for text in df['long_name']]

        # Select the auxiliary columns once; selecting them inside the loop
        # would build a fresh copy of the frame for every single row.
        remain = df[['age','overall','y','group']]
        self.remain_data = [remain.iloc[idx] for idx in range(len(remain))]

    def classes(self):
        """Return the list of labels (one per example)."""
        return self.labels

    def __len__(self):
        """Return the number of examples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return (tokenized text, label tensor, auxiliary-columns tensor) for `idx`."""
        batch_texts = self.texts[idx]
        batch_y = torch.tensor(self.labels[idx])
        batch_rest = torch.tensor(self.remain_data[idx])

        return batch_texts, batch_y, batch_rest

# Class for classifier
class BertClassifier(nn.Module):
    """Binary classifier: pretrained BERT followed by dropout, a linear unit and a sigmoid."""

    def __init__(self, dropout=0.5):
        """Load 'bert-base-cased' and build the classification head.

        dropout: drop probability applied to BERT's pooled output.
        """
        super().__init__()

        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.dropout = nn.Dropout(dropout)
        # 768 input features -> a single output unit
        self.linear = nn.Linear(768, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_id, mask):
        """Return the sigmoid output of the head for the given token ids and attention mask."""
        # With return_dict=False the model returns a tuple; only its second
        # element (the pooled output) is used.
        _unused, pooled = self.bert(input_ids=input_id, attention_mask=mask, return_dict=False)
        return self.sigmoid(self.linear(self.dropout(pooled)))


def _select_device():
    """Pick the torch device to run on: CUDA first, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    # torch.backends.mps is absent on older torch builds, hence the getattr
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _run_epoch(model, batches, criterion, device, optimizer=None):
    """Run one pass over `batches`; weights are updated only if `optimizer` is given.

    Returns (sum of per-batch losses, number of correct predictions).
    """
    total_loss = 0.0
    total_correct = 0

    for inputs, labels, _ in batches:
        labels = labels.to(device).float()
        # Each item is tokenized separately, so after collation both tensors
        # carry an extra size-1 dimension at position 1; drop it.
        mask = inputs['attention_mask'].squeeze(1).to(device)
        input_id = inputs['input_ids'].squeeze(1).to(device)

        # Flatten the (batch, 1) model output to (batch,) to match the labels
        output = model(input_id, mask).reshape(-1)
        batch_loss = criterion(output, labels)

        total_correct += torch.sum((output >= 0.5) == labels).item()
        total_loss += batch_loss.item()

        if optimizer is not None:
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

    return total_loss, total_correct


def train(model, train_data, val_data, learning_rate, epochs, batch_size=2):
    """Train `model` with binary cross-entropy and report metrics every epoch.

    Parameters:
        model:         module called as model(input_ids, attention_mask) and
                       returning probabilities (it is trained with BCELoss).
        train_data:    DataFrame wrapped in `Dataset` for training.
        val_data:      DataFrame wrapped in `Dataset` for validation.
        learning_rate: learning rate passed to Adam.
        epochs:        number of passes over the training data.
        batch_size:    mini-batch size for both loaders (default 2).

    Returns:
        None. One line with train/validation loss and accuracy is printed
        per epoch; the model is updated in place.
    """
    train_set, val_set = Dataset(train_data), Dataset(val_data)

    train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_set, batch_size=batch_size)

    # Default to GPU 0 unless the caller already restricted the visible GPUs.
    # NOTE(review): this presumably only takes effect if CUDA has not been
    # initialised yet in this process - confirm.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')
    device = _select_device()
    model.to(device)

    criterion = nn.BCELoss()
    optimizer = Adam(model.parameters(), lr=learning_rate)

    # Denominators for the reports; max(..., 1) keeps them defined for an empty split
    n_train_batches = max(len(train_dataloader), 1)
    n_val_batches = max(len(val_dataloader), 1)
    n_train = max(len(train_data), 1)
    n_val = max(len(val_data), 1)

    for epoch_num in range(epochs):

        model.train()
        total_loss_train, total_tp_train = _run_epoch(
            model, tqdm(train_dataloader), criterion, device, optimizer)

        model.eval()
        with torch.no_grad():
            total_loss_val, total_tp_val = _run_epoch(
                model, val_dataloader, criterion, device)

        print(
            f'Epochs: {epoch_num + 1} | Train Loss: {total_loss_train / n_train_batches: .3f} '
            f'| Train Accuracy: {total_tp_train / n_train: .3f} '
            f'| Val Loss: {total_loss_val / n_val_batches: .3f} '
            f'| Val Accuracy: {total_tp_val / n_val: .3f}')