import torch
from model import *
from dataloader import *
import numpy as np
from torch.utils.data import DataLoader
from pytorch_models.utils.logger import *
import random
from preprocess import *
from lstar_extraction.DFA import DFA

def save_lstar_dfa(dfa, dfa_path):
    """Pickle the state dictionary of ``dfa`` to ``dfa_path``.

    Args:
        dfa: object exposing ``state_dict()``; its return value is what gets
            serialized (it must be picklable).
        dfa_path: destination file path; opened in binary write mode, so an
            existing file is overwritten.
    """
    # Imported locally so this helper does not depend on ``pickle`` leaking
    # into the module namespace through one of the wildcard imports.
    import pickle

    state_dict = dfa.state_dict()
    with open(dfa_path, "wb") as ff:
        pickle.dump(state_dict, ff)


def load_lstar_dfa(dfa_path):
    """Rebuild a ``DFA`` from a pickled state dictionary.

    Args:
        dfa_path: path of a file holding a pickled state dictionary
            (the format written by ``save_lstar_dfa``).

    Returns:
        A ``DFA`` constructed with no arguments and then populated through
        ``load_state_dict``.
    """
    # Imported locally so this helper does not depend on ``pickle`` leaking
    # into the module namespace through one of the wildcard imports.
    import pickle

    # SECURITY: unpickling can execute arbitrary code. Only load files that
    # were produced by a trusted source.
    with open(dfa_path, "rb") as ff:
        state_dict = pickle.load(ff)

    dfa = DFA()
    dfa.load_state_dict(state_dict)
    return dfa


def words_from_dataset(dataset, vocab, remove_start = False, remove_end = False):
    """Decode batched index sequences into strings with boolean labels.

    Args:
        dataset: iterable of ``(input_seq, out_seq, seq_len)`` batches.
            ``input_seq`` is indexed as ``[row, position]`` and ``seq_len``
            is indexed per row (presumably the unpadded length of each row;
            confirm against the dataloader).
        vocab: mapping from symbol to integer index; it is inverted here to
            turn indices back into symbols.
        remove_start: drop the first column of ``input_seq`` and shorten
            every length by one.
        remove_end: shorten every length by one more, which also moves the
            label position one step earlier.

    Returns:
        ``(words, y)``: the decoded strings and, for each one, whether its
        label equals 1. When ``out_seq`` has more than one dimension the
        label is read from the row at the (adjusted) length position;
        otherwise ``out_seq`` holds a single label per row.
    """
    symbol_of = {index: symbol for symbol, index in vocab.items()}
    words, y = [], []

    for input_seq, out_seq, seq_len in dataset:
        if remove_start:
            input_seq, seq_len = input_seq[:, 1:], seq_len - 1

        for row, tokens in enumerate(input_seq):
            # A single position acts as both the slice end for the word and
            # the index of the label inside a per-step ``out_seq``.
            cut = seq_len[row] - 1 if remove_end else seq_len[row]

            # Everything from ``cut`` onwards is discarded before decoding.
            words.append("".join(symbol_of[t] for t in tokens[:cut].tolist()))

            label = out_seq[row][cut] if out_seq.dim() > 1 else out_seq[row]
            y.append(label.item() == 1)

    return words, y


def test_lstar_dfa(words, y_true, dfa):
    """Score the classifications of ``dfa`` against reference labels.

    Args:
        words: iterable of words, each passed to ``dfa.classify_word``.
        y_true: reference labels, one per word; truthy means positive.
        dfa: object exposing ``classify_word(word)`` whose result is
            interpreted as a boolean (truthy means accepted).

    Returns:
        Tuple ``(precision, recall, accuracy, f1, fp, fn)`` of 0-dim tensors.
        ``fp`` and ``fn`` are integer counts. A ratio whose denominator is
        zero is NaN (e.g. precision when nothing is predicted positive, or
        every ratio for empty input).
    """
    # Force a boolean dtype: letting torch infer it yields a float tensor for
    # empty input (on which ``~`` raises) and an integer tensor for 0/1
    # predictions (on which ``~`` is a bitwise NOT and corrupts the counts).
    y_pred = torch.as_tensor([dfa.classify_word(w) for w in words], dtype=torch.bool)
    y_true = torch.as_tensor(y_true, dtype=torch.bool)

    tp = torch.sum(y_pred & y_true)
    fp = torch.sum(y_pred & ~y_true)
    fn = torch.sum(~y_pred & y_true)

    accuracy = torch.mean((y_pred == y_true).type(torch.float))
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * tp / (2 * tp + fp + fn)

    return precision, recall, accuracy, f1, fp, fn


# The name matches pytest's default ``test_*`` pattern; opt out of collection
# so importing this helper into a test module does not trigger a spurious
# "fixture not found" error.
test_lstar_dfa.__test__ = False