# evaluate metric w/o MIA on prediction model
# evaluate metric on statistics
import pandas as pd
import torch
from my_utils.utils import DATASET_NAMES, PRE_MODEL_TYPES, get_dataset_params, get_metric, setup_seed
from dataset.dataset import MyDataset
from torch.utils.data import DataLoader

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
use_filter = True    # forwarded to MyDataset; also part of the checkpoint path
window = 15          # forwarded to MyDataset (presumably the filter window -- TODO confirm)
order = 5            # forwarded to MyDataset (presumably the filter order -- TODO confirm)
ATTACK_INDEX = -1    # position along dim 1 of target/prediction that gets scored
END = 1000           # NOTE(review): not referenced anywhere in this script
device = "cuda:0"
num_samples = 1001   # forwarded to MyDataset and reused as the DataLoader batch size
csv_path = 'result-metric_wo_defense-0817.csv'
pre_model_dir = "result_prediction_0813"

# Column order of the exported CSV (matches the key order of each result row).
RESULT_COLUMNS = ["pre_dir", "dataset_name", "pre_type", "MAE", "MSE", "use_filter"]


setup_seed()
# Rows are accumulated in a plain list and turned into a DataFrame once at the
# end: DataFrame.append was removed in pandas 2.0 and copied the whole frame on
# every call.
rows = []
count = 0


with torch.no_grad():

    # One pass per dataset: build the test split and a loader whose batch size
    # equals num_samples, so only the first batch is evaluated below.
    for dataset_name in DATASET_NAMES[:]:
        context_length, prediction_length = get_dataset_params(dataset_name)
        dataset = MyDataset(
            dataset_name,
            'test',
            context_length,
            prediction_length,
            num_samples,
            use_filter,
            window,
            order,
        )
        dataloader = DataLoader(dataset, num_samples, shuffle=False)

        # One pass per pre-trained prediction model type.
        for pre_model_type in PRE_MODEL_TYPES[:]:
            ck_path = f'{pre_model_dir}/{dataset_name}/{pre_model_type}/filter={use_filter}/checkpoint/best.pt'
            # SECURITY: torch.load unpickles arbitrary Python objects; only load
            # checkpoints produced by trusted training runs.
            # NOTE(review): the checkpoint is used as a whole module (.to/.eval),
            # so newer PyTorch releases that default to weights_only=True may
            # need weights_only=False here -- confirm against the installed version.
            # map_location keeps loading independent of the device the
            # checkpoint was saved from.
            model = torch.load(ck_path, map_location=device).to(device).eval()

            # Take the first batch of the (unshuffled) loader and move it to
            # the evaluation device.
            context, target = next(iter(dataloader))
            context = context.to(device)
            target = target.to(device)
            # normalize returns four values; the last two are not needed here.
            context, target, _, _ = dataset.normalize(context, target)
            # The model returns a pair; only the first element is scored.
            prediction, _ = model(context, target)

            # Score a single position (ATTACK_INDEX) along dim 1.
            y = target[:, ATTACK_INDEX]
            y_hat = prediction[:, ATTACK_INDEX]
            metric = get_metric(y, y_hat)
            mae = metric['mae']
            mse = metric['mse']

            rows.append({
                "pre_dir": pre_model_dir,
                "dataset_name": dataset_name,
                "pre_type": pre_model_type,
                "MAE": mae,
                "MSE": mse,
                "use_filter": use_filter,
            })
            count += 1
            print(count)


df = pd.DataFrame(rows, columns=RESULT_COLUMNS)
df.to_csv(csv_path, index=False)
