import os.path
from abc import ABC, abstractmethod
from typing import Generic, TypeVar, Optional, Any, Callable

import torch
from pydantic import BaseModel, ConfigDict, Field

from src.configs.checkpoint import CheckpointConfig
from src.configs.dataset import DatasetConfig
from src.configs.report import ReportConfig
from src.exceptions.ddp_consistency import DDPConsistencyError
from src.utils.save import save_state_dict
from src.utils.utils import report_values, free_all_memory
from torch_utils.distributed.distributed_manager import DistributedManager
from torch_utils.distributed.utils import check_ddp_consistency, check_ddp_params_consistency, \
    check_ddp_buffers_consistency
from torch_utils.tensorboard.tensorboard_logger import TensorboardLogger
from utils.base_object import BaseObject
from utils.logger.logger import Logger
from utils.utils import get_class_name


class Config(BaseModel):
    """Base configuration for a ``Trainer`` run.

    Instances are immutable (``frozen=True``), values are validated without
    type coercion (``strict=True``) and unknown keys are rejected
    (``extra='forbid'``). Concrete trainers presumably subclass this with their
    own fields (``Trainer`` is generic over a type bound to ``Config``).
    """
    # protected_namespaces=() disables pydantic's protected-namespace check for field names starting with "model_".
    model_config: ConfigDict = ConfigDict(
        frozen=True, strict=True, validate_assignment=True, extra='forbid', protected_namespaces=())
    # Root directory; the TensorBoard log directory and checkpoint folders are built beneath it.
    base_folder: str = Field()
    # Log file path; not read in this file — presumably relative to base_folder, confirm with the caller.
    log_path: Optional[str] = Field(default='log/log.log')
    # TensorBoard directory relative to base_folder; the trainer appends ".<rank>" when distributed is initialized.
    tensorboard_log_dir: Optional[str] = Field(default='tensorboard/logs')
    # Number of train_step calls made by train_loop; also the step recorded in the final "last" checkpoint.
    train_steps: int = Field()
    # The next four fields are not read in this file — presumably consumed by subclasses / data loading.
    batch_size: int = Field()
    batch_repeats: int = Field(default=1)
    learning_rate_batch_repeats: bool = Field(default=False)
    data_loader_workers: int = Field(default=0)
    # Default for save_checkpoint's consistency_check argument (run validate_consistency before saving).
    check_consistency_before_save: bool = Field(default=True)
    # Default for save_checkpoint's load_checkpoint argument (reload the just-saved checkpoint).
    load_after_save: bool = Field(default=True)
    # Run validate_consistency every N completed steps; None disables the callback.
    validate_ddp_consistency_steps: Optional[int] = Field(default=None)
    # Call free_all_memory every N completed steps; None disables the callback.
    free_all_memory_steps: Optional[int] = Field(default=None)
    # Not read in this file — presumably checked by subclasses inside train_step.
    free_memory_every_sub_step: bool = Field(default=False)
    dataset: DatasetConfig = Field()
    # report.train_steps sets how often Trainer.report_train_loss reports values.
    report: ReportConfig = Field()
    # checkpoint.folder / save_steps / last_steps set where and how often checkpoints are written.
    checkpoint: CheckpointConfig = Field()


class StepStore(BaseModel):
    """Validated, mutable record of the outcome of a single training step.

    Assignments are re-validated (``validate_assignment=True``) and unknown
    fields are rejected. Trainers presumably subclass it to carry extra
    per-step data.
    """
    model_config: ConfigDict = ConfigDict(
        strict=True,
        validate_assignment=True,
        extra='forbid',
        protected_namespaces=(),
    )
    # Zero-based index of the step this record describes.
    step: int = Field()

    def get_step_values(self) -> dict[str, Any]:
        """Return this record's values as a plain dict keyed by field name."""
        return dict(step=self.step)


# Type parameters of Trainer: C is the concrete Config subclass, S the StepStore subclass produced by train_step.
C = TypeVar('C', bound=Config)
S = TypeVar('S', bound=StepStore)


class Trainer(BaseObject, Generic[C, S], ABC):
    """Abstract driver for an (optionally DDP-distributed) training run.

    ``run`` saves an initial ``last`` checkpoint, calls ``train_step`` for each
    of ``config.train_steps`` zero-based steps, invokes every function returned
    by ``callbacks`` with the resulting step store, and finally saves the
    ``last`` checkpoint again.

    Subclasses must implement ``train_step``, ``callbacks``,
    ``check_consistency`` and ``load_from_folder``.

    The ``*_callback`` methods receive the zero-based ``step_store.step`` and
    act on ``step_store.step + 1``, i.e. the number of completed steps.
    """

    def __init__(self, config: C) -> None:
        """Keep ``config`` and create the TensorBoard logger.

        Logs go to ``<base_folder>/<tensorboard_log_dir>``. When distributed is
        initialized, the directory gets a ``.<rank>`` suffix so that each rank
        writes to its own directory.
        """
        Logger.debug(f'{get_class_name(Trainer.__init__)} start')
        self.config: C = config
        log_dir: str = f'{config.base_folder}/{config.tensorboard_log_dir}'
        if DistributedManager.initialized:
            log_dir = f'{log_dir}.{DistributedManager.rank}'
        self.tensorboard_logger: TensorboardLogger = TensorboardLogger(log_dir=log_dir)
        Logger.debug(f'{get_class_name(Trainer.__init__)} end')

    @staticmethod
    def _distributed_barrier() -> None:
        """Synchronize all ranks; a no-op when distributed is not initialized."""
        if DistributedManager.initialized:
            torch.distributed.barrier()

    @staticmethod
    def _is_due(completed_steps: int, interval: Optional[int]) -> bool:
        """Return True when ``interval`` is set and ``completed_steps`` is a multiple of it."""
        return interval is not None and completed_steps % interval == 0

    def _checkpoint_path(self, name: Any) -> str:
        """Return the checkpoint directory ``<base_folder>/<checkpoint.folder>/<name>``."""
        return f'{self.config.base_folder}/{self.config.checkpoint.folder}/{name}'

    def start_callback(self) -> None:
        """Hook run by ``run`` before training; logs the distributed setup."""
        Logger.debug(f'{get_class_name(Trainer.start_callback)}')
        Logger.debug(f'distributed initialized: {DistributedManager.initialized}')
        Logger.debug(f'distributed world size: {DistributedManager.world_size}')
        Logger.debug(f'distributed local world size: {DistributedManager.local_world_size}')
        Logger.debug(f'distributed rank: {DistributedManager.rank}')
        Logger.debug(f'distributed local rank: {DistributedManager.local_rank}')

    def end_callback(self) -> None:
        """Hook run by ``run`` after the final checkpoint has been saved."""
        Logger.debug(f'{get_class_name(Trainer.end_callback)}')

    @abstractmethod
    def train_step(self, step: int) -> S:
        """Run the training step with zero-based index ``step`` and return its step store."""
        raise NotImplementedError('train_step method must be implemented')

    @abstractmethod
    def callbacks(self) -> list[Callable[[S], None]]:
        """Return the functions ``train_loop`` calls, in order, with every step store."""
        raise NotImplementedError('callbacks method must be implemented')

    def report_values(self, step: int, values: dict[str, Any]) -> None:
        """Report ``values`` for ``step`` through this trainer's TensorBoard logger."""
        report_values(step, values, self.tensorboard_logger)

    def report_train_loss(self, step_store: StepStore) -> None:
        """Every ``config.report.train_steps`` completed steps, report values.

        The values come from ``reset_values``. The base implementation returns
        an empty dict; subclasses presumably return and clear accumulated
        metrics. They are reported at the completed-step count. Nothing is
        reported when ``report.train_steps`` is None.
        """
        completed_steps: int = step_store.step + 1
        if self._is_due(completed_steps, self.config.report.train_steps):
            Logger.debug(f'{get_class_name(Trainer.report_train_loss)} {completed_steps} start')
            values: dict[str, Any] = self.reset_values()
            self.report_values(completed_steps, values)
            Logger.debug(f'{get_class_name(Trainer.report_train_loss)} {completed_steps} end')

    def get_state_dicts(self, step: int) -> dict[str, Any]:
        """Return the objects to checkpoint, keyed by file stem.

        ``save_checkpoint`` writes each entry to ``<folder>/<key>.pth``. The base
        implementation stores the step and the dumped config; subclasses
        presumably extend it with model/optimizer state.
        """
        return {
            'step': step,
            'config': self.config.model_dump()
        }

    @abstractmethod
    def check_consistency(self, func: Callable[[torch.nn.Module], bool]) -> bool:
        """Return whether ``func`` reports this trainer's modules as consistent across ranks.

        Presumably applies ``func`` to each DDP-wrapped module and combines the
        results. The exact contract is defined by subclasses.
        """
        raise NotImplementedError(f'{get_class_name(Trainer.check_consistency)} method must be implemented')

    def check_ddp_consistency(self) -> bool:
        """Run ``check_consistency`` with the combined params/buffers check."""
        return self.check_consistency(check_ddp_consistency)

    def check_ddp_params_consistency(self) -> bool:
        """Run ``check_consistency`` with the parameter-only check."""
        return self.check_consistency(check_ddp_params_consistency)

    def check_ddp_buffers_consistency(self) -> bool:
        """Run ``check_consistency`` with the buffer-only check."""
        return self.check_consistency(check_ddp_buffers_consistency)

    def validate_consistency(self, step: int) -> None:
        """Check cross-rank consistency at ``step`` and recover where possible.

        If parameters diverge, the ``last`` checkpoint is reloaded. If only
        buffers diverge, a warning is logged. Each check is surrounded by
        barriers when distributed is initialized.
        """
        # NOTE(review): assumes every rank gets the same check result; otherwise ranks
        # would execute a different number of barriers and deadlock.
        Logger.debug(f'{get_class_name(Trainer.validate_consistency)} {step} start')
        free_all_memory()
        self._distributed_barrier()
        if not self.check_ddp_params_consistency():
            Logger.error(f'ddp params consistency check failed: {step}, loading latest checkpoint')
            self._distributed_barrier()
            self.load_from_folder(self._checkpoint_path('last'))
        self._distributed_barrier()
        if not self.check_ddp_buffers_consistency():
            Logger.warning(f'ddp buffers consistency check failed: {step}')
        self._distributed_barrier()
        free_all_memory()
        Logger.debug(f'{get_class_name(Trainer.validate_consistency)} {step} end')

    def save_checkpoint(
            self,
            folder: str,
            step: int,
            consistency_check: Optional[bool] = None,
            load_checkpoint: Optional[bool] = None
    ) -> None:
        """Save ``get_state_dicts(step)`` to ``folder`` and verify consistency afterwards.

        Args:
            folder: destination directory; each entry is written to ``<folder>/<name>.pth``.
            step: step recorded in the checkpoint.
            consistency_check: run ``validate_consistency`` first; None falls back to
                ``config.check_consistency_before_save``.
            load_checkpoint: reload ``folder`` after saving; None falls back to
                ``config.load_after_save``.

        Raises:
            DDPConsistencyError: if ``check_ddp_consistency`` fails at the end.
        """
        Logger.debug(
            f'{get_class_name(Trainer.save_checkpoint)} start - folder: {folder}, step: {step}, '
            f'consistency_check: {consistency_check}, load_checkpoint: {load_checkpoint}'
        )
        free_all_memory()
        if consistency_check is None:
            consistency_check = self.config.check_consistency_before_save
        if load_checkpoint is None:
            load_checkpoint = self.config.load_after_save
        if consistency_check:
            self.validate_consistency(step)
        free_all_memory()
        # Only the main rank writes files; the others wait at the barrier below.
        if DistributedManager.is_main():
            Logger.debug(f'{get_class_name(Trainer.save_checkpoint)} {step} start')
            for name, state_dict in self.get_state_dicts(step).items():
                path: str = f'{folder}/{name}.pth'
                directory: str = os.path.dirname(path)
                if directory != '':
                    os.makedirs(directory, exist_ok=True)
                save_state_dict(path, state_dict)
        free_all_memory()
        self._distributed_barrier()
        if load_checkpoint:
            self.load_from_folder(folder)
        free_all_memory()
        if not self.check_ddp_consistency():
            Logger.error(f'ddp consistency check failed after loading: {step}')
            raise DDPConsistencyError('ddp consistency check failed after loading')
        free_all_memory()
        Logger.debug(f'{get_class_name(Trainer.save_checkpoint)} {step} end')

    def save_step_checkpoint(
            self, step: int, consistency_check: Optional[bool] = None,
            load_checkpoint: Optional[bool] = None) -> None:
        """Save a checkpoint to ``<base_folder>/<checkpoint.folder>/<step>``."""
        Logger.debug(
            f'{get_class_name(Trainer.save_step_checkpoint)} {step} start - step: {step}, '
            f'consistency_check: {consistency_check}, load_checkpoint: {load_checkpoint}'
        )
        self.save_checkpoint(
            self._checkpoint_path(step), step,
            consistency_check=consistency_check, load_checkpoint=load_checkpoint
        )
        Logger.debug(f'{get_class_name(Trainer.save_step_checkpoint)} {step} end')

    def save_step_checkpoint_steps_callback(self, step_store: StepStore) -> None:
        """Save a numbered checkpoint every ``config.checkpoint.save_steps`` completed steps.

        Does nothing when ``checkpoint.folder`` or ``checkpoint.save_steps`` is None.
        """
        completed_steps: int = step_store.step + 1
        if self.config.checkpoint.folder is not None and \
                self._is_due(completed_steps, self.config.checkpoint.save_steps):
            Logger.debug(f'{get_class_name(Trainer.save_step_checkpoint_steps_callback)} {completed_steps} start')
            self.save_step_checkpoint(completed_steps)
            Logger.debug(f'{get_class_name(Trainer.save_step_checkpoint_steps_callback)} {completed_steps} end')

    def save_last_checkpoint(
            self, step: int, consistency_check: Optional[bool] = None,
            load_checkpoint: Optional[bool] = None) -> None:
        """Save (overwrite) the rolling checkpoint at ``<base_folder>/<checkpoint.folder>/last``."""
        Logger.debug(
            f'{get_class_name(Trainer.save_last_checkpoint)} {step} start - step: {step}, '
            f'consistency_check: {consistency_check}, load_checkpoint: {load_checkpoint}'
        )
        self.save_checkpoint(
            self._checkpoint_path('last'), step,
            consistency_check=consistency_check, load_checkpoint=load_checkpoint
        )
        Logger.debug(f'{get_class_name(Trainer.save_last_checkpoint)} {step} end')

    def save_last_checkpoint_steps_callback(self, step_store: StepStore) -> None:
        """Save the ``last`` checkpoint every ``config.checkpoint.last_steps`` completed steps.

        Does nothing when ``checkpoint.folder`` or ``checkpoint.last_steps`` is None.
        """
        completed_steps: int = step_store.step + 1
        if self.config.checkpoint.folder is not None and \
                self._is_due(completed_steps, self.config.checkpoint.last_steps):
            Logger.debug(f'{get_class_name(Trainer.save_last_checkpoint_steps_callback)} {completed_steps} start')
            self.save_last_checkpoint(completed_steps)
            Logger.debug(f'{get_class_name(Trainer.save_last_checkpoint_steps_callback)} {completed_steps} end')

    def validate_ddp_consistency_steps_callback(self, step_store: StepStore) -> None:
        """Run ``validate_consistency`` every ``config.validate_ddp_consistency_steps`` completed steps."""
        completed_steps: int = step_store.step + 1
        if self._is_due(completed_steps, self.config.validate_ddp_consistency_steps):
            Logger.debug(
                f'{get_class_name(Trainer.validate_ddp_consistency_steps_callback)} {completed_steps} start')
            self.validate_consistency(completed_steps)
            Logger.debug(f'{get_class_name(Trainer.validate_ddp_consistency_steps_callback)} {completed_steps} end')

    def free_all_memory_callback(self, step_store: StepStore) -> None:
        """Call ``free_all_memory`` every ``config.free_all_memory_steps`` completed steps."""
        completed_steps: int = step_store.step + 1
        if self._is_due(completed_steps, self.config.free_all_memory_steps):
            Logger.debug(f'{get_class_name(Trainer.free_all_memory_callback)} {completed_steps} start')
            free_all_memory()
            Logger.debug(f'{get_class_name(Trainer.free_all_memory_callback)} {completed_steps} end')

    @abstractmethod
    def load_from_folder(self, folder: str) -> None:
        """Restore trainer state from a checkpoint directory written by ``save_checkpoint``."""
        raise NotImplementedError('load_from_folder method must be implemented')

    def load_last_checkpoint(self) -> None:
        """Restore state from ``<base_folder>/<checkpoint.folder>/last``."""
        Logger.debug(f'{get_class_name(Trainer.load_last_checkpoint)} start')
        self.load_from_folder(self._checkpoint_path('last'))
        Logger.debug(f'{get_class_name(Trainer.load_last_checkpoint)} end')

    def load_step_checkpoint(self, step: int) -> None:
        """Restore state from ``<base_folder>/<checkpoint.folder>/<step>``."""
        Logger.debug(f'{get_class_name(Trainer.load_step_checkpoint)} {step} start')
        self.load_from_folder(self._checkpoint_path(step))
        Logger.debug(f'{get_class_name(Trainer.load_step_checkpoint)} {step} end')

    def reset_values(self) -> dict[str, Any]:
        """Return the values that ``report_train_loss`` reports; the base class has none."""
        Logger.debug(f'{get_class_name(self.reset_values)}')
        return {}

    def train_loop(self) -> None:
        """Run ``config.train_steps`` steps, passing each step store to every callback in order."""
        Logger.debug(f'{get_class_name(Trainer.train_loop)} start')
        callbacks: list[Callable[[S], None]] = self.callbacks()
        Logger.debug(f'callbacks: {list(map(get_class_name, callbacks))}')
        for step in range(self.config.train_steps):
            step_store: S = self.train_step(step)
            for step_callback in callbacks:
                step_callback(step_store)
        Logger.debug(f'{get_class_name(Trainer.train_loop)} end')

    def run(self) -> None:
        """Execute the full training run.

        Saves the ``last`` checkpoint at step 0 without the pre-save consistency
        check, runs ``train_loop``, then saves ``last`` at ``config.train_steps``.
        """
        Logger.debug(f'{get_class_name(Trainer.run)} start')
        self._distributed_barrier()
        self.start_callback()
        # NOTE(review): unlike the *_steps_callback methods, this does not skip saving when
        # config.checkpoint.folder is None — confirm the folder is always set for run().
        self.save_last_checkpoint(0, consistency_check=False)
        self.train_loop()
        self.save_last_checkpoint(self.config.train_steps)
        self.end_callback()
        Logger.debug(f'{get_class_name(Trainer.run)} end')

    def shutdown(self) -> None:
        """Close the TensorBoard logger.

        Tolerates a partially constructed instance (``__init__`` failed before the
        logger was created), because ``__del__`` also calls this method.
        """
        Logger.debug(f'{get_class_name(Trainer.shutdown)} start')
        tensorboard_logger: Optional[TensorboardLogger] = getattr(self, 'tensorboard_logger', None)
        if tensorboard_logger is not None:
            tensorboard_logger.close()
        Logger.debug(f'{get_class_name(Trainer.shutdown)} end')

    def __del__(self) -> None:
        """Release resources through ``shutdown`` when the trainer is garbage-collected."""
        self.shutdown()
