from copy import deepcopy
from torch import nn
import torch
import numpy as np
from typing import (
    Callable,
    Generic,
    Union,
    Optional,
    Tuple,
    Iterator,
    List,
    Any,
    Protocol,
    cast,
    TypeVar,
)
from typing_extensions import Self


# Type of the object a loss function is evaluated against.  Declared
# contravariant so that it can parameterise the ``LossFunction`` protocol,
# where it only appears in argument position.
LN = TypeVar("LN", contravariant=True)


class LossFunction(Protocol[LN]):
    """Structural type of a loss callable.

    Any callable that takes the object being trained as its first positional
    argument, followed by arbitrary extra positional arguments, and returns a
    ``torch.Tensor`` matches this protocol.
    """

    def __call__(self, this: LN, *args) -> torch.Tensor:
        """Compute the loss for ``this`` given the extra ``args``."""


class BasicNeuralNetwork(nn.Module):
    """``nn.Module`` base class with helpers for frozen copies and weight syncing.

    All helpers except :meth:`clone` return ``self`` so calls can be chained,
    e.g. ``net.clone().no_grad()``.
    """

    def __init__(
        self,
    ):
        super().__init__()

    def no_grad(self):
        """Turn off gradient tracking for every parameter and return ``self``."""
        # ``nn.Module.requires_grad_`` already visits every parameter of the
        # module, so a separate per-parameter loop is not needed.
        self.requires_grad_(False)
        return self

    def clone(self):
        """Return an independent deep copy of this module."""
        return deepcopy(self)

    def hard_update_to(self, target: nn.Module):
        """Overwrite this module's state with the values held by ``target``.

        Parameters and buffers (e.g. BatchNorm running statistics) are copied
        pairwise in iteration order, so that after the call this module is an
        exact copy of ``target`` rather than only sharing its weights.

        Args:
            target: Module to copy values from; expected to have the same
                architecture as ``self``.

        Returns:
            ``self``.
        """
        # In-place copies under ``no_grad`` keep autograd's bookkeeping intact,
        # unlike writes through ``.data``.
        with torch.no_grad():
            for own, other in zip(self.parameters(), target.parameters()):
                own.copy_(other)
            for own, other in zip(self.buffers(), target.buffers()):
                own.copy_(other)
        return self

    def soft_update_to(self, target: nn.Module, tau: float):
        """Move this module's parameters a fraction ``tau`` towards ``target``.

        Every parameter becomes ``self * (1 - tau) + target * tau``.  Buffers
        are left untouched.

        Args:
            target: Module to interpolate towards; expected to have the same
                architecture as ``self``.
            tau: Interpolation weight given to ``target``'s values.

        Returns:
            ``self``.
        """
        with torch.no_grad():
            for own, other in zip(self.parameters(), target.parameters()):
                own.copy_(own * (1 - tau) + other * tau)
        return self

# Type of object a NeuralNetworkTrainer can wrap: either a bare tensor or a
# BasicNeuralNetwork (or a subclass of one of them).
N = TypeVar("N", bound=Union[torch.Tensor, BasicNeuralNetwork])


class NeuralNetworkTrainer(Generic[N]):
    """Pair a trainable object with its optimizer, loss and optional target copy.

    ``inner`` is either a bare ``torch.Tensor`` or a ``BasicNeuralNetwork``.
    When ``with_target`` is true, a gradient-free copy of ``inner`` is kept in
    ``target`` and can be re-synchronised with :meth:`hard_update_target` or
    :meth:`soft_update_target`.
    """

    # Gradient-free copy of ``inner``; only assigned when ``with_target`` is true.
    target: N
    # The tensor or network being optimised.
    inner: N

    def __init__(
        self,
        inner: N,
        training: Tuple[
            Callable[[Any], torch.optim.Optimizer],
            LossFunction[N],
        ],
        with_target: bool = False,
    ):
        """Build the optimizer for ``inner`` and optionally create its target copy.

        Args:
            inner: Tensor or ``BasicNeuralNetwork`` to optimise.
            training: ``(optimizer_factory, loss_fn)``.  The factory is called
                once with the parameters to optimise (``inner.parameters()``
                for a network, ``[inner]`` for a tensor).  ``loss_fn`` is
                invoked by :meth:`loss` with ``inner`` as first argument.
            with_target: Whether to keep a gradient-free deep copy of
                ``inner`` in ``self.target``.
        """
        assert isinstance(inner, (torch.Tensor, BasicNeuralNetwork)), (
            "inner must be a torch.Tensor or a BasicNeuralNetwork"
        )

        self.inner = inner
        self.with_target = with_target
        self.training = training

        if self.with_target:
            self.target = deepcopy(self.inner)
            if isinstance(self.target, BasicNeuralNetwork):
                self.target.no_grad()
            else:
                self.target.requires_grad_(False)

        if isinstance(self.inner, BasicNeuralNetwork):
            optimised = self.inner.parameters()
        else:
            optimised = [self.inner]
        self.trainer = self.training[0](optimised)
        self._loss_fn = self.training[1]

    def zero_grad(self, set_to_none: bool = False):
        """Reset the gradients managed by the optimizer.

        ``set_to_none`` is forwarded to the optimizer's ``zero_grad``.
        """
        self.trainer.zero_grad(set_to_none)

    def parameters(self, *args, **kwargs):
        """Return ``inner.parameters(...)``; only valid when ``inner`` is a network."""
        assert isinstance(self.inner, BasicNeuralNetwork)
        return self.inner.parameters(*args, **kwargs)

    def loss(self, *args, **kwds):
        """Evaluate the configured loss function as ``loss_fn(inner, *args, **kwds)``."""
        return self._loss_fn(self.inner, *args, **kwds)

    def step(self):
        """Perform one optimizer step."""
        self.trainer.step()

    def hard_update_target(self):
        """Make ``target`` an exact, gradient-free copy of ``inner``.

        Requires the trainer to have been created with ``with_target=True``.
        """
        assert hasattr(self, "target"), "trainer was created without a target"
        if isinstance(self.inner, BasicNeuralNetwork):
            assert isinstance(self.target, BasicNeuralNetwork)
            self.target.hard_update_to(self.inner)
        else:
            # ``inner.clone()`` alone would stay attached to ``inner``'s graph;
            # such a non-leaf tensor cannot have ``requires_grad`` switched off
            # afterwards (PyTorch raises RuntimeError).  Detach first so the
            # copy is a plain gradient-free tensor.
            self.target = self.inner.detach().clone()

    def soft_update_target(self, tau: float):
        """Move ``target`` a fraction ``tau`` towards ``inner``.

        The new target is ``target * (1 - tau) + inner * tau``.  Requires the
        trainer to have been created with ``with_target=True``.

        Args:
            tau: Interpolation weight given to ``inner``.
        """
        assert hasattr(self, "target"), "trainer was created without a target"
        if isinstance(self.inner, BasicNeuralNetwork):
            assert isinstance(self.target, BasicNeuralNetwork)
            self.target.soft_update_to(self.inner, tau)
        else:
            assert isinstance(self.target, torch.Tensor)
            # Blend without recording the operation: otherwise the result is a
            # non-leaf tensor tied to ``inner``'s graph, and turning its
            # ``requires_grad`` flag off raises RuntimeError.
            with torch.no_grad():
                self.target = self.target * (1 - tau) + self.inner * tau

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        """Run the wrapped network's forward call; only valid when ``inner`` is a network."""
        assert isinstance(self.inner, BasicNeuralNetwork)
        return self.inner(*args, **kwds)


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place with gain-scaled Kaiming-normal weights.

    Weights are drawn from a zero-mean normal distribution with standard
    deviation ``std / sqrt(fan_in)``; i.e. ``std`` acts as the gain applied on
    top of fan-in scaling.  The default ``sqrt(2)`` is the gain PyTorch
    documents for ReLU.  The bias, when the layer has one, is filled with
    ``bias_const``.

    Args:
        layer: Module exposing ``weight`` (at least 2-D) and ``bias``
            attributes, e.g. ``nn.Linear`` or a convolution.
        std: Gain multiplying the fan-in-scaled standard deviation.
        bias_const: Constant written into every bias element.

    Returns:
        The same ``layer`` object, for inline use.
    """
    # The second positional parameter of ``kaiming_normal_`` is ``a`` (the
    # leaky-ReLU negative slope), not a gain, so ``std`` must not be passed
    # there.  Draw with unit gain ("linear") and apply ``std`` explicitly.
    torch.nn.init.kaiming_normal_(layer.weight, mode="fan_in", nonlinearity="linear")
    with torch.no_grad():
        layer.weight.mul_(float(std))
    # Layers constructed with ``bias=False`` expose ``bias`` as ``None``.
    if layer.bias is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer


def ignore_left(fn: Callable) -> Callable:
    """Adapt ``fn`` so it tolerates one extra leading positional argument.

    The returned callable drops its first positional argument and forwards all
    remaining positional and keyword arguments to ``fn``, returning whatever
    ``fn`` returns.
    """

    def trim_left(_, *rest, **options):
        # ``fn`` is only read from the enclosing scope, never rebound.
        return fn(*rest, **options)

    return trim_left