from typing import Optional, Dict, Any, List, Tuple
from abc import ABC, abstractmethod

import logging

logging.basicConfig(level = logging.INFO)
logger = logging.getLogger(__name__)

from pathlib import Path
import numpy as np
import torch
import lightning as pl
from einops import repeat, einsum

from lightning.pytorch.loggers import WandbLogger

import onnxruntime

# For dealing with model configuration
from context_general_bci.utils import suppress_default_registry
suppress_default_registry()
from context_general_bci.model import BrainBertInterface
from context_general_bci.data_utils import batchify_inference
from context_general_bci.ndt3_slim import NDT3, predict_prefill
from context_general_bci.config import RootConfig, DataKey, Output, ModelTask
from context_general_bci.rl.rl import IQLAgent

from context_general_bci.rtndt.config.decoder_config import (
    OnlineConfig,
    Accelerator
)

class RLMixin:
    r"""
        Container for the "agent" / RL side of an online decoder. It owns:
        - applying RL weight updates (via the wrapped IQLAgent)
        - (eventually) serving RL-guided behavior
    """
    def __init__(self, cfg: OnlineConfig):
        # Online config is stored for later use by RL logic; not otherwise read here.
        self.cfg = cfg
        # TODO pass params -- agent is currently built with IQLAgent's own defaults.
        self.agent = IQLAgent()

    def update(self, batch, utd_ratio: int):
        r"""
            Forward one update request to the wrapped agent.

            batch: passed through unchanged to `IQLAgent.update`.
            utd_ratio: passed through unchanged to `IQLAgent.update`.
            Returns None (the agent's return value is discarded).
        """
        agent = self.agent
        agent.update(batch, utd_ratio)

    # TODO override predict to serve the actor head + exploration policy


class BaseAccelerator:
    r"""
        Common interface for online decoding backends ("accelerators").

        Subclasses keep a loaded model and serve predictions from it.
        NOTE: this class does not derive from `ABC`, so the `@abstractmethod`
        markers below are advisory only -- they are not enforced when the
        class (or a subclass) is instantiated.
    """
    def __init__(self, cfg: OnlineConfig):
        # `cfg` is accepted for signature parity with subclasses; it is unused here.
        self.reset()

    @property
    def ready(self):
        r"""Whether `predict` can be served. The base class never holds a model."""
        return False

    def reset(self):
        r"""Clear any loaded model state. Nothing to clear at the base level."""

    @abstractmethod
    def compile_and_load(
        self,
        src_model: BrainBertInterface,
        cfg: RootConfig,
        spike_history: np.ndarray,
    ):
        r"""
            Prepare `src_model` for serving.
            Subclasses invoke this through `super()`; here it does nothing and returns None.
        """

    @abstractmethod
    def predict(
            self,
            spike_history: np.ndarray,
            covariate_history: np.ndarray,
            sparse_constraint: np.ndarray,
            reward_history: np.ndarray,
            return_history: np.ndarray,
            return_time: np.ndarray,
            temperature: float = 0.,
            num_kin: int = 2,
            sparse_constraint_time: np.ndarray | None = None,
            **kwargs, # reject TODO remove unneeded kwargs from pipeline to reduce overhead
    ):
        r"""Produce predictions from the supplied histories. Base implementation returns None."""

class DummyAccelerator(BaseAccelerator):
    r"""
        Pass-through accelerator: serves the original Lightning model directly
        (no compilation), delegating prediction to its `predict_simple`.
    """
    def __init__(self, cfg: OnlineConfig):
        self.reset()

    @property
    def ready(self):
        return self.lightning_model is not None

    def reset(self):
        super().reset()
        self.lightning_model = None

    def compile_and_load(
        self,
        src_model: BrainBertInterface,
        cfg: RootConfig,
        spike_history: np.ndarray,
    ):
        r"""
            Move `src_model` to cuda:0, switch it to eval mode and keep it for serving.
            `cfg` and `spike_history` are only forwarded to the base hook.
        """
        super().compile_and_load(src_model, cfg, spike_history)
        src_model = src_model.to('cuda:0') # For some reason, lab rig 3 doesn't keep tuned model device, and converts to cpu, and params are sharded on cpu and gpu
        self.lightning_model = src_model
        self.lightning_model.eval()

    def predict(
            self,
            spike_history: np.ndarray,
            covariate_history: np.ndarray,
            sparse_constraint: np.ndarray,
            reward_history: np.ndarray,
            return_history: np.ndarray,
            return_time: np.ndarray,
            reference: Any, # superhack
            kin_mask_timesteps: np.ndarray | None = None,
            temperature: float = 0.,
            num_kin: int = 2,
            spike_array_lengths: List[int] | None = None,
            return_seeded: bool = False,
            sparse_constraint_time: np.ndarray | None = None,
    ):
        r"""
            - predict multiple steps for multiple dimensions
            - TODO KV cache

            Note despite predict_simple autocast, we still need to do the device assignments here, go figure.

            spike_array_lengths: defaults to a fresh empty list per call (a shared
                mutable default would leak state across calls if mutated downstream).
            Returns a dict mapping each output key of `predict_simple` to a
            float32 numpy array (None entries are preserved).
        """
        if spike_array_lengths is None:
            spike_array_lengths = []
        # Resolve the device once instead of re-reading the property for every tensor.
        device = self.lightning_model.device
        out = self.lightning_model.predict_simple(
            torch.as_tensor(spike_history, device=device),
            torch.as_tensor(covariate_history, device=device),
            torch.as_tensor(sparse_constraint, device=device),
            torch.as_tensor(sparse_constraint_time, device=device) if sparse_constraint_time is not None else None,
            torch.as_tensor(reward_history, device=device),
            torch.as_tensor(return_history, device=device),
            torch.as_tensor(return_time, device=device),
            reference, # Not migrated to device, done internally (just a monkey patch path atm)
            kin_mask_timesteps=torch.as_tensor(kin_mask_timesteps, device=device) if kin_mask_timesteps is not None else None,
            temperature=temperature,
            spike_array_lengths=spike_array_lengths,
            return_seeded=return_seeded,
            num_kin=num_kin,
        )
        return {k: v.detach().float().cpu().numpy() if v is not None else None for k, v in out.items()} # unbatchify, as well

class SlimAccelerator(BaseAccelerator):
    r"""
        Serves a slim NDT3 copy of the training model, built with
        `NDT3.from_training_shell(..., use_kv_cache=True)`, for online prediction.
    """
    def __init__(self, cfg: OnlineConfig):
        self.reset()

    @property
    def ready(self):
        return self.model is not None

    def reset(self):
        super().reset()
        self.model = None

    def compile_and_load(
        self,
        src_model: BrainBertInterface,
        cfg: RootConfig,
        spike_history: np.ndarray,
    ):
        r"""
            Build the slim NDT3 model from `src_model` and place it on cuda:0.
            `cfg.dataset.max_tokens` sets the rotary extent to match training.
        """
        super().compile_and_load(src_model, cfg, spike_history)
        src_model = src_model.to('cuda:0') # Actual weights don't move to GPU in slim init - not sure why.
        fast_model = NDT3.from_training_shell(
            src_model,
            use_kv_cache=True,
            max_seqlen=cfg.dataset.max_tokens, # make sure rotary is on same extent as in training
        )
        self.model = fast_model.to('cuda:0')

    def set_streaming_timestep_limit(self, limit: int):
        r"""Forward the streaming timestep limit to the loaded model; no-op if none is loaded."""
        if self.model is not None:
            self.model.set_streaming_timestep_limit(limit)

    def predict(
            self,
            spike_history: np.ndarray,
            covariate_history: np.ndarray, # last timestep should be empty / 0s. since it's not predicted
            constraint_history: np.ndarray, # constraint: shape T/Event x 3 x CovDim
            reward_history: np.ndarray,
            return_history: np.ndarray,
            return_time: np.ndarray,
            kin_mask_timesteps: np.ndarray | None = None,
            temperature: float = 0.,
            num_kin: int = 2,
            sparse_constraint_time: np.ndarray | None = None, # If provided, implies sparse constraint
            **kwargs, # reject TODO remove unneeded kwargs from pipeline to reduce overhead
    ):
        r"""
            - predict multiple steps for multiple dimensions

            Inputs are moved onto the model's device, packed with `batchify_inference`,
            and decoded with `predict_prefill`.

            kin_mask_timesteps: if final mask is OFF, indicates we would like kinematic prior.
                Only the final entry is read; it is forwarded as `mask_kin_prior`.
            Returns a dict mapping each output key of `predict_prefill` to a float32
            numpy array (None entries are preserved).
        """
        # Resolve the device once instead of re-reading the property for every tensor.
        device = self.model.device
        batch = batchify_inference(
            torch.as_tensor(spike_history, device=device),
            torch.as_tensor(covariate_history, device=device),
            torch.as_tensor(constraint_history, device=device),
            torch.as_tensor(sparse_constraint_time, device=device) if sparse_constraint_time is not None else None,
            torch.as_tensor(reward_history, device=device),
            torch.as_tensor(return_history, device=device),
            torch.as_tensor(return_time, device=device),
            neurons_per_token=self.model.neurons_per_token,
            max_channel_count=self.model.max_channel_count,
        )
        out = predict_prefill(
            self.model,
            batch[DataKey.spikes.name],
            batch[DataKey.time.name],
            batch[DataKey.position.name],
            batch[DataKey.bhvr_vel.name],
            batch[DataKey.covariate_time.name],
            batch[DataKey.covariate_space.name],
            batch[DataKey.task_reward.name],
            batch[DataKey.task_return.name],
            batch[DataKey.task_return_time.name],
            batch[DataKey.constraint.name],
            batch[DataKey.constraint_time.name],
            batch[DataKey.constraint_space.name],
            temperature=temperature,
            num_kin=num_kin,
            mask_kin_prior=kin_mask_timesteps[-1] if kin_mask_timesteps is not None else False,
        )
        # detach() is a no-op on tensors without grad, but keeps .numpy() from raising
        # if any output still tracks gradients (same conversion DummyAccelerator.predict uses).
        return {k: v.detach().float().cpu().numpy() if v is not None else None for k, v in out.items()} # unbatchify, as well

class OnnxAccelerator(BaseAccelerator):
    r"""
        # ! Currently deprecated
        Exports the source model to ONNX in the run directory and serves it with
        an onnxruntime InferenceSession (CUDA provider first, CPU fallback).
    """
    ONNX_PATH = 'ndt_{train_tag}.onnx'

    def __init__(self, cfg: OnlineConfig):
        # Wrap in Path so the `/` join below works whether run_dir is a str or a Path.
        self.working_dir = Path(cfg.run_dir)
        self.reset()

    def reset(self):
        super().reset()
        self.input_name = None
        self.ort_session = None

    @property
    def ready(self):
        return self.ort_session is not None

    def compile_and_load(
        self,
        src_model: BrainBertInterface,
        cfg: RootConfig,
        spike_history: np.ndarray,
    ) -> Path:
        r"""
            Export `src_model` to `<run_dir>/ndt_<cfg.tag>.onnx` and open an inference session on it.

            spike_history: sample input used for export tracing.
            Returns the path of the exported ONNX file.
        """
        out_path = self.working_dir / self.ONNX_PATH.format(train_tag=cfg.tag)
        # The export is written to `out_path`. Load from that path rather than from
        # `to_onnx`'s return value: Lightning's `to_onnx` returns None, so using its
        # result would make onnxruntime try to open "None".
        src_model.to_onnx(
            out_path,
            torch.as_tensor(spike_history), # batch = 1, spike dim = 1
            export_params=True,
            input_names=["spikes"],
            dynamic_axes={
                "spikes": {1: "time"}
            }
        )
        self._load_decoder(out_path)
        # TODO: if an export for this tag already exists in the run dir, reload it instead of re-exporting.
        return out_path

    def _load_decoder(self, path):
        r"""Open an onnxruntime session on `path` and record its first input's name."""
        logger.info(f'Loading onnx {path}')
        self.ort_session = onnxruntime.InferenceSession(
            str(path),
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
        )
        self.input_name = self.ort_session.get_inputs()[0].name

    def predict(self, spike_history: np.ndarray):
        r"""
        Deprecated for ndt3 atm
        TODO work out dimensions

        Runs the session on `spike_history[None, :, :, None]` and returns its first output.
        Raises ValueError if no decoder has been loaded.
        """
        if self.ort_session is None:
            raise ValueError("Decoder not yet fit")
        ort_inputs = {self.input_name: spike_history[None, :, :, None]}
        out = self.ort_session.run(None, ort_inputs)[0]
        return out

