# Copyright (c) 2024, EleutherAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Adapted from https://github.com/awaelchli/pytorch-lightning-snippets/blob/master/checkpoint/peek.py

import code
import os
import re
from argparse import ArgumentParser, Namespace
from collections.abc import Mapping, Sequence
from pathlib import Path

import torch


class COLORS:
    BLUE = "\033[94m"
    CYAN = "\033[96m"
    GREEN = "\033[92m"
    RED = "\033[31m"
    YELLOW = "\033[33m"
    MAGENTA = "\033[35m"
    WHITE = "\033[37m"
    UNDERLINE = "\033[4m"
    END = "\033[0m"


PRIMITIVE_TYPES = (int, float, bool, str, type)


def natural_sort(l):
    convert = lambda text: int(text) if text.isdigit() else text.lower()
    alphanum_key = lambda key: [convert(c) for c in re.split("([0-9]+)", str(key))]
    return sorted(l, key=alphanum_key)


def sizeof_fmt(num, suffix="B"):
    for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
        if abs(num) < 1024.0:
            return "%3.1f%s%s" % (num, unit, suffix)
        num /= 1024.0
    return "%.1f%s%s" % (num, "Yi", suffix)


def pretty_print(contents: dict):
    """Prints a nice summary of the top-level contents in a checkpoint dictionary."""
    # assert len(contents) > 0
    if len(contents) <= 0:
        return
    col_size = max(len(str(k)) for k in contents)
    for k, v in sorted(contents.items()):
        key_length = len(str(k))
        line = " " * (col_size - key_length)
        line += f"{k}: {COLORS.BLUE}{type(v).__name__}{COLORS.END}"
        if isinstance(v, dict):
            pretty_print(v)
        elif isinstance(v, PRIMITIVE_TYPES):
            line += f" = "
            line += f"{COLORS.CYAN}{repr(v)}{COLORS.END}"
        elif isinstance(v, Sequence):
            line += ", "
            line += f"{COLORS.CYAN}len={len(v)}{COLORS.END}"
        elif isinstance(v, torch.Tensor):
            if v.ndimension() in (0, 1) and v.numel() == 1:
                line += f" = "
                line += f"{COLORS.CYAN}{v.item()}{COLORS.END}"
            else:
                line += ", "
                line += f"{COLORS.CYAN}shape={list(v.shape)}{COLORS.END}"
                line += ", "
                line += f"{COLORS.CYAN}dtype={v.dtype}{COLORS.END}"
            line += (
                ", "
                + f"{COLORS.CYAN}size={sizeof_fmt(v.nelement() * v.element_size())}{COLORS.END}"
            )
        print(line)


def common_entries(*dcts):
    if not dcts:
        return
    for i in set(dcts[0]).intersection(*dcts[1:]):
        yield (i,) + tuple(d[i] for d in dcts)


def pretty_print_double(contents1: dict, contents2: dict, args):
    """Prints a nice summary of the top-level contents in a checkpoint dictionary."""
    if len(contents1) <= 0:
        return
    col_size = max(
        max(len(str(k)) for k in contents1), max(len(str(k)) for k in contents2)
    )
    common_keys = list(contents1.keys() & contents2.keys())
    uncommon_keys_1 = [i for i in contents2.keys() if i not in common_keys]
    uncommon_keys_2 = [i for i in contents1.keys() if i not in common_keys]
    diffs_found = False
    if uncommon_keys_1 + uncommon_keys_2:
        diffs_found = True
        if uncommon_keys_1:
            print(
                f"{COLORS.RED}{len(uncommon_keys_1)} key(s) found in ckpt 1 that isn't present in ckpt 2:{COLORS.END} \n\t{COLORS.BLUE}{' '.join(uncommon_keys_1)}{COLORS.END}"
            )
        if uncommon_keys_2:
            print(
                f"{COLORS.RED}{len(uncommon_keys_2)} key(s) found in ckpt 2 that isn't present in ckpt 1:{COLORS.END} \n\t{COLORS.BLUE}{' '.join(uncommon_keys_2)}{COLORS.END}"
            )
    for k, v1, v2 in sorted(common_entries(contents1, contents2)):
        key_length = len(str(k))
        line = " " * (col_size - key_length)
        if type(v1) != type(v2):
            print(
                f"{COLORS.RED}{k} is a different type between ckpt1 and ckpt2: ({type(v1).__name__} vs. {type(v2).__name__}){COLORS.END}"
            )
            continue
        else:
            prefix = f"{k}: {COLORS.BLUE}{type(v1).__name__} | {type(v2).__name__}{COLORS.END}"
        if isinstance(v1, dict):
            pretty_print_double(v1, v2, args)
        elif isinstance(v1, PRIMITIVE_TYPES):
            if repr(v1) != repr(v2):
                c = COLORS.RED
                line += f" = "
                line += f"{c}{repr(v1)} | {repr(v2)}{COLORS.END}"
            else:
                c = COLORS.CYAN
                if not args.diff:
                    line += f" = "
                    line += f"{c}{repr(v1)} | {repr(v2)}{COLORS.END}"
        elif isinstance(v1, Sequence):
            if len(v1) != len(v2):
                c = COLORS.RED
                line += ", "
                line += f"{c}len={len(v1)} | len={len(v2)}{COLORS.END}"
            else:
                c = COLORS.CYAN
                if not args.diff:
                    line += ", "
                    line += f"{c}len={len(v1)} | len={len(v2)}{COLORS.END}"
        elif isinstance(v1, torch.Tensor):
            if v1.ndimension() != v2.ndimension():
                c = COLORS.RED
            else:
                c = COLORS.CYAN

            if (v1.ndimension() in (0, 1) and v1.numel() == 1) and (
                v2.ndimension() in (0, 1) and v2.numel() == 1
            ):
                if not args.diff:
                    line += f" = "
                    line += f"{c}{v1.item()} | {c}{v2.item()}{COLORS.END}"
            else:
                if list(v1.shape) != list(v2.shape):
                    c = COLORS.RED
                    line += ", "
                    line += f"{c}shape={list(v1.shape)} | shape={list(v2.shape)}{COLORS.END}"
                else:
                    c = COLORS.CYAN
                    if not args.diff:
                        line += ", "
                        line += f"{c}shape={list(v1.shape)} | shape={list(v2.shape)}{COLORS.END}"
                if v1.dtype != v2.dtype:
                    c = COLORS.RED
                    line += f"{c}dtype={v1.dtype} | dtype={v2.dtype}{COLORS.END}"

                else:
                    c = COLORS.CYAN
                    if not args.diff:
                        line += ", "
                        line += f"{c}dtype={v1.dtype} | dtype={v2.dtype}{COLORS.END}"
                if list(v1.shape) == list(v2.shape):
                    if torch.allclose(v1, v2):
                        if not args.diff:
                            line += f", {COLORS.CYAN}VALUES EQUAL{COLORS.END}"
                    else:
                        line += f", {COLORS.RED}VALUES DIFFER{COLORS.END}"

        if line.replace(" ", "") != "":
            line = prefix + line
            print(line)
            diffs_found = True
    if args.diff and not diffs_found:
        pass
    else:
        if not args.diff:
            print("\n")

    return diffs_found


def get_attribute(obj: object, name: str) -> object:
    if isinstance(obj, Mapping):
        return obj[name]
    if isinstance(obj, Namespace):
        return obj.name
    return getattr(object, name)


def get_files(pth):
    if os.path.isdir(pth):
        files = list(Path(pth).glob("*.pt")) + list(Path(pth).glob("*.ckpt"))
    elif os.path.isfile(pth):
        assert pth.endswith(".pt") or pth.endswith(".ckpt")
        files = [Path(pth)]
    else:
        raise ValueError("Dir / File not found.")
    return natural_sort(files)


def peek(args: Namespace):

    files = get_files(args.dir)

    for file in files:
        file = Path(file).absolute()
        print(f"{COLORS.GREEN}{file.name}:{COLORS.END}")
        ckpt = torch.load(file, map_location=torch.device("cpu"))
        selection = dict()
        attribute_names = args.attributes or list(ckpt.keys())
        for name in attribute_names:
            parts = name.split("/")
            current = ckpt
            for part in parts:
                current = get_attribute(current, part)
            selection.update({name: current})
        pretty_print(selection)
        print("\n")

        if args.interactive:
            code.interact(
                banner="Entering interactive shell. You can access the checkpoint contents through the local variable 'checkpoint'.",
                local={"checkpoint": ckpt, "torch": torch},
            )


def get_shared_fnames(files_1, files_2):
    names_1 = [Path(i).name for i in files_1]
    names_1_parent = Path(files_1[0]).parent
    names_2 = [Path(i).name for i in files_2]
    names_2_parent = Path(files_2[0]).parent
    shared_names = list(set.intersection(*map(set, [names_1, names_2])))
    return [names_1_parent / i for i in shared_names], [
        names_2_parent / i for i in shared_names
    ]


def get_selection(filename, args):
    ckpt = torch.load(filename, map_location=torch.device("cpu"))
    selection = dict()
    attribute_names = args.attributes or list(ckpt.keys())
    for name in attribute_names:
        parts = name.split("/")
        current = ckpt
        for part in parts:
            current = get_attribute(current, part)
        selection.update({name: current})
    return selection


def compare(args: Namespace):
    dirs = [i.strip() for i in args.dir.split(",")]
    assert len(dirs) == 2, "Only works with 2 directories / files"
    files_1 = get_files(dirs[0])
    files_2 = get_files(dirs[1])
    files_1, files_2 = get_shared_fnames(files_1, files_2)

    for file1, file2 in zip(files_1, files_2):
        file1 = Path(file1).absolute()
        file2 = Path(file2).absolute()
        print(f"COMPARING {COLORS.GREEN}{file1.name} & {file2.name}:{COLORS.END}")
        selection_1 = get_selection(file1, args)
        selection_2 = get_selection(file2, args)
        diffs_found = pretty_print_double(selection_1, selection_2, args)
        if args.diff and diffs_found:
            print(
                f"{COLORS.RED}THE ABOVE DIFFS WERE FOUND IN {file1.name} & {file2.name} ^{COLORS.END}\n"
            )

        if args.interactive:
            code.interact(
                banner="Entering interactive shell. You can access the checkpoint contents through the local variable 'selection_1' / 'selection_2'.\nPress Ctrl-D to exit.",
                local={
                    "selection_1": selection_1,
                    "selection_2": selection_2,
                    "torch": torch,
                },
            )


def main():
    parser = ArgumentParser()
    parser.add_argument(
        "dir",
        type=str,
        help="The checkpoint dir to inspect. Must be either: \
         - a directory containing pickle binaries saved with 'torch.save' ending in .pt or .ckpt \
         - a single path to a .pt or .ckpt file \
         - two comma separated directories - in which case the script will *compare* the two checkpoints",
    )
    parser.add_argument(
        "--attributes",
        nargs="*",
        help="Name of one or several attributes to query. To access an attribute within a nested structure, use '/' as separator.",
        default=None,
    )
    parser.add_argument(
        "--interactive",
        "-i",
        action="store_true",
        help="Drops into interactive shell after printing the summary.",
    )
    parser.add_argument(
        "--compare",
        "-c",
        action="store_true",
        help="If true, script will compare two directories separated by commas",
    )
    parser.add_argument(
        "--diff", "-d", action="store_true", help="In compare mode, only print diffs"
    )

    args = parser.parse_args()
    if args.compare:
        compare(args)
    else:
        peek(args)


if __name__ == "__main__":
    main()
