import importlib
import inspect
import json
import logging
from copy import deepcopy
from dataclasses import dataclass, field, fields
from typing import Dict, List, Optional, Tuple, Union

import torch
import yaml

logger = logging.getLogger(__name__)


def load_yaml(path):
    """Parse the YAML file at ``path`` and return the loaded object."""
    # NOTE(review): if ``path`` can ever point at untrusted content, consider
    # ``yaml.safe_load`` instead of ``FullLoader`` -- confirm with callers.
    with open(path, "r") as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)


def dataclass_from_dict(cls, d):
    """Recursively build an instance of dataclass ``cls`` from mapping ``d``.

    Every key of ``d`` is looked up among the fields of ``cls`` and its value
    is converted recursively using the field's declared type.

    This is a best-effort conversion: if ``cls`` is not a dataclass, if ``d``
    holds a key that is not a field of ``cls``, or if constructing ``cls``
    fails for any other ordinary reason, ``d`` is returned unchanged.

    Args:
        cls: Target type, normally a dataclass.
        d: Value to convert, normally a dict keyed by field name.

    Returns:
        An instance of ``cls`` on success, otherwise ``d`` itself.
    """
    try:
        fieldtypes = {f.name: f.type for f in fields(cls)}
        return cls(
            **{name: dataclass_from_dict(fieldtypes[name], d[name]) for name in d}
        )
    except Exception:
        # Only ordinary errors mean "not convertible"; KeyboardInterrupt and
        # SystemExit must keep propagating instead of being swallowed.
        return d  # Not a dataclass field


def from_dict_to_dataclass(cls, data):
    """Call ``cls`` with arguments picked from ``data`` by parameter name.

    The parameters are taken from ``inspect.signature(cls)``. A parameter
    without a default must be present in ``data``; a parameter with a default
    falls back to that default when its key is missing. Keys of ``data`` that
    do not match any parameter are ignored.

    Args:
        cls: Callable to invoke, normally a dataclass.
        data: Mapping from parameter name to value.

    Returns:
        The result of ``cls(**selected_arguments)``.

    Raises:
        KeyError: If a parameter without a default is missing from ``data``.
    """
    kwargs = {}
    for key, param in inspect.signature(cls).parameters.items():
        # ``empty`` is a sentinel, so test identity: ``==`` would invoke the
        # default value's own ``__eq__``, which may misreport or raise.
        if param.default is param.empty:
            kwargs[key] = data[key]
        else:
            kwargs[key] = data.get(key, param.default)
    return cls(**kwargs)


class ProtoConfig:
    """Base class for the config objects; provides a JSON-style export."""

    def __init__(self):
        super().__init__()

    def to_json(self):
        """Return this object as plain JSON-compatible Python data.

        The object is dumped to a JSON string and parsed back. Any value the
        encoder cannot handle natively is replaced by its ``__dict__`` when it
        has one, and by ``str(value)`` otherwise.
        """

        def _encode_unknown(o):
            return getattr(o, "__dict__", str(o))

        serialized = json.dumps(self, default=_encode_unknown)
        return json.loads(serialized)


@dataclass
class PrimitiveCBNNState:
    """Container for per-layer prior and initial-value parameters.

    Both fields are nested lists (one inner list per layer) and may be
    ``None``. On construction every tensor found in them is converted to
    float64 via ``_promote_float64``.
    """

    priors: Optional[List[List[torch.Tensor]]] = None
    var_init_value: Optional[List[List[torch.Tensor]]] = None

    def __post_init__(self):
        self.priors = _promote_float64(self.priors)
        self.var_init_value = _promote_float64(self.var_init_value)

    @classmethod
    def _scale_variance(
        cls,
        params: Optional[List[List[torch.Tensor]]],
        scaling_factor: float = 1.0,
    ):
        """Multiply every odd-indexed entry of each layer by ``scaling_factor``.

        Even-indexed entries are kept as they are.
        NOTE(review): presumably each layer alternates (mean, variance, ...)
        so that only variances get scaled -- confirm against the producer.

        Args:
            params: Nested per-layer list, or ``None``.
            scaling_factor: Multiplier for the odd-indexed entries.

        Returns:
            ``params`` itself when it is ``None`` or the factor is exactly
            1.0, otherwise a new nested list with the scaled entries.
        """
        # ``None`` is the declared default of both fields; there is nothing
        # to scale, so return it instead of failing on iteration.
        if params is None or scaling_factor == 1.0:
            return params
        return [
            [
                item if j % 2 == 0 else scaling_factor * item
                for j, item in enumerate(layer)
            ]
            for layer in params
        ]

    def scale_parameters(self, prior_variance_scale=1.0, init_variance_scale=1.0):
        """Rescale ``priors`` and ``var_init_value`` in place.

        Args:
            prior_variance_scale: Factor applied to ``priors``.
            init_variance_scale: Factor applied to ``var_init_value``.
        """
        self.priors = PrimitiveCBNNState._scale_variance(
            self.priors,
            prior_variance_scale,
        )
        self.var_init_value = PrimitiveCBNNState._scale_variance(
            self.var_init_value, init_variance_scale
        )


@dataclass
class ArchConfig(ProtoConfig):
    """Architecture settings of a single network.

    Within this file only ``name``, ``domain`` and ``transform`` are read;
    the remaining fields are consumed elsewhere.
    """

    # Identifier; SBConfigFactory uses it as the key of its ``cbnns`` dict.
    name: str
    # NOTE(review): presumably the layer sizes -- confirm against the consumer.
    dims: Optional[List[int]] = None
    # (a, b) pair unpacked by TableConfig.from_function as linspace endpoints.
    domain: Optional[Tuple[float, float]] = None
    # Transform name; TableConfig.from_function special-cases
    # "StickBreakingTransform".
    transform: str = "IdTransform"
    # NOTE(review): not read in this file; meaning to be confirmed.
    normalize: bool = False


@dataclass
class TrainConfig(ProtoConfig):
    """Training settings.

    Only ``pretrain`` is read in this file (by ``BNNConfig.__post_init__``);
    the other fields are consumed elsewhere, so the notes on them are
    assumptions to be confirmed.
    """

    nepochs: int = 1000  # presumably the number of training epochs
    sigma_obs: float = 0.002  # presumably the observation-noise scale
    # When true, BNNConfig.__post_init__ checks the table and may build one.
    pretrain: bool = True
    save_state_each: int = 50  # presumably a checkpoint interval
    lr_factor: float = 0.8  # presumably a learning-rate decay factor
    lr0: float = 1e-2  # presumably the initial learning rate
    positive_flag_window: int = 3  # NOTE(review): meaning not visible here


@dataclass
class VarScalingConfig(ProtoConfig):
    """Pair of variance scaling factors, both defaulting to 1.0.

    The field names match the keyword arguments of
    ``PrimitiveCBNNState.scale_parameters``.
    NOTE(review): presumably they are forwarded there -- confirm with callers.
    """

    init_variance_scale: float = 1.0
    prior_variance_scale: float = 1.0


@dataclass
class FooConfig(ProtoConfig):
    """Reference to a function that is resolved by name at runtime.

    ``TableConfig.from_function`` imports ``module_name``, fetches
    ``function_name`` from it and calls it with the grid tensor plus
    ``kwargs`` as keyword arguments.
    """

    function_name: str  # attribute looked up on the imported module
    module_name: str  # passed to importlib.import_module
    kwargs: Dict  # extra keyword arguments for the function call


@dataclass
class TableConfig(ProtoConfig):
    """Tabulated ``xs`` / ``ys`` values, given directly or built from a function.

    Lists passed for ``xs`` or ``ys`` are converted to tensors on
    construction. ``ngrid`` is the number of grid points used by
    ``from_function``.
    """

    xs: Optional[Union[List[float], torch.Tensor]] = None
    ys: Optional[Union[List[float], torch.Tensor]] = None
    ngrid: Optional[int] = 101

    def __post_init__(self):
        """Validate matching lengths and convert list inputs to tensors.

        Raises:
            ValueError: If both tables are set and their lengths differ.
        """
        if self.xs is not None and self.ys is not None and len(self.xs) != len(self.ys):
            raise ValueError(" xs and ys tables have different lengths")
        if isinstance(self.xs, list):
            self.xs = torch.tensor(self.xs)
        if isinstance(self.ys, list):
            self.ys = torch.tensor(self.ys)

    def from_function(self, aconfig: ArchConfig, fconfig: FooConfig):
        """Fill ``xs`` with a uniform grid and ``ys`` with a function of it.

        ``xs`` becomes ``torch.linspace(a, b, ngrid)`` with ``(a, b)`` taken
        from ``aconfig.domain``. ``ys`` is the result of calling the function
        named by ``fconfig`` on ``xs``. When ``aconfig.transform`` is
        ``"StickBreakingTransform"`` a second column ``1 - ys`` is appended.

        Args:
            aconfig: Supplies ``domain`` and ``transform``.
            fconfig: Names the module and function and provides its kwargs.

        Raises:
            ValueError: If building the grid or evaluating the function
                fails; the underlying error is attached as ``__cause__``.
        """
        a, b = aconfig.domain
        try:
            self.xs = torch.linspace(a, b, self.ngrid)
        except Exception as err:
            # Chain the cause so the real failure is not hidden.
            raise ValueError(
                " creating linspace failed : either ngrid is not set or there is a problem with ArchConfig.domain"
            ) from err

        try:
            # could be "builtins"
            module = importlib.import_module(fconfig.module_name)
            foo = getattr(module, fconfig.function_name)
            self.ys = foo(self.xs, **fconfig.kwargs)
            if aconfig.transform == "StickBreakingTransform":
                # Add a trailing axis, then append the remainder as a
                # second column.
                promote_2d = self.ys.unsqueeze(-1)
                leftover = (1.0 - promote_2d.sum(-1)).unsqueeze(-1)
                self.ys = torch.cat([promote_2d, leftover], dim=1)
        except Exception as err:
            raise ValueError(" creating TableConfig.ys from FooConfig failed") from err
        # Re-run the length validation on the freshly built tables.
        self.__post_init__()


@dataclass
class BNNConfig(ProtoConfig):
    """Full configuration of one BNN: architecture, training and table data.

    When ``train.pretrain`` is true and no table data is available (``table``
    is ``None``, or both its ``xs`` and ``ys`` are ``None``), a table is
    generated from ``foo`` over ``architecture.domain``.
    """

    architecture: ArchConfig
    train: TrainConfig
    table: Optional[TableConfig] = None
    foo: Optional[FooConfig] = None
    # default_factory gives every instance its own VarScalingConfig. A shared
    # instance as the default is rejected by dataclasses on Python 3.11+
    # because the object is unhashable.
    varscaling: Optional[VarScalingConfig] = field(default_factory=VarScalingConfig)

    def __post_init__(self):
        """Generate the pretraining table from ``foo`` when none was given."""
        if not self.train.pretrain:
            return
        # ``table`` defaults to None, so guard before touching its fields.
        if self.table is None or (self.table.xs is None and self.table.ys is None):
            # NOTE(review): a fresh TableConfig is used, so any ``ngrid`` set
            # on the supplied table is discarded -- confirm this is intended.
            self.table = TableConfig()
            self.table.from_function(self.architecture, self.foo)


class BNNConfigFactory:
    """Builds ``BNNConfig`` objects from JSON-like dictionaries."""

    @classmethod
    def get_bnn_config(cls, json_like):
        """Create a ``BNNConfig`` from a JSON-like mapping.

        Args:
            json_like: Mapping with required keys ``"architecture"`` and
                ``"train"``, and optional keys ``"table"``, ``"foo"`` and
                ``"varscaling"``. An optional key that is absent or ``None``
                falls back to its default (``TableConfig()``, ``None`` and
                ``VarScalingConfig()`` respectively).

        Returns:
            The assembled ``BNNConfig``.
        """
        ac = ArchConfig(**json_like["architecture"])
        train = TrainConfig(**json_like["train"])
        # The fallback must be a TableConfig: BNNConfig.__post_init__ reads
        # ``table.xs`` / ``table.ys``, which a TrainConfig does not have.
        table = (
            TableConfig(**json_like["table"])
            if "table" in json_like and json_like["table"] is not None
            else TableConfig()
        )
        foo = (
            FooConfig(**json_like["foo"])
            if "foo" in json_like and json_like["foo"] is not None
            else None
        )
        varscaling = (
            VarScalingConfig(**json_like["varscaling"])
            if "varscaling" in json_like and json_like["varscaling"] is not None
            else VarScalingConfig()
        )
        return BNNConfig(
            architecture=ac, train=train, table=table, foo=foo, varscaling=varscaling
        )


@dataclass
class ModelingConfig:
    """Modeling options stored in ``SBConfig.modeling``.

    Neither field is read in this file; the notes below are assumptions.
    """

    # NOTE(review): presumably the number of synthetic samples -- confirm.
    synthetic_nsample: int = 1000
    # NOTE(review): presumably selects a normalized working space -- confirm.
    normalized_space: bool = True


@dataclass
class SBConfig(ProtoConfig):
    """Top-level configuration grouping several ``BNNConfig`` objects.

    ``SBConfigFactory.get_sb_config`` builds instances of this class.
    """

    architecture: ArchConfig
    train: TrainConfig
    # The factory keys this dict by each entry's architecture name.
    cbnns: Dict[str, BNNConfig]
    modeling: ModelingConfig


class SBConfigFactory:
    """Builds ``SBConfig`` objects from JSON-like dictionaries."""

    @classmethod
    def get_sb_config(cls, json_like):
        """Create an ``SBConfig`` from a JSON-like mapping.

        ``json_like`` must provide ``"architecture"``, ``"train"``,
        ``"modeling"`` and ``"cbnns"``. The latter is an iterable of BNN
        config mappings; each is stored under its architecture name.
        """
        arch_cfg = ArchConfig(**json_like["architecture"])
        train_cfg = TrainConfig(**json_like["train"])
        modeling_cfg = ModelingConfig(**json_like["modeling"])

        cbnn_cfgs = {}
        for entry in json_like["cbnns"]:
            key = entry["architecture"]["name"]
            cbnn_cfgs[key] = BNNConfigFactory.get_bnn_config(entry)

        return SBConfig(
            architecture=arch_cfg,
            train=train_cfg,
            modeling=modeling_cfg,
            cbnns=cbnn_cfgs,
        )

        # self.dir_weight = [item for item in acc if "dir_weight" in item.config_nn.name][
        #     0
        # ]
        # self.dir_norm = [item for item in acc if "dir_norm" in item.config_nn.name][0]
        # self.tsignal_norm = [
        #     item for item in acc if "tsignal_norm" in item.config_nn.name
        # ][0]
        # self.vec_signal = [item for item in acc if "vec_signal" in item.config_nn.name][
        #     0
        # ]
        #
        # self.pickles_folder = pickles_folder


def _promote_float64(obj):
    if isinstance(obj, (list, tuple)):
        return [_promote_float64(item) for item in obj]
    elif torch.is_tensor(obj):
        return obj.double()
