import csv
import os
import sys
import argparse
import multiprocessing
import time

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from SourceCode.TaskRelatedClasses.TaskData import SupportSet, QuerySet, MetaTask

# Put the grandparent of this file's directory (presumably the repository root) on
# sys.path so project modules can be resolved at runtime, e.g. when torch.load
# unpickles model classes. os.path.dirname is used instead of splitting on '/', which
# produced '' (the current directory) for Windows-style paths.
# NOTE(review): this runs after the SourceCode import at the top of the file, so that
# import does not benefit from it — confirm the root is already importable at launch.
root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if root_path not in sys.path:
    sys.path.append(root_path)




def load_model(model_path):
    """Load a whole pickled model from *model_path* onto the best available device.

    Args:
        model_path: path to a file written with ``torch.save(model)`` (the full
            model object, which must expose a ``memory_matrix`` attribute).

    Returns:
        ``(model, device)`` where *device* is CUDA when available, otherwise CPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # SECURITY: torch.load unpickles arbitrary objects — only load trusted files.
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
    try:
        # torch >= 2.6 defaults to weights_only=True, which rejects full pickled models.
        model = torch.load(model_path, map_location=device, weights_only=False)
    except TypeError:
        # Older torch releases do not accept the weights_only keyword.
        model = torch.load(model_path, map_location=device)
    # The memory module carries its own `device` attribute; set it to match the model.
    model.memory_matrix.device = device
    model.to(device)
    return model, device


def convert_npz_list_2_meta_task_list(support_x_list, support_y_list, query_x_list, query_y_list, device):
    """Wrap four parallel lists of arrays into ``MetaTask`` objects on *device*.

    Element ``k`` of each list together forms task ``k``. Each array is copied
    into a new tensor on *device* via ``torch.tensor``.

    Returns:
        A list of ``MetaTask(SupportSet, QuerySet)`` in the same order as the inputs.
    """
    def to_device(array):
        return torch.tensor(array, device=device)

    tasks = []
    for idx, raw_support_x in enumerate(support_x_list):
        sx = to_device(raw_support_x)
        sy = to_device(support_y_list[idx])
        qx = to_device(query_x_list[idx])
        qy = to_device(query_y_list[idx])
        tasks.append(MetaTask(SupportSet(sx, sy, device), QuerySet(qx, qy, device)))
    return tasks


def load_npz(path):
    """Read one task from an ``.npz`` archive.

    Args:
        path: path to an archive containing the arrays ``support_x``,
            ``support_y``, ``query_x`` and ``query_y``.

    Returns:
        Tuple ``(support_x, support_y, query_x, query_y)`` of NumPy arrays.

    Raises:
        KeyError: if one of the four arrays is missing from the archive.
    """
    # np.load returns an NpzFile that holds an open file handle; indexing it reads
    # the array eagerly, so the handle can be released as soon as all four are read.
    with np.load(path) as archive:
        support_x = archive['support_x']
        support_y = archive['support_y']
        query_x = archive['query_x']
        query_y = archive['query_y']
    return support_x, support_y, query_x, query_y


def load_all_npz_in_dir(dir_path):
    """Load every ``.npz`` task archive in *dir_path*.

    All entries are validated before any archive is read: every entry must have
    an ``.npz`` extension.

    Args:
        dir_path: directory containing only ``.npz`` task archives.

    Returns:
        Four parallel lists ``(support_x_list, support_y_list, query_x_list,
        query_y_list)`` with one element per archive, in sorted file-name order.

    Raises:
        SystemExit: with status 1 if the directory contains a non-``.npz`` entry.
    """
    # os.listdir order is arbitrary; sort so task order is reproducible across runs.
    file_name_list = sorted(os.listdir(dir_path))
    for file_name in file_name_list:
        # Check the extension, not any occurrence of 'npz' anywhere in the name.
        if not file_name.lower().endswith('.npz'):
            print('illegal file! Not NPZ File: {}'.format(os.path.join(dir_path, file_name)))
            sys.exit(1)
    support_x_list = []
    support_y_list = []
    query_x_list = []
    query_y_list = []
    for file_name in file_name_list:
        support_x, support_y, query_x, query_y = load_npz(os.path.join(dir_path, file_name))
        support_x_list.append(support_x)
        support_y_list.append(support_y)
        query_x_list.append(query_x)
        query_y_list.append(query_y)
    return support_x_list, support_y_list, query_x_list, query_y_list


def mayfly_eval_on_one_task(model, test_meta_task):
    """Evaluate *model* on a single meta-task and return ``(AAE, ARE)``.

    The model is switched to eval mode and cleared. It then receives every
    support batch via ``model.write``. Next it is queried batch by batch, with the
    support-set total ``support_y.sum()`` repeated as one row per query item.

    Returns:
        ``(AAE, ARE)`` as Python floats: the mean absolute error, and the mean of
        ``|pred - y| / y``. NOTE: a zero in ``query_y`` makes ARE inf/nan.
    """
    support_set, query_set = test_meta_task.support_set, test_meta_task.query_set
    total_count = support_set.support_y.sum()
    model.eval()
    chunk = 10000
    with torch.no_grad():
        model.clear()
        for batch_x, batch_y in DataLoader(support_set, batch_size=chunk):
            model.write(batch_x, batch_y)

        preds, truths = [], []
        for batch_x, batch_y in DataLoader(query_set, batch_size=chunk, shuffle=False):
            lengths = total_count.unsqueeze(-1).repeat(batch_x.shape[0], 1)
            preds.append(model.query(batch_x, lengths))
            truths.append(batch_y)

        pred = torch.cat(preds).view(-1).cpu()
        truth = torch.cat(truths).view(-1).cpu()
        abs_err = torch.abs(pred - truth)
        aae = torch.mean(abs_err).cpu().item()
        are = torch.mean(abs_err / truth).cpu().item()
    return aae, are


def mayfly_eval_one_group(model, one_group_dir, device):
    """Evaluate *model* on every task archive in one group directory.

    Args:
        model: model passed through to ``mayfly_eval_on_one_task``.
        one_group_dir: directory of ``.npz`` task archives.
        device: device on which each task's tensors are created.

    Returns:
        ``(ave_ARE, ave_AAE, ARE_VAR, AAE_VAR)``: the mean ARE and AAE across tasks
        (rounded to 2 decimals) and their population variances (``np.var``,
        rounded to 5 decimals). This order matches the CSV header written by the script.

    Raises:
        ValueError: if the directory contains no tasks.
    """
    support_x_list, support_y_list, query_x_list, query_y_list = load_all_npz_in_dir(one_group_dir)
    num_of_task = len(support_x_list)
    if num_of_task == 0:
        # Previously surfaced as an unexplained ZeroDivisionError in the averages.
        raise ValueError('no tasks found in {}'.format(one_group_dir))
    AAE_list = []
    ARE_list = []
    for sx, sy, qx, qy in zip(support_x_list, support_y_list, query_x_list, query_y_list):
        # Move one task at a time onto the device instead of materialising the whole
        # group there up front.
        meta_task = convert_npz_list_2_meta_task_list([sx], [sy], [qx], [qy], device)[0]
        AAE, ARE = mayfly_eval_on_one_task(model, test_meta_task=meta_task)
        AAE_list.append(AAE)
        ARE_list.append(ARE)
    ave_AAE = round(sum(AAE_list) / num_of_task, 2)
    ave_ARE = round(sum(ARE_list) / num_of_task, 2)
    AAE_VAR = round(float(np.var(AAE_list)), 5)
    ARE_VAR = round(float(np.var(ARE_list)), 5)
    return ave_ARE, ave_AAE, ARE_VAR, AAE_VAR

def calculate_model_memory_size(model):
    """Return the size of the model's memory matrix in KB (1 KB = 1024 bytes).

    Reads ``model.memory_matrix.memory_matrix.data``, prints the unrounded size,
    and returns it rounded to two decimals.
    """
    memory_matrix = model.memory_matrix.memory_matrix.data
    # element_size() follows the tensor's dtype instead of assuming 4-byte float32.
    space = memory_matrix.numel() * memory_matrix.element_size() / 1024
    print(space)
    return round(space, 2)


def sort_by_length(name_list):
    """Return a new list of *name_list* ordered by the integer in its second '_' field.

    For example ``'g_100_x'`` sorts by ``100``. The sort is stable, so names with
    equal values keep their input order (``np.argsort``'s default sort does not
    guarantee this).

    Raises:
        IndexError: if a name contains no '_'.
        ValueError: if the second field is not an integer.
    """
    return sorted(name_list, key=lambda name: int(name.split('_')[1]))


def main():
    """Evaluate every model file in the working directory on every task group.

    The current directory must contain model files (names containing 'model' but
    not 'csv') and exactly one task directory (name containing 'test_tasks_').
    For each model, every group sub-directory is evaluated, and a CSV named
    ``calc<size>KB_<model>_eval_result.csv`` is written with one row per group.
    """
    model_path_list = []
    groups_dir = None
    # sorted() so models are processed in a reproducible order.
    for path in sorted(os.listdir('.')):
        if 'model' in path and 'csv' not in path:
            model_path_list.append(path)
        if 'test_tasks_' in path:
            if groups_dir is not None:
                print('error: more than one test_tasks_ directory: {} and {}'.format(groups_dir, path),
                      file=sys.stderr)
                sys.exit(1)
            groups_dir = path
    if groups_dir is None:
        # Without this guard os.listdir(None) silently lists the current directory.
        print('error: no test_tasks_ directory found in the current directory', file=sys.stderr)
        sys.exit(1)

    # The group listing does not depend on the model, so compute it once.
    groups_name_list = os.listdir(groups_dir)
    print(groups_name_list)
    groups_name_list = sort_by_length(groups_name_list)
    print(groups_name_list)

    for model_path in model_path_list:
        model, device = load_model(model_path)
        space_budget = calculate_model_memory_size(model)
        result_path = "./calc{}KB_{}_eval_result.csv".format(space_budget, model_path)
        # `with` guarantees the CSV is flushed and closed even if an evaluation fails.
        with open(result_path, "w", newline='') as result_file:
            csv_writer = csv.writer(result_file)
            csv_writer.writerow(["Group_Name", "ARE", "AAE", "ARE_VAR", "AAE_VAR"])
            for one_group_name in groups_name_list:
                print('eval ' + one_group_name)
                one_group_dir = os.path.join(groups_dir, one_group_name)
                result_tuple = mayfly_eval_one_group(model, one_group_dir, device=device)
                csv_writer.writerow([one_group_name, *result_tuple])


if __name__ == '__main__':
    main()

