import argparse
import copy
import json
import os
from typing import Any, Optional

from src.exceptions.ddp_consistency import DDPConsistencyError
from src.train.lines_gan_unconditional_folder_trainer_gan_train_real import Config, LinesGANUnconditionalFolderTrainerGANTrainReal
from torch_utils.distributed.distributed_manager import DistributedManager
from utils.logger.logger import Logger
from utils.utils import get_class_name


def _build_attempt_config(config: dict[str, Any], base_config_object: 'Config', attempt: int) -> dict[str, Any]:
    """Return a deep copy of ``config`` adjusted for the zero-based training attempt ``attempt``.

    Every attempt writes into its own ``<base_folder>.<attempt>`` folder. For ``attempt > 0`` the model,
    model optimizer, EMA, discriminator, discriminator feature extractor and discriminator optimizer load
    paths are pointed at the ``last`` checkpoint folder of the previous attempt (with their ``load_keys``
    cleared), and both distributed sampler seeds are offset by ``attempt``.

    Args:
        config: Raw configuration dict; it is not mutated (a deep copy is modified).
        base_config_object: ``Config`` parsed from ``config``; supplies the base folder, checkpoint folder,
            EMA entries and sampler seeds.
        attempt: Zero-based attempt index.

    Returns:
        The configuration dict to build this attempt's ``Config`` from.
    """
    attempt_config: dict[str, Any] = copy.deepcopy(config)
    attempt_config['base_folder'] = f'{base_config_object.base_folder}.{attempt}'
    if attempt <= 0:
        # The first attempt uses the configured load paths and seeds unchanged.
        return attempt_config

    # Every resume path lives under the previous attempt's "last" checkpoint folder.
    previous_last_folder: str = (
        f'{base_config_object.base_folder}.{attempt - 1}/{base_config_object.checkpoint.folder}/last'
    )

    attempt_config['model']['load_path'] = f'{previous_last_folder}/model.pth'
    attempt_config['model']['load_keys'] = []
    attempt_config['model_optimizer']['load_path'] = f'{previous_last_folder}/model_optimizer.pth'
    attempt_config['model_optimizer']['load_keys'] = []
    for ema_index, ema in enumerate(base_config_object.ema):
        attempt_config['ema'][ema_index]['load_path'] = f'{previous_last_folder}/ema/{ema.get_str()}.pth'
        attempt_config['ema'][ema_index]['load_keys'] = []

    discriminator_config: dict[str, Any] = attempt_config['discriminator']
    discriminator_config['discriminator_load_path'] = f'{previous_last_folder}/discriminator.pth'
    discriminator_config['discriminator_load_keys'] = []
    discriminator_config['discriminator_feature_extractor_load_path'] = (
        f'{previous_last_folder}/discriminator_feature_extractor.pth'
    )
    discriminator_config['discriminator_feature_extractor_load_keys'] = []

    attempt_config['discriminator_optimizer']['load_path'] = f'{previous_last_folder}/discriminator_optimizer.pth'
    attempt_config['discriminator_optimizer']['load_keys'] = []

    # Offset the sampler seeds so a resumed attempt does not replay the exact same data order.
    attempt_config['distributed_sampler_seed'] = base_config_object.distributed_sampler_seed + attempt
    attempt_config['distributed_sampler_real_seed'] = base_config_object.distributed_sampler_real_seed + attempt
    return attempt_config


def run(config: dict[str, Any], distributed: bool = False, backend: str = 'nccl', retries: Optional[int] = None) -> None:
    """Train the unconditional lines GAN, starting a new attempt after each DDP consistency error.

    Attempt ``i`` runs in ``<base_folder>.<i>``; attempts after the first resume from the ``last``
    checkpoint of the previous attempt. Only ``DDPConsistencyError`` triggers another attempt; any other
    exception is logged via ``Logger.exception`` and ends the run without being re-raised.

    Args:
        config: Raw training configuration, validated with ``Config(**config)``.
        distributed: If True, initialise ``DistributedManager`` with ``backend``, suffix each log file with
            the process rank, and destroy the manager on exit.
        backend: Backend name forwarded to ``DistributedManager.init``.
        retries: Maximum number of attempts, including the first one. ``None`` means unlimited; ``0`` means
            no attempt is made.
    """
    print(
        f'{get_class_name(run)} - config: {config}, distributed: {distributed}, backend: {backend}, retries: {retries}')
    try:
        if distributed:
            DistributedManager.init(backend=backend)
        base_config_object: Config = Config(**config)
        attempt: int = 0
        while retries is None or attempt < retries:
            config_object: Config = Config(**_build_attempt_config(config, base_config_object, attempt))
            # NOTE(review): Logger.init is called once per attempt without an intermediate Logger.shutdown;
            # confirm Logger.init tolerates re-initialisation.
            Logger.init(
                base_filename=f'{config_object.base_folder}/{config_object.log_path}.{DistributedManager.rank}'
                if distributed else f'{config_object.base_folder}/{config_object.log_path}'
            )
            Logger.debug(f'config_object: {config_object.model_dump()}')
            trainer: LinesGANUnconditionalFolderTrainerGANTrainReal = LinesGANUnconditionalFolderTrainerGANTrainReal(config_object)
            try:
                trainer.run()
                break
            except DDPConsistencyError as e:
                Logger.exception(f'ddp consistency error: {e}', e)
                attempt += 1
            finally:
                # Shut the trainer down on success, before a retry, and when any other exception escapes.
                trainer.shutdown()
        else:
            # Reached only when the loop ends without ``break``, i.e. no attempt succeeded.
            print(f'{get_class_name(run)} - no successful attempt after {attempt} attempt(s), retries: {retries}')
    except Exception as e:
        # Top-level boundary: the error is logged and swallowed, so the caller sees a normal return.
        # NOTE(review): this can run before Logger.init (e.g. Config validation failed) — confirm Logger copes.
        Logger.exception(f'exception: {e}', e)
    finally:
        if distributed:
            DistributedManager.destroy()
        Logger.shutdown()


def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse the command-line arguments of this training script.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``.

    Returns:
        Namespace with ``config_filepath``, ``distributed``, ``backend`` and ``retries``.
    """
    parser: argparse.ArgumentParser = argparse.ArgumentParser(
        description='Train the unconditional lines GAN, restarting after DDP consistency errors.')
    parser.add_argument('--config_filepath', type=str, required=True,
                        help='path to the JSON training configuration')
    parser.add_argument('--distributed', action='store_true',
                        help='initialise distributed training and write one log file per rank')
    parser.add_argument('--backend', type=str, default='nccl',
                        help='distributed backend (default: %(default)s)')
    parser.add_argument('--retries', type=int, default=None,
                        help='maximum number of training attempts, including the first (default: unlimited)')
    return parser.parse_args(argv)


def get_config_from_args(args: argparse.Namespace) -> dict[str, Any]:
    """Load the JSON training configuration referenced by ``args.config_filepath``.

    Args:
        args: Parsed CLI arguments; only ``config_filepath`` is read.

    Returns:
        The decoded top-level JSON object.

    Raises:
        FileNotFoundError: If ``config_filepath`` does not exist.
        ValueError: If the top-level JSON value is not an object. Malformed JSON raises
            ``json.JSONDecodeError``, a ``ValueError`` subclass.
    """
    if not os.path.exists(args.config_filepath):
        raise FileNotFoundError(f'config_filepath not found: {args.config_filepath}')
    # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
    with open(args.config_filepath, 'r', encoding='utf-8') as file:
        config: Any = json.load(file)
    # The caller unpacks this with ``**``, so anything but an object is a configuration error.
    if not isinstance(config, dict):
        raise ValueError(
            f'config must be a JSON object, got {type(config).__name__}: {args.config_filepath}')
    return config


def main() -> None:
    """CLI entry point: parse arguments, load the JSON config file and start training via ``run``."""
    cli_args: argparse.Namespace = parse_args()
    run(
        get_config_from_args(cli_args),
        distributed=cli_args.distributed,
        backend=cli_args.backend,
        retries=cli_args.retries,
    )


if __name__ == '__main__':
    # Start training only when executed as a script, so importing this module does not launch a run.
    main()
