from __future__ import annotations

import torch

from torch import nn
from torch.utils.data import DataLoader
from torch.func import functional_call, jvp, vjp, grad # type: ignore
from typing import Callable, Tuple, MutableMapping

from bayesopt.surrogates.fsplaplace_utils.ssrft import SSRFT

class LinearOperator:
    """Abstract base for operators that are applied with ``op @ M``.

    The attributes below are only declared here; the subclasses in this module
    assign them in their constructors.
    """

    _model: nn.Module
    _device: torch.device
    _n_chunks: int
    # Sketching operator instance, or None when sketching is disabled.
    # (Subclasses store an SSRFT object here, not the sketch name string.)
    _sketch: "SSRFT | None"

    @torch.no_grad()
    def __matmul__(self, M: torch.Tensor) -> torch.Tensor:
        """Apply the operator to ``M``.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError()


class NeuralNetworkLinearOperator(LinearOperator):
    """Common state for linear operators built from a neural network.

    Stores the model, the data to iterate over, the device of the model's
    first parameter, the parameters split by ``requires_grad`` and, if
    requested, an SSRFT sketching operator.
    """

    def __init__(
        self,
        model: nn.Module,
        data: DataLoader | Tuple[torch.Tensor] | Tuple[MutableMapping],
        sketch: str | None = None,
        sketch_dim: int = 10,
        n_chunks: int = 1,
        dict_key_x: str = "input_ids",
        dict_key_y: str = "labels",
    ):
        """Collect the model, data and sketch configuration.

        Args:
            model: Network whose parameters define the operator.
            data: Iterable of batches consumed by the subclasses.
            sketch: ``"ssrft"`` builds an SSRFT sketch; any other value
                disables sketching.
            sketch_dim: Number of rows of the sketch.
            n_chunks: Stored as ``_n_chunks`` for use by the subclasses.
            dict_key_x: Key of the inputs in mapping-style batches.
            dict_key_y: Key of the targets in mapping-style batches.
        """
        self._model = model
        self._data = data
        self._device = next(model.parameters()).device
        self._n_chunks = n_chunks
        self._dict_key_x, self._dict_key_y = dict_key_x, dict_key_y

        # Parameters with requires_grad are the "stochastic" ones the operator
        # acts on; the remaining ones are kept aside as "deterministic".
        trainable, frozen = {}, {}
        for name, param in model.named_parameters():
            bucket = trainable if param.requires_grad else frozen
            bucket[name] = param.detach()
        self._sto_params, self._det_params = trainable, frozen

        self._n_params = sum(w.numel() for w in trainable.values())

        # The seed is fixed: with differing seeds the sketched GGN would no
        # longer be symmetric.
        self._sketch = (
            SSRFT(shape=(sketch_dim, self._n_params), seed=1)
            if sketch == "ssrft"
            else None
        )


class GGNLinearOperator(NeuralNetworkLinearOperator):
    """Generalized Gauss-Newton (GGN) matrix of a model, applied as ``op @ M``.

    For every batch the product ``J^T H J v`` is evaluated matrix-free, where
    ``J`` is the Jacobian of the model output with respect to the trainable
    parameters and ``H`` is the Hessian of the loss with respect to the model
    output. The per-batch products are summed over ``data``. When a sketch
    ``S`` is configured, the operator applied is ``S (J^T H J) S^T``.
    """

    def __init__(
        self,
        model: nn.Module,
        loss_fn: nn.Module,
        sigma_noise: torch.Tensor,
        data: DataLoader | Tuple[torch.Tensor] | Tuple[MutableMapping],
        sketch: str | None = None,
        sketch_dim: int = 10,
        n_chunks: int = 1,
        dict_key_x: str = "input_ids",
        dict_key_y: str = "labels",
    ):
        """Store the loss configuration on top of the shared operator state.

        Args:
            model: Network whose GGN is applied.
            loss_fn: Loss module called as ``loss_fn(output, target)``; an
                ``nn.GaussianNLLLoss`` additionally receives the variance.
            sigma_noise: Only used with ``nn.GaussianNLLLoss``, where its
                square (expanded over the batch) is passed as the variance.
            data: Iterable of batches, each either a mapping containing
                ``dict_key_y`` or an indexable ``(inputs, targets)`` pair.
            sketch: ``"ssrft"`` to sketch the operator, otherwise no sketch.
            sketch_dim: Number of rows of the sketch.
            n_chunks: Forwarded to ``torch.vmap`` as ``chunk_size``.
            dict_key_x: Key of the inputs in mapping-style batches.
            dict_key_y: Key of the targets in mapping-style batches.
        """
        super().__init__(model, data, sketch, sketch_dim, n_chunks, dict_key_x, dict_key_y)
        self._loss_fn = loss_fn
        self._sigma_noise = sigma_noise

    @torch.no_grad()
    def __matmul__(self, M: torch.Tensor) -> torch.Tensor:
        """Return the (sketched) GGN applied to ``M``.

        Args:
            M: Vector of length ``d`` or matrix of shape ``(d, k)``, where
                ``d`` is the number of trainable parameters, or the sketch
                dimension when a sketch is configured.

        Returns:
            Tensor with the same number of rows and columns as ``M``; a 1-D
            input yields a 1-D output.
        """
        vector_input = M.ndim == 1
        if vector_input:
            M = M.unsqueeze(1)

        n_rows = self._sketch.shape[0] if self._sketch is not None else M.shape[0]
        result = torch.zeros(n_rows, M.shape[1], device=self._device, dtype=M.dtype)
        for data in self._data:
            if isinstance(data, MutableMapping):
                # Mapping batches are handed to the model as a single argument.
                x = data
                y = data[self._dict_key_y].to(self._device, non_blocking=True)
            else:
                # functional_call wraps a non-tuple argument itself, so the
                # plain tensor is passed (no accidental 1-tuple).
                x = data[0].to(self._device, non_blocking=True)
                y = data[1].to(self._device, non_blocking=True)
            result.add_(self._matmat_batch(x, y, M))

        if vector_input:
            result = result.squeeze(1)

        return result

    def _matmat_batch(self, x: torch.Tensor | MutableMapping, y: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
        """Apply the GGN of a single batch to every column of ``M``."""
        model_fn = lambda _sto_params: functional_call(self._model, _sto_params | self._det_params, x)
        if isinstance(self._loss_fn, nn.GaussianNLLLoss):
            # GaussianNLLLoss takes the variance as its third argument.
            loss_fn = lambda _f: self._loss_fn(_f, y, self._sigma_noise.expand(y.shape[0], -1)**2)
        else:
            loss_fn = lambda _f: self._loss_fn(_f, y)
        # vmap over the columns of M; the stacked results come back as rows.
        out = torch.vmap(
            self._ggn_vector_product, in_dims=(None, None, 1), chunk_size=self._n_chunks, randomness="same"
        )(model_fn, loss_fn, M)

        # Transpose so that columns of the result correspond to columns of M.
        return out.mT

    @torch.compile
    def _ggn_vector_product(
        self,
        model_fn: Callable[[torch.Tensor], torch.Tensor],
        loss_fn: Callable[[torch.Tensor], torch.Tensor],
        v: torch.Tensor
    ) -> torch.Tensor:
        """Compute ``J^T H J v`` (wrapped in ``S ... S^T`` when sketching)."""
        # S^T
        if self._sketch is not None:
            v = v @ self._sketch

        # Convert the flat vector to a dict shaped like the trainable parameters
        v_split = v.split([p.numel() for p in self._sto_params.values()])
        v_dict = {
            k: m.reshape(p.shape).to(p.device, p.dtype)
            for m, (k, p) in zip(v_split, self._sto_params.items())
        }

        # GGN-vector product: J v, then H (J v) as the gradient of the loss
        # directional derivative, then J^T (H J v).
        f, Jv = jvp(model_fn, (self._sto_params,), (v_dict,)) # type: ignore
        loss_jvp_fn = lambda _f: jvp(loss_fn, (_f,), (Jv,))[1]
        LJv = grad(loss_jvp_fn)(f)
        _, f_vjp = vjp(model_fn, self._sto_params) # type: ignore
        JtLJv = f_vjp(LJv)[0]

        # Convert back to a flat vector
        out = torch.cat([r.flatten() for r in JtLJv.values()])

        # S
        if self._sketch is not None:
            out = self._sketch @ out

        return out

class JacobianLinearOperator(NeuralNetworkLinearOperator):
    """Jacobian of the selected model outputs w.r.t. the trainable parameters.

    ``op @ M`` evaluates ``J M`` batch by batch with forward-mode products and
    stacks the batches along the first axis. When a sketch ``S`` is
    configured, the operator applied is ``J S^T``.
    """

    def __init__(
        self,
        model: nn.Module,
        data: DataLoader | Tuple[torch.Tensor] | Tuple[MutableMapping],
        n_outputs: int,
        sketch: str | None = None,
        sketch_dim: int = 10,
        n_chunks: int = 1,
        output_idx: int = -1,
        dict_key_x: str = "input_ids",
        dict_key_y: str = "labels",
    ):
        """Record the data size and the selected output columns.

        Args:
            model: Network whose Jacobian is applied.
            data: ``DataLoader`` (size taken from ``len(data.dataset)``) or a
                tuple whose first element determines the number of samples.
            n_outputs: Number of model outputs.
            sketch: ``"ssrft"`` to sketch the operator, otherwise no sketch.
            sketch_dim: Number of rows of the sketch.
            n_chunks: Forwarded to ``torch.vmap`` as ``chunk_size``.
            output_idx: Non-negative index of a single output column to use;
                a negative value (the default) selects all ``n_outputs``.
            dict_key_x: Key of the inputs in mapping-style batches.
            dict_key_y: Key of the targets in mapping-style batches.
        """
        super().__init__(model, data, sketch, sketch_dim, n_chunks, dict_key_x, dict_key_y)
        if isinstance(data, DataLoader):
            self._n_data = len(data.dataset)
        else:
            if isinstance(data[0], MutableMapping):
                self._n_data = data[0][self._dict_key_y].shape[0]
            else: # tuple of tensors
                self._n_data = data[0].shape[0] # type: ignore

        self._n_outputs = n_outputs
        self._output_idx = (output_idx,) if output_idx >= 0 else range(n_outputs)

    @torch.no_grad()
    def __matmul__(self, M: torch.Tensor) -> torch.Tensor:
        """Return the (sketched) Jacobian applied to ``M``.

        Args:
            M: Vector of length ``d`` or matrix of shape ``(d, k)``, where
                ``d`` is the number of trainable parameters, or the sketch
                dimension when a sketch is configured.

        Returns:
            Tensor of shape ``(n_data, n_selected_outputs, k)``; for a 1-D
            ``M`` the trailing axis is removed.
        """
        vector_input = M.ndim == 1
        if vector_input:
            M = M.unsqueeze(1)

        result = torch.zeros(self._n_data, len(self._output_idx), M.shape[-1], device=self._device, dtype=M.dtype)
        # Running row offset: batches may differ in size (e.g. a smaller last
        # batch), so the start row cannot be derived from the batch index.
        offset = 0
        for data in self._data:
            if isinstance(data, torch.Tensor):
                x = data.to(self._device, non_blocking=True)
                batch_size = x.shape[0]
            elif isinstance(data, MutableMapping):
                x = data
                batch_size = data[self._dict_key_x].shape[0]
            else:
                x = data[0].to(self._device, non_blocking=True)
                batch_size = x.shape[0]
            result[offset:offset + batch_size].add_(self._matmat_batch(x, M))
            offset += batch_size

        if vector_input:
            # The column axis added above ends up last (vmap out_dims=-1).
            result = result.squeeze(-1)

        return result

    def _matmat_batch(self, x: torch.Tensor | MutableMapping, M: torch.Tensor) -> torch.Tensor:
        """Apply the Jacobian of a single batch to every column of ``M``."""
        model_fn = lambda _sto_params: functional_call(self._model, _sto_params | self._det_params, x)[:,self._output_idx]
        # vmap over the columns of M, placing the column axis last.
        out = torch.vmap(
            self._jac_vector_product, in_dims=(None, 1), out_dims=-1, chunk_size=self._n_chunks, randomness="same"
        )(model_fn, M)

        return out

    @torch.compile
    def _jac_vector_product(
        self,
        model_fn: Callable[[torch.Tensor], torch.Tensor],
        v: torch.Tensor
    ) -> torch.Tensor:
        """Compute ``J v`` (``J S^T v`` when sketching) with a forward-mode product."""
        # S^T
        if self._sketch is not None:
            v = v @ self._sketch

        # Convert the flat vector to a dict shaped like the trainable parameters
        v_split = v.split([p.numel() for p in self._sto_params.values()])
        v_dict = {
            k: m.reshape(p.shape).to(p.device, p.dtype)
            for m, (k, p) in zip(v_split, self._sto_params.items())
        }

        out = jvp(model_fn, (self._sto_params,), (v_dict,))[1]

        return out

class JacobianTransposeLinearOperator(NeuralNetworkLinearOperator):
    """Transposed Jacobian of the selected model outputs.

    ``op @ M`` evaluates ``J^T M`` with reverse-mode products, one batch at a
    time, and sums the per-batch contributions. When a sketch ``S`` is
    configured, the operator applied is ``S J^T``.
    """

    def __init__(
        self,
        model: nn.Module,
        data: DataLoader | Tuple[torch.Tensor] | MutableMapping,
        n_outputs: int,
        sketch: str | None = None,
        sketch_dim: int = 10,
        n_chunks: int = 1,
        output_idx: int = -1,
        dict_key_x: str = "input_ids",
        dict_key_y: str = "labels",
    ):
        """Record the selected output columns.

        Args:
            model: Network whose transposed Jacobian is applied.
            data: Iterable of batches; each batch is a tensor, a mapping
                containing ``dict_key_x``, or an indexable whose first element
                holds the inputs.
            n_outputs: Number of model outputs.
            sketch: ``"ssrft"`` to sketch the operator, otherwise no sketch.
            sketch_dim: Number of rows of the sketch.
            n_chunks: Forwarded to ``torch.vmap`` as ``chunk_size``.
            output_idx: Non-negative index of a single output column to use;
                a negative value (the default) selects all ``n_outputs``.
            dict_key_x: Key of the inputs in mapping-style batches.
            dict_key_y: Key of the targets in mapping-style batches.
        """
        # The base constructor already sets ``_n_params``.
        super().__init__(model, data, sketch, sketch_dim, n_chunks, dict_key_x, dict_key_y)
        self._output_idx = (output_idx,) if output_idx >= 0 else range(n_outputs)

    @torch.no_grad()
    def __matmul__(self, M: torch.Tensor) -> torch.Tensor:
        """Return the (sketched) transposed Jacobian applied to ``M``.

        Args:
            M: Tensor whose first axis indexes the samples in the order the
                batches are iterated and whose last axis indexes the columns
                to multiply, e.g. ``(n_data, n_selected_outputs, k)``.
                NOTE(review): a 1-D ``M`` is reshaped to ``(n_data, 1)``,
                which does not look compatible with a 2-D model output;
                confirm the intended vector-input convention.

        Returns:
            Tensor of shape ``(d, k)``, where ``d`` is the number of trainable
            parameters, or the sketch dimension when a sketch is configured.
        """
        vector_input = M.ndim == 1
        if vector_input:
            M = M.unsqueeze(1)

        n_rows = self._sketch.shape[0] if self._sketch is not None else self._n_params
        result = torch.zeros(n_rows, M.shape[-1], device=self._device, dtype=M.dtype)
        # Running row offset into M: batches may differ in size (e.g. a
        # smaller last batch), so it cannot be derived from the batch index.
        offset = 0
        for data in self._data:
            if isinstance(data, torch.Tensor):
                x = data.to(self._device, non_blocking=True)
                batch_size = x.shape[0]
            elif isinstance(data, MutableMapping):
                x = data
                batch_size = data[self._dict_key_x].shape[0]
            else:
                x = data[0].to(self._device, non_blocking=True)
                batch_size = x.shape[0]
            result.add_(self._matmat_batch(x, M[offset:offset + batch_size]))
            offset += batch_size

        if vector_input:
            result = result.squeeze(1)

        return result

    def _matmat_batch(self, x: torch.Tensor | MutableMapping, M: torch.Tensor):
        """Apply the transposed Jacobian of one batch to every column of ``M``."""
        model_fn = lambda _sto_params: functional_call(self._model, _sto_params | self._det_params, x)[:,self._output_idx]
        # vmap over the last axis of M, placing the column axis last.
        result = torch.vmap(
            self._vector_jac_product, in_dims=(None, -1), out_dims=-1, chunk_size=self._n_chunks, randomness="same"
        )(model_fn, M)

        return result

    @torch.compile
    def _vector_jac_product(
        self,
        model_fn: Callable[[torch.Tensor], torch.Tensor],
        v: torch.Tensor
    ) -> torch.Tensor:
        """Compute ``J^T v`` (``S J^T v`` when sketching) as a flat vector."""
        out = vjp(model_fn, self._sto_params)[1](v)[0]
        # Flatten the per-parameter results into one vector
        out = torch.cat([r.flatten() for r in out.values()])

        # S
        if self._sketch is not None:
            out = self._sketch @ out

        return out


if __name__ == "__main__":
    # Smoke test: build every operator with and without sketching, apply it to
    # an all-ones matrix and print the shape of the product.
    dtype = torch.float32
    torch.manual_seed(0)
    n_outputs = 5
    # Data: two input intervals with 50 points each
    x1 = torch.linspace(-1, -0.5, 50).reshape(-1, 1).to(dtype)
    x2 = torch.linspace(0.5, 1, 50).reshape(-1, 1).to(dtype)

    # Train data (integer class targets)
    train_X = torch.cat([x1, x2], dim=0)
    train_Y = torch.ones(100).long()

    # Initialize model
    net = torch.nn.Sequential(
        torch.nn.Linear(train_X.shape[-1], 50),
        torch.nn.Tanh(),
        torch.nn.Linear(50, 50),
        torch.nn.Tanh(),
        torch.nn.Linear(50, n_outputs),
    )

    n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    train_dataset = torch.utils.data.TensorDataset(train_X, train_Y)  # type: ignore
    # No shuffling: the Jacobian operators map rows to samples by iteration
    # order, so the order must be the same on every pass over the loader.
    trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=False) # type: ignore
    if n_outputs == 1:
        loss_fn = torch.nn.GaussianNLLLoss()
    else:
        # NOTE(review): NLLLoss expects log-probabilities, but `net` ends in a
        # plain Linear layer; confirm whether CrossEntropyLoss was intended.
        loss_fn = torch.nn.NLLLoss()
    sigma_noise = torch.ones(n_outputs)

    print("-------------- NO SKETCH -------------------")
    ggn_linop = GGNLinearOperator(net, loss_fn, sigma_noise, trainloader, n_chunks=1)
    ggn_vp = ggn_linop @ torch.ones(n_params, 10)
    print("ggn", ggn_vp.shape)

    jac_linop = JacobianLinearOperator(net, trainloader, output_idx=0, n_outputs=n_outputs, n_chunks=1)
    jac_vp = jac_linop @ torch.ones(n_params, 10)
    print("jac for class=0", jac_vp.shape)

    jacT_linop = JacobianTransposeLinearOperator(net, trainloader, output_idx=0, n_outputs=n_outputs, n_chunks=1)
    jacT_vp = jacT_linop @ torch.ones(100, 1, 10)
    print("jacT for class=0", jacT_vp.shape)

    jac_linop = JacobianLinearOperator(net, trainloader, n_outputs=n_outputs, n_chunks=1)
    jac_vp = jac_linop @ torch.ones(n_params, 10)
    print("jac full output", jac_vp.shape)

    jacT_linop = JacobianTransposeLinearOperator(net, trainloader, n_outputs=n_outputs, n_chunks=1)
    jacT_vp = jacT_linop @ torch.ones(100, n_outputs, 10)
    print("jacT full output", jacT_vp.shape)

    print("-------------- SKETCH -------------------")
    # With a sketch, the parameter-side dimension becomes sketch_dim.
    sketch = "ssrft"
    sketch_dim = 10
    ggn_linop = GGNLinearOperator(net, loss_fn, sigma_noise, trainloader, sketch, sketch_dim, n_chunks=1)
    ggn_vp = ggn_linop @ torch.ones(sketch_dim, 10)
    print("ggn", ggn_vp.shape)

    jac_linop = JacobianLinearOperator(net, trainloader, output_idx=0, n_outputs=n_outputs, sketch=sketch, sketch_dim=sketch_dim, n_chunks=1)
    jac_vp = jac_linop @ torch.ones(sketch_dim, 10)
    print("jac for class=0", jac_vp.shape)

    jacT_linop = JacobianTransposeLinearOperator(net, trainloader, output_idx=0, n_outputs=n_outputs, sketch=sketch, sketch_dim=sketch_dim, n_chunks=1)
    jacT_vp = jacT_linop @ torch.ones(100, 1, 10)
    print("jacT for class=0", jacT_vp.shape)

    jac_linop = JacobianLinearOperator(net, trainloader, n_outputs=n_outputs, sketch=sketch, sketch_dim=sketch_dim, n_chunks=1)
    jac_vp = jac_linop @ torch.ones(sketch_dim, 10)
    print("jac full output", jac_vp.shape)

    jacT_linop = JacobianTransposeLinearOperator(net, trainloader, n_outputs=n_outputs, sketch=sketch, sketch_dim=sketch_dim, n_chunks=1)
    jacT_vp = jacT_linop @ torch.ones(100, n_outputs, 10)
    print("jacT full output", jacT_vp.shape)