import os
import random
import sys
import numpy as np
import torch
from pickle import load

from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, mean_squared_error

from blogfeedback_loader import *

def test(data_test_path='test.csv', model_path='model.pth', batch_size=None):
    """Evaluate a saved regression model on the BlogFeedback test split.

    Loads the dataset from ``data_test_path`` and the pickled model object
    from ``model_path``. Prints the dataset size and the model's parameter
    count, then prints the MSE and RMSE of the model's predictions over the
    whole test set.

    Args:
        data_test_path: CSV file passed to ``BlogFeedbackDataLoader``.
        model_path: Path of a model object saved with ``torch.save(model, ...)``.
        batch_size: Evaluation batch size. ``None`` (default) evaluates the
            whole dataset in a single batch, as the script always did.

    Returns:
        The test-set RMSE as a Python float.

    Raises:
        ValueError: If the test dataset is empty.
    """
    test_dataset = BlogFeedbackDataLoader(csv_path=data_test_path)
    n_samples = len(test_dataset)
    print("Testing dataset size:", n_samples)
    if n_samples == 0:
        # DataLoader rejects batch_size=0, and the metrics would be undefined.
        raise ValueError(f"test dataset {data_test_path!r} is empty")

    # Sample order does not affect an aggregate metric, so there is no reason
    # to shuffle; keeping it off makes evaluation deterministic.
    test_loader = DataLoader(
        test_dataset,
        batch_size=n_samples if batch_size is None else batch_size,
        shuffle=False,
        drop_last=False,
    )

    # NOTE(review): torch.load unpickles arbitrary objects, so only load
    # trusted checkpoints. On PyTorch >= 2.6 `weights_only` defaults to True,
    # which rejects full model objects; pass weights_only=False there if needed.
    model = torch.load(model_path)
    # Put dropout / batch-norm layers (if any) into inference behaviour.
    model.eval()

    param_amount = sum(p.numel() for p in model.parameters())
    print('The total param amount:', param_amount)

    # Accumulate the squared error over all batches (in float64) so the
    # metric is correct for any batch size, not just a single full batch.
    sq_err_sum = 0.0
    n_values = 0
    with torch.no_grad():
        for x, y in test_loader:
            pred_y = model(x)
            # Align e.g. (N, 1) predictions with (N,) targets instead of letting
            # broadcasting silently build an (N, N) error matrix; raises if the
            # element counts differ.
            pred_y = pred_y.reshape(y.shape)
            diff = pred_y.to(torch.float64) - y.to(torch.float64)
            sq_err_sum += diff.pow(2).sum().item()
            n_values += diff.numel()

    mse = sq_err_sum / n_values
    rmse = mse ** 0.5
    print('Test MSE: {}'.format(mse))
    print('Test RMSE: {}'.format(rmse))
    return rmse



if __name__ == "__main__":
    # Script entry point: evaluate the saved model on the default test split.
    test()