import csv
import json
import os
import random
import sys

import numpy as np
import sklearn
from sklearn import metrics
import torch
from torch.utils.data import DataLoader
import time
from AbstractClass.AbstractLogger import AbstractLogger
from SourceCode.TaskRelatedClasses import TaskData
import shutil


def exist_check(meta_task):
    """Sanity-check a meta task and print a diagnostic on any inconsistency.

    The number of rows of the support labels is compared with the rounded sum
    of the query labels.  The check is best-effort: a mismatch, or a task that
    cannot be read, is reported with ``print`` and never raised.

    Args:
        meta_task: object exposing ``support_set.support_y`` (must provide
            ``.shape``) and ``query_set.query_y`` (must provide
            ``.sum().item()``).
    """
    # Pre-bind both values so the diagnostic print can never fail with an
    # unbound-name error when reading the task is itself what went wrong.
    stream_length = None
    check_stream_length = None
    try:
        stream_length = meta_task.support_set.support_y.shape[0]
        check_stream_length = round(meta_task.query_set.query_y.sum().item())
    except Exception:
        # Deliberately broad: this is a diagnostic helper and must not abort
        # an evaluation run, whatever the task object looks like.
        print('exist_check error:', check_stream_length, stream_length)
        return
    # Explicit comparison instead of ``assert`` so the check survives ``-O``.
    if check_stream_length != stream_length:
        print('exist_check error:', check_stream_length, stream_length)


# Logger that writes evaluation metrics to CSV files (spreadsheet-readable)
class ExistLogger(AbstractLogger):
    def __init__(self, test_meta_task_group_list, meta_task_group_discribe_list, dataset_name, config, loss_func,
                 eval_query_batch_size=500, eval_support_batch_size=500, flush_gap=1, train_comment=" ", eval_gap=1000,
                 save_gap=10000, prod_env=False):
        self.flush_gap = flush_gap
        self.config = config
        self.eval_query_batch_size = eval_query_batch_size
        self.eval_support_batch_size = eval_support_batch_size
        self.model_path = None
        self.save_gap = save_gap
        # self.log_path = log_path
        self.meta_task_discribe_list = meta_task_group_discribe_list
        self.test_meta_task_group_list = test_meta_task_group_list
        self.loss_func = loss_func
        # 训练注释
        self.train_comment = train_comment
        self.project_root = os.getcwd().split('ExpCode')[0]
        path = os.getcwd()
        self.path = path
        self.py_file_path = sys.argv[0]
        print(self.py_file_path)
        if prod_env:
            self.dataset_name = "Prod_" + time.strftime('%m%d_', time.localtime(time.time())) + dataset_name
        else:
            self.dataset_name = "Dev_" + time.strftime('%m%d_', time.localtime(time.time())) + dataset_name
        self.csv_writer_list = None
        self.log_file_list = None
        self.__init_log_file()
        self.file_header = None
        self.eval_gap = eval_gap
        self.best_loss_for_checkpointing = None
        self.checkpointing_tasks_list = []

    def add_tasks_for_checkpointing(self, meta_task):
        self.checkpointing_tasks_list.append(meta_task)

    def clean_tasks_for_checkpoingting(self):
        self.checkpointing_tasks_list = []

    def calculate_checkpointing_loss(self, model):
        all_checkpointing_loss_list = []
        for meta_task in self.checkpointing_tasks_list:
            with torch.no_grad():
                support_set = meta_task.support_set
                query_set = meta_task.query_set
                # train的时候防止batch为1
                while len(support_set) % self.eval_support_batch_size == 1:
                    if random.random() > 0.5:
                        self.eval_support_batch_size += 1
                    else:
                        self.eval_support_batch_size -= 1
                stream_length = None
                while len(query_set) % self.eval_query_batch_size == 1:
                    if random.random() > 0.5:
                        self.eval_query_batch_size += 1
                    else:
                        self.eval_query_batch_size -= 1
                stream_length = support_set.support_y.sum()
                model.clear()
                if self.eval_support_batch_size > len(support_set):
                    model.write(support_set.support_x, support_set.support_y)
                else:
                    support_data_loader = DataLoader(support_set, batch_size=self.eval_support_batch_size)
                    for i, (support_x, support_y) in enumerate(support_data_loader):
                        model.write(support_x, support_y)
                if self.eval_query_batch_size > len(query_set):
                    query_pred = model.query(query_set.query_x,
                                             stream_length.unsqueeze(-1).repeat(query_set.query_x.shape[0], 1))
                    query_y = query_set.query_y
                    one_loss = self.loss_func(query_pred, query_y)
                    loss = one_loss.item()
                else:
                    query_data_loader = DataLoader(query_set, batch_size=self.eval_query_batch_size, shuffle=False)
                    loss_list = []
                    weight_pred_list = []
                    query_y_list = []
                    for j, (query_x, query_y) in enumerate(query_data_loader):
                        query_pred = model.query(query_x, stream_length.unsqueeze(-1).repeat(query_x.shape[0], 1))
                        weight_pred_list.append(query_pred)
                        query_y_list.append(query_y)
                        one_loss = self.loss_func(query_pred, query_y)
                        loss_list.append(one_loss.item())
                    loss = sum(one_loss)
                all_checkpointing_loss_list.append(loss)
        return all_checkpointing_loss_list

    def compare_loss_and_best_loss(self,check_loss):
        ratio_list = []
        for i in range(len(check_loss)):
            ratio_list.append(self.best_loss_for_checkpointing[i]/check_loss[i])
        if sum(ratio_list)/len(ratio_list) >= 1:
            print('best     :',self.best_loss_for_checkpointing)
            print('next best:',check_loss)

            self.best_loss_for_checkpointing = check_loss
            return True
        else:
            print('best     :',self.best_loss_for_checkpointing)
            print('not best:',check_loss)
            return False

    def logging(self, model, step, comment=""):
        model.eval()
        assert len(self.checkpointing_tasks_list) != 0, "these is no tasks for checkpointing"
        print(step, 'step logging begin...')
        if self.best_loss_for_checkpointing is None:
            self.best_loss_for_checkpointing = self.calculate_checkpointing_loss(model)
            print('set model checkpoint at ',step,' step')
            torch.save(model,self.checkpointing_model_path)
        else:
            check_loss = self.calculate_checkpointing_loss(model)
            if self.compare_loss_and_best_loss(check_loss):
                print('set model checkpoint at ', step, ' step')
                torch.save(model, self.checkpointing_model_path)

        if step % self.save_gap == 1:
            torch.save(model, self.model_path)
        for i in range(len(self.log_file_list)):
            if self.file_header is None:
                for group in self.test_meta_task_group_list:
                    for task in group:
                        task.to_device()
            # warning
            log_file = self.log_file_list[i]
            csv_writer = self.csv_writer_list[i]
            test_meta_task_group = self.test_meta_task_group_list[i]
            group_test_merged_info_dict = self.eval_on_one_group(test_meta_task_group, model)

            # 如果是第一次logging，初始化所有文件的表头
            if self.file_header is None:
                self.file_header = list(group_test_merged_info_dict.keys())
                self.file_header.insert(0, "step")
                for writer in self.csv_writer_list:
                    writer.writerow(self.file_header)
            group_test_merged_info_dict['step'] = step
            row_content = []
            for key in self.file_header:
                row_content.append(group_test_merged_info_dict[key])
            csv_writer.writerow(row_content)
            if (step // self.eval_gap) % self.flush_gap == 0:
                log_file.flush()
        model.train()
        print(step, ' step logging done...')
        return

    def eval_on_one_group(self, test_meta_task_group, model):
        # store all test result for a group of meta tasks
        group_dict_list = []
        for test_meta_task in test_meta_task_group:
            mt_test_info_dict = self.eval_on_one_task(test_meta_task, model)
            group_dict_list.append(mt_test_info_dict)

        group_test_merged_info_dict = {}
        for key in group_dict_list[0].keys():
            value_mean = 0
            for dic in group_dict_list:
                value_mean += dic[key]
            value_mean /= len(group_dict_list)
            group_test_merged_info_dict[key] = value_mean
        return group_test_merged_info_dict

    def eval_on_one_task(self, test_meta_task, model):
        """Run the basic and the additional evaluation on one task and merge them.

        The basic evaluation runs first, then the additional one; when both
        report the same key the additional value wins.

        Returns:
            dict: basic metrics followed by the additional diagnostics.
        """
        exist_check(test_meta_task)
        basic_info = self.get_basic_eval_info_on_one_task_only_exist(test_meta_task, model)
        additional_info = self.get_additional_eval_info_on_one_task_only_exist(test_meta_task, model)
        # Merge the two info dictionaries.
        return {**basic_info, **additional_info}

    def get_sparsity(self, batch_data, dim=(1,)):
        sparsity_data = torch.where(torch.abs(batch_data - 0.0) < 0.00001, 0.0, 1.0)
        res = sparsity_data.sum(dim=dim, keepdim=False)
        return res


    def get_additional_eval_info_on_one_task_only_exist(self, test_meta_task, model):
        info_dict = {}
        with torch.no_grad():
            query_set = test_meta_task.query_set
            query_data_loader = DataLoader(query_set, batch_size=self.eval_support_batch_size, shuffle=True)
            row_sparsity_list = []
            col_sparsity_list = []
            addresses_sparsity_list = []
            addresses_var_list = []
            source_var_list = []
            dest_var_list = []
            edge_embedding_sparsity_list = []
            edge_embedding_var_list = []
            edge_embedding_norm_list = []
            edge_embedding_l1_norm_list = []
            source_refine_vec_norm_list = []
            source_refine_vec_l1_norm_list = []
            dest_refine_vec_norm_list = []
            dest_refine_vec_l1_norm_list = []

        for j, (query_x, query_y) in enumerate(query_data_loader):
            source_embedding, dest_embedding, edge_embedding = model.get_embedding(query_x)
            source_refined, dest_refined = model.get_refined(source_embedding, dest_embedding)
            row_address, col_address = model.get_address(source_refined, dest_refined)

            addresses = row_address.view(row_address.shape[0], -1, 1).matmul(
                col_address.view(col_address.shape[0], 1, -1))
            addresses_var = ((addresses.shape[1] * addresses.shape[2]) * addresses.mean(0)).var().item()
            addresses_sparsity = self.get_sparsity((addresses), dim=(1, 2)).mean().item()
            # addresses_sparsity =
            edge_embedding_sparsity = self.get_sparsity(edge_embedding, dim=1).mean().item()
            edge_embedding_var = edge_embedding.sum(0).var().item()
            edge_embedding_norm = torch.sqrt((edge_embedding.square()).sum(dim=1)).mean().item()
            edge_embedding_l1_norm = edge_embedding.sum(dim=-1).mean().item()
            source_refine_vec_norm = torch.sqrt((source_refined.square()).sum(dim=1)).mean().item()
            dest_refine_vec_norm = torch.sqrt((dest_refined.square()).sum(dim=1)).mean().item()
            source_refine_vec_l1_norm = source_refined.sum(dim=1).mean().item()
            dest_refine_vec_l1_norm = dest_refined.sum(dim=1).mean().item()
            edge_embedding_sparsity_list.append(edge_embedding_sparsity)
            edge_embedding_var_list.append(edge_embedding_var)
            edge_embedding_norm_list.append(edge_embedding_norm)
            edge_embedding_l1_norm_list.append(edge_embedding_l1_norm)
            addresses_var_list.append(addresses_var)
            addresses_sparsity_list.append(addresses_sparsity)
            source_refine_vec_norm_list.append(source_refine_vec_norm)
            dest_refine_vec_norm_list.append(dest_refine_vec_norm)
            source_refine_vec_l1_norm_list.append(source_refine_vec_l1_norm)
            dest_refine_vec_l1_norm_list.append(dest_refine_vec_l1_norm)
            source_var = (row_address.sum(0) * row_address.shape[-1] / row_address.shape[0]).var().item()
            dest_var = (col_address.sum(0) * col_address.shape[-1] / col_address.shape[0]).var().item()
            source_var_list.append(source_var)
            dest_var_list.append(dest_var)
            row_sparsity = self.get_sparsity(row_address).mean().item()
            col_sparsity = self.get_sparsity(col_address).mean().item()
            row_sparsity_list.append(row_sparsity)
            col_sparsity_list.append(col_sparsity)
        row_sparsity = sum(row_sparsity_list) / len(row_sparsity_list)
        col_sparsity = sum(col_sparsity_list) / len(col_sparsity_list)
        source_var = sum(source_var_list) / len(source_var_list)
        dest_var = sum(dest_var_list) / len(dest_var_list)
        twodim_address_sparsity = sum(addresses_sparsity_list) / len(addresses_sparsity_list)
        twodim_address_var = sum(addresses_var_list) / len(addresses_var_list)
        edge_embedding_norm = sum(edge_embedding_norm_list) / len(edge_embedding_norm_list)
        edge_embedding_l1_norm = sum(edge_embedding_l1_norm_list) / len(edge_embedding_l1_norm_list)
        edge_embedding_var_list = sum(edge_embedding_var_list) / len(edge_embedding_var_list)
        edge_embedding_sparsity_list = sum(edge_embedding_sparsity_list) / len(edge_embedding_sparsity_list)
        source_refine_vec_norm = sum(source_refine_vec_norm_list) / len(source_refine_vec_norm_list)
        source_refine_vec_l1_norm = sum(source_refine_vec_l1_norm_list) / len(source_refine_vec_l1_norm_list)
        dest_refine_vec_norm = sum(dest_refine_vec_norm_list) / len(dest_refine_vec_norm_list)
        dest_refine_vec_l1_norm = sum(dest_refine_vec_l1_norm_list) / len(dest_refine_vec_l1_norm_list)

        info_dict['edge_embedding_norm'] = edge_embedding_norm
        info_dict['edge_embedding_l1_norm'] = edge_embedding_l1_norm
        info_dict["edge_embedding_var"] = edge_embedding_var_list
        info_dict['edge_embedding_sparsity'] = edge_embedding_sparsity_list
        info_dict['row_address_var'] = source_var
        info_dict['col_address_var'] = dest_var
        info_dict['row_address_sparsity'] = row_sparsity
        info_dict['col_address_sparsity'] = col_sparsity
        info_dict['twodim_address_sparsity'] = twodim_address_sparsity
        info_dict['twodim_address_var'] = twodim_address_var
        info_dict['source_refined_vec_norm'] = source_refine_vec_norm
        info_dict['source_refined_vec_l1_norm'] = source_refine_vec_l1_norm
        info_dict['dest_refined_vec_norm'] = dest_refine_vec_norm
        info_dict['dest_refined_vec_l1_norm'] = dest_refine_vec_l1_norm

        if hasattr(model, 'attention_matrix'):
            if hasattr(model.attention_matrix, 'scale'):
                info_dict['scale'] = model.attention_matrix.scale.item()
        return info_dict

    def get_basic_eval_info_on_one_task_only_exist(self, test_meta_task, model):
        """Write a task's support set into the model and score its query set.

        The model is cleared, the support set is written (in batches when it
        is at least as large as the batch size), and the query predictions are
        compared with the query labels.  Predictions above 0.5 count as
        positive for the label-based metrics.

        Note that ``self.eval_support_batch_size`` and
        ``self.eval_query_batch_size`` may be permanently nudged by +/-1 so
        that no trailing batch contains exactly one sample.

        Args:
            test_meta_task: task exposing ``support_set`` and ``query_set``.
            model: object exposing ``clear()``, ``write(x, y)`` and
                ``query(x, stream_length)``.

        Returns:
            dict: ``loss``, ``f1_score``, ``acc_score``, ``recall_score``,
            ``precision_score``, ``auc_score`` (in that order), plus
            ``scale_value`` when ``model.attention_matrix.scale_value`` exists.
        """
        info_dict = {}
        with torch.no_grad():
            support_set = test_meta_task.support_set
            query_set = test_meta_task.query_set
            # Avoid a trailing batch of size 1 by randomly nudging the batch
            # size.  NOTE(review): presumably some layer cannot handle
            # single-sample batches — confirm.
            while len(support_set) % self.eval_support_batch_size == 1:
                if random.random() > 0.5:
                    self.eval_support_batch_size += 1
                else:
                    self.eval_support_batch_size -= 1
            while len(query_set) % self.eval_query_batch_size == 1:
                if random.random() > 0.5:
                    self.eval_query_batch_size += 1
                else:
                    self.eval_query_batch_size -= 1
            stream_length = support_set.support_y.sum()

            # Write the support set into a freshly cleared model.
            model.clear()
            if self.eval_support_batch_size > len(support_set):
                model.write(support_set.support_x, support_set.support_y)
            else:
                support_data_loader = DataLoader(support_set, batch_size=self.eval_support_batch_size)
                for support_x, support_y in support_data_loader:
                    model.write(support_x, support_y)

            # Query the model, in one shot or batch by batch.
            if self.eval_query_batch_size > len(query_set):
                query_pred = model.query(query_set.query_x,
                                         stream_length.unsqueeze(-1).repeat(query_set.query_x.shape[0], 1))
                loss_sum = self.loss_func(query_pred, query_set.query_y)
                query_y = query_set.query_y.view(-1).cpu()
                pred = query_pred.view(-1).cpu()
            else:
                query_data_loader = DataLoader(query_set, batch_size=self.eval_query_batch_size, shuffle=False)
                loss_list = []
                pred_list = []
                label_list = []
                for query_x, batch_y in query_data_loader:
                    batch_pred = model.query(query_x, stream_length.unsqueeze(-1).repeat(query_x.shape[0], 1))
                    pred_list.append(batch_pred)
                    label_list.append(batch_y)
                    loss_list.append(self.loss_func(batch_pred, batch_y))
                pred = torch.cat(pred_list).view(-1).cpu()
                query_y = torch.cat(label_list).cpu().view(-1)
                # Unweighted mean of the per-batch losses.
                loss_sum = torch.stack(loss_list).sum() / len(loss_list)

            info_dict['loss'] = loss_sum.item()

            # Binarise the predictions for the label-based metrics; AUC uses
            # the raw scores.
            threshold = 0.5
            pred_label = torch.where(pred>threshold,torch.ones_like(pred),torch.zeros_like(pred))
            f1_score = metrics.f1_score(query_y, pred_label)
            acc_score = metrics.accuracy_score(query_y, pred_label)
            recall_score = metrics.recall_score(query_y, pred_label)
            precision_score = metrics.precision_score(query_y, pred_label)
            auc = metrics.roc_auc_score(query_y,pred)
            info_dict['f1_score'] = f1_score
            info_dict['acc_score'] = acc_score
            info_dict['recall_score'] = recall_score
            info_dict['precision_score'] = precision_score
            info_dict['auc_score'] = auc

            if hasattr(model, 'attention_matrix'):
                if hasattr(model.attention_matrix, 'scale_value'):
                    info_dict['scale_value'] = model.attention_matrix.scale_value.item()
        return info_dict

    # Create one log file per test group to store that group's results
    def __init_log_file(self):
        self.log_file_list = []
        self.csv_writer_list = []
        time_str = time.strftime('_%m_%d_%H_%M_%S', time.localtime(time.time()))
        if not os.path.exists(os.path.join(self.project_root,
                                           'LogDir/{}/{}/'.format(self.dataset_name, self.train_comment + time_str))):
            os.makedirs(os.path.join(self.project_root,
                                     'LogDir/{}/{}/'.format(self.dataset_name, self.train_comment + time_str)))
        # info_example = ["step","embeding","col_sparsity","row_sparsity", "loss", "weight_ARE", "weight_AAE", "exist_precise", "exist_acc", "exist_F1","exist_recall"]
        self.model_path = os.path.join(self.project_root, 'LogDir/{}/{}/model'.format(self.dataset_name,self.train_comment + time_str))
        self.checkpointing_model_path = os.path.join(self.project_root,
                                                     'LogDir/{}/{}/checkpointing_model'.format(self.dataset_name,
                                                                                               self.train_comment + time_str))

        config_str = json.dumps(self.config)
        config_file = open(os.path.join(self.project_root, 'LogDir/{}/{}/config'.format(self.dataset_name,
                                                                                        self.train_comment + time_str)),
                           'w',
                           newline='', encoding='utf-8')
        config_file.write(config_str)
        config_file.close()
        for meta_task_group_discribe in self.meta_task_discribe_list:
            self.log_file_list.append(open(os.path.join(self.project_root,
                                                        'LogDir/{}/{}/log{}.csv'.format(self.dataset_name,
                                                                                        self.train_comment + time_str,
                                                                                        meta_task_group_discribe)), 'w',
                                           newline='', encoding='utf-8'))
            self.csv_writer_list.append(csv.writer(self.log_file_list[-1]))
        # 这里添加一个持久化test meta task的方法
        os.makedirs(os.path.join(self.project_root,
                                 'LogDir/{}/{}/test_tasks_{}/'.format(self.dataset_name, self.train_comment + time_str,
                                                                      self.train_comment + time_str)))
        for i in range(len(self.test_meta_task_group_list)):
            os.makedirs(os.path.join(self.project_root,
                                     'LogDir/{}/{}/test_tasks_{}/{}/'.format(self.dataset_name,
                                                                             self.train_comment + time_str,
                                                                             self.train_comment + time_str,
                                                                             self.meta_task_discribe_list[i])))
            for j in range(len(self.test_meta_task_group_list[i])):
                path = os.path.join(self.project_root,
                                    'LogDir/{}/{}/test_tasks_{}/{}/{}.npz'.format(self.dataset_name,
                                                                                  self.train_comment + time_str,
                                                                                  self.train_comment + time_str,
                                                                                  self.meta_task_discribe_list[i],
                                                                                  str(j)))

                self.save_meta_task(self.test_meta_task_group_list[i][j], path)
        print('持久化test_meta_task完成...')
        shutil.copytree(os.path.join(self.project_root, 'SourceCode'), os.path.join(self.project_root,
                                                                                    'LogDir/{}/{}/SourceCode/'.format(
                                                                                        self.dataset_name,
                                                                                        self.train_comment + time_str)))

        if '\\' in self.py_file_path:
            self.py_file_path = self.py_file_path.replace('\\', "/")
            self.path = self.path.replace('\\', "/")
            # print(self.py_file_path)
        os.makedirs(os.path.join(self.project_root,
                                 'LogDir/{}/{}/ExpCode/{}/'.format(self.dataset_name, self.train_comment + time_str,
                                                                   self.path.split('/')[-1])))
        shutil.copy(self.py_file_path, os.path.join(self.project_root,
                                                    'LogDir/{}/{}/ExpCode/{}/'.format(self.dataset_name,
                                                                                      self.train_comment + time_str,
                                                                                      self.path.split('/')[-1])))
        print('代码备份完成...')
        # 保存 baseline.py

        print('代码备份完成...')

    def save_meta_task(self, meta_task, path):
        np.savez(path, support_x=meta_task.support_set.support_x.cpu().numpy(),
                 support_y=meta_task.support_set.support_y.cpu().numpy(),
                 query_x=meta_task.query_set.query_x.cpu().numpy(), query_y=meta_task.query_set.query_y.cpu().numpy())

    def close_all_file(self):
        for file in self.log_file_list:
            file.close()
