import os
import pickle
import sys

import pandas
from sklearn.utils import shuffle


def read_TREC():
    """Load the TREC train/test files as token lists with their labels.

    Each line of ``data/TREC/TREC_<mode>.txt`` is tokenised on whitespace. The
    part of the first token before its first ``:`` is used as the label
    (presumably the coarse class of a ``coarse:fine`` label — confirm against
    the data files). The remaining tokens form the example. Blank lines are
    skipped.

    Both files are shuffled with a fixed seed (``random_state=11``). The first
    10% of the shuffled training file becomes the dev set.

    Returns:
        dict with keys ``dev_x``/``dev_y``, ``train_x``/``train_y`` and
        ``test_x``/``test_y``.
    """
    data = {}

    def read(mode):
        x, y = [], []

        with open("data/TREC/TREC_" + mode + ".txt", "r", encoding="utf-8") as f:
            for line in f:
                # split() already drops the trailing newline; tokenise only once.
                tokens = line.split()
                if not tokens:
                    # Blank / whitespace-only line: tokens[0] would raise IndexError.
                    continue
                y.append(tokens[0].split(":")[0])
                x.append(tokens[1:])

        x, y = shuffle(x, y, random_state=11)

        if mode == "train":
            dev_idx = len(x) // 10
            data["dev_x"], data["dev_y"] = x[:dev_idx], y[:dev_idx]
            data["train_x"], data["train_y"] = x[dev_idx:], y[dev_idx:]
        else:
            data["test_x"], data["test_y"] = x, y

    read("train")
    read("test")

    return data


def read_MR():
    """Load the MR polarity files and split them 80/10/10 into train/dev/test.

    Lines of ``rt-polarity.pos`` are labelled 1 and lines of
    ``rt-polarity.neg`` are labelled 0. Each line becomes a list of whitespace
    tokens. The combined corpus is shuffled with a fixed seed
    (``random_state=11``) before slicing, so the split is reproducible.

    Returns:
        dict with keys ``train_x``/``train_y``, ``dev_x``/``dev_y`` and
        ``test_x``/``test_y``.
    """
    sentences, labels = [], []

    # Positives first, then negatives: this fixes the pre-shuffle order.
    sources = (("data/MR/rt-polarity.pos", 1), ("data/MR/rt-polarity.neg", 0))
    for path, polarity in sources:
        with open(path, "r", encoding="utf-8") as handle:
            for raw in handle:
                text = raw[:-1] if raw.endswith("\n") else raw
                sentences.append(text.split())
                labels.append(polarity)

    sentences, labels = shuffle(sentences, labels, random_state=11)

    tenth = len(sentences) // 10
    train_end = tenth * 8
    dev_end = tenth * 9

    return {
        "train_x": sentences[:train_end],
        "train_y": labels[:train_end],
        "dev_x": sentences[train_end:dev_end],
        "dev_y": labels[train_end:dev_end],
        "test_x": sentences[dev_end:],
        "test_y": labels[dev_end:],
    }

def read_pkl(f):
    """Load a pickled DataFrame corpus and return train/dev/test token splits.

    The object read by ``pandas.read_pickle`` must provide ``sentence``
    (whitespace-separated text) and ``label`` columns. If it also has a
    ``split`` column whose distinct values are exactly ``train``, ``dev`` and
    ``test``, those predefined splits are used. Otherwise the whole corpus is
    shuffled with a fixed seed (``random_state=11``) and split 80/10/10.

    Args:
        f: path or file-like object accepted by ``pandas.read_pickle``.

    Returns:
        dict with keys ``train_x``/``train_y``, ``dev_x``/``dev_y`` and
        ``test_x``/``test_y``; every ``*_x`` is a list of token lists.

    NOTE(review): ``pandas.read_pickle`` unpickles arbitrary objects — only
    call this on trusted files.
    """
    data = {}
    corpus = pandas.read_pickle(f)

    # Require the exact split names. Counting distinct values alone would accept
    # e.g. train/val/test and then silently produce an empty dev set.
    splits = set(corpus["split"]) if "split" in corpus.columns else set()

    if splits == {"train", "dev", "test"}:
        for name in ("train", "dev", "test"):
            part = corpus.loc[corpus["split"] == name]
            data[name + "_x"] = [s.split() for s in part.sentence]
            data[name + "_y"] = list(part.label)
    else:
        sentences = [s.split() for s in corpus.sentence]
        labels = list(corpus.label)
        x, y = shuffle(sentences, labels, random_state=11)
        dev_idx = len(x) // 10 * 8
        test_idx = len(x) // 10 * 9

        # No usable predefined split: carve train/dev/test from the shuffled corpus.
        data["train_x"], data["train_y"] = x[:dev_idx], y[:dev_idx]
        data["dev_x"], data["dev_y"] = x[dev_idx:test_idx], y[dev_idx:test_idx]
        data["test_x"], data["test_y"] = x[test_idx:], y[test_idx:]

    return data

def read_SST1():
    """Load ``data/sentiment_dataset/SST1.pkl`` through ``read_pkl``."""
    return read_pkl('data/sentiment_dataset/SST1.pkl')

def read_SST2():
    """Load ``data/sentiment_dataset/SST2.pkl`` through ``read_pkl``."""
    return read_pkl('data/sentiment_dataset/SST2.pkl')

def read_SUBJ():
    """Load ``data/sentiment_dataset/SUBJ.pkl`` through ``read_pkl``."""
    return read_pkl('data/sentiment_dataset/SUBJ.pkl')

def read_MPQA():
    """Load ``data/sentiment_dataset/MPQA.pkl`` through ``read_pkl``."""
    return read_pkl('data/sentiment_dataset/MPQA.pkl')

def read_CR():
    """Load ``data/sentiment_dataset/CR.pkl`` through ``read_pkl``."""
    return read_pkl('data/sentiment_dataset/CR.pkl')

def save_model(model, params):
    """Pickle ``model`` to ``saved_models/<DATASET>_<MODEL>_<EPOCH>.pkl``.

    Args:
        model: any picklable object.
        params: mapping providing the ``DATASET``, ``MODEL`` and ``EPOCH``
            entries used to build the file name.

    The ``saved_models`` directory is created if it does not exist yet.
    """
    path = f"saved_models/{params['DATASET']}_{params['MODEL']}_{params['EPOCH']}.pkl"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    # ``with`` closes (and flushes) the file deterministically instead of
    # leaving an anonymous handle for the garbage collector.
    with open(path, "wb") as out:
        pickle.dump(model, out)
    print(f"A model is saved successfully as {path}!")


def load_model(params):
    """Unpickle the model stored at ``saved_models/<DATASET>_<MODEL>_<EPOCH>.pkl``.

    Args:
        params: mapping providing the ``DATASET``, ``MODEL`` and ``EPOCH``
            entries used to build the file name.

    Returns:
        The unpickled object.

    If the file cannot be opened, a message is printed and the process exits
    with status 1. Errors raised while unpickling a file that does exist
    propagate unchanged, so a corrupt model is not mistaken for a missing one.

    NOTE(review): ``pickle.load`` can execute arbitrary code — only load model
    files you created yourself.
    """
    path = f"saved_models/{params['DATASET']}_{params['MODEL']}_{params['EPOCH']}.pkl"

    try:
        handle = open(path, "rb")
    except OSError:
        print(f"No available model such as {path}.")
        # sys.exit is always available (the bare ``exit`` builtin comes from
        # ``site``), and a non-zero status signals the failure to the caller.
        sys.exit(1)

    with handle:
        model = pickle.load(handle)
    print(f"Model in {path} loaded successfully!")

    return model
