import csv
import json
import os
import random
import sys

import numpy as np
from sklearn import metrics
import torch
from torch.utils.data import DataLoader
import time
from AbstractClass.AbstractLogger import AbstractLogger
from SourceCode.TaskRelatedClasses import TaskData
import shutil


def check(meta_task):
    """Print a warning when a meta task's support and query sets look inconsistent.

    The check compares the sum of ``support_set.support_y`` with the sum of the
    first column of ``query_set.query_y`` and prints an error line when

    * the two sums differ, or
    * the support sum is zero, or
    * the support sum does not exceed the number of rows of ``support_x``.

    The function is purely diagnostic: it returns ``None`` and never raises for
    ordinary errors (a malformed task only prints ``'except'``).

    Args:
        meta_task: object exposing ``support_set.support_x``,
            ``support_set.support_y`` and ``query_set.query_y``.
    """
    try:
        unique_item = meta_task.support_set.support_x.shape[0]
        stream_length = meta_task.support_set.support_y.sum()
        sum_stream_length = meta_task.query_set.query_y[:, 0].sum()
        if (stream_length != sum_stream_length
                or stream_length.item() == 0
                or stream_length.item() <= unique_item):
            print('logger error!!!!!!!!!!!!!!!!!!!!!!!!!')
    except Exception:
        # Best-effort check: report and carry on. ``Exception`` (not a bare
        # ``except``) so Ctrl-C / SystemExit still interrupt the run.
        print('except')



class BasicLogger(AbstractLogger):
    def __init__(self, test_meta_task_group_list, meta_task_group_discribe_list, dataset_name, config, loss_func,
                 eval_query_batch_size=500, eval_support_batch_size=500, flush_gap=1, train_comment=" ", eval_gap=1000,
                 save_gap=10000, prod_env=False):
        self.flush_gap = flush_gap
        self.config = config
        self.eval_query_batch_size = eval_query_batch_size
        self.eval_support_batch_size = eval_support_batch_size
        self.model_path = None
        self.save_gap = save_gap
        self.meta_task_discribe_list = meta_task_group_discribe_list
        self.test_meta_task_group_list = test_meta_task_group_list
        self.loss_func = loss_func
        self.train_comment = train_comment
        self.project_root = os.getcwd().split('ExpCode')[0]
        path = os.getcwd()
        self.path = path
        self.py_file_path = sys.argv[0]
        print(self.py_file_path)
        if prod_env:
            self.dataset_name = "Prod_" + time.strftime('%m%d_', time.localtime(time.time())) + dataset_name
        else:
            self.dataset_name = "Dev_" + time.strftime('%m%d_', time.localtime(time.time())) + dataset_name
        self.csv_writer_list = None
        self.log_file_list = None
        self.__init_log_file()
        self.file_header = None
        self.eval_gap = eval_gap
        self.best_loss_for_checkpointing = None
        self.checkpointing_tasks_list = []

    def add_tasks_for_checkpointing(self, meta_task):
        self.checkpointing_tasks_list.append(meta_task)

    def clean_tasks_for_checkpoingting(self):
        self.checkpointing_tasks_list = []

    def calculate_checkpointing_loss(self, model):
        all_checkpointing_loss_list = []
        for meta_task in self.checkpointing_tasks_list:
            with torch.no_grad():
                support_set = meta_task.support_set
                query_set = meta_task.query_set

                while len(support_set) % self.eval_support_batch_size == 1:
                    if random.random() > 0.5:
                        self.eval_support_batch_size += 1
                    else:
                        self.eval_support_batch_size -= 1
                stream_length = None
                while len(query_set) % self.eval_query_batch_size == 1:
                    if random.random() > 0.5:
                        self.eval_query_batch_size += 1
                    else:
                        self.eval_query_batch_size -= 1
                stream_length = support_set.support_y.sum()
                model.clear()
                if self.eval_support_batch_size > len(support_set):
                    model.write(support_set.support_x, support_set.support_y)
                else:
                    support_data_loader = DataLoader(support_set, batch_size=self.eval_support_batch_size)
                    for i, (support_x, support_y) in enumerate(support_data_loader):
                        model.write(support_x, support_y)
                if self.eval_query_batch_size > len(query_set):
                    query_pred = model.query(query_set.query_x,
                                             stream_length.unsqueeze(-1).repeat(query_set.query_x.shape[0], 1))
                    query_y = query_set.query_y
                    one_loss = self.loss_func(query_pred, query_y)
                    loss = one_loss.item()
                else:
                    query_data_loader = DataLoader(query_set, batch_size=self.eval_query_batch_size, shuffle=False)
                    loss_list = []
                    weight_pred_list = []
                    query_y_list = []
                    for j, (query_x, query_y) in enumerate(query_data_loader):
                        query_pred = model.query(query_x, stream_length.unsqueeze(-1).repeat(query_x.shape[0], 1))
                        weight_pred_list.append(query_pred)
                        query_y_list.append(query_y)
                        one_loss = self.loss_func(query_pred, query_y)
                        loss_list.append(one_loss.item())
                    loss = sum(loss_list) / len(loss_list)
                all_checkpointing_loss_list.append(loss)
        return all_checkpointing_loss_list

    def compare_loss_and_best_loss(self, check_loss):
        ratio_list = []
        for i in range(len(check_loss)):
            ratio_list.append(self.best_loss_for_checkpointing[i] / check_loss[i])
        if sum(ratio_list) / len(ratio_list) >= 1:
            print('best     :', self.best_loss_for_checkpointing)
            print('next best:', check_loss)

            self.best_loss_for_checkpointing = check_loss
            return True
        else:
            print('best     :', self.best_loss_for_checkpointing)
            print('not best:', check_loss)
            return False

    def logging(self, model, step, comment=""):
        model.eval()
        assert len(self.checkpointing_tasks_list) != 0, "these is no tasks for checkpointing"
        print(step, 'step logging begin...')
        # if step % self.save_gap == 0:
        torch.save(model, self.model_path)
        if self.best_loss_for_checkpointing is None:
            self.best_loss_for_checkpointing = self.calculate_checkpointing_loss(model)
            print('set model checkpoint at ', step, ' step')
            torch.save(model, self.checkpointing_model_path)
        else:
            check_loss = self.calculate_checkpointing_loss(model)
            if self.compare_loss_and_best_loss(check_loss):
                print('set model checkpoint at ', step, ' step')
                torch.save(model, self.checkpointing_model_path)
        for i in range(len(self.log_file_list)):
            if self.file_header is None:
                for group in self.test_meta_task_group_list:
                    for task in group:
                        task.to_device()
            # warning
            log_file = self.log_file_list[i]
            csv_writer = self.csv_writer_list[i]
            test_meta_task_group = self.test_meta_task_group_list[i]
            group_test_merged_info_dict = self.eval_on_one_group(test_meta_task_group, model)

            # 如果是第一次logging，初始化所有文件的表头
            if self.file_header is None:
                self.file_header = list(group_test_merged_info_dict.keys())
                self.file_header.insert(0, "step")
                for writer in self.csv_writer_list:
                    writer.writerow(self.file_header)
            group_test_merged_info_dict['step'] = step
            row_content = []
            for key in self.file_header:
                row_content.append(group_test_merged_info_dict[key])
            csv_writer.writerow(row_content)
            if (step // self.eval_gap) % self.flush_gap == 0:
                log_file.flush()
        model.train()
        print(step, ' step logging done...')
        return

    def eval_on_one_group(self, test_meta_task_group, model):
        # store all test result for a group of meta tasks
        group_dict_list = []
        for test_meta_task in test_meta_task_group:
            mt_test_info_dict = self.eval_on_one_task(test_meta_task, model)
            group_dict_list.append(mt_test_info_dict)

        group_test_merged_info_dict = {}
        for key in group_dict_list[0].keys():
            value_mean = 0
            for dic in group_dict_list:
                value_mean += dic[key]
            value_mean /= len(group_dict_list)
            group_test_merged_info_dict[key] = value_mean
        return group_test_merged_info_dict

    def eval_on_one_task(self, test_meta_task, model):
        """Run the sanity check and both metric collectors on a single task.

        Returns:
            dict: the basic metrics followed by the additional metrics. If a key
            appears in both, the additional value wins and the key keeps the
            position it had among the basic metrics.
        """
        check(test_meta_task)
        basic_metrics = self.get_basic_eval_info_on_one_task_only_weight(test_meta_task, model)
        extra_metrics = self.get_additional_eval_info_on_one_task_only_weight(test_meta_task, model)
        return {**basic_metrics, **extra_metrics}

    def get_sparsity(self, batch_data, dim=(1,)):
        sparsity_data = torch.where(torch.abs(batch_data - 0.0) < 0.00001, 0.0, 1.0)
        res = sparsity_data.sum(dim=dim, keepdim=False)
        return res


    def get_additional_eval_info_on_one_task_only_weight(self, test_meta_task, model):
        info_dict = {}
        with torch.no_grad():
            query_set = test_meta_task.query_set
            query_data_loader = DataLoader(query_set, batch_size=self.eval_support_batch_size, shuffle=True)
            row_sparsity_list = []
            col_sparsity_list = []
            addresses_sparsity_list = []
            addresses_var_list = []
            source_var_list = []
            dest_var_list = []
            edge_embedding_sparsity_list = []
            edge_embedding_var_list = []
            edge_embedding_norm_list = []
            edge_embedding_l1_norm_list = []
            source_refine_vec_norm_list = []
            source_refine_vec_l1_norm_list = []
            dest_refine_vec_norm_list = []
            dest_refine_vec_l1_norm_list = []

        for j, (query_x, query_y) in enumerate(query_data_loader):
            source_embedding, dest_embedding, edge_embedding = model.get_embedding(query_x)
            source_refined, dest_refined = model.get_refined(source_embedding, dest_embedding)
            row_address, col_address = model.get_address(source_refined, dest_refined)

            addresses = row_address.view(row_address.shape[0], -1, 1).matmul(
                col_address.view(col_address.shape[0], 1, -1))
            addresses_var = ((addresses.shape[1] * addresses.shape[2]) * addresses.mean(0)).var().item()
            addresses_sparsity = self.get_sparsity((addresses), dim=(1, 2)).mean().item()
            # addresses_sparsity =
            edge_embedding_sparsity = self.get_sparsity(edge_embedding, dim=1).mean().item()
            edge_embedding_var = edge_embedding.sum(0).var().item()
            edge_embedding_norm = torch.sqrt((edge_embedding.square()).sum(dim=1)).mean().item()
            edge_embedding_l1_norm = edge_embedding.sum(dim=-1).mean().item()
            source_refine_vec_norm = torch.sqrt((source_refined.square()).sum(dim=1)).mean().item()
            dest_refine_vec_norm = torch.sqrt((dest_refined.square()).sum(dim=1)).mean().item()
            source_refine_vec_l1_norm = source_refined.sum(dim=1).mean().item()
            dest_refine_vec_l1_norm = dest_refined.sum(dim=1).mean().item()
            edge_embedding_sparsity_list.append(edge_embedding_sparsity)
            edge_embedding_var_list.append(edge_embedding_var)
            edge_embedding_norm_list.append(edge_embedding_norm)
            edge_embedding_l1_norm_list.append(edge_embedding_l1_norm)
            addresses_var_list.append(addresses_var)
            addresses_sparsity_list.append(addresses_sparsity)
            source_refine_vec_norm_list.append(source_refine_vec_norm)
            dest_refine_vec_norm_list.append(dest_refine_vec_norm)
            source_refine_vec_l1_norm_list.append(source_refine_vec_l1_norm)
            dest_refine_vec_l1_norm_list.append(dest_refine_vec_l1_norm)
            source_var = (row_address.sum(0) * row_address.shape[-1] / row_address.shape[0]).var().item()
            dest_var = (col_address.sum(0) * col_address.shape[-1] / col_address.shape[0]).var().item()
            source_var_list.append(source_var)
            dest_var_list.append(dest_var)
            row_sparsity = self.get_sparsity(row_address).mean().item()
            col_sparsity = self.get_sparsity(col_address).mean().item()
            row_sparsity_list.append(row_sparsity)
            col_sparsity_list.append(col_sparsity)
        row_sparsity = sum(row_sparsity_list) / len(row_sparsity_list)
        col_sparsity = sum(col_sparsity_list) / len(col_sparsity_list)
        source_var = sum(source_var_list) / len(source_var_list)
        dest_var = sum(dest_var_list) / len(dest_var_list)
        twodim_address_sparsity = sum(addresses_sparsity_list) / len(addresses_sparsity_list)
        twodim_address_var = sum(addresses_var_list) / len(addresses_var_list)
        edge_embedding_norm = sum(edge_embedding_norm_list) / len(edge_embedding_norm_list)
        edge_embedding_l1_norm = sum(edge_embedding_l1_norm_list) / len(edge_embedding_l1_norm_list)
        edge_embedding_var_list = sum(edge_embedding_var_list) / len(edge_embedding_var_list)
        edge_embedding_sparsity_list = sum(edge_embedding_sparsity_list) / len(edge_embedding_sparsity_list)
        source_refine_vec_norm = sum(source_refine_vec_norm_list) / len(source_refine_vec_norm_list)
        source_refine_vec_l1_norm = sum(source_refine_vec_l1_norm_list) / len(source_refine_vec_l1_norm_list)
        dest_refine_vec_norm = sum(dest_refine_vec_norm_list) / len(dest_refine_vec_norm_list)
        dest_refine_vec_l1_norm = sum(dest_refine_vec_l1_norm_list) / len(dest_refine_vec_l1_norm_list)

        info_dict['edge_embedding_norm'] = edge_embedding_norm
        info_dict['edge_embedding_l1_norm'] = edge_embedding_l1_norm
        info_dict["edge_embedding_var"] = edge_embedding_var_list
        info_dict['edge_embedding_sparsity'] = edge_embedding_sparsity_list
        info_dict['row_address_var'] = source_var
        info_dict['col_address_var'] = dest_var
        info_dict['row_address_sparsity'] = row_sparsity
        info_dict['col_address_sparsity'] = col_sparsity
        info_dict['twodim_address_sparsity'] = twodim_address_sparsity
        info_dict['twodim_address_var'] = twodim_address_var
        info_dict['source_refined_vec_norm'] = source_refine_vec_norm
        info_dict['source_refined_vec_l1_norm'] = source_refine_vec_l1_norm
        info_dict['dest_refined_vec_norm'] = dest_refine_vec_norm
        info_dict['dest_refined_vec_l1_norm'] = dest_refine_vec_l1_norm

        if hasattr(model, 'attention_matrix'):
            if hasattr(model.attention_matrix, 'scale'):
                info_dict['scale'] = model.attention_matrix.scale.item()
        return info_dict

    def get_basic_eval_info_on_one_task_only_weight(self, test_meta_task, model):
        """Write a task's support set into ``model``, query it and report prediction metrics.

        Everything runs under ``torch.no_grad()``. Reported keys, in order:
        ``pred_var``, ``label_var``, ``ARE`` (mean of ``|pred - label| / label``;
        zero labels are not guarded against), ``AAE`` (mean absolute error),
        ``label_sum``, ``pre_sum``, ``item_num``, ``loss``, ``weight_F1`` and,
        when ``model.attention_matrix.scale_value`` exists, ``scale_value``.

        ``weight_F1`` is an F1 score between two binary labelings: labels and
        predictions are each thresholded at their own value found at rank
        ``int(0.2 * N)`` in descending order (strictly greater counts as
        positive).

        Side effect: ``self.eval_support_batch_size`` and
        ``self.eval_query_batch_size`` are permanently nudged by +/-1 (chosen at
        random) until the corresponding set size modulo the batch size is not 1.

        Args:
            test_meta_task: task exposing ``support_set`` and ``query_set``.
            model: object providing ``clear()``, ``write(x, y)`` and
                ``query(x, stream_length)``.

        Returns:
            dict: metric name -> Python scalar.
        """
        info_dict = {}
        with torch.no_grad():
            support_set = test_meta_task.support_set
            query_set = test_meta_task.query_set

            # Avoid a trailing batch holding exactly one sample.
            # NOTE(review): presumably the model cannot handle a batch of
            # size 1 - confirm the reason with the model code.
            while len(support_set) % self.eval_support_batch_size == 1:
                if random.random() > 0.5:
                    self.eval_support_batch_size += 1
                else:
                    self.eval_support_batch_size -= 1
            while len(query_set) % self.eval_query_batch_size == 1:
                if random.random() > 0.5:
                    self.eval_query_batch_size += 1
                else:
                    self.eval_query_batch_size -= 1

            # The sum of all support labels is handed to every query call.
            stream_length = support_set.support_y.sum()
            model.clear()
            if self.eval_support_batch_size > len(support_set):
                model.write(support_set.support_x, support_set.support_y)
            else:
                support_data_loader = DataLoader(support_set, batch_size=self.eval_support_batch_size)
                for support_x, support_y in support_data_loader:
                    model.write(support_x, support_y)

            if self.eval_query_batch_size > len(query_set):
                # Whole query set in a single call.
                query_pred = model.query(query_set.query_x,
                                         stream_length.unsqueeze(-1).repeat(query_set.query_x.shape[0], 1))
                query_y = query_set.query_y
                loss_sum = self.loss_func(query_pred, query_y)
                weight_query_y = query_y.view(-1).cpu()
                weight_pred = query_pred.view(-1).cpu()
            else:
                query_data_loader = DataLoader(query_set, batch_size=self.eval_query_batch_size, shuffle=False)
                loss_list = []
                weight_pred_list = []
                query_y_list = []
                for query_x, query_y in query_data_loader:
                    query_pred = model.query(query_x, stream_length.unsqueeze(-1).repeat(query_x.shape[0], 1))
                    weight_pred_list.append(query_pred)
                    query_y_list.append(query_y)
                    loss_list.append(self.loss_func(query_pred, query_y))
                weight_pred = torch.cat(weight_pred_list).view(-1).cpu()
                weight_query_y = torch.cat(query_y_list).view(-1).cpu()
                # Unweighted mean of the per-batch losses.
                loss_sum = torch.stack(loss_list).sum() / len(loss_list)

            weight_ARE = torch.mean(torch.abs(weight_pred - weight_query_y) / weight_query_y)
            weight_AAE = torch.mean(torch.abs(weight_pred - weight_query_y))
            info_dict["pred_var"] = weight_pred.var().item()
            info_dict["label_var"] = weight_query_y.var().item()
            info_dict['ARE'] = weight_ARE.item()
            info_dict['AAE'] = weight_AAE.item()
            info_dict['label_sum'] = weight_query_y.sum().item()
            info_dict['pre_sum'] = weight_pred.sum().item()
            info_dict['item_num'] = weight_pred.shape[0]
            info_dict['loss'] = loss_sum.item()

            sorted_query_y, _ = weight_query_y.sort(descending=True)
            sorted_weighted_pred, _ = weight_pred.sort(descending=True)
            # ratio as benchmark: the value at the 20% rank is the threshold
            index = int(sorted_query_y.size(dim=0) * 0.2)
            benchmark_weight_pred = sorted_weighted_pred[index].item()
            benchmark_query_y = sorted_query_y[index].item()
            # Both labelings are built from the flattened CPU tensors so they
            # always share one shape and device.
            label_query_y = torch.where(weight_query_y > benchmark_query_y, torch.ones_like(weight_query_y),
                                        torch.zeros_like(weight_query_y))
            label_weight_pred = torch.where(weight_pred > benchmark_weight_pred, torch.ones_like(weight_pred),
                                            torch.zeros_like(weight_pred))
            info_dict['weight_F1'] = metrics.f1_score(label_query_y.cpu(), label_weight_pred.cpu())
            if hasattr(model, 'attention_matrix'):
                if hasattr(model.attention_matrix, 'scale_value'):
                    info_dict['scale_value'] = model.attention_matrix.scale_value.item()
        return info_dict

    def __init_log_file(self):
        self.log_file_list = []
        self.csv_writer_list = []
        time_str = time.strftime('_%m_%d_%H_%M_%S', time.localtime(time.time()))
        if not os.path.exists(os.path.join(self.project_root,
                                           'LogDir/{}/{}/'.format(self.dataset_name, self.train_comment + time_str))):
            os.makedirs(os.path.join(self.project_root,
                                     'LogDir/{}/{}/'.format(self.dataset_name, self.train_comment + time_str)))
        self.model_path = os.path.join(self.project_root, 'LogDir/{}/{}/model'.format(self.dataset_name,
                                                                                      self.train_comment + time_str))
        self.checkpointing_model_path = os.path.join(self.project_root,
                                                     'LogDir/{}/{}/checkpointing_model'.format(self.dataset_name,
                                                                                               self.train_comment + time_str))

        config_str = json.dumps(self.config)
        config_file = open(os.path.join(self.project_root, 'LogDir/{}/{}/config'.format(self.dataset_name,
                                                                                        self.train_comment + time_str)),
                           'w',
                           newline='', encoding='utf-8')
        config_file.write(config_str)
        config_file.close()
        for meta_task_group_discribe in self.meta_task_discribe_list:
            self.log_file_list.append(open(os.path.join(self.project_root,
                                                        'LogDir/{}/{}/log{}.csv'.format(self.dataset_name,
                                                                                        self.train_comment + time_str,
                                                                                        meta_task_group_discribe)), 'w',
                                           newline='', encoding='utf-8'))
            self.csv_writer_list.append(csv.writer(self.log_file_list[-1]))
        os.makedirs(os.path.join(self.project_root,
                                 'LogDir/{}/{}/test_tasks_{}/'.format(self.dataset_name, self.train_comment + time_str,
                                                                      self.train_comment + time_str)))
        for i in range(len(self.test_meta_task_group_list)):
            os.makedirs(os.path.join(self.project_root,
                                     'LogDir/{}/{}/test_tasks_{}/{}/'.format(self.dataset_name,
                                                                             self.train_comment + time_str,
                                                                             self.train_comment + time_str,
                                                                             self.meta_task_discribe_list[i])))
            for j in range(len(self.test_meta_task_group_list[i])):
                path = os.path.join(self.project_root,
                                    'LogDir/{}/{}/test_tasks_{}/{}/{}.npz'.format(self.dataset_name,
                                                                                  self.train_comment + time_str,
                                                                                  self.train_comment + time_str,
                                                                                  self.meta_task_discribe_list[i],
                                                                                  str(j)))

                self.save_meta_task(self.test_meta_task_group_list[i][j], path)
        print('saving test_meta_task done...')
        shutil.copytree(os.path.join(self.project_root, 'SourceCode'), os.path.join(self.project_root,
                                                                                    'LogDir/{}/{}/SourceCode/'.format(
                                                                                        self.dataset_name,
                                                                                        self.train_comment + time_str)))

        if '\\' in self.py_file_path:
            self.py_file_path = self.py_file_path.replace('\\', "/")
            self.path = self.path.replace('\\', "/")

        os.makedirs(os.path.join(self.project_root,
                                 'LogDir/{}/{}/ExpCode/{}/'.format(self.dataset_name, self.train_comment + time_str,
                                                                   self.path.split('/')[-1])))
        shutil.copy(self.py_file_path, os.path.join(self.project_root,
                                                    'LogDir/{}/{}/ExpCode/{}/'.format(self.dataset_name,
                                                                                      self.train_comment + time_str,
                                                                                      self.path.split('/')[-1])))


    def save_meta_task(self, meta_task, path):
        np.savez(path, support_x=meta_task.support_set.support_x.cpu().numpy(),
                 support_y=meta_task.support_set.support_y.cpu().numpy(),
                 query_x=meta_task.query_set.query_x.cpu().numpy(), query_y=meta_task.query_set.query_y.cpu().numpy())

    def close_all_file(self):
        for file in self.log_file_list:
            file.close()
