r"""
    Establish agent abstraction
"""
from typing import List
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import time
from copy import deepcopy
from einops import rearrange
import logging
logging.basicConfig(level = logging.INFO)
logger = logging.getLogger(__name__)

from context_general_bci.utils import suppress_default_registry
suppress_default_registry()
from context_general_bci.model import BrainBertInterface
from context_general_bci.config import RootConfig, DataKey, Output, ModelTask
from context_general_bci.streaming_utils import precrop_batch, shift_constraint
from context_general_bci.rl.rl import IQLAgent, UnifiedCausalLayer
from context_general_bci.rl.utils import hard_update_params

from context_general_bci.rtndt.config.decoder_config import (
    OnlineConfig,
    OnlineConfigMutable,
    Accelerator
)
from context_general_bci.rtndt.accelerators import DummyAccelerator, OnnxAccelerator, SlimAccelerator
from context_general_bci.rtndt.decoder_utils import (
    roll_and_pad,
    nucleus_filter,
    sample_from_logits
)

# # TODO try to copy gym env?
# class BCIEnv:
#     def reset(self):
#         # Return observation
#         return np.zeros((1, 1))

class StateLogger:
    r"""
        Debugging class for storing agent IO.
        Tracking parity with offline analysis.
        If this logger's IO can be identically reproduced, then issues are upstream
        (CLIMBER dropped packets/transforms/ or climber/ndt3.py - normalization, etc)

        Usage: call `trigger_start()` to begin collecting, then `record(...)` once per step.
        Each time the buffer reaches `limit` entries it is written to
        `state_logger_<timestamp>.pt` (relative to the current working directory).
        Collection cannot be (re)started once `save_limit` payloads have been written.
    """
    def __init__(self, limit=3000, save_limit=1): # about 60s of data, should be plenty if we know what we're doing
        self.limit = limit  # number of records per saved payload
        self.save_limit = save_limit  # max payloads that may be written over this logger's lifetime
        self.saved_count = 0
        # Callers may call `trigger_start` every timestep; remember that the
        # "save limit reached" warning was already emitted so the log is not flooded.
        self._warned_save_limit = False
        self.reset()

    def reset(self):
        """Clear the buffered records and stop collecting."""
        self.neural_obs = []
        self.actions = []
        self.collecting = False

    def trigger_start(self):
        """Start collecting, unless already collecting or the save limit has been reached."""
        if self.collecting:
            logger.warning("StateLogger already collecting")
            return
        if self.saved_count >= self.save_limit:
            if not self._warned_save_limit:
                logger.warning("StateLogger save limit reached, not recording anymore")
                self._warned_save_limit = True
            return
        self.collecting = True

    def record(self,
               neural_obs,
               actions):
        r"""
            Append one (neural_obs, actions) pair. No-op unless collecting.

            If the buffer already holds `limit` entries it is saved and cleared first. Clearing
            also stops collection (the current pair is still appended to the fresh buffer), so the
            caller must call `trigger_start()` again to keep recording.
        """
        if not self.collecting:
            return
        if len(self.neural_obs) >= self.limit:
            self.save_payload()
            self.reset()
        self.neural_obs.append(neural_obs)
        self.actions.append(actions)

    def save_payload(self):
        """Write the buffered records with `torch.save` to a timestamped file in the working directory."""
        # note timestamp in human readable
        timestamp = time.strftime("%Y%m%d-%H%M%S")

        torch.save({
            'neural_obs': self.neural_obs,
            'actions': self.actions
        }, f'state_logger_{timestamp}.pt')
        self.saved_count += 1

class BCIAgent:
    r"""
        Agent abstraction.
        Responsible for managing cached history needed for inference.

        Current responsibilities still tentative:
        - reconcile exploration policy (optimal control) with underlying neural decoder

        Call order, as suggested by the methods below (confirm against callers):
        1. `reset_buffer(...)` allocates the rolling history buffers.
        2. `compile_and_load(model, cfg)` hands the model to the accelerator
           (and builds the IQL agent if `cfg.serve_rl`).
        3. `act(...)` once per timestep; `observe(...)` alone can be used for burn-in.
    """
    # Upper bound accepted by `set_streaming_timestep_limit`; larger requests are clamped.
    MAX_STREAMING_TIMESTEPS = 750

    def __init__(self, cfg: OnlineConfig, cfg_mutable: OnlineConfigMutable):
        r"""
            Select the serving accelerator from `cfg.serve_framework` and reset all state.

            Raises:
                NotImplementedError: for an unsupported `serve_framework`.
        """
        self.cfg = cfg
        self.cfg_mutable = cfg_mutable
        self.backbone_cfg = None
        # RL (IQL) agent; only built by `compile_and_load` when `cfg.serve_rl` is set.
        self._agent = None

        if self.cfg.serve_framework == Accelerator.vanilla:
            self.accelerator = DummyAccelerator(self.cfg)
        elif self.cfg.serve_framework == Accelerator.slim:
            self.accelerator = SlimAccelerator(self.cfg)
        elif self.cfg.serve_framework == Accelerator.onnx:
            self.accelerator = OnnxAccelerator(self.cfg)
        else:
            raise NotImplementedError

        self.logger = StateLogger()
        self.reset()

    @property
    def ready(self):
        """Whether the accelerator reports a loaded, servable model."""
        return self.accelerator.ready

    @property
    def buffer_ready(self):
        """Whether `reset_buffer` has allocated the history buffers."""
        return self.spike_history is not None

    @property
    def tag(self):
        """Tag of the loaded backbone config, or None if no model is loaded."""
        if self.backbone_cfg is not None:
            return self.backbone_cfg.tag
        return None

    def reset(self):
        """Drop the loaded model config, reset the accelerator and deallocate all history buffers."""
        # TODO refactor into two calls.
        self.backbone_cfg = None
        self.reference = None
        self.accelerator.reset()
        self.spike_history = None
        self.constraint_history = None
        self.cov_history = None
        self.reward_history = None
        self.neural_obs_lengths = None
        self.model_value_history = None
        self.model_return_condition_seeded = False

    def reset_buffer(
            self,
            timesteps: int,
            neural_obs_dims: int,
            constraint_obs_dims: int,
            action_space_dims: int,
            neural_obs_lengths: List[int],
        ):
        r"""
            Reset the accelerator and allocate zeroed rolling buffers of length `timesteps`:
            spikes (T, neural_obs_dims) uint8, constraints (T, 3, constraint_obs_dims) float32,
            actions (T, action_space_dims) float32, rewards (T,) uint8, model values (T,) float.
            If `cfg.default_constraint_fbc` is truthy, constraint channel 0 is pre-filled with it.
        """
        self.backbone_cfg = None
        self.accelerator.reset()

        self.buffer_timesteps = timesteps

        # TODO how do we signify whether to actually cache here?
        # Note need is different per accelerator.
        self.spike_history = np.zeros((self.buffer_timesteps, neural_obs_dims), dtype=np.uint8)
        self.constraint_history = np.zeros((self.buffer_timesteps, 3, constraint_obs_dims), dtype=np.float32)
        if self.cfg.default_constraint_fbc:
            self.constraint_history[:, 0] = self.cfg.default_constraint_fbc
        self.cov_history = np.zeros((self.buffer_timesteps, action_space_dims), dtype=np.float32)
        # NOTE(review): uint8 storage - negative or >255 rewards would wrap; confirm reward range.
        self.reward_history = np.zeros((self.buffer_timesteps), np.uint8)

        self.neural_obs_lengths = neural_obs_lengths

        # RCBC
        self.model_value_history = np.zeros((self.buffer_timesteps), dtype=float)
        self.model_return_condition_seeded = False
        self.debug = defaultdict(list)

    def compile_and_load(
            self,
            model: BrainBertInterface,
            cfg: RootConfig,
        ):
        r"""
            Hand `model` (with its `cfg`) and the current spike buffer to the accelerator.

            If `cfg.serve_rl` (on this agent's online config) is set, also build an IQL agent whose
            policy is the model's kinematic-infill pipeline and whose critics are fresh single-layer
            `UnifiedCausalLayer`s, each synced to its target via `hard_update_params`.
            Asserts that the accelerator is ready afterwards.
        """
        self.backbone_cfg = cfg
        self.accelerator.compile_and_load(model, cfg, self.spike_history)
        self.reference = None

        if self.cfg.serve_rl:
            stub_cfg = deepcopy(cfg.model.transformer)
            stub_cfg.n_layers = 1
            qf1 = UnifiedCausalLayer(stub_cfg)
            target_qf1 = UnifiedCausalLayer(stub_cfg)
            hard_update_params(qf1, target_qf1)
            qf2 = UnifiedCausalLayer(stub_cfg)
            target_qf2 = UnifiedCausalLayer(stub_cfg)
            hard_update_params(qf2, target_qf2)
            policy = model.task_pipelines[ModelTask.kinematic_infill]
            # Not quite... my pipeline acts from backbone queries. I need to rerun backbone if I want this..
            # This doesn't make sense. We basically need to rerun backbone. That's the only thing that makes sense.
            self._agent = IQLAgent(
                policy=policy,
                qf1=qf1,
                qf2=qf2,
                vf=UnifiedCausalLayer(stub_cfg),
                target_qf1=target_qf1,
                target_qf2=target_qf2,
                quantile=self.cfg.rl.expectile_tau,
                discount=self.cfg.rl.discount,
            )
        assert self.ready

    def set_prefix(self, reference, reference_timesteps):
        """Store a cropped reference batch as `self.prefix`."""
        # NOTE(review): `act` passes `self.reference` (always None here), not `self.prefix`,
        # to the accelerator - see "TODO restore reference" in `act`.
        self.prefix = precrop_batch(reference, reference_timesteps)

    def update_last_value(self, value):
        """Overwrite the most recent entry of the model value (return-to-go) history."""
        self.model_value_history[-1] = value

    def _condition_return_rcbc(self, return_logits: torch.Tensor) -> int:
        r"""
            Choose the return-conditioning value for return-conditioned BC and write it into
            `model_value_history`.

            If `cfg_mutable.fixed_return` is truthy it is used directly. Otherwise the logits are
            nucleus-filtered (excluding indices 0 and 1), optionally ramp-offset by
            `return_logit_offset_kappa`, and sampled; see the RETURN_FLOOR note below.

            The first call after a buffer reset overwrites the whole value history with the chosen
            value (otherwise the conditioned return fluctuates); later calls only overwrite the last
            entry. Returns the chosen value.
        """
        if self.cfg_mutable.fixed_return:
            return_cond = self.cfg_mutable.fixed_return
        else:
            # Compute kappa offset based on the original tensor
            # Perform initial filtering on logits before offset to find relevant dynamic range for task
            # ASSUMES MODEL HAS LEARNED A REASONABLE DYNAMIC RANGE AFTER TUNING (REDACT thinks this is reasonable)
            # Indices 0 (padding) and 1 (zero reward) are excluded, hence the +2 shift back to original indices.
            plausible_indices = nucleus_filter(return_logits[2:], self.cfg_mutable.return_nucleus_p) + 2
            if self.cfg_mutable.return_logit_offset_kappa > 0.0:
                ramp_end = int(plausible_indices.max())
                # Make a cropped linear ramp so we offset in relevant dynamic range, but preserve
                # euclidean preference within range. Total length: 1 + ramp_end + (n - ramp_end - 1) == n.
                logits_opt = np.concatenate([
                    np.zeros(1), # extra 0 for padding index
                    np.linspace(0.0, 1.0, ramp_end),
                    np.ones(return_logits.shape[-1] - ramp_end - 1),
                ])
                return_logits = return_logits + self.cfg_mutable.return_logit_offset_kappa * logits_opt
            return_logits = return_logits[plausible_indices]

            return_cond = sample_from_logits(
                return_logits,
                temperature=self.cfg_mutable.return_temp,
                nucleus_p=self.cfg_mutable.return_nucleus_p
            )

            # Map sampled indices back to original positions
            return_cond = plausible_indices[return_cond]
            RETURN_FLOOR = 6
            # NOTE(review): the sampled value above is unconditionally replaced by RETURN_FLOOR
            # (fallback, hacky stabilization heuristic). The commented `max(...)` variant was
            # presumably the intended floor; kept as-is to preserve current online behavior.
            # return_cond = max(return_cond, RETURN_FLOOR)
            return_cond = RETURN_FLOOR

            if return_cond >= 1:
                return_cond -= 1 # -1 for removing padding offset, see dataloader
            else:
                logger.warning("Decoder predicted padding as return: Concerning if continuously reporting.")
        if not self.model_return_condition_seeded:
            self.model_value_history[:] = return_cond # Overwrite full history to _stabilize_ - else return fluctuates
            self.model_return_condition_seeded = True
        else:
            self.update_last_value(return_cond)
        return return_cond

    def observe(
            self,
            neural_obs: np.ndarray, # (obs_space,)
            constraint_obs: np.ndarray, # 3
        ):
        r"""
            Push one timestep of neural and constraint observations into the rolling buffers.
            Usable on its own for burn-in, where data is streaming but weights aren't loaded.
            No-op if `reset_buffer` has not allocated the buffers yet. Always returns {}.
        """
        if not self.buffer_ready:
            return {}
        self.spike_history = roll_and_pad(self.spike_history, 1, new_value=neural_obs)
        self.constraint_history = roll_and_pad(self.constraint_history, 1, new_value=constraint_obs)
        return {}

    # Generic agent inputs
    def act(
            self,
            neural_obs: np.ndarray, # (obs_space,)
            constraint_obs: np.ndarray, # (constraint_types=3, MAX_ACTION_SPACE)
            prev_action: np.ndarray, # (MAX_ACTION_SPACE,)
            prev_reward: np.ndarray, # (,) # Scalar arr, not offset
            dimensions: np.ndarray, # (action_space,)
            temperature: float = 1.0,
            # NDT3 specific implementation
            action_mask_timesteps: np.ndarray | None = None, # (T,)
            reference_timesteps: int = 0,
            working_timesteps: int = 0,
        ):
        r"""
            Agent abstraction:
            - mediates the flexible action space
            - mediates working / reference cropping (but not their merging)
            # ! TODO only cache as much as the underlying accelerator requests
            We roll everything because current model implementation expects timelocked seq e.g. s_t, a_t, r_t.

            Per call: push the new observations and the previous action/reward into the rolling
            buffers, crop the last `working_timesteps` steps (action/constraint buffers restricted
            to `dimensions`), run `accelerator.predict`, then update the value history from the
            returned return logits or state value. Returns the accelerator output dict.
            NOTE(review): `working_timesteps=0` slices with `[-0:]`, i.e. the full buffer is used.

            `action_mask_timesteps`: Boolean array of context length. True if action prior at timestep should be ignored.
            - In a non-cached full context implementation, this prior can be dynamically adjusted to e.g. allow conditioning
            on the first 50% of the context, to allow a "weak, distant prior" for the present timestep - supported by prefix masking in pretraining.
            - Recurrent/caching implementations only support all or none masking. Implementation wise, this is set by the
            value of the final timestep. When it's true, the just predicted kin token is provided to the model, and the model
            will cache that / integrate that until it leaves context. (There is no way to "remove" the token without recomputing all tokens after it, and
            its relative location in context shifts with each timestep).
            # TODO restore reference
        """
        self.observe(neural_obs, constraint_obs)
        # Write last step's action into the newest slot, then roll (presumably padding a fresh 0 slot).
        self.cov_history[-1] = prev_action
        self.cov_history = roll_and_pad(self.cov_history, 1, new_value=0)
        # Note RCBC was trained taking acausal reward after current timestep. We correct at this level, high autocorrelation should make this ok.
        self.reward_history[-1] = prev_reward
        self.reward_history = roll_and_pad(self.reward_history, 1, new_value=0)
        # Carry the latest model value forward into the new slot.
        self.model_value_history = roll_and_pad(self.model_value_history, 1, new_value=self.model_value_history[-1])

        working_constraint = self.constraint_history[..., dimensions]

        working_spikes = self.spike_history[-working_timesteps:]
        working_cov = self.cov_history[-working_timesteps:, dimensions]
        working_constraint = working_constraint[-working_timesteps:]
        working_reward = self.reward_history[-working_timesteps:]
        working_model_return = self.model_value_history[-working_timesteps:]
        if self.backbone_cfg.dataset.sparse_constraints:
            # Keep only the first step and steps where any constraint value changed.
            change_steps = np.concatenate([
                np.array([0]),
                (working_constraint[1:] != working_constraint[:-1]).any(1).any(1).nonzero()[0] + 1 # [0] for numpy instead of squeeze in `dataset`, this is a 1D array
            ])
            # T x 3 x Bhvr_Dim
            working_constraint = working_constraint[change_steps]
            working_constraint_time = change_steps
        else:
            working_constraint_time = None

        # Apply observation filtering trickery
        # changestep = False
        # if len(self.debug['true']) > 0:
        #     if (self.debug['true'][-1] != working_constraint[-1]).any():
        #         changestep = True
        #         print(f"True {working_constraint.shape}: \n{working_constraint[-1, :2]}")
        # self.debug['true'].append(working_constraint[-1])
        # normal, kappa, for aa = 1.0 - I alternately incremented kappa and decrement 0.0
        # The challenge is to understand why kappa = 1.0 and aa = 0.0 doesn't reflect original, i.e. how is the model learning to fail?
        # We can compare the final timesteps vs the first timesteps, but also see transition in b/n
        if self.cfg_mutable.return_logit_offset_kappa > 0.0:
            # Make constraints appear to agent stronger than they are - so that we are more in distribution
            if not (
                self.backbone_cfg.model.task.constraint_support_mute or \
                self.backbone_cfg.model.task.constraint_mute or \
                self.backbone_cfg.model.task.constraint_noise
            ):
                # breakpoint() # TODO Double check effect / dimensions
                working_constraint = rearrange(shift_constraint(
                    rearrange(working_constraint, 't h k -> t k h'),
                    self.cfg_mutable.return_logit_offset_kappa
                ), 't k h -> t h k')

        if action_mask_timesteps is not None:
            action_mask_timesteps = action_mask_timesteps[-(reference_timesteps + working_timesteps):]

        out = self.accelerator.predict(
                working_spikes,
                working_cov,
                working_constraint,
                working_reward,
                working_model_return,
                np.arange(working_model_return.shape[0]),
                reference=self.reference,
                kin_mask_timesteps=action_mask_timesteps,
                temperature=temperature,
                spike_array_lengths=self.neural_obs_lengths,
                return_seeded=self.model_return_condition_seeded, # ! Not implemented internally e.g. no special treatment
                num_kin=working_cov.shape[-1],
                sparse_constraint_time=working_constraint_time,
            )
        if self.cfg.record_initial:
            # record since the first prediction. TODO, record down the full history...s
            assert temperature == 0, 'No recording for nondeterministic predictions'
            assert self.backbone_cfg.model.task.constraint_mute, "Muted constraint expected"
            assert self.backbone_cfg.model.task.return_mute, "Muted return expected"
            assert self.reference is None, "No reference expected"
            if not self.logger.collecting:
                self.logger.trigger_start()
            # Note this is limited, we'll extend if this is a useful abstraction
            self.logger.record(
                working_spikes,
                out
            )

        if out.get(Output.return_logits, None) is not None:
            self._condition_return_rcbc(out[Output.return_logits])
        elif Output.state_value in out:
            self.update_last_value(out[Output.state_value])
        # else assume values aren't used, return muted

        return out

    def get_last_value(self) -> float:
        """Most recent entry of the model value (return-to-go) history."""
        return self.model_value_history[-1]

    def set_streaming_timestep_limit(self, limit: int):
        r"""
            Clamp `limit` to [1, MAX_STREAMING_TIMESTEPS] (warning when clamped) and forward it to
            the accelerator. Only `SlimAccelerator` receives it; other accelerators ignore it.
        """
        max_limit = self.MAX_STREAMING_TIMESTEPS
        if limit < 1:
            logger.warning(f"Invalid streaming timestep limit {limit}, setting to 1")
            limit = 1
        if limit > max_limit:
            logger.warning(f"Streaming timestep limit {limit} exceeds recommended maximum of {max_limit}, setting to {max_limit}")
            limit = max_limit
        if isinstance(self.accelerator, SlimAccelerator):
            self.accelerator.set_streaming_timestep_limit(limit)

    # --- RL pieces ---

    def update(
            self,
            batch,
            utd_ratio: int,
        ):
        r"""
            Forward `batch` and `utd_ratio` to the IQL agent's `update`.

            Raises:
                RuntimeError: if no RL agent was built (`cfg.serve_rl` unset, or
                    `compile_and_load` not yet called).
        """
        if self._agent is None:
            raise RuntimeError("No RL agent available; enable cfg.serve_rl and call compile_and_load first.")
        self._agent.update(batch, utd_ratio)

# class BCIRunner:

#     def main():
#         observation, done = env.reset(), False
#         while True:
#             action, agent = agent.sample_actions(observation)
#             next_observation, reward, done, info = env.step(action)

