# -*- coding: utf-8 -*-
import os
import glob
import math
import numpy as np
import itertools
from omegaconf import DictConfig
from typing import Dict, List, Tuple, Generator

import torch
from torch import nn
from torch import optim
import torch.nn.init as init
from torch.optim.lr_scheduler import _LRScheduler
from transformers.models.gpt2.modeling_gpt2 import GPT2Model

from common.log import access_log
from hgtft.model.loss import (get_quantiles_loss_and_q_risk, get_mse_loss, compute_mse_loss, compute_log_likelihood_loss,
                              compute_CRS_loss, compute_FDS_loss, cal_loss_rule_multi_reverse_scale)
from hgtft.utils.data_utils import DictDataSet, GraphData
from .. import ROOT_DIR



class EarlyStopping(object):
    """
    Signal when a monitored metric has stopped improving.

    Call :meth:`step` once per evaluation with the latest metric value; it returns
    ``True`` when training should stop.

    @param mode: ``'min'`` when lower metric values are better, ``'max'`` otherwise.
    @param min_delta: minimum change over the best value that counts as an improvement
        (absolute, or a percentage of the best value when ``percentage`` is True).
    @param patience: number of consecutive non-improving steps tolerated before stopping;
        ``0`` disables early stopping entirely.
    @param percentage: interpret ``min_delta`` as a percentage of the best value.
    """

    def __init__(self, mode='min', min_delta=0, patience=10, percentage=False):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            # Early stopping disabled: every value "improves" and step never stops.
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    @staticmethod
    def _is_nan(metrics):
        """Return True if ``metrics`` (a tensor, numpy value or plain number) contains NaN."""
        if isinstance(metrics, torch.Tensor):
            return bool(torch.isnan(metrics).any())
        return bool(np.isnan(np.asarray(metrics, dtype=float)).any())

    def step(self, metrics):
        """
        Record ``metrics`` and report whether training should stop.

        A NaN metric stops training immediately; this is checked before the first value is
        stored so that NaN never becomes the reference ``best``.

        @param metrics: latest metric value (tensor or plain number).
        @return: True to stop, False to continue.
        """
        if self._is_nan(metrics):
            return True

        if self.best is None:
            self.best = metrics
            return False

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        return self.num_bad_epochs >= self.patience

    def _init_is_better(self, mode, min_delta, percentage):
        """
        Install the ``is_better(a, best)`` comparison for ``mode``.

        @raises ValueError: if ``mode`` is neither ``'min'`` nor ``'max'``.
        """
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if not percentage:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - min_delta
            if mode == 'max':
                self.is_better = lambda a, best: a > best + min_delta
        else:
            # min_delta is a percentage of the current best value.
            if mode == 'min':
                self.is_better = lambda a, best: a < best - (
                            best * min_delta / 100)
            if mode == 'max':
                self.is_better = lambda a, best: a > best + (
                            best * min_delta / 100)


class QueueAggregator(object):
    """
    Bounded FIFO buffer that keeps the most recent ``max_size`` items.

    Each ``append`` that pushes the length above ``max_size`` evicts the oldest item.
    ``get`` returns the live underlying list (oldest first), not a copy.
    """

    def __init__(self, max_size):
        self.max_size = max_size
        self._queued_list = []

    def append(self, elem):
        """Push ``elem`` at the tail; drop the oldest item if capacity is exceeded."""
        buffer = self._queued_list
        buffer.append(elem)
        if len(buffer) > self.max_size:
            del buffer[0]

    def get(self):
        """Return the buffered items, oldest first."""
        return self._queued_list


def weight_init(m):
    """
    Randomly (re-)initialise the parameters of ``m`` according to its layer type.

    Usage:
        model = Model()
        model.apply(weight_init)

    ``nn.Module.apply`` already visits every sub-module. Modules that are not one of the
    handled layer types (containers, custom modules) are also recursed into here, so a
    direct call ``weight_init(model)`` initialises the whole tree too. ``GPT2Model``
    sub-trees are skipped so that their (presumably pretrained) weights are kept.

    @param m: an ``nn.Module``, ``nn.ParameterDict`` or a bare ``nn.Parameter``.
    @return: None; parameters are modified in place.
    """
    if isinstance(m, GPT2Model):
        return

    # The concrete layer types must be tested before the generic ``nn.Module`` fallback:
    # every layer is an ``nn.Module``, so testing that first would make these unreachable.
    if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
        init.normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d, nn.Linear)):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        # affine=False batch norms have no weight / bias.
        if m.weight is not None:
            init.normal_(m.weight.data, mean=1, std=0.02)
        if m.bias is not None:
            init.constant_(m.bias.data, 0)
    elif isinstance(m, (nn.LSTM, nn.LSTMCell, nn.GRU, nn.GRUCell)):
        # Orthogonal for weight matrices, standard normal for bias vectors.
        for param in m.parameters():
            if param.dim() >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
        if isinstance(m, nn.GRU):
            # Fill the first third of every GRU bias with -1 (the reset-gate slice in
            # PyTorch's documented (r|z|n) gate layout).
            for names in m._all_weights:
                for name in names:
                    if "bias" in name:
                        bias = getattr(m, name)
                        bias.data[:bias.size(0) // 3].fill_(-1.)
    elif isinstance(m, nn.TransformerEncoderLayer):
        for param in m.parameters():
            if param.dim() > 1:
                init.xavier_normal_(param.data)
            else:
                init.normal_(param.data, mean=0, std=0.02)
    elif isinstance(m, nn.ParameterDict):
        for param in m.values():
            init.normal_(param.data, mean=0, std=1)
    elif isinstance(m, nn.Parameter):
        init.normal_(m.data, mean=0, std=1)
    elif isinstance(m, nn.Module):
        # Containers (Sequential, ModuleDict, ...) and custom modules: recurse.
        for child in m.children():
            weight_init(child)


def recycle(iterable):
    """
    Yield the items of ``iterable`` endlessly, starting a fresh pass after each one.

    Unlike ``itertools.cycle`` nothing is cached: ``iterable`` is re-iterated every pass,
    so a re-iterable source (e.g. a shuffling DataLoader) yields a new pass each time.
    An empty ``iterable`` makes ``next()`` spin forever without yielding.
    """
    endless_passes = itertools.chain.from_iterable(itertools.repeat(iterable))
    for item in endless_passes:
        yield item


def get_iter_loaders(data_dict: Dict, loader_config: Dict, obj_type_name_list: List,
                     ignore_keys: List[str] = None) -> Tuple[Generator, int]:
    """
    Flatten per-object data into one ``DictDataSet`` and wrap it in an endless loader.

    Every kept value is stored under the key ``"<obj_type>+<obj_id>+<key>"``.

    @param data_dict: ``{obj_type: {'data_sets': {obj_id: {key: value}}, 'obj_id_list': [...]}}``.
    @param loader_config: keyword arguments forwarded to ``torch.utils.data.DataLoader``.
    @param obj_type_name_list: object types to include; types absent from ``data_dict`` are skipped.
    @param ignore_keys: per-object keys to leave out; ``None`` (default) keeps every key.
    @return: ``(iterator, len(loader))``; the iterator re-iterates the loader forever.
    """
    # The previous default (the ``list`` type itself) raised TypeError on ``in``.
    ignored = set(ignore_keys) if ignore_keys is not None else set()

    full_data_dict = {}
    for obj_type in obj_type_name_list:
        if obj_type not in data_dict:
            continue
        data_sets = data_dict[obj_type]['data_sets']
        for obj_id in data_dict[obj_type]['obj_id_list']:
            for k, v in data_sets[obj_id].items():
                if k not in ignored:
                    full_data_dict[f"{obj_type}+{obj_id}+{k}"] = v

    dataset = DictDataSet(full_data_dict)
    loader = torch.utils.data.DataLoader(dataset, **loader_config)
    return recycle(loader), len(loader)


def process_batch_with_time_series_mask(batch: Dict[str, Dict], mask: Dict[str, Dict], target: Dict[str, Dict],
                                        model: nn.Module, graph: Dict) -> torch.tensor:
    """
    Run ``model`` on ``batch`` and average a masked MSE over every predicted feature.

    For each ``(obj_type, obj_id, feature_type)`` in the model output, only the positions
    where the mask is True are passed (as float) to ``compute_mse_loss``; the resulting
    per-feature losses are stacked and averaged with equal weight.

    @param batch: model input, forwarded unchanged as ``model(batch, graph)``.
    @param mask: boolean masks nested as ``mask[obj_type][obj_id][feature_type]``.
    @param target: ground truth nested like ``mask``.
    @param model: callable returning ``{obj_type: {obj_id: {feature_type: tensor}}}``.
    @param graph: graph structure forwarded to the model.
    @return: mean (over dim 0) of the stacked per-feature losses.
    """
    outputs = model(batch, graph)

    def _select(pred, truth, keep):
        # Flatten the masked positions of prediction and truth to float vectors.
        return (torch.masked_select(pred, keep).float(),
                torch.masked_select(truth, keep).float())

    feature_losses = []
    for obj_type, by_id in outputs.items():
        for obj_id, by_feature in by_id.items():
            for feature_type, pred in by_feature.items():
                truth = target[obj_type][obj_id][feature_type]
                keep = mask[obj_type][obj_id][feature_type]
                selected_pred, selected_truth = _select(pred, truth, keep)
                feature_losses.append(compute_mse_loss(selected_pred, selected_truth))

    return torch.mean(torch.stack(feature_losses), dim=0)


def process_batch_task(batch_outputs, target: Dict[str, Dict]) -> torch.tensor:
    """
    Average ``get_mse_loss`` over every object type that produced predictions.

    For each non-empty ``batch_outputs[obj_type]`` the labels
    ``{obj_type: {obj_id: target[obj_type][obj_id]['temporal_feats_numeric']}}`` are built
    and passed, together with the whole ``batch_outputs``, to ``get_mse_loss``.
    NOTE(review): get_mse_loss receives every type's outputs but labels for a single
    type; presumably it iterates over ``targets`` — confirm.

    @param batch_outputs: predictions nested as ``{obj_type: {obj_id: tensor}}``.
    @param target: ground truth nested as ``target[obj_type][obj_id]['temporal_feats_numeric']``.
    @return: mean (over dim 0) of the stacked per-type losses.
    """
    per_type_losses = []
    for obj_type, predictions in batch_outputs.items():
        if not predictions:
            continue
        type_labels = {
            obj_id: target[obj_type][obj_id]['temporal_feats_numeric']
            for obj_id in predictions.keys()
        }
        per_type_losses.append(get_mse_loss(outputs=batch_outputs, targets={obj_type: type_labels}))
    return torch.mean(torch.stack(per_type_losses), dim=0)


def _crs_losses(pred_a, true_a, pred_b, true_b, feature_pairs):
    """
    Return one ``compute_CRS_loss`` per batch row and per ``(feature_a, feature_b)`` pair.

    Predictions are sliced as ``pred[batch, :, feature, :]`` and targets as
    ``true[batch, :, feature]``; every slice is squeezed before the loss call. The
    batch size is taken from ``pred_a``.
    """
    losses = []
    for batch_index in range(pred_a.shape[0]):
        for feature_a, feature_b in feature_pairs:
            losses.append(compute_CRS_loss(
                pred_a[batch_index, :, feature_a, :].squeeze(),
                pred_b[batch_index, :, feature_b, :].squeeze(),
                true_a[batch_index, :, feature_a].squeeze(),
                true_b[batch_index, :, feature_b].squeeze(),
            ))
    return losses


def _graph_crs_losses(obj_id, obj_data, obj_target, graph_data, from_predictions, from_targets):
    """
    Return CRS losses between ``obj_id`` and every source object linked to it in ``graph_data``.

    ``graph_data`` is read as an adjacency frame whose rows are destination ids and whose
    columns are source ids, with ``1`` marking an edge. This is the same orientation
    ``get_rel_from_graph`` uses via ``df[src_id] == 1``. Linked source objects without a
    prediction in ``from_predictions`` are skipped.
    """
    row_key = str(obj_id)
    # obj_id is a destination id, so its presence must be checked on the index (rows),
    # which is exactly what ``.loc[row_key]`` reads.
    if row_key not in graph_data.index:
        return []
    row = graph_data.loc[row_key]

    losses = []
    for from_obj_id in row[row == 1].index.tolist():
        from_obj_data = from_predictions.get(from_obj_id)
        if from_obj_data is None:
            continue
        from_target = from_targets[from_obj_id]['temporal_feats_numeric']
        cross_pairs = list(itertools.product(range(obj_data.shape[2]), range(from_obj_data.shape[2])))
        losses.extend(_crs_losses(obj_data, obj_target, from_obj_data, from_target, cross_pairs))
    return losses


def process_batch_task_CRS(batch_outputs, target: Dict[str, Dict], graph) -> torch.tensor:
    """
    Correlation (CRS) loss over predicted features.

    For every predicted object two families of ``compute_CRS_loss`` terms are collected:

    * self-correlation: every unordered pair of the object's own features, per batch row;
    * graph correlation: for each graph named ``"<from>-...-<to>"`` whose ``<to>`` is the
      object's type and whose ``<from>`` type is also predicted, every (own feature,
      linked object's feature) pair, per batch row.

    Each object's terms are averaged; objects with no terms are skipped. The result is
    the mean over objects.

    @param batch_outputs: predictions nested as ``{obj_type: {obj_id: tensor}}``; tensors
        are sliced ``[batch, :, feature, :]``.
    @param target: ground truth; ``target[obj_type][obj_id]['temporal_feats_numeric']`` is
        sliced ``[batch, :, feature]``.
    @param graph: ``{graph_name: obj}`` where ``obj.graph_data`` is the adjacency frame.
    @return: mean CRS loss over objects.
    @raises ValueError: wrapping (and chaining) any error raised while handling a graph.
    """
    loss_list = []
    for obj_type, obj_predict_dict in batch_outputs.items():
        for obj_id, obj_data in obj_predict_dict.items():
            obj_target = target[obj_type][obj_id]['temporal_feats_numeric']

            # Self-correlation between each unordered pair of the object's own features.
            own_pairs = list(itertools.combinations(range(obj_data.shape[2]), 2))
            obj_loss_list = _crs_losses(obj_data, obj_target, obj_data, obj_target, own_pairs)

            # Correlation with objects linked to this one through a graph relation.
            for graph_name, graph_data_obj in graph.items():
                try:
                    name_parts = graph_name.split('-')
                    from_obj, to_obj = name_parts[0], name_parts[-1]
                    if to_obj != obj_type or from_obj not in batch_outputs:
                        continue
                    obj_loss_list.extend(_graph_crs_losses(
                        obj_id, obj_data, obj_target, graph_data_obj.graph_data,
                        batch_outputs[from_obj], target.get(from_obj, {})))
                except Exception as e:
                    access_log.error(f'{obj_type}-{obj_id}-{graph_name}')
                    raise ValueError(e) from e

            if not obj_loss_list:
                # Single feature and no linked objects: nothing to correlate.
                continue
            loss_list.append(torch.mean(torch.stack(obj_loss_list), dim=0))

    return torch.mean(torch.stack(loss_list), dim=0)


def process_batch_task_FDS(batch_outputs, target: Dict[str, Dict]) -> torch.tensor:
    """
    Per-feature similarity (FDS) loss between predictions and targets.

    For every predicted object, ``compute_FDS_loss`` is evaluated for each batch row and
    each feature, comparing ``pred[b, :, f, :]`` with
    ``target[obj_type][obj_id]['temporal_feats_numeric'][b, :, f]`` (both squeezed).
    The terms are averaged per object, then averaged over objects.

    @param batch_outputs: predictions nested as ``{obj_type: {obj_id: tensor}}``.
    @param target: ground truth nested as ``target[obj_type][obj_id]['temporal_feats_numeric']``.
    @return: mean FDS loss over objects.
    """
    per_object_losses = []
    for obj_type, predictions in batch_outputs.items():
        if not predictions:
            continue
        for obj_id, pred in predictions.items():
            terms = [
                compute_FDS_loss(
                    pred[b, :, f, :].squeeze(),
                    target.get(obj_type).get(obj_id).get('temporal_feats_numeric')[b, :, f].squeeze(),
                )
                for b in range(pred.shape[0])
                for f in range(pred.shape[2])
            ]
            per_object_losses.append(torch.mean(torch.stack(terms), dim=0))
    return torch.mean(torch.stack(per_object_losses), dim=0)

def _update_loss_history(name: str, loss: torch.Tensor, history_loss_dict: Dict) -> None:
    """
    更新损失的统计，包括损失的计数、平均值和标准差
    @param name:
    @param loss:
    @param history_loss_dict:
    @return:
    """
    with torch.no_grad():
        history_count = history_loss_dict[name]['count']
        history_mean = history_loss_dict[name]['mean']
        history_std = history_loss_dict[name]['std']
        if history_count == 0:
            count = history_count + 1
            mean = loss
            std = torch.sqrt(((loss - mean) ** 2) / count)
        else:
            count = history_count + 1
            mean = (history_mean * history_count + loss) / count
            std = torch.sqrt((history_std ** 2 * history_count + (loss - mean) ** 2) / count)
        history_loss_dict[name]['count'] = count
        history_loss_dict[name]['mean'] = mean
        history_loss_dict[name]['std'] = std


def process_batch_with_time_series_mask_mse_likelihood_loss(
        batch: Dict[str, Dict], mask: Dict[str, Dict], target: Dict[str, Dict], model: nn.Module, graph: Dict,
        config_future_len: int) -> Tuple[torch.tensor, ...]:
    """
    Time-series masking objective returning an MSE loss and a log-likelihood loss.

    For every ``(obj_type, obj_id, feature_type)`` predicted by ``model(batch, graph)``:

    * MSE: ``compute_mse_loss`` on the positions where the mask is True (as float);
    * likelihood: ``compute_log_likelihood_loss`` on ``[:, config_future_len:, :]`` of the
      prediction and the target (no mask applied; dim 1 is presumably time — confirm).

    @param batch: model input.
    @param mask: boolean masks nested as ``mask[obj_type][obj_id][feature_type]``.
    @param target: ground truth nested like ``mask``.
    @param model: callable returning ``{obj_type: {obj_id: {feature_type: tensor}}}``.
    @param graph: graph structure forwarded to the model.
    @param config_future_len: start index along dim 1 for the likelihood slice.
    @return: ``(mean MSE loss, mean likelihood loss)``, each averaged over dim 0 of the stacked terms.
    """
    outputs = model(batch, graph)
    mse_terms = []
    likelihood_terms = []
    tail = slice(config_future_len, None)

    for obj_type, by_id in outputs.items():
        for obj_id, by_feature in by_id.items():
            for feature_type, pred in by_feature.items():
                truth = target[obj_type][obj_id][feature_type]
                keep = mask[obj_type][obj_id][feature_type]

                mse_terms.append(compute_mse_loss(
                    torch.masked_select(pred, keep).float(),
                    torch.masked_select(truth, keep).float(),
                ))
                likelihood_terms.append(compute_log_likelihood_loss(pred[:, tail, :], truth[:, tail, :]))

    mse_mean = torch.mean(torch.stack(mse_terms), dim=0)
    likelihood_mean = torch.mean(torch.stack(likelihood_terms), dim=0)
    return mse_mean, likelihood_mean


def process_batch(batch: Dict, model: nn.Module, graph: Dict, quantiles_tensor: torch.tensor = torch.tensor([0.5]),
                  loss_func: str = 'quantiles') -> torch.tensor:
    """
    Run ``model`` on ``batch`` and return the loss averaged over object types.

    For each object type with predictions, labels are taken from
    ``batch[obj_type][obj_id]['target']``. Either the quantile loss or the MSE loss is
    computed, and the per-type losses are averaged.

    @param batch: model input; also provides the labels as described above.
    @param model: callable ``model(batch, graph)`` returning ``{obj_type: {obj_id: tensor}}``.
    @param graph: graph structure forwarded to the model.
    @param quantiles_tensor: quantiles passed to ``get_quantiles_loss_and_q_risk``.
    @param loss_func: ``'quantiles'`` or ``'mse'``.
    @return: mean (over dim 0) of the stacked per-type losses.
    @raises ValueError: if ``loss_func`` is not a supported loss name.
    """
    # Validate up front: previously the ValueError was constructed but never raised,
    # so an unknown name crashed later (UnboundLocalError) or silently reused a stale loss.
    if loss_func not in ('quantiles', 'mse'):
        raise ValueError(f'{loss_func} loss function does not exist')

    batch_outputs = model(batch, graph)

    loss_list = []
    for obj_type, obj_predict_dict in batch_outputs.items():
        if not obj_predict_dict:
            continue
        labels = {obj_type: {obj_id: batch[obj_type][obj_id]['target'] for obj_id in obj_predict_dict}}
        if loss_func == 'quantiles':
            q_loss = get_quantiles_loss_and_q_risk(
                outputs=batch_outputs, targets=labels, desired_quantiles=quantiles_tensor
            )
        else:
            q_loss = get_mse_loss(outputs=batch_outputs, targets=labels)
        loss_list.append(q_loss)
    return torch.mean(torch.stack(loss_list), dim=0)


def clear_files(path: str, pattern: str = '*') -> None:
    """
    Delete the entries of ``path`` whose names match the glob ``pattern``.

    Real sub-directories matched by the pattern are left untouched, because ``os.remove``
    cannot delete them.

    @param path: directory to clean (with or without a trailing separator).
    @param pattern: glob pattern relative to ``path``; defaults to everything.
    @return: None
    """
    # basename(normpath(...)) names the directory itself whether or not ``path`` ends with
    # a separator (``split("/")[-2]`` raised IndexError for bare names).
    access_log.info(f'del {os.path.basename(os.path.normpath(path))}')
    for file in glob.glob(os.path.join(path, pattern)):
        if os.path.isdir(file) and not os.path.islink(file):
            continue
        os.remove(file)


def load_iter_loaders(config: DictConfig, data: Dict, meta_keys: List, data_type: str) -> Tuple[Generator, int]:
    """
    Build the endless data iterator for one split.

    The ``'train'`` split uses the training batch size and shuffles. Any other split uses
    the inference batch size without shuffling. ``drop_last`` is always False.

    @param config: experiment config providing ``configuration.optimization.batch_size``
        and ``obj_type_name_list``.
    @param data: per-object-type data forwarded to ``get_iter_loaders``.
    @param meta_keys: per-object keys excluded from the dataset.
    @param data_type: split name; only ``'train'`` is treated specially.
    @return: ``(iterator, batch_count)``.
    """
    batch_sizes = config.configuration.optimization.batch_size
    is_train = data_type == 'train'
    loader_config = {
        'batch_size': batch_sizes.training if is_train else batch_sizes.inference,
        'drop_last': False,
        'shuffle': is_train,
    }
    iterator, batch_count = get_iter_loaders(
        data, loader_config, config.obj_type_name_list, ignore_keys=meta_keys
    )
    access_log.info(f'{data_type} batch count: {batch_count}')
    return iterator, batch_count


def load_graph(project_name: str, graph_relation: list[Dict], dataset_data_path: str = None) -> Dict:
    """
    Load the graph files named by ``graph_relation`` that exist on disk.

    The graph directory is ``<ROOT_DIR>/data/<project>/graph`` by default, or
    ``<dataset_data_path>/<project>/graph`` when a dataset path is given. A relation
    ``{'from': A, 'edge': E, 'to': B}`` is loaded only if a file whose name (up to the
    first '.') equals ``"A-E-B"`` is present there.

    @param project_name: project sub-directory name.
    @param graph_relation: list of ``{'from', 'edge', 'to'}`` dicts.
    @param dataset_data_path: optional alternative data root.
    @return: ``{"A-E-B": GraphData}`` for every relation that was found.
    """
    if dataset_data_path is None:
        graph_dir = os.path.join(ROOT_DIR, f'data/{project_name}/graph')
    else:
        graph_dir = os.path.join(dataset_data_path, f'{project_name}/graph')
    available = {entry.split('.')[0] for entry in os.listdir(graph_dir)}

    loaded = {}
    for relation in graph_relation:
        src_type_name = relation.get('from')
        edge_name = relation.get('edge')
        dst_type_name = relation.get('to')
        key = f"{src_type_name}-{edge_name}-{dst_type_name}"
        if key not in available:
            continue
        loaded[key] = GraphData(project_name, src_type_name, dst_type_name, edge_name,
                                dataset_data_path=dataset_data_path)
    return loaded


class WarmupCosineScheduler(_LRScheduler):
    """
    Linear warm-up followed by cosine annealing, evaluated per scheduler step.

    * step 0: each group's ``initial_lr``;
    * 0 < step < ``warmup_epochs``: linear ramp from ``initial_lr`` to ``max_lr``;
    * ``warmup_epochs`` <= step <= ``max_epochs``: cosine decay from ``max_lr`` to ``min_lr``;
    * step > ``max_epochs``: ``min_lr``.

    @param optimizer: wrapped optimizer.
    @param warmup_epochs: number of warm-up steps.
    @param max_epochs: step at which the cosine decay reaches ``min_lr``.
    @param min_lr: floor learning rate.
    @param max_lr: peak learning rate; defaults to the first param group's current ``lr``.
    @param last_epoch: index of the last step (``-1`` to start fresh).
    """

    def __init__(self, optimizer, warmup_epochs, max_epochs, min_lr=0, max_lr=None, last_epoch=-1):
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.min_lr = min_lr
        self.max_lr = max_lr if max_lr is not None else optimizer.param_groups[0]['lr']
        super(WarmupCosineScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every param group for the current ``last_epoch``."""
        if not self._get_lr_called_within_step:
            # A warning category is not a logger argument: passed positionally it was
            # treated as a %-format argument of the message.
            access_log.warning("To get the last learning rate computed by the scheduler, "
                               "please use `get_last_lr()`.")

        if self.last_epoch == 0:
            return [group['initial_lr'] for group in self.optimizer.param_groups]
        if self.last_epoch < self.warmup_epochs:
            return [
                group['initial_lr'] + (self.max_lr - group['initial_lr']) * self.last_epoch / self.warmup_epochs
                for group in self.optimizer.param_groups
            ]
        if self.last_epoch <= self.max_epochs:
            # max(..., 1) avoids ZeroDivisionError when warmup_epochs == max_epochs.
            decay_epochs = max(self.max_epochs - self.warmup_epochs, 1)
            cosine_decay = 0.5 * (
                1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / decay_epochs)
            )
            decayed = cosine_decay * self.max_lr + (1 - cosine_decay) * self.min_lr
            return [decayed for _ in self.optimizer.param_groups]
        return [self.min_lr for _ in self.optimizer.param_groups]


def load_optimizer_and_scheduler(model, config):
    """
    Build the Adam optimizer and the LR scheduler selected by
    ``config.configuration.optimization.scheduler.type``.

    * ``'step_lr'``: Adam at ``scheduler.learning_rate`` with ``StepLR``
      (``learning_rate_step_size`` / ``learning_rate_gamma``).
    * ``'warm_up+cosine_anneal'``: Adam starting at ``scheduler.min_learning_rate`` with
      ``WarmupCosineScheduler`` up to ``scheduler.max_learning_rate``.

    Only parameters with ``requires_grad`` are optimised.

    @param model: module whose trainable parameters are optimised.
    @param config: experiment config.
    @return: ``(optimizer, scheduler)``.
    @raises ValueError: for an unknown scheduler type (previously ``raise`` of a str,
        which itself failed with TypeError).
    """
    optimization = config.configuration.optimization
    scheduler_config = optimization.scheduler
    scheduler_type = scheduler_config.type

    if scheduler_type == 'step_lr':
        optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                               lr=scheduler_config.learning_rate)
        scheduler = optim.lr_scheduler.StepLR(
            optimizer,
            step_size=scheduler_config.learning_rate_step_size,
            gamma=scheduler_config.learning_rate_gamma)
    elif scheduler_type == 'warm_up+cosine_anneal':
        optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                               lr=scheduler_config.min_learning_rate)
        scheduler = WarmupCosineScheduler(
            optimizer,
            warmup_epochs=scheduler_config.warm_up_epochs,
            max_epochs=optimization.max_epochs,
            max_lr=scheduler_config.max_learning_rate,
            min_lr=scheduler_config.min_learning_rate)
    else:
        raise ValueError(f'no func {scheduler_type}')
    return optimizer, scheduler


class DWALossWeightAdjustment:
    """
    Dynamic-Weight-Average style task weighting.

    The last ``2`` loss values of every task are kept in ``loss_history`` (oldest first,
    initialised to 1). Weights are a softmax over tasks of ``latest / previous / T``,
    normalised to sum to 1.
    """

    def __init__(self, num_tasks: int, T: int = 2):
        self.num_tasks = num_tasks
        self.T = T
        self.loss_history = np.ones((num_tasks, 2))

    def get_weights(self) -> np.array:
        """Return the per-task weights (numpy array summing to 1)."""
        ratios = self.loss_history[:, -1] / self.loss_history[:, -2]
        scores = np.exp(ratios / self.T)
        return scores / scores.sum()

    def update_loss_history(self, loss_list: List[torch.Tensor]):
        """Shift each task's history left by one and append the new loss (via ``.item()``)."""
        for row_index, loss in enumerate(loss_list):
            history = self.loss_history[row_index]
            history[:-1] = history[1:].copy()
            history[-1] = loss.item()


class GradNormLossAdjustment:
    """
    Gradient-magnitude based task weighting (GradNorm-inspired heuristic).

    Each task loss is back-propagated separately. Its gradient magnitude is measured as
    the sum of the per-parameter gradient L2 norms. Every weight is then smoothed towards
    ``grad_i / mean(grad)`` with factor ``alpha``.

    @param num_tasks: number of task losses / weights.
    @param device: device that holds the weight tensor.
    @param alpha: smoothing factor; weight keeps ``alpha`` of its previous value.
    """

    def __init__(self, num_tasks: int, device: torch.device, alpha: float = 0.5):
        self.num_tasks = num_tasks
        self.alpha = alpha
        self.device = device
        self.weights = torch.ones(num_tasks).to(device)

    def update_weights(self, task_losses: List[torch.Tensor], model) -> None:
        """
        Update ``self.weights`` from the gradient magnitude of every task loss.

        Side effects: calls ``backward(retain_graph=True)`` on each loss and clears the
        model gradients at the end, so the last task's gradients do not leak into the
        caller's subsequent backward pass. The weights are left unchanged when there are
        no losses or every gradient magnitude is zero (the ratio is undefined).

        @param task_losses: one scalar loss per task, in weight order.
        @param model: module whose parameter gradients are measured.
        """
        grads = []
        for loss in task_losses:
            model.zero_grad()
            loss.backward(retain_graph=True)
            grad_norm = 0.0
            for param in model.parameters():
                if param.grad is not None:
                    grad_norm += param.grad.norm().item()
            grads.append(grad_norm)
        model.zero_grad()

        if not grads:
            return
        avg_grad = sum(grads) / len(grads)
        if avg_grad == 0:
            return
        for i in range(len(self.weights)):
            self.weights[i] = self.alpha * self.weights[i] + (1 - self.alpha) * (grads[i] / avg_grad)

    def get_weights(self) -> torch.Tensor:
        """Return the current per-task weight tensor."""
        return self.weights


def process_batch_rule_check_reverse_scale(batch_outputs, device, src_map, predict_map, batch, mse_loss, relation_accc_zone) -> torch.tensor:
    """
    Evaluate the reverse-scaled multi-rule check loss for a batch.

    Delegates to ``cal_loss_rule_multi_reverse_scale``, which returns
    ``(triggered, rule_loss, multi_check)``.

    @return: a list ``[triggered, total_loss, components, multi_check]``:
        ``[True, check_loss, [mse_loss, check_loss], multi_check]`` when the rule
        triggered (``check_loss`` is ``rule_loss`` passed through stack+mean), otherwise
        ``[False, mse_loss, [], multi_check]``.
    """
    triggered, rule_loss, multi_check = cal_loss_rule_multi_reverse_scale(
        batch_outputs, predict_map, device, relation_accc_zone, batch, src_map)
    if not triggered:
        return [False, mse_loss, [], multi_check]
    check_loss = torch.mean(torch.stack([rule_loss]), dim=0)
    return [True, check_loss, [mse_loss, check_loss], multi_check]


def get_rel_from_graph(src_graph):
    """
    Build a nested relation lookup from the ACCC / GeneralZone / building / ACATACAT* graphs.

    Only entries of ``src_graph`` whose key starts with ``ACCC-``, ``GeneralZone-``,
    ``building-``, ``ACATACATAH-``, ``ACATACATFC-`` or ``ACATACATFU-`` are used. Each value
    is read through ``src_type_id``, ``dst_type_id`` and ``graph_data``. ``graph_data`` is
    filtered as ``df[df[src_id] == 1]`` and the matching row labels are stored as
    destination ids. It is therefore presumably an adjacency frame with source ids as
    columns and destination ids as rows — confirm against GraphData.

    Shape of the returned dict:
        rel[from_type][src_id][to_type]              -> list of linked destination ids
        rel['GeneralZone'][zone_id]['all']           -> {terminal_id: terminal_type}
                                                        for ACATAH/ACATFC/ACATFU targets
        rel['building'][building_id]['GeneralZone']  -> list of zone ids
                                                        (reverse of GeneralZone -> building)
        rel['ACCC'][accc_id]['all_eq']               -> {terminal_id: terminal_type}

    @param src_graph: mapping ``"<from>-<edge>-<to>"`` -> graph object.
    @return: the ``rel`` dict described above.
    """
    all_ids = []
    rel = {}
    # Keep only the graphs whose source type is one of the supported prefixes.
    graph = {k: v for k, v in src_graph.items() if k.startswith('ACCC-') or k.startswith('GeneralZone-')
             or k.startswith('building-') or k.startswith('ACATACATAH-') or k.startswith('ACATACATFC-') or k.startswith(
        'ACATACATFU-')}

    # Union of every source / destination id appearing in the kept graphs.
    for k, v in graph.items():
        all_ids.extend(v.src_type_id)
        all_ids.extend(v.dst_type_id)
    all_ids = list(set(all_ids))

    # Ids of terminal devices and of systems, collected per type while scanning the graphs.
    dict_terminal_device = {'ACATAH': [], 'ACATFC': [], 'ACATFU': []}
    dict_system_terminal = {'ACATACATAH': [], 'ACATACATFC': [], 'ACATACATFU': []}
    for k, v in graph.items():
        # NOTE(review): all_ids is the union of every kept graph's ids, so this only skips
        # graphs whose src_type_id or dst_type_id is empty.
        if (not list(set(all_ids).intersection(set(v.src_type_id))) or
                not list(set(all_ids).intersection(set(v.dst_type_id)))):
            continue
        # from_type = text before the first '-', to_type = text after the last '-'.
        p1, p2 = k.index('-'), k.rfind('-')
        from_type, to_type = k[:p1], k[p2 + 1:]
        # Relations between objects of the same type are ignored.
        if from_type == to_type:
            continue
        # Record terminal-device / system ids seen on either side of the relation.
        if from_type in dict_terminal_device:
            dict_terminal_device[from_type].extend(v.src_type_id)
        if to_type in dict_terminal_device:
            dict_terminal_device[to_type].extend(v.dst_type_id)
        if from_type in dict_system_terminal:
            dict_system_terminal[from_type].extend(v.src_type_id)
        if to_type in dict_system_terminal:
            dict_system_terminal[to_type].extend(v.dst_type_id)

        df = v.graph_data
        if from_type not in rel:
            rel[from_type] = {}

        # GeneralZone -> building links are additionally stored in reverse under rel['building'].
        if_building = False
        if from_type == 'GeneralZone' and to_type == 'building':
            if_building = True
            if to_type not in rel:
                rel[to_type] = {}

        for s1 in v.src_type_id:
            # NOTE(review): s1 comes from v.src_type_id, which is contained in all_ids, so this
            # never skips; it is also an O(len(all_ids)) list membership test.
            if s1 not in all_ids:
                continue
            # First time this source id is seen: create its entry (zones also get 'all').
            if s1 not in rel[from_type]:
                rel[from_type][s1] = {to_type: []}
                if from_type == 'GeneralZone':
                    rel[from_type][s1].update({'all': {}})
            # Rows where column s1 equals 1 -> destination ids linked to s1.
            s_df = df[df[s1] == 1]
            arr = []
            if not s_df.empty:
                arr = [item for item in s_df.index.tolist() if item in all_ids]
            # NOTE(review): if s1's entry was created by an earlier graph with another
            # to_type, this to_type key is only added when arr is non-empty.
            if arr:
                rel[from_type][s1][to_type] = arr
                # Zone -> terminal device: remember each terminal's device type under 'all'.
                if from_type == 'GeneralZone' and to_type in dict_terminal_device:
                    rel[from_type][s1]['all'].update({term: to_type for term in arr})
                # Reverse link: building -> list of zones pointing to it.
                if if_building:
                    for each_b in arr:
                        if each_b not in rel[to_type]:
                            rel[to_type][each_b] = {from_type: []}
                        elif from_type not in rel[to_type][each_b]:
                            rel[to_type][each_b][from_type] = []
                        rel[to_type][each_b][from_type].append(s1)

    # De-duplicate the collected ids (set() does not preserve order).
    for k in dict_terminal_device:
        dict_terminal_device[k] = list(set(dict_terminal_device[k]))
    for k in dict_system_terminal:
        dict_system_terminal[k] = list(set(dict_system_terminal[k]))
    # For every ACCC linked to at least one system of a given system type, map the
    # terminal devices of those systems to their device type under 'all_eq'.
    # NOTE(review): the inner loop walks every id in dict_system_terminal[t_type_sys]
    # (all systems of that type), not only accc_value[t_type_sys] (the systems linked to
    # this ACCC) — confirm which is intended.
    # NOTE(review): raises KeyError if no ACCC-* relation was kept, if a system type never
    # appeared as a source type, or if a system id only appears as a destination id.
    for accc_key, accc_value in rel['ACCC'].items():
        rel['ACCC'][accc_key]['all_eq'] = {}
        for t_type_sys, arr_idx_sys in dict_system_terminal.items():
            if t_type_sys in accc_value and accc_value[t_type_sys]:
                for sys_id in arr_idx_sys:
                    for eq_type in dict_terminal_device:
                        if eq_type in rel[t_type_sys][sys_id]:
                            rel['ACCC'][accc_key]['all_eq'].update(
                                {idx: eq_type for idx in rel[t_type_sys][sys_id][eq_type]})
    return rel
