import argparse
import os
import errno
import random
import pickle

import numpy as np

import torch
import torch.nn as nn
import torch.utils.data


class Net(nn.Module):
    """Fully connected network with a tanh activation after every hidden layer.

    Layout: ``input_dim -> hidden_dim[0] -> ... -> hidden_dim[num_hidden - 1]
    -> output_dim``. The first layer is registered as ``fc``; the remaining
    layers are registered as ``fc2``, ``fc3``, ... and are also kept, in
    order, in the plain list ``fc_list`` (whose last entry is the output
    layer). These sub-module names determine the ``state_dict`` keys.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output units.
        num_hidden: Number of hidden layers.
        hidden_dim: Sequence of hidden-layer widths; only the first
            ``num_hidden`` entries are used.
    """

    def __init__(self, input_dim, output_dim, num_hidden=2, hidden_dim=(64, 64)):
        super(Net, self).__init__()
        self.num_hidden = num_hidden
        self.fc = nn.Linear(input_dim, hidden_dim[0])
        # A plain Python list does not register its members as sub-modules,
        # so every layer is additionally registered through add_module().
        self.fc_list = []

        for i in range(num_hidden - 1):
            layer = nn.Linear(hidden_dim[i], hidden_dim[i + 1])
            self.fc_list.append(layer)
            self.add_module('fc' + str(i + 2), layer)

        # Size the output layer from the last hidden layer that was actually
        # built, not from hidden_dim[-1]; the two differ when hidden_dim has
        # more entries than num_hidden.
        last_width = hidden_dim[max(num_hidden, 1) - 1]
        output_layer = nn.Linear(last_width, output_dim)
        self.fc_list.append(output_layer)
        self.add_module('fc' + str(num_hidden + 1), output_layer)

    def forward(self, x):
        """Return the network output (no activation on the output layer)."""
        return self.fc_list[-1](self.penultimate(x))

    def penultimate(self, x):
        """Return the activations of the last hidden layer for input ``x``."""
        x = torch.tanh(self.fc(x))
        # Every entry of fc_list except the final one is a hidden layer.
        for layer in self.fc_list[:-1]:
            x = torch.tanh(layer(x))
        return x

    
def load_data_post_stonet(dataset_name, random_state, base_path, model_path):
    """Load a pickled data split and a trained ``Net``, and return penultimate-layer features.

    Files are located by plain string concatenation:
    ``base_path + dataset_name + '/data_split_<random_state>/' + model_path``
    is used as a prefix for ``data.txt`` (the pickled split) and
    ``model<epoch>.pt`` (the network weights). ``base_path`` and
    ``model_path`` must therefore already carry any separators they need.

    Args:
        dataset_name: Dataset directory name. ``"Year"`` and ``"Protein"``
            select checkpoint epochs 200 and 1000; any other name uses 5000.
        random_state: Seed for ``random``, NumPy and torch; also part of the
            split directory name.
        base_path: Path prefix placed before ``dataset_name``.
        model_path: Path fragment appended after the split directory.

    Returns:
        Tuple ``(x_train_post, y_train_post, x_test_post, y_test_post)``.
        The "train" pair is built from the calibration split: the
        penultimate-layer features of ``x_cal`` together with ``y_cal``.
        The "test" pair is the penultimate-layer features of ``x_test``
        together with ``y_test``. All tensors are detached and live on the
        device the model was evaluated on.
    """
    seed = random_state
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Architecture of the stored network.
    num_hidden = 2
    hidden_dim = [1000, 100]
    output_dim = 1

    prefix = base_path + dataset_name + '/' + 'data_split_' + str(random_state) + '/' + model_path

    # NOTE: pickle can execute arbitrary code while loading; only use this
    # with data files from a trusted source.
    with open(prefix + 'data.txt', 'rb') as f:
        # Stored order: X_train, Y_train, X_test, Y_test, x_train, x_cal,
        # x_test, y_train, y_cal, y_test, scalerX, scalerY.
        (_, _, _, _, x_train, x_cal, x_test, _, y_cal, y_test, _, _) = pickle.load(f)

    dim = x_train.shape[1]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Re-seed immediately before building the network (as the code always
    # did) so the global RNG state left behind is unchanged.
    np.random.seed(seed)
    torch.manual_seed(seed)
    load_net = Net(dim, output_dim, num_hidden, hidden_dim)
    load_net.to(device)

    # Checkpoint epoch depends on the dataset.
    load_epoch = {"Year": 200, "Protein": 1000}.get(dataset_name, 5000)

    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only
    # host (and vice versa).
    state = torch.load(prefix + 'model' + str(load_epoch) + '.pt', map_location=device)
    load_net.load_state_dict(state)
    load_net.eval()

    # .to() is a no-op for tensors already on the target device, so no
    # explicit device-mismatch check is needed.
    x_cal = x_cal.to(device)
    y_cal = y_cal.to(device)
    x_test = x_test.to(device)
    y_test = y_test.to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        x_train_post = load_net.penultimate(x_cal).detach()
        x_test_post = load_net.penultimate(x_test).detach()
    y_train_post = y_cal.detach()
    y_test_post = y_test.detach()

    return x_train_post, y_train_post, x_test_post, y_test_post





