import dataclasses
import hashlib
import json
import subprocess
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as tf
from tabulate import tabulate


class ModelInterface(nn.Module):
    """Base ``nn.Module`` that records the code version and offers parameter helpers."""

    # Placeholder kept when the git commit hash cannot be determined.
    version = "VERSION"

    def __init__(self):
        super().__init__()
        # Tag the instance with the short hash of the current git HEAD, if possible.
        try:
            command = "git rev-parse --short HEAD"
            self.version = subprocess.check_output(
                command.split(), stderr=subprocess.DEVNULL).strip().decode()
        except (OSError, subprocess.CalledProcessError):
            # OSError: git executable not available; CalledProcessError: git
            # exited non-zero (e.g. not inside a repository). Best effort only:
            # fall back to the class-level placeholder.
            pass

    def get_device(self):
        """Return ``Tensor.get_device()`` of the module's first parameter.

        Only the first parameter is inspected. Per the PyTorch docs this is the
        GPU ordinal for CUDA tensors and -1 for CPU tensors.

        Raises:
            StopIteration: if the module has no parameters.
        """
        param = next(self.parameters())
        return param.get_device()

    def num_parameters(self):
        """Return the total element count of parameters with ``requires_grad=True``."""
        return sum(param.numel() for param in self.parameters()
                   if param.requires_grad)


@dataclasses.dataclass
class Hyperparameters:
    """Base dataclass for hyperparameter sets with merge, hashing and JSON I/O.

    Subclasses declare their hyperparameters as annotated dataclass fields.
    ``load_json`` builds instances via ``cls()``, so every field needs a default
    for loading to work.
    """

    # Optional key prefix that load_json strips from the start of JSON keys.
    # Unannotated, so it is a plain class attribute, not a dataclass field.
    prefix = None

    def __add__(self, other):
        """Merge two hyperparameter dataclasses into a new instance.

        When a field name occurs in both operands, ``other``'s value wins. The
        result is an instance of a freshly created dataclass
        ``NewHyperparameters`` (a ``Hyperparameters`` subclass) whose field types
        are ``type(value)`` of the merged values. ``dataclasses.asdict`` turns
        nested dataclasses into dicts. ``prefix`` is not a field, so the result
        keeps the base-class value (``None``).
        """
        merged = {**dataclasses.asdict(self), **dataclasses.asdict(other)}
        fields = [(key, type(value)) for key, value in merged.items()]
        NewHyperparameters = dataclasses.make_dataclass(
            "NewHyperparameters", fields, bases=(Hyperparameters, ))
        return NewHyperparameters(**merged)

    def hash(self):
        """Return the MD5 hex digest of the fields serialized as key-sorted JSON.

        This is intended as a stable identifier, not for security. ``json.dumps``
        raises ``TypeError`` if a field value is not JSON-serializable.
        """
        return hashlib.md5(
            (json.dumps(dataclasses.asdict(self),
                        sort_keys=True)).encode("utf-8")).hexdigest()

    def __str__(self):
        """Render the instance attributes as a name-sorted RST table."""
        rows = [[key, value] for key, value in sorted(self.__dict__.items())]
        return tabulate(rows,
                        headers=["hyperparameter", "value"],
                        tablefmt="rst")

    def save(self, path):
        """Write the dataclass fields to ``path`` as indented, key-sorted JSON."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump(dataclasses.asdict(self), f, indent=4, sort_keys=True)

    @classmethod
    def load_json(cls, path: str):
        """Create an instance of ``cls`` from a JSON object stored at ``path``.

        The method starts from ``cls()`` and copies every JSON value whose key
        names an attribute the instance has. Other keys are ignored. If
        ``cls.prefix`` is set, a leading ``"{prefix}_"`` is removed from each
        key before matching. Unprefixed keys are still matched as-is.

        Raises:
            FileNotFoundError: if ``path`` is not an existing regular file.
        """
        path = Path(path)
        if not path.is_file():
            raise FileNotFoundError(f"Hyperparameter file not found: {path}")
        with open(path, "r", encoding="utf-8") as f:
            json_dict: dict = json.load(f)

        instance = cls()
        prefix = instance.prefix
        key_prefix = None if prefix is None else f"{prefix}_"
        for name, value in json_dict.items():
            # Only strip the prefix at the start of the key; str.replace would
            # also delete occurrences in the middle (e.g. "enc_dec_enc_size").
            if key_prefix is not None and name.startswith(key_prefix):
                name = name[len(key_prefix):]
            # NOTE(review): hasattr also matches methods and class attributes
            # (e.g. "save", "prefix"), so such JSON keys would overwrite them.
            if hasattr(instance, name):
                setattr(instance, name, value)
        return instance
