import torch
import torch.nn as nn
from torchvision.models import resnet18, resnet50
from FishLeg.src.optim.FishLeg.utils import (
    recursive_setattr,
    recursive_getattr,
    update_dict,
    _use_grad_for_differentiable,
)
from FishLeg.src.optim.FishLeg.fishleg_layers import FishLinear, FishConv2d

# Lookup table from lower-case activation name to an activation module.
# The module instances are created once, when this file is imported, and
# ModelConstructor.build() inserts them as-is into the models it assembles.
ACTIVATIONS = {
    "relu": nn.ReLU(),
    "sigmoid": nn.Sigmoid(),
}


class ModelConstructor:
    """Build a model from an architecture description and optional checkpoint.

    Two architecture names are understood by :meth:`build`:

    * ``"autoencoder"`` -- a ``nn.Sequential`` stack of ``nn.Linear`` layers
      described by ``architecture["layer_sizes"]`` and
      ``architecture["activation_funcs"]``.
    * ``"resnet"`` -- a torchvision ``resnet50``.

    Args:
        load_path: Optional path handed to ``torch.load``. When falsy, no
            checkpoint is loaded.
        architecture: Dict with at least a ``"name"`` key. For
            ``"autoencoder"`` it must also contain ``"layer_sizes"`` and
            ``"activation_funcs"``.

    Raises:
        TypeError: If ``architecture`` is not supplied.
    """

    def __init__(
        self,
        load_path=None,
        architecture=None,
    ):
        if architecture is None:
            # Previously this surfaced as "'NoneType' object is not
            # subscriptable"; keep the exception type, improve the message.
            raise TypeError(
                "ModelConstructor requires an `architecture` dict "
                "with a 'name' key"
            )

        self.load_path = load_path
        self.architecture = architecture

        # NOTE(security): torch.load unpickles the file; only point
        # `load_path` at checkpoints from a trusted source.
        self.checkpoint = torch.load(load_path) if self.load_path else None

        if architecture["name"] == "autoencoder":
            self.layer_sizes = architecture["layer_sizes"]
            self.activation_funcs = architecture["activation_funcs"]

    def build(self, to_fishleg: bool = False):
        """Instantiate the model described by ``self.architecture``.

        Args:
            to_fishleg: When true, pass the freshly built model through
                :meth:`modeltofishleg` before returning it.

        Returns:
            A ``(model, opt_state)`` tuple. ``opt_state`` is the checkpoint's
            ``"optimizer_state_dict"`` entry for a non-resnet model built
            with a checkpoint, and ``None`` otherwise.

        Raises:
            ValueError: If the architecture name is neither ``"autoencoder"``
                nor ``"resnet"``.
        """
        name = self.architecture["name"]

        if name == "autoencoder":
            model = self._build_autoencoder()
        elif name == "resnet":
            model = resnet50()
            # For resnet the checkpoint itself is used as the state dict.
            # Without a checkpoint the network keeps its fresh initialisation
            # instead of crashing inside load_state_dict(None).
            if self.checkpoint is not None:
                model.load_state_dict(self.checkpoint)
        else:
            raise ValueError(
                f"Unknown architecture name {name!r}; "
                "expected 'autoencoder' or 'resnet'"
            )

        if to_fishleg:
            model = self.modeltofishleg(model)

        # Non-resnet checkpoints are loaded *after* the optional FishLeg
        # conversion, so the state dict is applied to the converted model.
        if self.checkpoint and name != "resnet":
            model.load_state_dict(self.checkpoint["model_state_dict"])
            opt_state = self.checkpoint["optimizer_state_dict"]
        else:
            opt_state = None

        return model, opt_state

    def _build_autoencoder(self) -> nn.Sequential:
        """Assemble the Linear(+activation) stack from ``self.layer_sizes``."""
        sizes = self.layer_sizes
        modules = []
        for i, (n_in, n_out) in enumerate(zip(sizes, sizes[1:])):
            modules.append(nn.Linear(n_in, n_out, dtype=torch.float32))
            # A falsy entry (e.g. None or "") means "no activation here".
            activation = self.activation_funcs[i]
            if activation:
                modules.append(ACTIVATIONS[activation.lower()])
        return nn.Sequential(*modules)

    def modeltofishleg(self, model: nn.Module) -> nn.Module:
        """Replace Linear/Conv2d submodules with their FishLeg counterparts.

        Every ``nn.Linear`` found in ``model`` is swapped for a
        ``FishLinear`` and every ``nn.Conv2d`` for a ``FishConv2d`` built
        with the same layer configuration. ``update_dict`` is then applied to
        the replacement and the original module before the replacement is
        set on ``model`` under the same dotted name.

        Args:
            model (:class:`torch.nn.Module`, required):
                A model containing modules to replace with FishLeg modules
                containing extra functionality related to FishLeg algorithm.
        Returns:
            :class:`torch.nn.Module`, the same ``model`` object, modified
            in place.
        """
        # Snapshot the module list: submodules are replaced inside the loop
        # and named_modules() is a lazy generator over the live module tree.
        for name, module in list(model.named_modules()):
            try:
                if isinstance(module, nn.Linear):
                    replace = FishLinear(
                        module.in_features,
                        module.out_features,
                        module.bias is not None,
                    )
                elif isinstance(module, nn.Conv2d):
                    replace = FishConv2d(
                        in_channels=module.in_channels,
                        out_channels=module.out_channels,
                        # kernel_size was not forwarded before, so the
                        # replacement could not mirror the source layer.
                        kernel_size=module.kernel_size,
                        bias=module.bias is not None,
                        stride=module.stride,
                        padding=module.padding,
                        dilation=module.dilation,
                        groups=module.groups,
                        padding_mode=module.padding_mode,
                    )
                else:
                    continue
                replace = update_dict(replace, module)
                recursive_setattr(model, name, replace)
            except KeyError:
                # Best-effort, as before: a module whose conversion raises
                # KeyError is left untouched.
                # NOTE(review): presumably raised by update_dict on a
                # state-key mismatch -- confirm, and consider logging.
                pass
        return model
