import glob
import os
import logging
import re
import torch
import yaml

from pathlib import Path
from typing import List, Union, Dict


def load_checkpoint(model: torch.nn.Module,
                    path: Union[Path, str]):
    """Load weights stored at ``path`` into ``model`` in place.

    The checkpoint (a mapping of parameter/buffer names to tensors, as written
    by ``torch.save(model.state_dict(), ...)``) is overlaid on the model's
    current state dict. Entries absent from the checkpoint therefore keep
    their current values. Names the model does not have are passed through to
    ``load_state_dict``, which loads strictly and will raise on them.

    Args:
      model: Module whose parameters/buffers are overwritten.
      path: Checkpoint file to read.
    """
    on_gpu = torch.cuda.is_available()
    if on_gpu:
        logging.info(' | Checkpoint | loading from checkpoint %s for GPU', path)
    else:
        logging.info(' | Checkpoint | loading from checkpoint %s for CPU', path)
    # map_location=None is torch.load's default (tensors go back to the device
    # they were saved from); without CUDA everything is remapped to the CPU.
    checkpoint = torch.load(path, map_location=None if on_gpu else 'cpu')

    # Update the model's own state-dict object (rather than building a new
    # dict) so that its metadata travels into load_state_dict unchanged.
    merged = model.state_dict()
    merged.update(checkpoint)
    model.load_state_dict(merged)


def save_checkpoint(model: torch.nn.Module,
                    path: Union[Path, str],
                    infos=None):
    """Save ``model``'s state dict to ``path`` and ``infos`` to a YAML side-car.

    The side-car is written next to the checkpoint with the file suffix
    replaced by ``.yaml`` (``checkpoint-10.pt`` -> ``checkpoint-10.yaml``).

    Args:
      model: Module to save. A ``DataParallel`` / ``DistributedDataParallel``
        wrapper is unwrapped so the wrapped module's own state dict is saved.
      path: Destination for the weights (written with ``torch.save``).
      infos: Optional YAML-serializable metadata; ``{}`` is written when None.
    """
    logging.debug(' | Checkpoint | save to checkpoint %s', path)
    wrappers = (torch.nn.DataParallel,
                torch.nn.parallel.DistributedDataParallel)
    inner = model.module if isinstance(model, wrappers) else model
    torch.save(inner.state_dict(), path)

    # Derive the side-car name with pathlib rather than a regex: it accepts
    # both str and Path, and a path that does not end in ".pt" (e.g. ".pth")
    # still gets a separate ".yaml" file instead of the regex leaving the name
    # unchanged and the YAML overwriting the weights just saved.
    info_path = Path(path).with_suffix('.yaml')
    if infos is None:
        infos = {}
    with open(info_path, 'w') as fout:
        yaml.dump(infos, fout)
        

def save_optimizer(optimizer: torch.optim.Optimizer,
                   path: Union[Path, str]):
    """Write ``optimizer.state_dict()`` to ``path`` via ``torch.save``.

    Args:
      optimizer: Optimizer whose state (per-parameter buffers and
        param_groups hyper-parameters) is serialized.
      path: Destination file.
    """
    logging.debug(" | Optimizer | save to optimizer %s", path)
    snapshot = optimizer.state_dict()
    torch.save(snapshot, path)


def load_optimizer(optimizer: torch.optim.Optimizer,
                   path: Union[Path, str]):
    """Restore ``optimizer``'s state in place from ``path``.

    ``path`` is expected to hold an optimizer state dict, as written by
    ``save_optimizer``.

    Args:
      optimizer: Optimizer to update; it must have been constructed over the
        same parameter layout as the one that was saved.
      path: File to read.
    """
    if torch.cuda.is_available():
        logging.info(' | Optimizer | loading from optimizer %s for GPU', path)
        # None is torch.load's default: tensors return to their saved device.
        map_location = None
    else:
        logging.info(' | Optimizer | loading from optimizer %s for CPU', path)
        map_location = 'cpu'
    optimizer_ckpt = torch.load(path, map_location=map_location)
    optimizer.load_state_dict(optimizer_ckpt)


def find_checkpoints(out_dir: Union[Path, str], iteration: int = 0) -> List[str]:
    """Find all available checkpoints in a directory.

    The checkpoint filenames have the form: `checkpoint-xxx.pt`
    where xxx is a numerical value. Files whose name does not match this
    form exactly (e.g. `checkpoint-5-avg.pt`) are ignored.

    Assume you have the following checkpoints in the folder `foo`:

        - checkpoint-1.pt
        - checkpoint-20.pt
        - checkpoint-300.pt
        - checkpoint-4000.pt

    Case 1 (Return all checkpoints)::

        find_checkpoints(out_dir='foo')

    Case 2 (Return checkpoints newer than checkpoint-20.pt, i.e.,
    checkpoint-4000.pt, checkpoint-300.pt, and checkpoint-20.pt)::

        find_checkpoints(out_dir='foo', iteration=20)

    Case 3 (Return checkpoints older than checkpoint-20.pt, i.e.,
    checkpoint-20.pt, checkpoint-1.pt)::

        find_checkpoints(out_dir='foo', iteration=-20)

    Args:
      out_dir:
        The directory where to search for checkpoints. Glob metacharacters
        in it (e.g. `[`) are treated literally.
      iteration:
        If it is 0, return all available checkpoints.
        If it is positive, return the checkpoints whose iteration number is
        greater than or equal to `iteration`.
        If it is negative, return the checkpoints whose iteration number is
        less than or equal to `-iteration`.
    Returns:
      Return a list of checkpoint filenames, sorted in descending
      order by the numerical value in the filename.
    """
    # Escape the directory so characters like "[" are not taken as a pattern.
    candidates = glob.glob(f"{glob.escape(str(out_dir))}/checkpoint-[0-9]*.pt")
    # The glob above also accepts names such as "checkpoint-5-avg.pt", so
    # re-check the basename exactly and skip anything that is not pure digits.
    pattern = re.compile(r"checkpoint-([0-9]+)\.pt")

    # List of (iteration_number, path) tuples.
    iter_checkpoints = []
    for c in candidates:
        match = pattern.fullmatch(os.path.basename(c))
        if match is not None:
            iter_checkpoints.append((int(match.group(1)), c))

    iter_checkpoints.sort(reverse=True, key=lambda x: x[0])
    if iteration >= 0:
        return [c for it, c in iter_checkpoints if it >= iteration]
    return [c for it, c in iter_checkpoints if it <= -iteration]


def remove_checkpoints(
    out_dir: Union[Path, str],
    topk: int,
    rank: int = 0,
):
    """Remove checkpoints from the given directory.

    We assume that checkpoint filename has the form `checkpoint-xxx.pt`
    where xxx is a number, representing the number of processed batches
    when saving that checkpoint. Checkpoints are ordered by the numeric
    value of `xxx`, and only the `topk` checkpoints with the highest `xxx`
    are kept. For every removed checkpoint, a `checkpoint-xxx.yaml` file next
    to it is removed as well if it exists.

    Args:
      out_dir:
        The directory containing checkpoints to be removed.
      topk:
        Number of checkpoints to keep. Must be at least 1.
      rank:
        If using DDP for training, it is the rank of the current node.
        Use 0 if no DDP is used for training. Only rank 0 deletes files.
    """
    assert topk >= 1, topk
    if rank != 0:
        return
    checkpoints = find_checkpoints(out_dir)

    if not checkpoints:
        logging.warning("No checkpoints found in %s", out_dir)
        return

    # find_checkpoints returns the highest iteration first, so everything
    # after the first `topk` entries is older (empty slice if len <= topk).
    for c in checkpoints[topk:]:
        os.remove(c)
        # Matching metadata file, e.g. checkpoint-10.pt -> checkpoint-10.yaml.
        config_yaml = os.path.splitext(c)[0] + ".yaml"
        if os.path.exists(config_yaml):
            os.remove(config_yaml)


def average_checkpoints(
    filenames: List[Union[Path, str]], device: torch.device = torch.device("cpu")
) -> dict:
    """Return the element-wise mean of the state dicts stored in ``filenames``.

    Floating-point entries are divided with true division; all other dtypes
    (e.g. integer counters) use floor division.

    Args:
      filenames:
        Checkpoint files to average; all must contain the same keys.
      device:
        Device every checkpoint is mapped onto when loading.
    Returns:
      The averaged state dict (the dict loaded from ``filenames[0]``,
      modified in place).
    """
    num = len(filenames)
    avg = torch.load(filenames[0], map_location=device)

    # NOTE: shared parameters. Entries with the same data_ptr alias one
    # storage, so only the first name seen for each pointer is accumulated.
    # Updating that name in place also updates its aliases, and none of them
    # is summed more than once.
    first_name_by_ptr: Dict[int, str] = {}
    for name, tensor in avg.items():
        first_name_by_ptr.setdefault(tensor.data_ptr(), name)
    names = list(first_name_by_ptr.values())

    for fname in filenames[1:]:
        other = torch.load(fname, map_location=device)
        for name in names:
            avg[name] += other[name]

    for name in names:
        if avg[name].is_floating_point():
            avg[name] /= num
        else:
            avg[name] //= num

    return avg
