# -*- coding: UTF-8 -*-

import torch
import numpy as np
from torch import nn
import torch
import torch.nn.functional as F
from typing import Union
import os.path as osp
from glob import glob
import os
import time
from torchvision.utils import save_image
import torch.distributed as dist
import math
import inspect
# from torch._six import string_classes
import collections.abc as container_abcs
import warnings
from utils.ops import load_network

# Local stand-in for `torch._six.string_classes`, whose import is commented out
# above. NOTE(review): presumably kept so older code that expects this name
# keeps working; it is not referenced in the visible part of this file.
string_classes = str


def get_varname(var):
    """
    Look up the name under which an object is bound in the current call stack.

    The stack is walked from the outermost frame towards this function, and
    each frame's local variables are compared to `var` by identity (`is`).
    The first matching local name in the first matching frame wins.

    Args:
        var (Any): The object whose variable name is wanted.

    Returns:
        str or None: The first local name bound to `var`, or None when no
        frame on the stack holds it.
    """
    # inspect.stack() lists the innermost frame first; reversing it makes the
    # search start at the outermost caller.
    for frame_info in reversed(inspect.stack()):
        for candidate, value in frame_info.frame.f_locals.items():
            if value is var:
                return candidate
    return None


def reduce_tensor(rt):
    """
    Average a tensor over all processes of the default process group.

    A clone of `rt` is summed across processes with `dist.all_reduce` and then
    divided in place by the world size. When torch.distributed has not been
    initialized, the clone is divided by 1 and returned, so the values are
    unchanged. The input tensor itself is never modified.

    Args:
        rt (torch.Tensor): The tensor to average.

    Returns:
        torch.Tensor: A new tensor holding the averaged values.
    """
    averaged = rt.clone()
    num_procs = 1
    if dist.is_initialized():
        # Sum over every process first; the division below turns it into a mean.
        dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
        num_procs = dist.get_world_size()
    averaged /= num_procs
    return averaged


class LoggerX(object):
    """
    Logging and checkpointing helper for distributed training.

    It provides the following functionality:
    - Creates `<save_root>/save_models` and `<save_root>/save_images`.
    - Keeps a list of registered modules together with a name for each.
    - Saves and loads per-epoch checkpoints and "best" checkpoints.
    - Prints training statistics and optionally sends them to Weights & Biases.
    - Saves image grids to the image directory.

    Checkpoint saving, console printing and Weights & Biases logging are only
    performed by the process whose `dist.get_rank()` is 0. Checkpoint loading
    and image saving run on every process.
    """

    def __init__(self, save_root, enable_wandb=False, **kwargs):
        """
        Set up the output directories and, optionally, Weights & Biases.

        Args:
            save_root (str): Root directory for saved models and images.
            enable_wandb (bool, optional): Whether to log to Weights & Biases.
                Defaults to False.
            **kwargs: Extra keyword arguments forwarded to `wandb.init`.

        Raises:
            AssertionError: If torch.distributed is not initialized.
        """
        # Raised explicitly (instead of a bare `assert`) so the check is not
        # stripped when Python runs with -O.
        if not dist.is_initialized():
            raise AssertionError('torch.distributed must be initialized before creating LoggerX')
        self.models_save_dir = osp.join(save_root, 'save_models')
        self.images_save_dir = osp.join(save_root, 'save_images')
        os.makedirs(self.models_save_dir, exist_ok=True)
        os.makedirs(self.images_save_dir, exist_ok=True)
        self._modules = []
        self._module_names = []
        self.world_size = dist.get_world_size()
        # Despite the attribute name, this stores the value of dist.get_rank().
        self.local_rank = dist.get_rank()
        self.enable_wandb = enable_wandb
        if enable_wandb and self.local_rank == 0:
            import wandb
            wandb.init(dir=save_root, settings=wandb.Settings(_disable_stats=True, _disable_meta=True), **kwargs)

    @property
    def modules(self):
        """
        Returns the list of registered modules.
        """
        return self._modules

    @modules.setter
    def modules(self, modules):
        """
        Registers additional modules; existing registrations are kept.

        Assigning to `modules` appends to the internal lists rather than
        replacing them. Each module's name is resolved with `get_varname()`.

        Args:
            modules (list): Module objects to register.
        """
        # Index explicitly instead of binding each element to a local name:
        # get_varname() scans the locals of every frame on the stack, including
        # this one, so an extra local here could change the name it reports.
        for index in range(len(modules)):
            self._modules.append(modules[index])
            self._module_names.append(get_varname(modules[index]))

    @property
    def module_names(self):
        """
        Returns the names of the registered modules, in registration order.
        """
        return self._module_names

    def append(self, module, name=None):
        """
        Registers a single module, optionally under an explicit name.

        Args:
            module (object): The module to register.
            name (str, optional): Name used in checkpoint file names. When
                omitted, the name is resolved with `get_varname()`.
        """
        self._modules.append(module)
        if name is None:
            name = get_varname(module)
        self._module_names.append(name)

    def checkpoints(self, epoch):
        """
        Saves every registered module for `epoch` and drops the previous epoch's file.

        Files are written to `<models_save_dir>/<module_name>-<epoch>`. Only
        the process with rank 0 does anything.

        Args:
            epoch (int): The current epoch number.

        Returns:
            None
        """
        if self.local_rank != 0:
            return
        for module_name, module in zip(self.module_names, self.modules):
            current_path = osp.join(self.models_save_dir, '{}-{}'.format(module_name, epoch))
            previous_path = osp.join(self.models_save_dir, f'{module_name}-{epoch-1}')
            # Write the new checkpoint before deleting the old one, so a failed
            # save never leaves the run without any checkpoint.
            torch.save(module.state_dict(), current_path)
            if osp.exists(previous_path):
                os.remove(previous_path)

    def best_checkpoints(self, results):
        """
        Saves every registered module as its new "best" checkpoint.

        The file name is `<module_name>-best` followed by `-<key>-<value>` for
        up to the first three entries of `results`, with values formatted to
        two decimals. Any older `<module_name>-best*` files are removed after
        the new file has been written. Only the process with rank 0 does
        anything.

        Args:
            results (dict): Maps metric names to numeric values. The first 10
                characters of each key are dropped from the file name.
                NOTE(review): presumably a fixed-width prefix on the metric
                keys; confirm against the caller.
        """
        if self.local_rank != 0:
            return
        # Slicing tolerates fewer than three results instead of raising IndexError.
        leading_results = list(results.items())[:3]
        output_str = ''.join(
            '-{}-{:2.2f}'.format(var_name[10:], var) for var_name, var in leading_results
        )
        for module_name, module in zip(self.module_names, self.modules):
            stale_paths = glob(osp.join(self.models_save_dir, f'{module_name}-best*'))
            best_path = osp.join(self.models_save_dir, f'{module_name}-best' + output_str)
            # Save first, then clean up, so a best checkpoint always exists.
            torch.save(module.state_dict(), best_path)
            for stale_path in stale_paths:
                # Skip the file just written (same metrics give the same name).
                if osp.abspath(stale_path) != osp.abspath(best_path):
                    os.remove(stale_path)

    def load_checkpoints(self, epoch):
        """
        Loads the checkpoint of the given epoch into every registered module.

        Each state dict is read with `load_network` from
        `<models_save_dir>/<module_name>-<epoch>`.

        Args:
            epoch (int): The epoch number to load the checkpoints for.

        Returns:
            None
        """
        for module_name, module in zip(self.module_names, self.modules):
            checkpoint_path = osp.join(self.models_save_dir, '{}-{}'.format(module_name, epoch))
            module.load_state_dict(load_network(checkpoint_path))

    def load_best_checkpoints(self):
        """
        Loads the "best" checkpoint into every registered module.

        For each module, the first file matching
        `<models_save_dir>/<module_name>-best*` is read with `load_network`.
        If no such file exists, a warning is issued and that module is left
        unmodified.
        """
        for module_name, module in zip(self.module_names, self.modules):
            matches = glob(osp.join(self.models_save_dir, f'{module_name}-best*'))
            if not matches:
                warnings.warn(f'No best checkpoint found for {module_name}')
                # Nothing to load: skip rather than passing an empty path on.
                continue
            module.load_state_dict(load_network(matches[0]))

    def load_restart_checkpoints(self, epoch, restart_dir):
        """
        Loads a checkpoint from a previous run into the module named 'byol'.

        Only a registered module whose name is exactly 'byol' is affected; the
        state dict is read with `load_network` from
        `ckpt/<restart_dir>/save_models/byol-<epoch>`.

        Args:
            epoch (int): The epoch to load the checkpoint for.
            restart_dir (str): Run directory name below `ckpt`.

        NOTE(review): behaviour for a missing file depends on `load_network`,
        which is defined outside this file.
        """
        for module_name, module in zip(self.module_names, self.modules):
            if module_name == 'byol':
                module.load_state_dict(load_network(osp.join('ckpt',restart_dir,'save_models',  '{}-{}'.format(module_name, epoch))))

    def msg(self, stats, step):
        """
        Prints a timestamped line of statistics and optionally logs it to Weights & Biases.

        Tensor values are detached, averaged over their elements, averaged
        across processes with `reduce_tensor`, and converted to Python numbers.
        Other values are used as they are and must work with the `{:2.5f}`
        format.

        Args:
            stats (list, tuple, or dict): The statistics to log. For a list or
                tuple, each name is resolved with `get_varname()`. For a dict,
                the keys are used as names.
            step (int): The current training or evaluation step.

        Raises:
            NotImplementedError: If `stats` is not a list, tuple or dict.

        Returns:
            None
        """
        if isinstance(stats, (list, tuple)):
            named_stats = []
            # The loop variable is deliberately called `var`: get_varname()
            # also scans this frame's locals, and callers may see that name
            # when no outer frame binds the object.
            for var in stats:
                named_stats.append((get_varname(var), var))
        elif isinstance(stats, dict):
            named_stats = list(stats.items())
        else:
            raise NotImplementedError

        output_str = '[{}] {:05d}, '.format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), step)
        output_dict = {}
        for var_name, var in named_stats:
            if isinstance(var, torch.Tensor):
                # reduce_tensor runs a collective all_reduce when distributed
                # is initialized, so every process goes through this path, not
                # only rank 0.
                var = reduce_tensor(var.detach().mean()).item()
            output_str += '{} {:2.5f}, '.format(var_name, var)
            output_dict[var_name] = var

        if self.local_rank != 0:
            return
        if self.enable_wandb:
            import wandb
            wandb.log(output_dict, step)
        print(output_str)

    def msg_str(self, output_str):
        """
        Prints `str(output_str)` on the process with rank 0 only.
        """
        if self.local_rank == 0:
            print(str(output_str))

    def save_image(self, grid_img, n_iter, sample_type):
        """
        Saves an image tensor as `<n_iter>_<rank>_<sample_type>.jpg` in the image directory.

        Every process saves its own file; the rank in the file name keeps them
        apart. The image is written with torchvision's `save_image` using
        `nrow=1`.

        Args:
            grid_img (torch.Tensor): The image tensor to be saved.
            n_iter (int): The current iteration number.
            sample_type (str): Label for the kind of sample, used in the file name.

        Returns:
            None
        """
        file_name = '{}_{}_{}.jpg'.format(n_iter, self.local_rank, sample_type)
        save_image(grid_img, osp.join(self.images_save_dir, file_name), nrow=1)
