import torch
from model import *
from dataloader import *
import numpy as np
from torch.utils.data import DataLoader
from pytorch_models.utils.logger import *
import random
from preprocess import *
import json

def load_dataset(session_args, logger=None):
    """Build the training and validation ``DataLoader`` objects.

    Args:
        session_args: Object exposing ``grammar``, ``mtype``, ``dataset``,
            ``batch_size``, ``valid_batch_size`` and ``shuffle_data``.
        logger: Optional logger; when given, the sizes of both datasets are
            reported through ``logger.info``.

    Returns:
        A ``(train_data, val_data)`` tuple of ``DataLoader`` instances. The
        validation loader is never shuffled.

    Raises:
        NotImplementedError: If ``mtype`` is not ``"classification"`` or the
            grammar does not start with ``"tomita"`` or ``"dyck"``.
    """
    grammar = session_args.grammar

    # Fail early with a clear error rather than hitting an unbound local
    # variable further down.
    if session_args.mtype not in ["classification"]:
        raise NotImplementedError(
            "unsupported model type: {0}".format(session_args.mtype))

    if grammar[:6] == "tomita":
        # Everything after the "tomita" prefix plus one separator character.
        gnum = grammar[7:]
        train_dataset = TomitaDatasetPartialClassification(gnum, session_args.dataset, "train")
        val_dataset = TomitaDatasetPartialClassification(gnum, session_args.dataset, "val")
    elif grammar[:4] == "dyck":
        # Dyck splits are stored as separate positive/negative .npz files.
        train_pos = "../data/dyck/" + session_args.dataset + "_pos_train.npz"
        train_neg = "../data/dyck/" + session_args.dataset + "_neg_train.npz"
        val_pos = "../data/dyck/" + session_args.dataset + "_pos_val.npz"
        val_neg = "../data/dyck/" + session_args.dataset + "_neg_val.npz"

        train_dataset = DyckDataLoader(pos_file=train_pos, neg_file=train_neg)
        val_dataset = DyckDataLoader(pos_file=val_pos, neg_file=val_neg)
    else:
        raise NotImplementedError("unsupported grammar: {0}".format(grammar))

    train_data = DataLoader(train_dataset, batch_size=session_args.batch_size,
                            shuffle=session_args.shuffle_data, num_workers=0)
    val_data = DataLoader(val_dataset, batch_size=session_args.valid_batch_size,
                          shuffle=False, num_workers=0)

    if logger is not None:
        logger.info("train data size : {0} validation data size : {1}".format(len(train_dataset), len(val_dataset)))

    return train_data, val_data


def load_test_dataset(session_args, logger=None):
    """Build the test ``DataLoader``.

    Args:
        session_args: Object exposing ``grammar``, ``mtype``, ``dataset`` and
            ``batch_size``; ``nbin`` is additionally read for tomita grammars.
        logger: Optional logger; when given, the test dataset size is
            reported through ``logger.info``.

    Returns:
        A non-shuffled ``DataLoader`` over the test dataset.

    Raises:
        NotImplementedError: If ``mtype`` is not ``"classification"`` or the
            grammar does not start with ``"tomita"`` or ``"dyck"``.
    """
    grammar = session_args.grammar

    # Fail early with a clear error rather than hitting an unbound local
    # variable further down.
    if session_args.mtype not in ["classification"]:
        raise NotImplementedError(
            "unsupported model type: {0}".format(session_args.mtype))

    if grammar[:6] == "tomita":
        # Everything after the "tomita" prefix plus one separator character.
        gnum = grammar[7:]
        test_dataset = TomitaDatasetPartialClassification(gnum, session_args.dataset, "test", nbin = session_args.nbin)
    elif grammar[:4] == "dyck":
        # Dyck test data is stored as separate positive/negative .npz files.
        test_pos = "../data/dyck/" + session_args.dataset + "_pos_test.npz"
        test_neg = "../data/dyck/" + session_args.dataset + "_neg_test.npz"
        test_dataset = DyckDataLoader(pos_file=test_pos, neg_file=test_neg)
    else:
        raise NotImplementedError("unsupported grammar: {0}".format(grammar))

    test_data = DataLoader(test_dataset, batch_size=session_args.batch_size,
                           shuffle=False, num_workers=0)

    if logger is not None:
        logger.info("test data size : {0} ".format(len(test_dataset)))

    return test_data


def load_model(session_args):
    """Construct a fresh (untrained) model for the session.

    Args:
        session_args: Object exposing ``model``, ``mtype``, ``device`` and the
            ``ClassificationModel`` hyper-parameters ``state_dim``,
            ``nlayers``, ``rnn_cell``, ``dropout``, ``vocab_size``, ``hinit``
            and ``partial_strings``.

    Returns:
        A ``ClassificationModelAdaptor`` around a ``ClassificationModel`` that
        has been moved to ``session_args.device``.

    Raises:
        NotImplementedError: If ``model`` is not ``"rnn"`` or ``mtype`` is not
            ``"classification"``.
    """
    # Reject unsupported architectures explicitly; previously this fell
    # through to the return statement with no model bound.
    if session_args.model != "rnn":
        raise NotImplementedError(
            "unsupported model: {0}".format(session_args.model))
    if session_args.mtype != "classification":
        raise NotImplementedError(
            "unsupported model type: {0}".format(session_args.mtype))

    model = ClassificationModel(state_dim=session_args.state_dim,
                                num_layers=session_args.nlayers,
                                rnn_type=session_args.rnn_cell,
                                dropout=session_args.dropout,
                                vocab_size=session_args.vocab_size,
                                hinit=session_args.hinit,
                                partial_strings=session_args.partial_strings).to(session_args.device)
    return ClassificationModelAdaptor(model, session_args.device)


def load_model_for_test(session_args):
    """Rebuild the model, restore its weights from a checkpoint and set eval mode.

    Args:
        session_args: Object exposing ``model``, ``mtype``, ``device``,
            ``load_model`` (the checkpoint path) and the
            ``ClassificationModel`` hyper-parameters ``state_dim``,
            ``nlayers``, ``rnn_cell``, ``dropout``, ``vocab_size``, ``hinit``
            and ``partial_strings``.

    Returns:
        A ``(model, model_iter)`` tuple: the restored
        ``ClassificationModelAdaptor`` in eval mode and the value stored
        under the checkpoint's ``"iter"`` key.

    Raises:
        NotImplementedError: If ``model`` is not ``"rnn"`` or ``mtype`` is not
            ``"classification"``.
    """
    # Reject unsupported architectures before touching the checkpoint;
    # previously this path ended in an unbound ``model`` variable.
    if session_args.model != "rnn":
        raise NotImplementedError(
            "unsupported model: {0}".format(session_args.model))
    if session_args.mtype != "classification":
        raise NotImplementedError(
            "unsupported model type: {0}".format(session_args.mtype))

    model = ClassificationModel(state_dim=session_args.state_dim,
                                num_layers=session_args.nlayers,
                                rnn_type=session_args.rnn_cell,
                                dropout=session_args.dropout,
                                vocab_size=session_args.vocab_size,
                                hinit=session_args.hinit,
                                partial_strings=session_args.partial_strings).to(session_args.device)
    model = ClassificationModelAdaptor(model, session_args.device)

    # Remap stored tensors onto the session device so that a checkpoint saved
    # on a GPU can still be restored on a machine without one.
    checkpoint = torch.load(session_args.load_model,
                            map_location=torch.device(session_args.device))
    # The state dict is loaded through the adaptor, as in the original code.
    model.load_state_dict(checkpoint['model_state'])
    model_iter = checkpoint["iter"]
    model.eval()
    return model, model_iter


def load_model_for_lstar(session_args, alphabet, vocab):
    """Restore a trained model and expose it through two adaptors.

    Args:
        session_args: Object exposing ``model``, ``mtype``, ``device``,
            ``load_model`` (the checkpoint path) and the
            ``ClassificationModel`` hyper-parameters ``state_dim``,
            ``nlayers``, ``rnn_cell``, ``dropout``, ``vocab_size``, ``hinit``
            and ``partial_strings``.
        alphabet: Passed through to ``ClassificationModelLSTARAdaptor``.
        vocab: Passed through to ``ClassificationModelLSTARAdaptor``.

    Returns:
        A ``(lstar_model, test_model)`` tuple. Both adaptors are built around
        the same restored ``ClassificationModel`` instance, which is in eval
        mode.

    Raises:
        NotImplementedError: If ``model`` is not ``"rnn"`` or ``mtype`` is not
            ``"classification"``.
    """
    # Reject unsupported architectures explicitly; previously this fell
    # through to the return statement with nothing bound.
    if session_args.model != "rnn":
        raise NotImplementedError(
            "unsupported model: {0}".format(session_args.model))
    if session_args.mtype != "classification":
        raise NotImplementedError(
            "unsupported model type: {0}".format(session_args.mtype))

    model = ClassificationModel(state_dim=session_args.state_dim,
                                num_layers=session_args.nlayers,
                                rnn_type=session_args.rnn_cell,
                                dropout=session_args.dropout,
                                vocab_size=session_args.vocab_size,
                                hinit=session_args.hinit,
                                partial_strings=session_args.partial_strings).to(session_args.device)

    # Remap stored tensors onto the session device so that a checkpoint saved
    # on a GPU can still be restored on a machine without one.
    checkpoint = torch.load(session_args.load_model,
                            map_location=torch.device(session_args.device))
    # Here the weights are loaded into the bare model, before any adaptor
    # is created.
    model.load_state_dict(checkpoint['model_state'])
    model.eval()

    lstar_model = ClassificationModelLSTARAdaptor(model, alphabet, vocab, session_args.device, start = True, end = True)
    test_model = ClassificationModelAdaptor(model, session_args.device)
    return lstar_model, test_model

def get_dyck_vocab(gnum):
    """Load the JSON vocabulary for the given Dyck grammar.

    Args:
        gnum: Value substituted into the file name
            ``../data/dyck/dyck_{gnum}_vocab.json`` (resolved relative to the
            current working directory).

    Returns:
        The object decoded from the JSON file.
    """
    # The file is only read, so open it read-only; "r+" would needlessly
    # require write permission on the vocabulary file.
    with open("../data/dyck/dyck_{}_vocab.json".format(gnum), "r") as ff:
        vocab = json.load(ff)
    return vocab

def get_vocab(session_args):
    """Return the vocabulary matching ``session_args.grammar``.

    Args:
        session_args: Object exposing ``grammar``; a value starting with
            ``"tomita"`` or ``"dyck"`` is supported.

    Returns:
        The result of ``get_tomita_vocab()`` for tomita grammars, or of
        ``get_dyck_vocab(gnum)`` for dyck grammars.

    Raises:
        NotImplementedError: If the grammar belongs to neither family.
    """
    grammar = session_args.grammar
    if grammar[:6] == "tomita":
        vocab = get_tomita_vocab()
    elif grammar[:4] == "dyck":
        # NOTE(review): only the single character at index 5 is used as the
        # grammar number, so a multi-character suffix would be truncated --
        # confirm this is intended for the grammar names in use.
        gnum = grammar[5]
        vocab = get_dyck_vocab(gnum)
    else:
        # Previously this path reached the return with ``vocab`` unbound.
        raise NotImplementedError("unsupported grammar: {0}".format(grammar))

    return vocab


def get_loss_func(args):
    """Instantiate the loss selected by ``args.loss_func``.

    ``"ce"`` yields ``torch.nn.CrossEntropyLoss()`` and ``"bce"`` yields
    ``torch.nn.BCELoss()``. Any other value results in ``None``.
    """
    requested = args.loss_func
    known_losses = (
        ("ce", torch.nn.CrossEntropyLoss),
        ("bce", torch.nn.BCELoss),
    )
    for label, loss_cls in known_losses:
        if requested == label:
            return loss_cls()
    # Unknown loss names fall through without raising.
    return None


def get_random_string(length):
    """Return a string of ``length`` characters drawn from ``"0"`` and ``"1"``.

    Each character is picked independently with ``random.choice``, so the
    result is reproducible under a fixed ``random.seed``.
    """
    bits = []
    for _ in range(length):
        bits.append(random.choice(["0", "1"]))
    return "".join(bits)


def load_excluded_data(flist):
    """Collect the lines of every file in *flist* into one list.

    Args:
        flist: Iterable of file paths to read.

    Returns:
        A list with the lines of all files (line terminators removed), in
        the order the files are given. Empty when *flist* is empty.
    """
    excluded_strings = []
    for fpath in flist:
        # The files are only read, so open them read-only; "r+" would
        # needlessly require write permission.
        with open(fpath, "r") as ff:
            excluded_strings.extend(ff.read().splitlines())

    return excluded_strings


def create_early_stopping_criteria(valid_acc):
    """Create a predicate that signals when training may stop early.

    The returned callable takes ``(valid_loss, valid_metrics)`` and returns
    ``True`` once ``valid_metrics["classification_avg"]`` is at least
    *valid_acc*, otherwise ``False``. ``valid_loss`` is not consulted.
    """
    def early_stopping(valid_loss, valid_metrics):
        reached_target = valid_metrics["classification_avg"] >= valid_acc
        # Coerce to a plain bool, matching the True/False literals callers get.
        return bool(reached_target)

    return early_stopping