import csv
import json
import os
import random
import sys

import numpy as np
from sklearn import metrics
import torch
from torch.utils.data import DataLoader
import time
from AbstractClass.AbstractLogger import AbstractLogger
from SourceCode.TaskRelatedClasses import TaskData
import shutil


def check(meta_task):
    """Sanity-check that a meta task's support and query label sums agree.

    Prints a warning when ``meta_task.support_set.support_y.sum()`` differs from
    ``meta_task.query_set.query_y.sum()``. The check is best-effort. If the task lacks
    these attributes, or the sums cannot be computed or compared, the error is printed
    and the function returns normally. Interpreter-level interrupts (KeyboardInterrupt,
    SystemExit) are not swallowed.

    Args:
        meta_task: object exposing ``support_set.support_y`` and ``query_set.query_y``
            (anything with a ``.sum()`` method).

    Returns:
        None
    """
    try:
        stream_length = meta_task.support_set.support_y.sum()
        check_stream_length = meta_task.query_set.query_y.sum()
        if stream_length != check_stream_length:
            print('logger error!!!!!!!!!!!!!!!!!!!!!!!!!')
    except Exception as err:  # best-effort diagnostic: never abort evaluation over it
        print('except: {}: {}'.format(type(err).__name__, err))


# Writes evaluation logs to CSV files (one file per test task group).
class DegreeLogger(AbstractLogger):
    """Periodically evaluate a model on groups of test meta tasks and log the metrics.

    On construction a run directory
    ``<project_root>/LogDir/<dataset_name>/<train_comment><time_str>/`` is created. It
    holds:

    * the JSON config;
    * one ``log<describe>.csv`` per test task group;
    * the persisted test tasks (``.npz``);
    * a copy of ``<project_root>/SourceCode``;
    * a copy of the launching script.

    ``logging`` then appends one CSV row per group. It also keeps a best-so-far model
    checkpoint, selected on the registered "checkpointing" tasks.
    """

    def __init__(self, test_meta_task_group_list, meta_task_group_discribe_list, dataset_name, config, loss_func,
                 eval_query_batch_size=500, eval_support_batch_size=500, flush_gap=1, train_comment=" ", eval_gap=1000,
                 save_gap=10000, prod_env=False):
        """Create the run directory and open one CSV log file per test task group.

        Args:
            test_meta_task_group_list: list of groups. Each group is an iterable of
                meta tasks exposing ``support_set`` / ``query_set``.
            meta_task_group_discribe_list: one label per group, used in file and
                directory names.
            dataset_name: dataset label. It is prefixed with ``Prod_``/``Dev_`` and
                the current ``%m%d_`` date.
            config: JSON-serialisable object, written to ``<run_dir>/config``.
            loss_func: ``loss_func(pred, y)`` returning a tensor with ``.item()``.
                Used for checkpoint selection.
            eval_query_batch_size: query batch size. A smaller query set is queried
                in a single call.
            eval_support_batch_size: support batch size. A smaller support set is
                written in a single call.
            flush_gap: the CSV files are flushed when
                ``(step // eval_gap) % flush_gap == 0``.
            train_comment: free-text prefix of the run directory name.
            eval_gap: step interval used in the flush condition (presumably the
                evaluation period; confirm with the caller).
            save_gap: period (in steps) of the unconditional model save.
            prod_env: use the ``Prod_`` instead of the ``Dev_`` prefix.
        """
        self.flush_gap = flush_gap
        self.config = config
        self.eval_query_batch_size = eval_query_batch_size
        self.eval_support_batch_size = eval_support_batch_size
        self.model_path = None
        self.save_gap = save_gap
        self.meta_task_discribe_list = meta_task_group_discribe_list
        self.test_meta_task_group_list = test_meta_task_group_list
        self.loss_func = loss_func
        # Free-text training comment; prefixes the run directory name.
        self.train_comment = train_comment
        # Everything before 'ExpCode' in the working directory is taken as the project root.
        self.project_root = os.getcwd().split('ExpCode')[0]
        self.path = os.getcwd()
        self.py_file_path = sys.argv[0]
        print(self.py_file_path)
        env_prefix = "Prod_" if prod_env else "Dev_"
        self.dataset_name = env_prefix + time.strftime('%m%d_', time.localtime(time.time())) + dataset_name
        self.csv_writer_list = None
        self.log_file_list = None
        self.__init_log_file()
        # CSV header; built from the metric names and written on the first `logging` call.
        self.file_header = None
        self.eval_gap = eval_gap
        self.best_loss_for_checkpointing = None
        self.checkpointing_tasks_list = []

    def add_tasks_for_checkpointing(self, meta_task):
        """Register a meta task whose loss drives best-checkpoint selection."""
        self.checkpointing_tasks_list.append(meta_task)

    def clean_tasks_for_checkpoingting(self):
        """Forget all registered checkpointing tasks (misspelled name kept for callers)."""
        self.checkpointing_tasks_list = []

    @staticmethod
    def _avoid_singleton_last_batch(num_items, batch_size):
        """Randomly nudge ``batch_size`` by +/-1 until the last batch is not of size 1.

        Sets with at most one item are returned unchanged. They have no partial batch
        to fix. Also, for ``num_items == 1`` every ``batch_size > 1`` leaves remainder
        1, so the random walk could run for an unbounded time.

        The walk stops at ``batch_size == 1`` at the latest, because ``n % 1 == 0``.
        So it never reaches 0.
        """
        while num_items > 1 and num_items % batch_size == 1:
            batch_size += 1 if random.random() > 0.5 else -1
        return batch_size

    def _write_support_set(self, model, support_set):
        """Call ``model.clear()``, then ``model.write`` the whole support set.

        The set is written in one call if it is smaller than the support batch size,
        otherwise batch by batch.
        """
        model.clear()
        if self.eval_support_batch_size > len(support_set):
            model.write(support_set.support_x, support_set.support_y)
        else:
            support_data_loader = DataLoader(support_set, batch_size=self.eval_support_batch_size)
            for support_x, support_y in support_data_loader:
                model.write(support_x, support_y)

    @staticmethod
    def _query_with_stream_length(model, query_x, stream_length):
        """Call ``model.query`` with ``stream_length`` repeated once per row of ``query_x``."""
        return model.query(query_x, stream_length.unsqueeze(-1).repeat(query_x.shape[0], 1))

    def calculate_checkpointing_loss(self, model):
        """Return one loss per registered checkpointing task, in registration order.

        For each task the support set is written into ``model`` and the query set is
        scored with ``self.loss_func``. When the query set is processed in batches,
        the per-batch losses are summed.

        Side effect: may permanently adjust ``eval_support_batch_size`` /
        ``eval_query_batch_size`` so that no batch of size 1 is produced.
        """
        all_checkpointing_loss_list = []
        with torch.no_grad():
            for meta_task in self.checkpointing_tasks_list:
                support_set = meta_task.support_set
                query_set = meta_task.query_set
                # Prevent a final batch of size 1.
                self.eval_support_batch_size = self._avoid_singleton_last_batch(
                    len(support_set), self.eval_support_batch_size)
                self.eval_query_batch_size = self._avoid_singleton_last_batch(
                    len(query_set), self.eval_query_batch_size)
                stream_length = support_set.support_y.sum()
                self._write_support_set(model, support_set)
                if self.eval_query_batch_size > len(query_set):
                    query_pred = self._query_with_stream_length(model, query_set.query_x, stream_length)
                    loss = self.loss_func(query_pred, query_set.query_y).item()
                else:
                    query_data_loader = DataLoader(query_set, batch_size=self.eval_query_batch_size, shuffle=False)
                    loss = sum(
                        self.loss_func(self._query_with_stream_length(model, query_x, stream_length), query_y).item()
                        for query_x, query_y in query_data_loader
                    )
                all_checkpointing_loss_list.append(loss)
        return all_checkpointing_loss_list

    @staticmethod
    def _loss_ratio(best_loss, new_loss):
        """Return ``best_loss / new_loss``.

        A new loss of 0 counts as infinitely better, unless the best loss is also 0
        (ratio 1).
        """
        if new_loss == 0:
            return 1.0 if best_loss == 0 else float('inf')
        return best_loss / new_loss

    def compare_loss_and_best_loss(self, check_loss):
        """Adopt ``check_loss`` as the new best if it is at least as good on average.

        The criterion is: the mean over tasks of ``best_loss[i] / check_loss[i]`` is
        >= 1. On acceptance, ``self.best_loss_for_checkpointing`` is replaced.

        Returns:
            True if ``check_loss`` became the new best, otherwise False.
        """
        ratio_list = [self._loss_ratio(self.best_loss_for_checkpointing[i], loss)
                      for i, loss in enumerate(check_loss)]
        print('best     :', self.best_loss_for_checkpointing)
        if sum(ratio_list) / len(ratio_list) >= 1:
            print('next best:', check_loss)
            self.best_loss_for_checkpointing = check_loss
            return True
        print('not best:', check_loss)
        return False

    def logging(self, model, step, comment=""):
        """Evaluate ``model`` on every test group and append one CSV row per group.

        Also performs two saves:

        * updates the best-so-far checkpoint at ``self.checkpointing_model_path``,
          based on the registered checkpointing tasks;
        * saves the model to ``self.model_path`` when ``step % save_gap == 1``.

        ``model.eval()`` is called at the start and ``model.train()`` at the end.

        Args:
            model: model exposing ``eval``/``train``/``clear``/``write``/``query``.
            step: current training step; written to the ``step`` column.
            comment: unused; kept for interface compatibility.
        """
        model.eval()
        assert len(self.checkpointing_tasks_list) != 0, "these is no tasks for checkpointing"
        print(step, 'step logging begin...')
        if self.best_loss_for_checkpointing is None:
            self.best_loss_for_checkpointing = self.calculate_checkpointing_loss(model)
            print('set model checkpoint at ', step, ' step')
            torch.save(model, self.checkpointing_model_path)
        elif self.compare_loss_and_best_loss(self.calculate_checkpointing_loss(model)):
            print('set model checkpoint at ', step, ' step')
            torch.save(model, self.checkpointing_model_path)

        # NOTE(review): saves on steps congruent to 1 mod save_gap. Presumably step
        # counting starts at 1; confirm with the training loop.
        if step % self.save_gap == 1:
            torch.save(model, self.model_path)

        # First logging call only: call `to_device()` on every test task
        # (presumably moves its tensors to the model's device).
        if self.file_header is None and self.log_file_list:
            for group in self.test_meta_task_group_list:
                for task in group:
                    task.to_device()

        for log_file, csv_writer, test_meta_task_group in zip(
                self.log_file_list, self.csv_writer_list, self.test_meta_task_group_list):
            group_test_merged_info_dict = self.eval_on_one_group(test_meta_task_group, model)

            # First logging call only: derive the header from the metric names and
            # write it to every file.
            if self.file_header is None:
                self.file_header = ["step"] + list(group_test_merged_info_dict.keys())
                for writer in self.csv_writer_list:
                    writer.writerow(self.file_header)
            group_test_merged_info_dict['step'] = step
            csv_writer.writerow([group_test_merged_info_dict[key] for key in self.file_header])
            if (step // self.eval_gap) % self.flush_gap == 0:
                log_file.flush()
        model.train()
        print(step, ' step logging done...')

    def eval_on_one_group(self, test_meta_task_group, model):
        """Evaluate every task of a group and return the per-metric mean over its tasks.

        The metric keys, and their order, come from the first task's result.
        """
        group_dict_list = [self.eval_on_one_task(test_meta_task, model) for test_meta_task in test_meta_task_group]
        return {
            key: sum(task_info[key] for task_info in group_dict_list) / len(group_dict_list)
            for key in group_dict_list[0]
        }

    def eval_on_one_task(self, test_meta_task, model):
        """Return the basic metrics followed by the additional metrics for one task."""
        check(test_meta_task)
        basic_info_dict = self.get_basic_eval_info_on_one_task_only_weight(test_meta_task, model)
        additional_info_dict = self.get_additional_eval_info_on_one_task_only_weight(test_meta_task, model)
        # Key order is preserved: basic metrics first. On a key collision the
        # additional value wins.
        return {**basic_info_dict, **additional_info_dict}

    def get_sparsity(self, batch_data, dim=(1,)):
        """Count the entries with ``|x| >= 1e-5`` along ``dim``.

        Despite the name, this is a non-zero count (a density), not a sparsity ratio.
        """
        nonzero_mask = torch.where(torch.abs(batch_data) < 0.00001, 0.0, 1.0)
        return nonzero_mask.sum(dim=dim, keepdim=False)

    def get_additional_eval_info_on_one_task_only_weight(self, test_meta_task, model):
        """Return embedding/address statistics of ``model`` on a task's support set.

        Each statistic is computed per shuffled support batch (under
        ``torch.no_grad()``) and then averaged over batches. Keys, in order:
        ``edge_embedding_var``, ``edge_embedding_sparsity``,
        ``twodim_address_sparsity``, ``twodim_address_var``.
        """
        edge_embedding_var_list = []
        edge_embedding_sparsity_list = []
        addresses_sparsity_list = []
        addresses_var_list = []
        with torch.no_grad():
            support_data_loader = DataLoader(test_meta_task.support_set, batch_size=self.eval_support_batch_size,
                                             shuffle=True)
            for support_x, _ in support_data_loader:
                source_embedding, dest_embedding, edge_embedding = model.get_embedding(support_x)
                source_refined, dest_refined = model.get_refined(source_embedding, dest_embedding)
                row_address, col_address = model.get_address(source_refined, dest_refined)
                # Per-sample outer product of row and column addresses: (batch, rows, cols).
                addresses = row_address.view(row_address.shape[0], -1, 1).matmul(
                    col_address.view(col_address.shape[0], 1, -1))
                num_cells = addresses.shape[1] * addresses.shape[2]
                addresses_var_list.append((num_cells * addresses.mean(0)).var().item())
                addresses_sparsity_list.append(self.get_sparsity(addresses, dim=(1, 2)).mean().item())
                edge_embedding_sparsity_list.append(self.get_sparsity(edge_embedding, dim=1).mean().item())
                edge_embedding_var_list.append(edge_embedding.sum(0).var().item())

        def mean(values):
            return sum(values) / len(values)

        # Key order matters: it becomes part of the CSV header.
        return {
            "edge_embedding_var": mean(edge_embedding_var_list),
            'edge_embedding_sparsity': mean(edge_embedding_sparsity_list),
            'twodim_address_sparsity': mean(addresses_sparsity_list),
            'twodim_address_var': mean(addresses_var_list),
        }

    def _predict_query_set(self, model, query_set, stream_length):
        """Return flattened CPU tensors ``(pred, label)`` covering the whole query set.

        Predictions are summed over dim 1 (keepdim) before flattening. In the batched
        path an AssertionError is raised if a batch contains NaN labels (unless
        assertions are disabled).
        """
        if self.eval_query_batch_size > len(query_set):
            query_pred = self._query_with_stream_length(model, query_set.query_x, stream_length)
            query_pred = query_pred.sum(dim=1, keepdim=True)
            return query_pred.view(-1).cpu(), query_set.query_y.view(-1).cpu()
        query_data_loader = DataLoader(query_set, batch_size=self.eval_query_batch_size, shuffle=False)
        pred_chunks = []
        label_chunks = []
        for query_x, query_y in query_data_loader:
            query_pred = self._query_with_stream_length(model, query_x, stream_length)
            pred_chunks.append(query_pred.sum(dim=1, keepdim=True))
            label_chunks.append(query_y)
            assert not torch.isnan(query_y).any(), "NaN found in query labels"
        return torch.cat(pred_chunks).view(-1).cpu(), torch.cat(label_chunks).cpu().view(-1)

    @staticmethod
    def _top_ratio_f1(labels, preds, ratio=0.2):
        """F1 score of classifying items as "heavy" (above the value ranked at ``ratio``).

        Labels and predictions get separate thresholds. Each threshold is the value
        at index ``int(len * ratio)`` of that tensor's descending sort. Items strictly
        above the threshold are positive. ``labels`` and ``preds`` must be aligned 1-D
        CPU tensors covering the same items.
        """
        sorted_labels, _ = labels.sort(descending=True)
        sorted_preds, _ = preds.sort(descending=True)
        index = int(sorted_labels.size(dim=0) * ratio)
        label_threshold = sorted_labels[index].item()
        pred_threshold = sorted_preds[index].item()
        label_classes = torch.where(labels > label_threshold, torch.ones_like(labels), torch.zeros_like(labels))
        pred_classes = torch.where(preds > pred_threshold, torch.ones_like(preds), torch.zeros_like(preds))
        return metrics.f1_score(label_classes.numpy(), pred_classes.numpy())

    def get_basic_eval_info_on_one_task_only_weight(self, test_meta_task, model):
        """Write a task's support set into ``model``, query it, and return accuracy metrics.

        Keys, in order:

        * ``pred_var``, ``label_var``;
        * ``ARE`` (mean ``|pred - y| / y``), ``AAE`` (mean ``|pred - y|``);
        * ``label_sum``, ``pre_sum``, ``item_num``;
        * ``weight_F1`` (top-20% heavy-item F1 over the whole query set);
        * ``scale_value``, only when ``model.attention_matrix`` has one.
        """
        info_dict = {}
        with torch.no_grad():
            support_set = test_meta_task.support_set
            query_set = test_meta_task.query_set
            stream_length = support_set.support_y.sum()
            self._write_support_set(model, support_set)
            weight_pred, weight_query_y = self._predict_query_set(model, query_set, stream_length)

            # NOTE(review): ARE divides by the label, so zero labels yield inf/nan.
            weight_ARE = torch.mean(torch.abs(weight_pred - weight_query_y) / weight_query_y)
            weight_AAE = torch.mean(torch.abs(weight_pred - weight_query_y))
            info_dict["pred_var"] = weight_pred.var().item()
            info_dict["label_var"] = weight_query_y.var().item()
            info_dict['ARE'] = weight_ARE.item()
            info_dict['AAE'] = weight_AAE.item()
            info_dict['label_sum'] = weight_query_y.sum().item()
            info_dict['pre_sum'] = weight_pred.sum().item()
            info_dict['item_num'] = weight_pred.shape[0]
            # Use the full flattened label vector (not just the last query batch) so it
            # aligns with `weight_pred`.
            info_dict['weight_F1'] = self._top_ratio_f1(weight_query_y, weight_pred)
            if hasattr(model, 'attention_matrix') and hasattr(model.attention_matrix, 'scale_value'):
                info_dict['scale_value'] = model.attention_matrix.scale_value.item()
        return info_dict

    # Create the run directory with one log file per test task group.
    def __init_log_file(self):
        """Create and populate the run directory.

        Contents: config, CSV logs, persisted test tasks and a code backup.

        Sets ``self.model_path``, ``self.checkpointing_model_path``,
        ``self.log_file_list`` and ``self.csv_writer_list``. The CSV files stay open
        until ``close_all_file`` is called.
        """
        self.log_file_list = []
        self.csv_writer_list = []
        time_str = time.strftime('_%m_%d_%H_%M_%S', time.localtime(time.time()))
        run_name = self.train_comment + time_str

        def run_path(suffix):
            # Path of `suffix` inside <project_root>/LogDir/<dataset_name>/<run_name>/.
            return os.path.join(self.project_root, 'LogDir/{}/{}/{}'.format(self.dataset_name, run_name, suffix))

        if not os.path.exists(run_path('')):
            os.makedirs(run_path(''))
        self.model_path = run_path('model')
        self.checkpointing_model_path = run_path('checkpointing_model')

        config_str = json.dumps(self.config)
        with open(run_path('config'), 'w', newline='', encoding='utf-8') as config_file:
            config_file.write(config_str)

        # One CSV log file (kept open) per test task group.
        for meta_task_group_discribe in self.meta_task_discribe_list:
            log_file = open(run_path('log{}.csv'.format(meta_task_group_discribe)), 'w',
                            newline='', encoding='utf-8')
            self.log_file_list.append(log_file)
            self.csv_writer_list.append(csv.writer(log_file))

        # Persist the test meta tasks alongside the logs.
        os.makedirs(run_path('test_tasks_{}/'.format(run_name)))
        for i, group in enumerate(self.test_meta_task_group_list):
            group_describe = self.meta_task_discribe_list[i]
            os.makedirs(run_path('test_tasks_{}/{}/'.format(run_name, group_describe)))
            for j, meta_task in enumerate(group):
                self.save_meta_task(meta_task, run_path('test_tasks_{}/{}/{}.npz'.format(run_name, group_describe, j)))
        print('持久化test_meta_task完成...')

        # Back up the project's SourceCode tree and the launching script.
        shutil.copytree(os.path.join(self.project_root, 'SourceCode'), run_path('SourceCode/'))
        if '\\' in self.py_file_path:
            # Normalise Windows separators so the '/'-split below works.
            self.py_file_path = self.py_file_path.replace('\\', "/")
            self.path = self.path.replace('\\', "/")
        exp_dir = run_path('ExpCode/{}/'.format(self.path.split('/')[-1]))
        os.makedirs(exp_dir)
        shutil.copy(self.py_file_path, exp_dir)
        print('代码备份完成...')

    def save_meta_task(self, meta_task, path):
        """Save a meta task's support/query tensors to ``path`` with ``np.savez``."""
        np.savez(path, support_x=meta_task.support_set.support_x.cpu().numpy(),
                 support_y=meta_task.support_set.support_y.cpu().numpy(),
                 query_x=meta_task.query_set.query_x.cpu().numpy(), query_y=meta_task.query_set.query_y.cpu().numpy())

    def close_all_file(self):
        """Close every CSV log file opened at construction."""
        for file in self.log_file_list:
            file.close()
