import time
import os
import numpy as np
import torch
from tqdm import tqdm
import pickle

import torch.optim as optim
from torch.optim import AdamW
from torch_ema import ExponentialMovingAverage
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import mean_squared_error
import sklearn.metrics as skm


import pandas as pd
import pickle
from utils.data_handling import CustomDataset
import interference_util 


def test(config, net, X_test, device):
    """Run inference on ``X_test`` and package the outputs with the elapsed time.

    Args:
        config: Configuration mapping, forwarded unchanged to ``inference_loop``.
        net: Model, forwarded unchanged to ``inference_loop``.
        X_test: Test inputs, forwarded unchanged to ``inference_loop``.
        device: Target device, forwarded unchanged to ``inference_loop``.

    Returns:
        dict with ``"time"`` (``end_time - start_time`` as reported by
        ``inference_loop``) and ``"generated_data"`` (the generated samples).
    """
    generated, t_start, t_end = inference_loop(config, net, X_test, device)
    elapsed = t_end - t_start
    return {"time": elapsed, "generated_data": generated}
    

def inference_loop(config, net, X_test, device):
    """Generate samples for all test inputs batch by batch and time the run.

    Args:
        config: Configuration mapping. ``config["batch_size_test"]`` sets the
            DataLoader batch size; the whole mapping is passed to
            ``interference_util.get_interference`` and to the callable it returns.
        net: Model; switched to eval mode and moved to ``device`` before use.
        X_test: Array-like test inputs, converted to a float32 tensor with a
            new axis inserted at dim 1 before batching.
        device: Device for the model and every input batch.

    Returns:
        (all_generated, start_time, end_time): the generated samples of all
        batches concatenated along dim 0 (with dim 1 squeezed) on CPU, plus the
        ``time.time()`` stamps taken before setup and after the last batch's
        device work has completed.

    Raises:
        AssertionError: if a generated batch's shape differs from its input
            batch's shape.
    """
    start_time = time.time()
    test_data_pkl = torch.tensor(X_test, dtype=torch.float32).unsqueeze(1)
    # NOTE: the batch size must match num_samples used by the inference routine.
    batch_size = config["batch_size_test"]
    dataset = CustomDataset(test_data_pkl)
    test_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        prefetch_factor=2,
        persistent_workers=False,
    )
    net.eval()
    net = net.to(device)

    all_generated_list = []
    interference = interference_util.get_interference(config, device)
    for x in tqdm(test_loader):
        x = x.to(device, non_blocking=True)
        generated_samples = interference(config, net, x)
        # Check that the generated batch has exactly the input batch's dimensions.
        assert x.shape == generated_samples.shape, (
            f"shape mismatch: x_test {tuple(x.shape)} vs x_pred {tuple(generated_samples.shape)}"
        )
        # detach() so the collected outputs do not keep any autograd graph
        # (and its intermediate activations) alive across batches.
        all_generated_list.append(generated_samples.detach().squeeze(1))

    # CUDA kernels run asynchronously; wait for them so end_time reflects the
    # actual inference work rather than just kernel launches.
    if torch.device(device).type == "cuda":
        torch.cuda.synchronize(device)
    end_time = time.time()

    if all_generated_list:
        all_generated = torch.cat(all_generated_list, dim=0).cpu()
    else:
        # Empty test set: torch.cat would raise on an empty list, so return an
        # empty tensor shaped like the (squeezed) input instead.
        all_generated = test_data_pkl.squeeze(1)

    return all_generated, start_time, end_time

def create_result_csv(dict_data_list, result_path):
    """Flatten per-run result dicts into a CSV with one row per classifier method.

    Every entry ``d`` of ``dict_data_list`` contributes one row per name in
    ``d["classifier_methods"]``. A row combines the run-level fields
    ``dataset``, ``seed``, ``save_epoch`` and ``time`` from ``d`` with the
    fields ``method``, ``mean``, ``f1_score``, ``aucroc`` and ``aucpr`` from
    ``d["classify_method_<name>"]``.

    Args:
        dict_data_list: Iterable of result dicts containing the keys above.
        result_path: Output prefix. The file is written to
            ``result_path + "metrics.csv"`` (plain string concatenation, so a
            directory prefix must include its trailing separator).

    Raises:
        KeyError: if a result dict lacks one of the required keys.
    """
    columns_main = ["dataset", "seed", "save_epoch", "time"]
    columns_method = ["method", "mean", "f1_score", "aucroc", "aucpr"]

    rows = []
    for d in dict_data_list:
        run_fields = {key: d[key] for key in columns_main}
        for method in d["classifier_methods"]:
            method_metrics = d["classify_method_" + method]
            row = dict(run_fields)
            row.update({key: method_metrics[key] for key in columns_method})
            rows.append(row)

    # Explicit columns pin the column order and keep the header row even when
    # there are no rows (an empty DataFrame would otherwise write no header).
    df = pd.DataFrame(rows, columns=columns_main + columns_method)
    df.to_csv(result_path + "metrics.csv", index=False)






