# -*- coding: utf-8 -*-
import gc
import os
import copy
import numpy as np
from omegaconf import DictConfig

import torch
from torch import nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

from common.log import access_log
from hgtft.model.model import HeterogeneousGraphTemporalFusionTransformerTask
from hgtft.utils.train_utils import (weight_init, process_batch_task, load_optimizer_and_scheduler, load_graph,
                                     process_batch_task_CRS, process_batch_task_FDS, process_batch_rule_check_reverse_scale,
                                     get_rel_from_graph)
from hgtft.utils.data_utils import make_batch, task_temporal_batch_mask, task_temporal_target_mask, ModelData, init_minmax_normalization_model

from hgtft.train.train_base import TrainBase

# Enable torch.distributed INFO-level debug logging. This is set at import time,
# presumably so it is already in effect when the process group is initialized.
os.environ.update({'TORCH_DISTRIBUTED_DEBUG': 'INFO'})


class FinetuningManager(TrainBase):
    """Fine-tune a ``HeterogeneousGraphTemporalFusionTransformerTask`` on the tasks in ``config.use_task``.

    The model is wrapped in ``DistributedDataParallel`` and trained under fp16 autocast. Every loader
    batch is evaluated once per active task. The per-task losses are averaged and applied with a
    single scaled optimizer step.
    """

    # Weights of the loss components combined for every task in ``_train``.
    _MSE_LOSS_WEIGHT = 0.5
    _CRS_LOSS_WEIGHT = 0.1
    _FDS_LOSS_WEIGHT = 0.2
    _CHECK_LOSS_WEIGHT = 0.2

    def __init__(self, config: DictConfig):
        super().__init__(config)
        # Holds (dataset_data_path, scaler models); loaded lazily by ``_get_scaler_model_dict``.
        self._scaler_model_cache = None

    def _get_model(self, device):
        """Build the model, transfer reused weights, freeze untrained parameters and wrap it in DDP.

        Args:
            device: device the model is moved to (``cuda:{rank}`` when called from ``_main_worker``).

        Returns:
            The model wrapped in ``DistributedDataParallel``.
        """
        model = HeterogeneousGraphTemporalFusionTransformerTask(config=self.config, device=device)
        model.apply(weight_init)

        # Transfer parameters. A layer of the new model is overwritten with the reused weights when its
        # name contains one of the configured layer patterns and it also exists in the reused model.
        if self.config.reuse_structure.mode:
            init_state_dict = model.state_dict()
            for reuse_model_item in self.config.reuse_structure.model:
                access_log.info('reuse model: {}'.format(reuse_model_item.name))
                load_model = ModelData.load_model(reuse_model_item.name, device)
                load_model_layers = load_model.keys()
                for layer_name in init_state_dict:
                    if layer_name in load_model_layers and any(
                            reuse_layer in layer_name for reuse_layer in reuse_model_item.layers):
                        init_state_dict[layer_name] = load_model[layer_name]
            model.load_state_dict(init_state_dict)

        model = model.to(device)

        # Explicitly configured frozen layers (substring match on the parameter name).
        if self.config.freeze_layers:
            for name, param in model.named_parameters():
                if any(layer_name in name for layer_name in self.config.freeze_layers):
                    param.requires_grad = False

        # Output heads of tasks that are not trained in this run stay frozen.
        for name, param in model.named_parameters():
            if name.startswith('task_output') and not any(task_name in name for task_name in self.config.use_task):
                param.requires_grad = False

        # A parameter stays trainable only if its name contains ".<obj_type>." for an object type
        # used by an active task.
        use_object_type_patterns = [f".{obj_type}." for obj_type in self._get_use_object_types()]
        for name, param in model.named_parameters():
            if not any(pattern in name for pattern in use_object_type_patterns):
                param.requires_grad = False

        access_log.info(f'params count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')
        # Pass the device the module was actually moved to, so DDP's device_ids always matches the
        # module placement (torch.cuda.current_device() only matches if set_device was called).
        # find_unused_parameters=True: presumably because each forward only exercises one task's head.
        return DistributedDataParallel(model, device_ids=[device], find_unused_parameters=True)

    def _get_use_object_types(self):
        """Return the set of object types referenced by the active tasks (``config.use_task``).

        An object type is included when it is a key of the task's ``task_feature_info`` or is the
        ``from``/``to`` endpoint of one of the task's ``task_graph_relation`` entries.
        """
        use_tasks = set(self.config.use_task)
        object_types = set()
        for task_name, task_item in self.config.task.items():
            if task_name not in use_tasks:
                continue
            object_types.update(task_item.get('task_feature_info').keys())
            for relation in task_item.get('task_graph_relation'):
                object_types.add(relation.get('from'))
                object_types.add(relation.get('to'))
        return object_types

    @staticmethod
    def _get_input_target_mask_tensor(task_history_feature_list, task_future_feature_list, task_target_feature_list,
                                      temporal_feats_numeric, temporal_feats_categorical, history_len, future_len):
        """Build the boolean feature masks for one object type of one task.

        Every mask has ``history_len`` history rows followed by ``future_len`` future rows. It has one
        column per name in the corresponding feature list, and a column is True in a window when that
        feature name is selected for that window.

        Args:
            task_history_feature_list: names marked True in the history rows of the input masks.
            task_future_feature_list: names marked True in the future rows of the input masks.
            task_target_feature_list: names marked True in the future rows of the target mask; the
                history rows of the target mask are always False.
            temporal_feats_numeric: ordered numeric temporal feature names of the object type.
            temporal_feats_categorical: ordered categorical temporal feature names of the object type.
            history_len: number of history rows.
            future_len: number of future rows.

        Returns:
            ``(input_mask_numeric, input_mask_categorical, target_mask)`` bool tensors of shapes
            ``(history_len + future_len, len(temporal_feats_numeric))``,
            ``(history_len + future_len, len(temporal_feats_categorical))`` and
            ``(history_len + future_len, len(temporal_feats_numeric))``. An empty feature list yields a
            zero-width mask instead of raising (``torch.concat`` of an empty list would fail).
        """
        def window_mask(feature_names, selected_features, length):
            # One boolean row over the features, repeated for every time step of the window.
            selected = set(selected_features or ())
            row = torch.tensor([name in selected for name in feature_names], dtype=torch.bool)
            return row.unsqueeze(0).repeat(length, 1)

        # input - numeric
        input_mask_numeric_tensor = torch.concat([
            window_mask(temporal_feats_numeric, task_history_feature_list, history_len),
            window_mask(temporal_feats_numeric, task_future_feature_list, future_len),
        ], dim=0)

        # input - categorical
        input_mask_categorical_tensor = torch.concat([
            window_mask(temporal_feats_categorical, task_history_feature_list, history_len),
            window_mask(temporal_feats_categorical, task_future_feature_list, future_len),
        ], dim=0)

        # target: only the future window can be a target
        target_mask_tensor = torch.concat([
            torch.zeros((history_len, len(temporal_feats_numeric)), dtype=torch.bool),
            window_mask(temporal_feats_numeric, task_target_feature_list, future_len),
        ], dim=0)

        return input_mask_numeric_tensor, input_mask_categorical_tensor, target_mask_tensor

    def _get_mask_tensor(self):
        """Build the input/target masks for every task and every object type it uses.

        Returns:
            ``{task_name: {obj_type: {"input_numeric": ..., "input_categorical": ..., "target": ...}}}``
            with tensors produced by ``_get_input_target_mask_tensor``.
        """
        history_len = self.config.configuration.model.history_len
        future_len = self.config.configuration.model.future_len
        mask_result_dict = {}
        for task_name, task_item_data in self.config.task.items():
            mask_result_dict[task_name] = {}
            for obj_type, obj_task_feature_dict in task_item_data.task_feature_info.items():
                feature_map = self.config.sample_data[obj_type]['feature_map']
                input_numeric, input_categorical, target = self._get_input_target_mask_tensor(
                    obj_task_feature_dict['history'], obj_task_feature_dict['future'],
                    obj_task_feature_dict['target'], feature_map['temporal_feats_numeric'],
                    feature_map['temporal_feats_categorical'], history_len, future_len)
                mask_result_dict[task_name][obj_type] = {
                    "input_numeric": input_numeric,
                    "input_categorical": input_categorical,
                    "target": target,
                }
        return mask_result_dict

    @staticmethod
    def _select_task_graph(task_item, full_graph_data_dict):
        """Return the subset of ``full_graph_data_dict`` used by one task.

        Keys are ``"<from>-<edge>-<to>"`` built from the task's ``task_graph_relation`` entries.
        Relations missing from ``full_graph_data_dict`` are skipped.
        """
        graph_data_dict = {}
        for relation in task_item.task_graph_relation:
            file_name = "{}-{}-{}".format(relation.get('from'), relation.get('edge'), relation.get('to'))
            if file_name in full_graph_data_dict:
                graph_data_dict[file_name] = full_graph_data_dict.get(file_name)
        return graph_data_dict

    def _build_task_batch(self, next_dict, task_name, task_item, task_mask_tensor_dict, device, is_cuda):
        """Assemble one task's input and target batches from a raw loader batch.

        Loader keys have the form ``"<obj_type>+<obj_id>+<data_type>"``. Entries whose object type is
        not in the task's ``task_feature_info`` are skipped. Temporal features are masked with the
        task's masks. Numeric temporal features also produce a target entry when the task has targets
        for that object type.

        Returns:
            ``(batch_input, batch_target)`` as accumulated by ``make_batch``.
        """
        task_feature_info_dict = task_item.task_feature_info
        task_object_types = set(task_feature_info_dict.keys())
        batch_input, batch_target = dict(), dict()
        for key, values in next_dict.items():
            key_parts = key.split('+')
            obj_type, obj_id, data_type = key_parts[0], key_parts[1], key_parts[2]
            if obj_type not in task_object_types:
                continue
            if data_type == 'temporal_feats_numeric':
                input_value_mask = task_mask_tensor_dict[task_name][obj_type]['input_numeric']
                masked_values = task_temporal_batch_mask(values, input_value_mask)
                batch_input = make_batch(batch_input, obj_type, obj_id, data_type, masked_values, device=device,
                                         is_cuda=is_cuda)
                task_target_len = len(task_feature_info_dict[obj_type]['target'])
                if task_target_len > 0:
                    target_value_mask = task_mask_tensor_dict[task_name][obj_type]['target']
                    target = task_temporal_target_mask(values, target_value_mask, task_target_len)
                    batch_target = make_batch(batch_target, obj_type, obj_id, data_type, target, device=device,
                                              is_cuda=is_cuda)
            elif data_type == 'temporal_feats_categorical':
                value_mask = task_mask_tensor_dict[task_name][obj_type]['input_categorical']
                masked_values = task_temporal_batch_mask(values, value_mask)
                batch_input = make_batch(batch_input, obj_type, obj_id, data_type, masked_values, device=device,
                                         is_cuda=is_cuda)
            else:
                batch_input = make_batch(batch_input, obj_type, obj_id, data_type, values, device=device,
                                         is_cuda=is_cuda)
        return batch_input, batch_target

    def _train(self, model, optimizer, scheduler, data_loader, full_graph_data_dict, epoch_id, rank, device,
               train_project_count):
        """Run one pass over ``data_loader`` for one project.

        For every batch, each active task does the following:
            1. builds its own masked input/target batch;
            2. runs the model under fp16 autocast;
            3. combines the MSE, CRS, FDS and rule-check losses with the class loss weights.
        The per-task losses are averaged and applied with one scaled optimizer step.

        Args:
            model: DDP-wrapped model from ``_get_model``.
            optimizer: optimizer stepped once per batch that produced at least one task loss.
            scheduler: learning-rate scheduler; only read here for logging.
            data_loader: loader whose sampler supports ``set_epoch``.
            full_graph_data_dict: graph data of the project; task relations are looked up in it by
                ``"<from>-<edge>-<to>"`` keys.
            epoch_id: current epoch, passed to ``sampler.set_epoch`` and logged.
            rank: process rank; only rank 0 logs.
            device: device batches are moved to.
            train_project_count: unused here; kept for interface compatibility.
        """
        is_cuda = torch.cuda.is_available()
        data_loader.sampler.set_epoch(epoch_id)
        optimization_config = self.config.configuration.optimization
        num_batches = len(data_loader)
        train_result_loss_list = []
        # Gradient scaler for fp16 mixed precision.
        # NOTE(review): a new scaler is created on every call, so the dynamic loss scale restarts at
        # 2**16 for each project/epoch. Confirm this is intended.
        scaler = torch.cuda.amp.GradScaler(init_scale=2.0**16, growth_interval=1000)
        task_mask_tensor_dict = self._get_mask_tensor()

        model.train()
        for batch_idx, next_dict in enumerate(data_loader):
            full_task_loss_list = []
            for task_name, task_item in self.config.task.items():
                if task_name not in self.config.use_task:
                    continue
                task_feature_info_dict = task_item.task_feature_info
                graph_data_dict = self._select_task_graph(task_item, full_graph_data_dict)
                batch_input, batch_target = self._build_task_batch(
                    next_dict, task_name, task_item, task_mask_tensor_dict, device, is_cuda)
                # A task without targets in this batch contributes no loss.
                if len(batch_target) == 0:
                    continue
                with torch.cuda.amp.autocast(dtype=torch.float16):
                    batch_outputs = model(batch_input, graph_data_dict, task_name)
                    mse_task_loss = process_batch_task(batch_outputs, batch_target)

                    # The rule check is given de-normalized outputs and inputs.
                    inverse_batch_outputs = self._inverse_transform(batch_outputs, task_item)
                    batch_input = self.input_inverse_transform(batch_input, self.config.sample_data)
                    # NOTE(review): depends only on full_graph_data_dict and could be hoisted out of the
                    # batch loop if get_rel_from_graph is pure. Confirm before moving it.
                    relation_accc_zone = get_rel_from_graph(full_graph_data_dict)
                    _, check_loss, _, _ = process_batch_rule_check_reverse_scale(
                        inverse_batch_outputs, device, self.config.sample_data, task_feature_info_dict,
                        batch_input, mse_task_loss, relation_accc_zone)
                    crs_task_loss = process_batch_task_CRS(batch_outputs, batch_target, graph_data_dict)
                    fds_task_loss = process_batch_task_FDS(batch_outputs, batch_target)

                    full_task_loss_list.append(
                        mse_task_loss * self._MSE_LOSS_WEIGHT + crs_task_loss * self._CRS_LOSS_WEIGHT
                        + fds_task_loss * self._FDS_LOSS_WEIGHT + check_loss * self._CHECK_LOSS_WEIGHT)

                # Release per-task tensors eagerly before the next task's forward pass.
                del batch_input, batch_target
                gc.collect()
                if is_cuda:
                    torch.cuda.empty_cache()

            if not full_task_loss_list:
                continue
            loss = torch.mean(torch.stack(full_task_loss_list), dim=0)
            scaler.scale(loss).backward()

            if optimization_config.max_grad_norm > 0:
                # Gradients must be unscaled before clipping so the norm is measured in real units.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), optimization_config.max_grad_norm)

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            loss_value = loss.item()
            train_result_loss_list.append(loss_value)

            if rank == 0 and (batch_idx % optimization_config.log_interval == 0 or batch_idx == num_batches - 1):
                access_log.info(
                    f"Epoch: {epoch_id}, Total Batch: {num_batches}, Batch Index: {batch_idx}, "
                    f"Train Loss = {loss_value}, lr = {scheduler.state_dict().get('_last_lr')[0]}"
                )

        # Guard against an empty list: np.mean([]) returns nan and emits a RuntimeWarning.
        if rank == 0 and self.config.wandb.mode and train_result_loss_list:
            train_result = round(np.mean(train_result_loss_list), 5)
            access_log.info(f"train_project_loss: {train_result}")

    def _get_scaler_model_dict(self):
        """Return the min-max scaler models for ``config.dataset_data_path``, loading them at most once per path.

        The inverse transforms run for every task of every batch, so reloading the models each time
        is wasteful. Callers must treat the returned structure as read-only.
        """
        data_path = self.config.dataset_data_path
        cache = getattr(self, '_scaler_model_cache', None)
        if cache is None or cache[0] != data_path:
            cache = (data_path, init_minmax_normalization_model(data_path))
            self._scaler_model_cache = cache
        return cache[1]

    @staticmethod
    def _apply_inverse_minmax(feature_data, scaler_model_list):
        """Undo a chain of min-max normalizations on ``feature_data``.

        The scalers are undone in reverse list order. Each step computes
        ``x * (max_ - min_) + min_``, with constants created on the device and dtype of the current
        ``feature_data``. An empty list returns the input unchanged.
        """
        # NOTE(review): under fp16 autocast, feature_data may be float16, so denormalized values
        # above ~65504 would overflow. Confirm the value ranges.
        for scaler_model in reversed(scaler_model_list):
            min_val = torch.tensor(scaler_model.min_, device=feature_data.device, dtype=feature_data.dtype)
            max_val = torch.tensor(scaler_model.max_, device=feature_data.device, dtype=feature_data.dtype)
            range_val = max_val - min_val
            feature_data = feature_data * range_val + min_val
        return feature_data

    def _inverse_transform(self, batch_outputs, task_item):
        """De-normalize the model's target predictions for one task.

        Args:
            batch_outputs: ``{object_type: {obj_id: prediction_tensor}}``. Predictions are indexed as
                ``[:, :, feature_index, 0]``, so they are assumed to be at least 4-D with target
                features on dim 2 (TODO confirm layout).
            task_item: task config; its ``task_feature_info[object_type]['target']`` lists the
                predicted features in order.

        Returns:
            ``{object_type: {obj_id: {feature_name: tensor}}}``. Object types without targets map to
            an empty dict.
        """
        try:
            out_put_data_dict = {}
            scaler_model_dict = self._get_scaler_model_dict()
            task_feature_info_dict = task_item.get('task_feature_info')
            for object_type, obj_data in batch_outputs.items():
                out_put_data_dict.setdefault(object_type, {})
                target_features = task_feature_info_dict.get(object_type, {}).get('target', [])
                if len(target_features) == 0:
                    continue
                object_scalers = scaler_model_dict.get(object_type, {})
                for obj_id, predict_data in obj_data.items():
                    obj_result = out_put_data_dict[object_type].setdefault(obj_id, {})
                    for feature_index, feature_name in enumerate(target_features):
                        feature_data = predict_data[:, :, feature_index, 0]
                        obj_result[feature_name] = self._apply_inverse_minmax(
                            feature_data, object_scalers.get(feature_name, []))
            return out_put_data_dict
        except Exception as e:
            access_log.error(e)
            raise

    def input_inverse_transform(self, batch_input, feature_all):
        """De-normalize the numeric input features of a batch.

        Args:
            batch_input: ``{object_type: {obj_id: {data_type: tensor}}}``. Only the
                ``'temporal_feats_numeric'`` (indexed ``[:, :, i]``) and ``'static_feats_numeric'``
                (indexed ``[:, i]``) entries are used.
            feature_all: sample-data config; ``feature_all[object_type]['feature_map']`` gives the
                ordered feature names of each data type.

        Returns:
            ``{object_type: {obj_id: {feature_name: tensor}}}`` containing the de-normalized temporal
            and static numeric features. Other data types are not included.
        """
        try:
            out_put_data_dict = {}
            scaler_model_dict = self._get_scaler_model_dict()
            for object_type, obj_data in batch_input.items():
                feature_obj = feature_all[object_type]['feature_map']
                object_scalers = scaler_model_dict.get(object_type, {})
                out_put_data_dict.setdefault(object_type, {})
                for obj_id, predict_data in obj_data.items():
                    obj_result = out_put_data_dict[object_type].setdefault(obj_id, {})

                    # temporal_feats_numeric
                    if 'temporal_feats_numeric' in predict_data:
                        temporal_data = predict_data['temporal_feats_numeric']
                        for feature_index, feature_name in enumerate(feature_obj['temporal_feats_numeric']):
                            obj_result[feature_name] = self._apply_inverse_minmax(
                                temporal_data[:, :, feature_index], object_scalers.get(feature_name, []))
                    # static_feats_numeric
                    if 'static_feats_numeric' in predict_data:
                        static_data = predict_data['static_feats_numeric']
                        for feature_index, feature_name in enumerate(feature_obj['static_feats_numeric']):
                            obj_result[feature_name] = self._apply_inverse_minmax(
                                static_data[:, feature_index], object_scalers.get(feature_name, []))
            return out_put_data_dict
        except Exception as e:
            access_log.error(e)
            raise

    def _main_worker(self, rank):
        """Per-rank entry point: set up, train every project for every epoch, save, clean up.

        Args:
            rank: process rank; it also selects the CUDA device ``cuda:{rank}``.
        """
        device = torch.device(f"cuda:{rank}")

        self._setup(rank)
        model = self._get_model(device)
        optimizer, scheduler = load_optimizer_and_scheduler(model, self.config)

        train_project_count = 0
        for epoch_id in range(self.config.configuration.optimization.max_epochs):
            for project_index, train_project_id in enumerate(self.config.project_train):
                if rank == 0:
                    access_log.info(f'===== epoch: {epoch_id}, project: {train_project_id}-{project_index} =====')
                data_loader = self._get_dataloader(train_project_id, 'train')
                graph_data_dict = load_graph(train_project_id, self.config.graph_relation,
                                             dataset_data_path=self.config.get('dataset_data_path', None))
                self._train(model, optimizer, scheduler, data_loader, graph_data_dict, epoch_id, rank, device,
                            train_project_count)
                train_project_count += 1

                # The scheduler advances once per (epoch, project) pair, not once per batch.
                scheduler.step()

        # Only rank 0 persists; the barrier makes the other ranks wait for it before cleanup.
        if rank == 0:
            self._persistence_model_optimizer(model, optimizer)
        dist.barrier()
        self._cleanup()
