from data_generator_val import DataGenerator
import numpy as np
import json
import os
import pandas as pd
import pickle
import time
import argparse
import torch
import traceback
from sklearn.model_selection import train_test_split


from adbench.myutils import Utils

import model_util
import utils.data_handling
import utils.file_handling
import utils.metric
import utils.monitor_train
import test
import interference_util
import utils.plots

# Root folder for per-run result directories and folder holding training_monitor.csv.
RESULT_PATH = "./result/"
CODE_PATH = "./"

# Evaluate on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device('cpu')

def safe_test(config, id, result_path, file_path_toml):
    """Run ``main`` for one training run without letting a failure propagate.

    Creates ``result_path``, stores a copy of the used config file there as
    ``config.toml`` and then calls ``main``. Any exception raised by ``main``
    is printed together with its traceback and swallowed, so a caller that
    evaluates several runs in a row can continue with the next one.

    Args:
        config: Flattened configuration mapping handed on to ``main``.
        id: Training id handed on to ``main``.
        result_path: Directory that receives the config copy and the results.
        file_path_toml: Path of the config file this evaluation was started with.
    """
    print(file_path_toml)
    utils.file_handling.create_folder(result_path)

    # os.path.join keeps the target correct even if result_path lacks a trailing slash.
    config_copy_path = os.path.join(result_path, "config.toml")
    # start_test passes the same directory as source and target for config id 0,
    # which makes source and target the same file. Copying a file onto itself is
    # pointless at best and may fail or truncate the file, depending on how
    # copy_file is implemented (not visible here), so skip it.
    if os.path.abspath(file_path_toml) != os.path.abspath(config_copy_path):
        utils.file_handling.copy_file(file_path_toml, config_copy_path)

    try:
        main(config, id, result_path)
    except Exception as e:
        # Deliberate best-effort boundary: report the failure and return normally.
        print("Terminated unexpected")
        print(e)
        print(traceback.format_exc())


def main(config, id, result_path):
    """Evaluate the stored best checkpoints for every configured seed and dataset.

    For each seed in ``range(config["start_seed"], config["end_seed"])`` (end
    excluded) and each dataset of that seed, the pickled data split is loaded,
    the matching model is restored via ``load_model``, ``test.test`` is run on
    ``data['X_test']`` and the anomaly scores plus the metrics of every entry
    in ``config["classify_method"]`` are collected. All collected result dicts
    are finally passed to ``test.create_result_csv``.

    A dataset whose pickle file for the current seed is missing is reported
    and skipped; the remaining datasets of that seed are still evaluated.

    Args:
        config: Flattened configuration mapping. Keys read here: start_seed,
            end_seed, model_path, model_type, backbone_model, dataset_list_AD,
            dataset_list, num_epochs, data_path, classify_method.
        id: Training id; part of the checkpoint directory name.
        result_path: Directory handed to the metric and CSV helpers.
    """
    start_seed = config["start_seed"]
    end_seed = config["end_seed"]
    # NOTE: model_path is concatenated directly, so it must already end with a separator.
    ckps__base_path = (
        config["model_path"] + config["model_type"] + "/"
        + config["backbone_model"] + "_" + str(id) + "/"
    )
    result_dict_list = []
    anomalie_score_fn = interference_util.get_anomalie_score(config)
    for seed in range(start_seed, end_seed):
        datagenerator = DataGenerator(seed=seed, test_size=0.5, normal=True)
        dataset_list = utils.data_handling.create_dataset_list(
            datagenerator, config["dataset_list_AD"], config["dataset_list"]
        )
        print(dataset_list)

        for dataset in dataset_list:
            utils.monitor_train.print_console_output(
                id, seed, dataset, config["num_epochs"],
                config["model_type"], config["backbone_model"],
            )
            print(dataset)
            data_dir = config["data_path"] + "/" + dataset + "/seed_" + str(seed) + ".pkl"
            if not os.path.exists(data_dir):
                print("Seed for dataset ", dataset, " does not exists")
                # Skip only this dataset. The previous `break` silently dropped
                # every remaining dataset of the seed as well.
                continue
            # NOTE(security): pickle.load can execute arbitrary code; only point
            # data_path at files produced by this project.
            with open(data_dir, "rb") as data_file:
                data = pickle.load(data_file)

            ckps_path = os.path.join(ckps__base_path + dataset, "seed_" + str(seed) + "best.pkl")
            net = load_model(ckps_path, config, data['X_test'].shape[1], seed)

            result_dict = test.test(config, net, data['X_test'], device)
            result_dict["seed"] = seed
            result_dict["dataset"] = dataset
            result_dict["save_epoch"] = 0
            (
                result_dict["anomalie_score"],
                result_dict["anomalie_score_all_features"],
            ) = anomalie_score_fn(result_dict["generated_data"], data["X_test"])
            del result_dict['generated_data']  # save disk space when the dict is saved

            result_dict["classifier_methods"] = config["classify_method"]
            for method in config["classify_method"]:
                result_dict_method = create_metrics_and_plots(config, data, result_dict, result_path, method)
                result_dict["classify_method_" + method] = result_dict_method
            result_dict_list.append(result_dict)

    # test.average_seeds(result_dict_list)
    # test.save_result_dict_list(result_dict_list, result_path)
    test.create_result_csv(result_dict_list, result_path)
    
def create_metrics_and_plots(config, data, result_dict, result_path, method):
    """Compute the evaluation metrics of one classify method for one test run.

    Args:
        config: Flattened configuration mapping, passed on to
            ``utils.metric.calculate_metrics_threshold``.
        data: Loaded data split; must contain the labels under ``"y_test"``.
        result_dict: Result of the test run; ``result_dict["anomalie_score"]``
            is read.
        result_path: Currently unused (the plot call below is disabled).
        method: Name of the classify method; stored in the returned dict.
            NOTE(review): ``method`` is not passed to the metric helper, so
            every method yields the same metrics unless the helper selects
            the method through ``config`` - confirm.

    Returns:
        dict with the keys ``mean`` (mean anomaly score), ``method``,
        ``f1_score``, ``aucroc`` and ``aucpr``.

    Raises:
        AssertionError: If ``data`` has no ``"y_test"`` entry.
    """
    if "y_test" not in data:
        # Explicit raise instead of `assert False`: asserts are removed under
        # `python -O`, which would let this function silently return {}.
        raise AssertionError('data contains no "y_test" labels; metrics cannot be computed')

    f1_score, aucroc, aucpr = utils.metric.calculate_metrics_threshold(
        config, result_dict["anomalie_score"], data["y_test"]
    )
    # utils.plots.create_mse_plot_picture(result_path, result_dict, result_dict_method)

    return {
        "mean": np.mean(np.squeeze(np.array(result_dict["anomalie_score"]))),
        "method": method,
        "f1_score": f1_score,
        "aucroc": aucroc,
        "aucpr": aucpr,
    }


def load_model(path, config, d_in, seed):
    """Restore the trained network that belongs to one seed.

    Two storage layouts are handled, selected by ``config["backbone_model"]``:

    * Backbones in ``whole_object_backbones`` are loaded as a complete pickled
      model object from a ``model_object_seed_<seed>...`` file in the same
      directory as ``path``; ``config["validation"]`` chooses between the
      ``_val_best.pkl`` and the ``best.pkl`` file.
    * Every other backbone is built with ``model_util.get_model`` and then
      filled from the checkpoint dict stored at ``path``.

    Args:
        path: Checkpoint path for the state-dict layout; only its directory is
            used for the whole-object layout.
        config: Flattened configuration mapping.
        d_in: Input dimension handed to ``model_util.get_model``.
        seed: Seed number used in the whole-object file name.

    Returns:
        The restored network.
    """
    whole_object_backbones = ("TTVAE", "TabM", "MLP2048", "Base_Transformer")

    if config["backbone_model"] not in whole_object_backbones:
        net, optimizer, _, _ = model_util.get_model(config, device, d_in, None)
        state = torch.load(path, map_location=device)
        # Feed the stored weights into the model; the optimizer state is
        # restored as well when present, although only the network is returned.
        net.load_state_dict(state['model_state_dict'])
        if 'optimizer_state_dict' in state:
            optimizer.load_state_dict(state['optimizer_state_dict'])
        print('Successfully loaded model at iteration best')
        return net

    model_dir = os.path.dirname(path)
    suffix = "_val_best.pkl" if config["validation"] else "best.pkl"
    object_file = os.path.join(model_dir, "model_object_" + "seed_" + str(seed) + suffix)
    # weights_only=False because the file holds a full pickled model object.
    return torch.load(object_file, map_location=device, weights_only=False)


def _select_config_filename(result_path, config_id):
    """Return the name of the config file in ``result_path`` chosen by ``config_id``.

    ``0`` selects ``config.toml``, ``-1`` selects the first ``.toml`` file found
    in ``result_path`` and any other value selects ``config_<config_id>.toml``.

    Raises:
        FileNotFoundError: If ``config_id`` is ``-1`` and ``result_path``
            contains no ``.toml`` file.
    """
    if config_id == 0:
        return "config.toml"
    if config_id != -1:
        return "config_" + str(config_id) + ".toml"

    # os.listdir returns entries in arbitrary order; sorting makes "first" reproducible.
    for config_file in sorted(os.listdir(result_path)):
        if config_file.endswith('.toml'):
            print(f"Found the first .toml file: {config_file}")
            return config_file
    # Previously this case fell through with an unbound variable (UnboundLocalError).
    raise FileNotFoundError(f"No .toml config file found in {result_path}")


def start_test(train_id, config_id):
    """Evaluate the training run ``train_id`` with the config chosen by ``config_id``.

    Model type and backbone of the run are looked up in ``training_monitor.csv``
    (located in ``CODE_PATH``); together with ``train_id`` they form the name
    of the run directory below ``RESULT_PATH``. The selected config file is read
    from that directory, flattened and handed to ``safe_test``.

    Args:
        train_id: Id of the training run to evaluate.
        config_id: ``0`` uses ``config.toml`` and writes the results into the
            run directory itself. ``-1`` uses the first ``.toml`` file found,
            any other value uses ``config_<config_id>.toml``; in both cases the
            results go into the sub-directory ``<config_id>/`` of the run
            directory.

    Raises:
        FileNotFoundError: If ``config_id`` is ``-1`` and the run directory
            contains no ``.toml`` file.
    """
    csv_path = CODE_PATH + "training_monitor.csv"
    model_type, backbone = utils.data_handling.get_model_backbone_from_csv(csv_path, train_id)
    result_path = RESULT_PATH + str(train_id) + "_" + model_type + "_" + backbone + "/"

    filename = _select_config_filename(result_path, config_id)

    if config_id == 0:
        result_path_new = result_path
    else:
        result_path_new = result_path + str(config_id) + "/"

    data = utils.file_handling.read_toml(result_path + filename)
    # Flatten using the "General" section and the section named after the model type.
    flatten_dict = utils.data_handling.flatten_dict(data, ["General", data["General"]["model_type"]])
    print(flatten_dict)

    safe_test(flatten_dict, train_id, result_path_new, result_path + filename)


if __name__ == "__main__":
    # Command line: which training run to evaluate and which config file to use.
    parser = argparse.ArgumentParser(description='Settings')
    parser.add_argument(
        '--config',
        type=int,
        default=2,
        help='number config file',  # -1 to get the first .toml file
    )
    parser.add_argument(
        '--train_id',
        type=int,
        default=2143,
        help='number training_id',
    )
    args = parser.parse_args()

    # Developer switch: set to True to evaluate a fixed list of training ids
    # with config 1 instead of the run given on the command line.
    run_loop = False
    if not run_loop:
        start_test(args.train_id, args.config)
    else:
        for i in (1062, 1066):
            start_test(i, 1)


    
