import dataclasses
import hashlib
import json
import subprocess
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as tf
from tabulate import tabulate


class ModelInterface(nn.Module):
    """Base ``nn.Module`` that records the current git revision.

    On construction, ``self.version`` is set to the short hash of the checked-out
    git commit when it can be determined; otherwise the class-level default
    ``"VERSION"`` is kept. Also offers small introspection helpers.
    """

    # Fallback used when the git revision cannot be determined.
    version = "VERSION"

    def __init__(self):
        super().__init__()
        try:
            # Explicit argv list (no shell). stderr is discarded so that running
            # outside a git checkout does not print "fatal: ..." to the console.
            self.version = subprocess.check_output(
                ["git", "rev-parse", "--short", "HEAD"],
                stderr=subprocess.DEVNULL).strip().decode()
        except (OSError, subprocess.SubprocessError, UnicodeDecodeError):
            # Best effort: git missing, not a repository, undecodable output...
            # Keep the class default instead of failing model construction.
            pass

    def get_device(self):
        """Return the device index of the first parameter.

        Delegates to ``Tensor.get_device()``, which reports -1 for CPU tensors.
        Raises ``StopIteration`` if the module has no parameters.
        """
        param = next(self.parameters())
        return param.get_device()

    def num_parameters(self):
        """Return the total element count of all trainable parameters."""
        return sum(param.numel() for param in self.parameters()
                   if param.requires_grad)


class MLP(nn.Module):
    """Fully connected feed-forward network.

    Builds ``Linear -> activation`` for every hidden layer, then an output
    ``Linear`` layer, optionally followed by one more activation.

    Args:
        input_size: number of input features.
        output_size: number of output features.
        hidden_sizes: widths of the hidden layers, in order. Any iterable of
            ints is accepted (list, tuple, ...). An empty one yields a single
            linear layer.
        bias: whether each ``nn.Linear`` has a bias term.
        activation_func_module: zero-argument callable (e.g. an ``nn.Module``
            subclass) instantiated once per activation layer.
        output_activation: if True, also apply an activation after the
            output layer.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 hidden_sizes: list,
                 bias=True,
                 activation_func_module=torch.nn.ReLU,
                 output_activation=False):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_sizes = hidden_sizes

        # Unpacking (rather than list concatenation) accepts tuples and other
        # iterables, not just lists.
        sizes = [input_size, *hidden_sizes, output_size]
        last_index = len(sizes) - 2  # index of the output Linear layer

        layers = []
        for index, (in_features, out_features) in enumerate(
                zip(sizes[:-1], sizes[1:])):
            layers.append(nn.Linear(in_features, out_features, bias=bias))
            # Hidden layers always get an activation; the output layer only
            # when output_activation is requested.
            if output_activation or index != last_index:
                layers.append(activation_func_module())

        self.layers = nn.Sequential(*layers)

    def forward(self, x) -> torch.Tensor:
        return self.layers(x)

    def __call__(self, x) -> torch.Tensor:
        # Overridden only to give callers a ``torch.Tensor`` return annotation.
        return super().__call__(x)


@dataclasses.dataclass
class Hyperparameters:
    """Base class for dataclass-based hyperparameter containers.

    Subclasses declare hyperparameters as (annotated) dataclass fields. The
    base class adds merging (``+``), a content hash, a table rendering, and
    JSON save/load.
    """

    # Optional key prefix used by ``load_json``. When set (e.g. "lr"), a JSON
    # key "<prefix>_<name>" is mapped onto attribute "<name>". It is
    # unannotated, so it is a class attribute rather than a dataclass field.
    prefix = None

    def __add__(self, other):
        """Merge two hyperparameter sets into a new dataclass instance.

        Values from ``other`` override those of ``self`` for shared names. The
        result's class is created on the fly with one field per merged key,
        typed by the runtime type of its value.
        """
        # Key order: self's keys first, then keys only present in other.
        merged = {**dataclasses.asdict(self), **dataclasses.asdict(other)}
        fields = [(key, type(value)) for key, value in merged.items()]
        NewHyperparameters = dataclasses.make_dataclass(
            "NewHyperparameters", fields, bases=(Hyperparameters, ))
        return NewHyperparameters(**merged)

    def hash(self):
        """Return an MD5 hex digest of the fields, serialized as sorted JSON.

        Used as a content fingerprint, not for security.
        """
        return hashlib.md5(
            (json.dumps(dataclasses.asdict(self),
                        sort_keys=True)).encode("utf-8")).hexdigest()

    def __str__(self):
        """Render the instance attributes as an RST table sorted by name."""
        rows = []
        for key, value in sorted(self.__dict__.items(), key=lambda x: x[0]):
            rows.append([key, value])
        return tabulate(rows,
                        headers=["hyperparameter", "value"],
                        tablefmt="rst")

    def save(self, path):
        """Write the dataclass fields to ``path`` as indented, key-sorted JSON."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump(dataclasses.asdict(self), f, indent=4, sort_keys=True)

    @classmethod
    def load_json(cls, path: str):
        """Create a default instance and overwrite attributes from a JSON file.

        Keys that do not name an existing attribute are ignored. When
        ``cls.prefix`` is set, a leading "<prefix>_" is stripped from each key
        before the lookup.

        Raises:
            FileNotFoundError: if ``path`` is not an existing regular file.
        """
        path = Path(path)
        if not path.is_file():
            raise FileNotFoundError(f"hyperparameter file not found: {path}")
        with open(path, "r", encoding="utf-8") as f:
            json_dict: dict = json.load(f)

        instance = cls()
        prefix = instance.prefix
        marker = None if prefix is None else f"{prefix}_"
        for name, value in json_dict.items():
            # Strip only a *leading* prefix. str.replace would also delete
            # occurrences in the middle of the key (e.g. "lr_lr_decay").
            if marker is not None and name.startswith(marker):
                name = name[len(marker):]
            if hasattr(instance, name):
                setattr(instance, name, value)
        return instance
