# -*- coding: utf-8 -*-
import os
from omegaconf import DictConfig
from typing import Any

from abc import ABC, abstractmethod

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler

from common.log import access_log
from hgtft.model.model import HeterogeneousGraphTemporalFusionTransformerTask
from hgtft.utils.data_utils import add_feature_configuration, DatasetData, ModelData, DictDataSet
from hgtft.utils.train_utils import weight_init, load_graph, load_optimizer_and_scheduler


class TrainBase(ABC):
    """
    Abstract base for multi-process (one process per GPU) distributed training.

    ``main`` spawns ``config.world_size`` worker processes; each worker runs
    ``_main_worker``, which initialises the process group, builds the model,
    iterates over epochs and training projects, and finally lets rank 0
    persist the model and optimizer state.

    Subclasses must implement ``_get_model`` and ``_train``.
    """

    def __init__(self, config: DictConfig):
        """
        :param config: training configuration; it is passed through
            ``add_feature_configuration`` and the result is stored on ``self.config``
        """
        self.config = add_feature_configuration(config)

    def _setup(self, rank: int) -> None:
        """
        Initialise the NCCL process group for this worker and bind it to its GPU.

        :param rank: index of the current process within the process group
        :return: None
        """
        os.environ['MASTER_ADDR'] = 'localhost'
        master_port = self.config.get('MASTER_PORT', None)
        # os.environ only accepts str values, so a numeric port read from the
        # config must be converted explicitly.
        os.environ['MASTER_PORT'] = '12358' if master_port is None else str(master_port)
        dist.init_process_group("nccl", rank=rank, world_size=self.config.world_size)
        torch.cuda.set_device(rank)

    @staticmethod
    def _cleanup() -> None:
        """
        Destroy the process group.

        :return: None
        """
        dist.destroy_process_group()

    @abstractmethod
    def _get_model(self, device):
        """
        Build the model for this worker.

        The method is abstract, so subclasses must override it. This body is a
        default implementation: it builds the task model, applies
        ``weight_init``, moves the model to ``device`` and wraps it in
        ``DistributedDataParallel`` on the current CUDA device.

        :param device: device the model is created on and moved to
        :return: the ``DistributedDataParallel``-wrapped model
        """
        # np.random.seed(42)
        # torch.manual_seed(42)
        # torch.cuda.manual_seed_all(42)
        model = HeterogeneousGraphTemporalFusionTransformerTask(config=self.config, device=device)
        model.apply(weight_init)
        model = model.to(device)
        model = DistributedDataParallel(model, device_ids=[torch.cuda.current_device()], find_unused_parameters=True)
        return model

    def _get_dataloader(self, project_id: str, data_type: str) -> DataLoader:
        """
        Load one project's dataset and wrap it in a distributed ``DataLoader``.

        Every per-object entry except the ignored ``'time'`` key is flattened
        into a single dict keyed by ``"{obj_type}+{obj_id}+{key}"``.

        :param project_id: project identifier; the part before the first ``'_'``
            is used to look up a project-specific training batch size
        :param data_type: dataset split; ``'train'`` selects the training batch
            size, any other value selects the inference batch size
        :return: a ``DataLoader`` whose sampler is a shuffling ``DistributedSampler``
        """
        project_data = DatasetData.load_dataset(project_id, data_type, self.config.get('dataset_data_path', None))
        ignore_keys = ('time',)
        full_data_dict = {}
        for obj_type in project_data.keys():
            type_data = project_data[obj_type]
            data_sets = type_data['data_sets']
            for obj_id in type_data['obj_id_list']:
                for k, v in data_sets[obj_id].items():
                    if k not in ignore_keys:
                        full_data_dict[f"{obj_type}+{obj_id}+{k}"] = v
        dataset = DictDataSet(full_data_dict)

        batch_size_cfg = self.config.configuration.optimization.batch_size
        if data_type == 'train':
            project_name = project_id.split('_')[0]
            # A per-project entry overrides the generic training batch size.
            if project_name in batch_size_cfg.keys():
                batch_size = batch_size_cfg.get(project_name)
            else:
                batch_size = batch_size_cfg.training
        else:
            batch_size = batch_size_cfg.inference

        # Shuffling is delegated to the sampler, so the loader's own shuffle
        # flag stays off.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            sampler=DistributedSampler(dataset, shuffle=True)
        )

    @abstractmethod
    def _train(self, model, optimizer, scheduler, data_loader, graph_data_dict, epoch_id, rank, device,
               train_project_count):
        """
        Train on one project's data loader; must be implemented by subclasses.

        :param model: model returned by ``_get_model``
        :param optimizer: optimizer returned by ``load_optimizer_and_scheduler``
        :param scheduler: scheduler returned by ``load_optimizer_and_scheduler``
        :param data_loader: loader returned by ``_get_dataloader`` for the project
        :param graph_data_dict: graph data returned by ``load_graph`` for the project
        :param epoch_id: zero-based index of the current epoch
        :param rank: index of the current process within the process group
        :param device: CUDA device assigned to this worker
        :param train_project_count: number of ``_train`` calls made so far by this
            worker, counted across all epochs
        :return:
        """
        pass

    def _persistence_model_optimizer(self, model: Any, optimizer: Any) -> None:
        """
        Save the model and optimizer state dicts via ``ModelData.save_model``.

        The model is stored under ``config.model_name`` and the optimizer under
        ``"{model_name}_optimizer"``.

        :param model: object exposing ``state_dict()``
        :param optimizer: object exposing ``state_dict()``
        :return: None
        """
        model_name = self.config.model_name
        ModelData.save_model(model_name, model.state_dict())

        optimizer_name = f'{model_name}_optimizer'
        ModelData.save_model(optimizer_name, optimizer.state_dict())
        access_log.info(f'save model {model_name}')

    def _main_worker(self, rank) -> None:
        """
        Entry point of one spawned worker process.

        Sets up the process group, builds the model, optimizer and scheduler,
        then calls ``_train`` for every training project in every epoch. After a
        final barrier the process group is destroyed and rank 0 persists the
        model and optimizer.

        :param rank: index of the current process, supplied by ``mp.spawn``
        :return: None
        """
        device = torch.device(f"cuda:{rank}")
        self._setup(rank)
        model = self._get_model(device)
        if rank == 0:
            access_log.info(f'params count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')
        optimizer, scheduler = load_optimizer_and_scheduler(model, self.config)

        train_project_count = 0
        for epoch_id in range(self.config.configuration.optimization.max_epochs):
            for project_index, train_project_id in enumerate(self.config.project_train):
                if rank == 0:
                    access_log.info(f'===== epoch: {epoch_id}, project: {train_project_id}-{project_index}=====')
                data_loader = self._get_dataloader(train_project_id, 'train')
                # A fresh DistributedSampler always starts at epoch 0, which would
                # reproduce the same shuffle order in every epoch. Feeding it the
                # real epoch makes the permutation differ between epochs while
                # staying identical across ranks.
                sampler = getattr(data_loader, 'sampler', None)
                if isinstance(sampler, DistributedSampler):
                    sampler.set_epoch(epoch_id)
                graph_data_dict = load_graph(train_project_id, self.config.graph_relation)
                self._train(model, optimizer, scheduler, data_loader, graph_data_dict, epoch_id, rank, device,
                            train_project_count)
                train_project_count += 1

        # Wait until every rank has finished training before tearing down.
        dist.barrier()
        self._cleanup()

        # Only one process writes the checkpoint.
        # NOTE(review): if the model is DDP-wrapped, the saved state-dict keys
        # presumably carry the "module." prefix — confirm the loader expects that.
        if rank == 0:
            self._persistence_model_optimizer(model, optimizer)

    def main(self) -> None:
        """
        Spawn ``config.world_size`` worker processes and wait for them to finish.

        :return: None
        """
        mp.spawn(self._main_worker, nprocs=self.config.world_size, join=True)
