import numpy as np
import signal
import time
import torch
import copy
from dl_utils.utils import get_mask_past_ids, get_mask_past_id
from utils import check_correct_count, print_tensors
from das import CausalInterchange, RecurRotCausalInterchange
#from models.interventions import (
#    RotatedSpaceIntervention,
#    BoundlessRotatedSpaceIntervention,
#    #LearnableMaskRotatedSpaceIntervention,
#)

def error_handle(signum, frame):
    """
    Unconditionally raises a KeyboardInterrupt. The signature matches
    what `signal.signal` expects of a handler; both arguments are
    ignored.

    Args:
        signum: int
            the signal number (unused)
        frame: frame object or None
            the interrupted stack frame (unused)
    Raises:
        KeyboardInterrupt: always
    """
    raise KeyboardInterrupt()

def device_fxn(device):
    """
    Maps a negative device index to the string "cpu" and passes any
    other value through untouched.

    Args:
        device: int
            a device index, e.g. the value returned by a tensor's
            `get_device()`
    Returns:
        "cpu" if `device` is less than zero, otherwise `device` itself
    """
    return "cpu" if device < 0 else device

def reset_comms_dict(comms_dict=None):
    """
    Resets the entries of a comms dict in place after a data batch
    has been processed.

    An empty (or missing) dict is first seeded with the keys "source",
    "loop_count", and "unit_forward". Afterwards every value is reset
    according to its exact type: ints become 0, lists become a new
    empty list, and bools become False. Values of any other type are
    left as they are.

    Args:
        comms_dict: dict or None
            the communications dict to reset. A new dict is created
            when None is argued.
    Returns:
        comms_dict: dict
            the same dict object that was argued (or the newly created
            one), after the reset.
    """
    if comms_dict is None:
        comms_dict = {}
    if not comms_dict:
        comms_dict.update(source=[], loop_count=0, unit_forward=False)
    for name in list(comms_dict):
        # Exact type comparison is deliberate: bool subclasses int, so
        # an isinstance check would reset bools to 0 instead of False.
        kind = type(comms_dict[name])
        if kind is int:
            comms_dict[name] = 0
        elif kind is list:
            comms_dict[name] = []
        elif kind is bool:
            comms_dict[name] = False
    return comms_dict

def get_activations_hook(comms_dict, key="source", to_cpu=False):
    """
    Returns a hook that can be used to collect activations at the
    specified layer. Each call to the hook appends the latest
    activations to the list stored under `key` in the comms dict.

    If the hooked layer outputs a plain dict, the "hidden_states" entry
    is collected when present, otherwise the "attentions" entry. This
    selection is the same regardless of `to_cpu`.

    Args:
        comms_dict: dict
            This is the communications dict. It will store a list of the
            activations that come through the hooked layer. Each new entry
            in the list corresponds to a call to the model. If `key` is
            missing or does not hold a list, it is (re)initialized to an
            empty list when this function is called.
        key: str
            the key to use to store the activations list.
        to_cpu: bool
            if true, `.cpu()` is called on the activations before they
            are stored. Otherwise they are stored as they are.
    Returns:
        hook: pytorch hook function
            this function will simply collect the activations that exit
            the hooked layer. It returns None, so the layer's output is
            not modified.
    """
    if key not in comms_dict or type(comms_dict[key]) != list:
        comms_dict[key] = []

    def hook(module, inp, out):
        actvs = out
        if type(out) == dict:
            # Fall back to the attentions when no hidden states exist
            if "hidden_states" in out:
                actvs = out["hidden_states"]
            else:
                actvs = out["attentions"]
        if to_cpu:
            actvs = actvs.cpu()
        comms_dict[key].append(actvs)
    return hook

def get_rnn_interchange_hook(
        intr_modu,
        comms_dict,
        source_actvs_key="source",
        base_idx_key="base_idxs",
        source_idx_key="source_idxs",
        loop_count_key="loop_count",
        subspaces=None,
        *args, **kwargs):
    """
    Returns a hook that can be used to perform a causal interchange at
    the hooked layer. Assumes that the comms dict is already populated
    with the source activations. If using a sequential model, the comms
    dict also needs to store the indexs for both the base and source
    sequences at which to perform the causal interchange.

    Args:
        intr_modu: Intervention
            the module to perform the intervention on the activation
            vectors. Called as `intr_modu(base=..., source=...)`, or as
            `intr_modu.unit_forward(base=..., source=...)` when
            `comms_dict["unit_forward"]` is present and truthy.
        comms_dict: dict
            This is the communications dict. The hook reads the
            following entries at call time:

            comms_dict[source_actvs_key]: torch tensor (B,S,...)
                The activations from the source sequence.
            comms_dict[base_idx_key]: torch tensor (B,)
                The positions of the representations in the base sequence
                that we want to use for the interchange.
            comms_dict[source_idx_key]: torch tensor (B,)
                The positions of the representations in the source sequence
                that we want to use for the interchange.
            comms_dict[loop_count_key]: int
                indicates which loop we're in if using a recurrent model
                with a for loop. Incremented by the hook.
            comms_dict["pad_mask"]: optional tensor (B,S) or None
                true denotes padding. A missing key is treated the
                same as None, i.e. nothing is padded.
        source_actvs_key: str
        base_idx_key: str
        source_idx_key: str
        loop_count_key: str
            the comms dict keys for the entries described above.
        subspaces: unused
    Returns:
        hook: pytorch hook function
            returns the hooked layer's output with the selected rows
            replaced by the intervention output, in the original shape.
    """
    def hook(module, inp, out):
        device = device_fxn(out.get_device())
        og_out_shape = out.shape

        # A missing "pad_mask" entry is handled like an explicit None so
        # that a comms dict that never set the key does not raise a
        # KeyError here.
        pad_mask = comms_dict.get("pad_mask", None)
        if pad_mask is None:
            pad_mask = torch.ones(len(out)).bool().to(device)
        else:
            # assumes mask denotes padding with true
            pad_mask = ~(pad_mask.bool()).to(device)
            pad_mask = pad_mask[:,comms_dict[loop_count_key]]

        source_actvs = comms_dict[source_actvs_key][pad_mask]
        shape = source_actvs.shape
        # NOTE(review): this early exit leaves the loop counter
        # untouched -- confirm that is intended.
        if shape[0]==0: return out
        source_actvs = source_actvs.reshape(shape[0],shape[1],-1)
        source_actvs = source_actvs.to(device)

        # Get positional indices of the interchange for each sample in
        # the batch.
        source_seq_idxs = comms_dict[source_idx_key].long()[pad_mask]
        base_seq_idxs = comms_dict[base_idx_key].long()[pad_mask]
        batch_bools = torch.ones_like(base_seq_idxs).bool()

        # If we're using a sequential model that uses for loops to step
        # through the sequence, we will only perform the interchange on
        # the samples in the batch that are indicated at this loop by
        # the base_idxs.
        if len(out.shape)==2 or out.shape[1]==1:
            batch_bools = base_seq_idxs==comms_dict[loop_count_key]
            base_seq_idxs = torch.zeros_like(batch_bools).long()
            out = out.reshape(len(out),1,-1)
        intr_out = out.clone()

        comms_dict[loop_count_key] += 1
        if batch_bools.float().sum()==0: return out.reshape(og_out_shape)

        # Get appropriate inputs for interchange.
        # NOTE(review): these row indices are positions within the
        # pad-filtered batch but are also used to index `out`, which is
        # not filtered here -- presumably `out` only holds unpadded
        # rows or no row is padded at this step; confirm.
        idxs = torch.arange(len(batch_bools)).long().to(device)
        idxs = idxs[batch_bools]
        base_idxs = base_seq_idxs[batch_bools]
        source_idxs = source_seq_idxs[batch_bools]

        base_inpts = out[idxs, base_idxs]
        source_inpts = source_actvs[idxs, source_idxs]

        # Perform causal interchange
        if "unit_forward" in comms_dict and comms_dict["unit_forward"]:
            outs = intr_modu.unit_forward(
                base=base_inpts,
                source=source_inpts.to(device),)
        else:
            outs = intr_modu(
                base=base_inpts,
                source=source_inpts.to(device),)

        # Place causally intervened outputs into appropriate locations
        # in the cloned output tensor (zero first, then add)
        intr_out[idxs, base_idxs] = 0
        intr_out[idxs, base_idxs] += outs

        return intr_out.reshape(og_out_shape)
    return hook

def get_state_interchange_hook(
        intr_modu,
        comms_dict,
        source_actvs_key="source",
        base_idx_key="base_idxs",
        source_idx_key="source_idxs",
        loop_count_key="loop_count",
        subspaces=None,
        *args, **kwargs):
    """
    Builds a forward hook that performs a causal interchange on a
    recurrent state that has no sequence dimension. The hook is meant
    to be called once per step of the model's loop; on each call it
    intervenes on the batch rows whose base index equals the current
    loop count and then increments the loop count.

    Args:
        intr_modu: Intervention
            the module to perform the intervention on the activation
            vectors. Called as `intr_modu(base=..., source=...)`.
        comms_dict: dict
            This is the communications dict. The hook reads the
            following entries at call time:

            comms_dict[source_actvs_key]: torch tensor (B,S,...)
                The activations from the source sequence. All dims
                after the first two are flattened.
            comms_dict[base_idx_key]: torch tensor (B,)
                The loop step at which each base sample should be
                intervened upon.
            comms_dict[source_idx_key]: torch tensor (B,)
                The positions of the representations in the source
                sequence that we want to use for the interchange.
            comms_dict[loop_count_key]: int
                indicates which loop we're in. Incremented by the hook.
        source_actvs_key: str
        base_idx_key: str
        source_idx_key: str
        loop_count_key: str
            the comms dict keys for the entries described above.
        subspaces: unused
    Returns:
        hook: pytorch hook function
            accepts either a tensor or a dict with a "hidden_states"
            tensor as the layer output (e.g. (B,M,D) recurrent states
            from an identity module in a Mamba based model). Returns
            the same kind of object; a dict output is updated in place.
    """
    def hook(module, inp, out):
        wrapped = type(out) == dict
        state = out["hidden_states"] if wrapped else out
        device = device_fxn(state.get_device())
        state_shape = state.shape

        def repack(tensor):
            # Restore the original shape and, for dict outputs, write
            # the tensor back into the dict that the layer produced.
            tensor = tensor.reshape(state_shape)
            if wrapped:
                out["hidden_states"] = tensor
                return out
            return tensor

        # General to multi-dimensional states or single vector states
        src = comms_dict[source_actvs_key]
        n_batch, n_steps = src.shape[:2]
        src = src.reshape(n_batch, n_steps, -1).to(device)

        # Rows whose base index matches the current loop step are the
        # ones to intervene upon.
        src_positions = comms_dict[source_idx_key].long()
        base_positions = comms_dict[base_idx_key].long()
        selected = base_positions == comms_dict[loop_count_key]
        flat_state = state.reshape(n_batch, -1)  # assume no seq dim
        patched = flat_state.clone()

        comms_dict[loop_count_key] += 1
        if not selected.any():
            return repack(flat_state)

        rows = torch.arange(len(selected)).long().to(device)[selected]
        chosen_src_positions = src_positions[selected]

        # Perform causal interchange
        replacement = intr_modu(
            base=flat_state[rows],
            source=src[rows, chosen_src_positions].to(device),)

        # Write into the clone by zeroing and then adding. We do it
        # this way to avoid auto-grad errors for in-place operations.
        patched[rows] = 0
        patched[rows] += replacement

        return repack(patched)
    return hook

def get_recurrot_interchange_hook(
        intr_modu,
        comms_dict,
        source_actvs_key="source",
        base_idx_key="base_idxs",
        source_idx_key="source_idxs",
        loop_count_key="loop_count",
        seq_dim=1,
        batch_first=True,
        subspaces=None,
        *args, **kwargs):
    """
    Returns a hook that can be used to perform a causal interchange at
    the hooked layer in a transformer using a recurrent rotation.
    Assumes that the comms dict is already populated with the source
    activations.

    There are many ways that you might want to intervene on a
    transformer's representations. This function will perform a series
    of invertible rotations on the items in the sequence to make a
    compressed state vector. It will then intervene on this state vector
    using the source sequence's invertible state vector similar to regular
    DAS. It will then invert the intervened state vector for the modified
    context. All of that work is delegated to `intr_modu`; the hook only
    gathers its inputs and repackages its output.

    Args:
        intr_modu: RecurRotCausalInterchange
            the module to perform the intervention on the activation
            vectors. Called as `intr_modu(base=..., source=...,
            base_idxs=..., source_idxs=...)` and expected to return a
            pair `(intervened_base, h)`.
        comms_dict: dict
            This is the communications dict. The hook reads the
            following entries at call time:

            comms_dict[source_actvs_key]: torch tensor (B,...)
                The activations from the source sequence.
            comms_dict[base_idx_key]: torch tensor (B,)
                The positions of the representations in the base sequence
                that we want to use for the interchange.
            comms_dict[source_idx_key]: torch tensor (B,)
                The positions of the representations in the source sequence
                that we want to use for the interchange.

            The hook writes the second value returned by `intr_modu`
            to comms_dict["intr_h"].
        source_actvs_key: str
        base_idx_key: str
        source_idx_key: str
            the comms dict keys for the entries described above.
        loop_count_key: str
            unused by this hook
        seq_dim: int
            the dimension to treat as the sequence dimension. Only 1 is
            supported.
        batch_first: bool
            if true assumes the batch dimension is the first dimension.
            Only True is supported.
        subspaces: unused
    Returns:
        hook: pytorch hook function
            accepts a tensor or a dict with a "hidden_states" tensor as
            the layer output. For dict outputs a new dict is returned
            with "hidden_states" replaced.
    Raises:
        NotImplementedError: if `batch_first` is False (when this
            function is called) or if `seq_dim` is not 1 (when the hook
            is called).
    """
    if not batch_first:
        # `NotImplemented` is a comparison sentinel, not an exception;
        # raising it would produce a confusing TypeError.
        raise NotImplementedError(
            "batch_first=False is untested and not supported")

    def hook(module, inp, base_actvs):
        if seq_dim!=1:
            # only handles the sequence dimension in the 1st dim
            raise NotImplementedError(
                "only seq_dim=1 is supported")

        og_actvs = base_actvs
        if type(og_actvs)==dict:
            base_actvs = base_actvs["hidden_states"]
        device = device_fxn(base_actvs.get_device())
        source_actvs = comms_dict[source_actvs_key].to(device)

        # Get positional indices of the interchange for each sample in
        # the batch.
        source_idxs = comms_dict[source_idx_key].long().to(device)
        base_idxs = comms_dict[base_idx_key].long().to(device)

        # Perform causal interchange
        base_copy, h = intr_modu(
            base=base_actvs,
            source=source_actvs,
            base_idxs=base_idxs,
            source_idxs=source_idxs)

        comms_dict["intr_h"] = h
        if type(og_actvs)==dict:
            base_copy = {**og_actvs, "hidden_states": base_copy}
        return base_copy
    return hook

def get_transformer_interchange_hook(
        intr_modu,
        comms_dict,
        source_actvs_key="source",
        base_idx_key="base_idxs",
        source_idx_key="source_idxs",
        loop_count_key="loop_count",
        seq_dim=1,
        batch_first=True,
        subspaces=None,
        *args, **kwargs):
    """
    Returns a hook intended to perform a causal interchange at the
    hooked layer in a transformer. Assumes that the comms dict is
    already populated with the source activations.

    NOTE(review): as written, the returned hook prints the base and
    source shapes and returns the layer output unmodified. The
    interchange code that follows the early return is never executed.

    There are many ways that you might want to intervene on a
    transformer's representations. This function is meant to isolate a
    single index in the sequence for both the source and base
    activations for a single sample. If, however, you want to treat the
    entire context as a single representation, see
    `get_context_interchange_hook`.

    You probably want to intervene on a full transformer encoder layer.
    Make sure that you do not use teacher forcing.

    Args:
        intr_modu: Intervention
            the module to perform the intervention on the activation
            vectors.
        comms_dict: dict
            This is the communications dict. The hook reads the
            following entries at call time:

            comms_dict[source_actvs_key]: torch tensor (B,S,1,D)
                The activations from the source sequence. Unpacked as
                four dims and reshaped to (B,S,D).
            comms_dict[base_idx_key]: torch tensor (B,)
                The positions of the representations in the base sequence
                that we want to use for the interchange.
            comms_dict[source_idx_key]: torch tensor (B,)
                The positions of the representations in the source sequence
                that we want to use for the interchange.
            comms_dict[loop_count_key]: int
                indicates which loop we're in if using a recurrent model
                with a for loop.
        source_actvs_key: str
        base_idx_key: str
        source_idx_key: str
        loop_count_key: str
            the comms dict keys for the entries described above.
        seq_dim: int
            the dimension to treat as the sequence dimension
        batch_first: bool
            if true assumes the batch dimension is the first dimension.
            Only True is supported.
        subspaces: unused
    Returns:
        hook: pytorch hook function
    Raises:
        NotImplementedError: if `batch_first` is False.
    """
    if not batch_first:
        # `NotImplemented` is a comparison sentinel, not an exception;
        # raising it would produce a confusing TypeError. The former
        # batch-second hook could never run and referenced undefined
        # names, so it has been dropped.
        raise NotImplementedError(
            "batch_first=False is untested and not supported")

    def hook(module, inp, base_actvs):
        og_actvs = base_actvs
        if type(og_actvs)==dict:
            base_actvs = base_actvs["hidden_states"]
        device = device_fxn(base_actvs.get_device())
        og_base_shape = base_actvs.shape
        source_actvs = comms_dict[source_actvs_key].to(device)
        B,S,dummy,D = source_actvs.shape
        source_actvs = source_actvs.reshape(B,S,D)

        # NOTE(review): debugging short-circuit. The hook stops here
        # and hands back the untouched layer output, so no interchange
        # is performed and the loop counter is never advanced. Before
        # removing this return, verify the disabled code below -- in
        # particular the write-back indexes rows with the filtered
        # `idxs` but columns with the unfiltered `base_idxs`.
        print("og_base_shape:", og_base_shape)
        print("source shape:", source_actvs.shape)
        return og_actvs

        # Get positional indices of the interchange for each sample in
        # the batch.
        source_idxs = comms_dict[source_idx_key].long().to(device)
        base_idxs = comms_dict[base_idx_key].long().to(device)

        base_bools = comms_dict[loop_count_key] == base_idxs
        comms_dict[loop_count_key] += 1

        if seq_dim!=1:
            # only handles the sequence dimension in the 1st dim
            raise NotImplementedError(
                "only seq_dim=1 is supported")

        # Get appropriate inputs for interchange. Assumes using
        # no tforcing
        idxs = torch.arange(len(base_actvs)).long().to(device)
        base_inpts = base_actvs[base_bools][:,0]
        idxs = idxs[base_bools]
        sidxs = source_idxs[base_bools]
        source_inpts = source_actvs[idxs, sidxs]

        # Perform causal interchange
        interveneds = intr_modu(base=base_inpts,source=source_inpts,)

        # Place causally intervened outputs into appropriate locations
        # in original output tensor
        base_copy = base_actvs.clone()
        base_copy[idxs, base_idxs] = 0
        base_copy[idxs, base_idxs] += interveneds

        if type(og_actvs)==dict:
            base_copy = {**og_actvs, "hidden_states": base_copy}
        return base_copy
    return hook

def get_context_interchange_hook(
        intr_modu,
        comms_dict,
        source_actvs_key="source",
        base_idx_key="base_idxs",
        source_idx_key="source_idxs",
        loop_count_key="loop_count",
        seq_dim=1,
        batch_first=True,
        subspaces=None,
        trailing_idx=0,
        *args, **kwargs):
    """
    Returns a hook that can be used to perform a causal interchange at
    the hooked layer across multiple entries in the context in a
    transformer. Assumes that the comms dict is
    already populated with the source activations.

    The base and source activations are each multiplied by a 0/1 mask
    derived from their interchange index, flattened to one vector per
    sample, and passed to `intr_modu`. The intervened vector is
    reshaped back, and it replaces the base activations wherever the
    base mask is 1.

    Args:
        intr_modu: Intervention
            the module to perform the intervention on the activation
            vectors. Called as `intr_modu(base=..., source=...)` with
            inputs of shape (B, S*D).
        comms_dict: dict
            This is the communications dict. The hook reads the
            following entries at call time:

            comms_dict[source_actvs_key]: torch tensor (B,...)
                The activations from the source sequence.
            comms_dict[base_idx_key]: torch tensor (B,)
                The positions of the representations in the base sequence
                that we want to use for the interchange.
            comms_dict[source_idx_key]: torch tensor (B,)
                The positions of the representations in the source sequence
                that we want to use for the interchange.
        source_actvs_key: str
        base_idx_key: str
        source_idx_key: str
            the comms dict keys for the entries described above.
        loop_count_key: str
            unused by this hook
        seq_dim: int
            the dimension to treat as the sequence dimension. It is
            swapped with dim 1 for the intervention and swapped back
            afterwards.
        batch_first: bool
            if true assumes the batch dimension is the first dimension.
            Only True is supported.
        subspaces: unused
        trailing_idx: int
            reserved for limiting the number of tokens in the flattened
            context. Values greater than 0 are not implemented.
    Returns:
        hook: pytorch hook function
            accepts a tensor or a dict with a "hidden_states" tensor as
            the layer output. For dict outputs a new dict is returned
            with "hidden_states" replaced.
    Raises:
        NotImplementedError: if `trailing_idx` is greater than 0 or
            `batch_first` is False.
    """
    # `NotImplemented` is a comparison sentinel, not an exception, so
    # the proper exception class is raised for unsupported options.
    if trailing_idx>0:
        raise NotImplementedError(
            "trailing_idx>0 is not supported")
    if not batch_first:
        # Previously this fell through to `return hook` with no hook
        # defined, which surfaced as an UnboundLocalError.
        raise NotImplementedError(
            "batch_first=False is not supported")

    def hook(module, inp, base_actvs):
        og_actvs = base_actvs
        if type(og_actvs)==dict:
            base_actvs = base_actvs["hidden_states"]
        device = device_fxn(base_actvs.get_device())
        source_actvs = comms_dict[source_actvs_key].to(device).clone()
        if seq_dim!=1:
            # Swap the sequence dim into position 1
            dims = list(range(len(base_actvs.shape)))
            dims[1] = seq_dim
            dims[seq_dim] = 1
            base_actvs = base_actvs.permute(dims)
            source_actvs = source_actvs.permute(dims)
        base_copy = base_actvs.clone()
        B,S = base_actvs.shape[:2]

        # Get positional indices of the interchange for each sample in
        # the batch.
        source_idxs = comms_dict[source_idx_key].long().to(device)
        base_idxs = comms_dict[base_idx_key].long().to(device)

        idxs = torch.arange(len(base_actvs)).long().to(device)
        # Get masks that will indicate which values to zero out. The
        # interchange position is flagged with an id that differs from
        # the fill value of 1 so that `get_mask_past_id` can find it.
        # presumably that helper marks the positions after the flagged
        # id -- verify against dl_utils.
        arbitrary_id = 2
        base_mask = torch.ones(base_actvs.shape[:2], device=device)
        base_mask[idxs,base_idxs] = arbitrary_id
        base_mask = get_mask_past_id(base_mask, id_=arbitrary_id)
        base_mask = (~base_mask).float()[...,None]
        source_mask = torch.ones(source_actvs.shape[:2], device=device)
        source_mask[idxs,source_idxs] = arbitrary_id
        source_mask = get_mask_past_id(source_mask, id_=arbitrary_id)
        source_mask = (~source_mask).float()[...,None]

        # Get appropriate inputs for interchange
        base_shape = base_actvs.shape
        base_inpts = (base_actvs*base_mask).reshape(B,-1)
        source_inpts = (source_actvs*source_mask).reshape(B,-1)

        # Perform causal interchange
        interveneds = intr_modu(base=base_inpts,source=source_inpts,)

        # Place causally intervened outputs into appropriate locations
        # in original output tensor: intervened values where the base
        # mask is 1, original values where it is 0.
        interveneds = interveneds.reshape(base_shape)
        base_copy = (1-base_mask)*base_copy + base_mask*interveneds

        if seq_dim!=1:
            # Return the sequence dim to its original position
            dims = list(range(len(base_copy.shape)))
            dims[1] = seq_dim
            dims[seq_dim] = 1
            base_copy = base_copy.permute(dims)
        if type(og_actvs)==dict:
            base_copy = {**og_actvs, "hidden_states": base_copy}

        return base_copy
    return hook

def get_layer_names(model, layer_types=None):
    """
    Finds the default layer name(s) to hook in the argued model.

    If the model has a module named exactly "identities.0", only that
    name is returned. Otherwise every module named exactly "layers.0"
    is returned.

    Args:
        model: torch Module
            only its `named_modules()` method is used.
        layer_types: list of torch nn types or None
            currently ignored.
    Returns:
        layers: list of str
    """
    module_names = [name for name, _ in model.named_modules()]
    # identities take precedence over layers if they exist in the model
    if "identities.0" in module_names:
        return ["identities.0"]
    return [name for name in module_names if name == "layers.0"]

def register_activation_hooks(
        model,
        layers,
        comms_dict,
        to_cpu=True):
    """
    Helper function to register activation-collecting forward hooks at
    multiple layers.

    Args:
        model: torch Module
        layers: sequence of str
            the names of the modules to hook, as they appear in
            `model.named_modules()`. A single str is treated as one
            layer name. Names that are not found in the model are
            reported with a printed message and otherwise ignored.
        comms_dict: dict
            the communications dict. The activations of each hooked
            layer are collected in a list stored under the layer's name.
        to_cpu: bool
            passed on to `get_activations_hook`.
    Returns:
        handles: list of hook handles
            the values returned by `register_forward_hook`, one per
            hooked layer.
        layers: list of str
            the names of the layers that were actually hooked, in the
            order they appear in the model. `handles[i]` belongs to
            `layers[i]`.
    """
    if isinstance(layers, str):
        layers = [layers]
    wanted = set(layers)
    handles = []
    # A list (not a set) keeps the returned names in a reproducible
    # order that lines up with `handles`.
    hooked_layers = []
    for name, mod in model.named_modules():
        if name not in wanted:
            continue
        hook = get_activations_hook(
            comms_dict=comms_dict,
            key=name,
            to_cpu=to_cpu,)
        handles.append(mod.register_forward_hook(hook))
        hooked_layers.append(name)
    missing_layers = wanted - set(hooked_layers)
    if missing_layers:
        print("Layers", missing_layers, "not found")
    return handles, hooked_layers

def _print_example(row, comms_dict, output_ids, pred_ids, task_mask,
                   src=None, sep=","):
    """
    Prints the interchange indices, targets, predictions, optional
    source ids, and task mask of a single row of the current batch.
    """
    print("Seq1 Idx:", comms_dict["base_idxs"][row], "-",
          "Seq2 Idx:", comms_dict["source_idxs"][row])
    src_row = None if src is None else src[row]
    width = max(
        len(output_ids[row]),
        len(pred_ids[row]),
        0 if src_row is None else len(src_row))
    fmt = "{:2}".format
    print("Idxs :", sep.join([fmt(j) for j in range(width)]))
    # The leading " 1" is a one-slot placeholder in front of the
    # targets and predictions.
    print("Targs:", sep.join(
        [" 1"] + [fmt(s) for s in output_ids[row].tolist()]))
    print("Preds:", sep.join(
        [" 1"] + [fmt(s) for s in pred_ids[row].tolist()]))
    if src_row is not None and len(src_row)>0:
        print("Seq2 :", sep.join([fmt(s) for s in src_row.tolist()]))
    print("Tmask:", sep.join([fmt(s) for s in task_mask[row].tolist()]))
    print()

def process_samples(
        model,
        source_actvs,
        source_idxs,
        base_data,
        base_idxs,
        comms_dict,
        key="val",
        resp_signal_only=True,
        layer="layer",
        optimizer=None,
        device=0,
        batch_size=512,
        N=np.inf,
        source_data=None,
        verbose=False):
    """
    This function runs the forward processing for evaluation. For each
    batch it loads the source activations and interchange indices into
    the comms dict, runs the model on the base inputs, and computes the
    loss, the accuracy, and whether each sample was fully correct. If
    an optimizer is argued, a gradient step is taken on each batch.

    Args:
        model: torch Module
            called with the base input ids, `pad_mask`, `task_mask`,
            `ret_gtruth`, `output_attentions`, `n_steps`, and `tforce`
            keyword arguments. Its output must support
            `output["pred_ids"]`.
        source_actvs: dict
            source_actvs[key] is a tensor of source activations with
            one entry per sample.
        source_idxs: dict
            source_idxs[key] holds the source interchange index of each
            sample.
        base_data: dict
            base_data[key] is a dict with the tensors "input_ids",
            "output_ids", "task_mask", "pad_mask", and
            "output_pad_mask". Pad masks are inverted before use, i.e.
            true is treated as padding.
        base_idxs: dict
            base_idxs[key] holds the base interchange index of each
            sample.
        comms_dict: dict
            the communications dict shared with the hooks. The entries
            "source", "base_idxs", "source_idxs", "loop_count", and
            "pad_mask" are overwritten for every batch.
        key: str
            the data split to process.
        resp_signal_only: bool
            if true, the loss is restricted to positions where the task
            mask is true.
        layer: str
            attentions are requested from the model when this name
            contains "sdp".
        optimizer: torch Optimizer or None
        device: int or str
        batch_size: int
        N: int or float
            upper bound on the sample offset at which a batch may
            start. The default of infinity processes every sample.
        source_data: dict or None
            optional; source_data[key]["input_ids"] is only used for
            the verbose printout.
        verbose: bool
            if true, prints up to three incorrect and up to three
            arbitrary examples from the final batch.
    Returns:
        loss: float
            the mean over batches of the mean batch loss
        acc: float
            the mean over batches of the mean batch accuracy
        correct: float
            the mean over batches of the fraction of fully correct
            samples
        results: dict
            "losses", "accs", and "corrects" tensors concatenated over
            all batches.
    """
    loss,acc,correct = 0,0,0
    n_loops = 0
    perm = torch.arange(len(source_actvs[key])).long()
    # `range` only accepts ints, so the default of np.inf (or any
    # float) has to be converted to a concrete integer limit first.
    n_samples = len(perm)
    if N is not None and N < n_samples:
        n_samples = int(N)
    losses = []
    accs = []
    corrects = []
    for batch in range(0,n_samples,batch_size):
        n_loops += 1
        if optimizer is not None:
            optimizer.zero_grad()
        idxs = perm[batch:batch+batch_size]
        comms_dict["source"] = source_actvs[key][idxs].to(device)
        comms_dict["base_idxs"] = base_idxs[key][idxs].to(device)
        comms_dict["source_idxs"] = source_idxs[key][idxs].to(device)
        comms_dict["loop_count"] = 0

        inpts =      base_data[key]["input_ids"][idxs].to(device)
        output_ids = base_data[key]["output_ids"][idxs].to(device)
        task_mask =  base_data[key]["task_mask"][idxs].to(device).bool()
        inpt_pmask = base_data[key]["pad_mask"][idxs].bool().to(device)
        out_pmask =  base_data[key]["output_pad_mask"][idxs].bool().to(device)

        comms_dict["pad_mask"] = inpt_pmask

        output = model(
            inpts,
            pad_mask=inpt_pmask,
            task_mask=task_mask[:,:-1],
            ret_gtruth=False,
            output_attentions="sdp" in layer,
            n_steps=0,
            tforce=False,)
        if hasattr(output, "logits"): logits = output.logits
        elif type(output)==dict and "logits" in output:
            logits = output["logits"]
        else: logits = output
        pred_ids = output["pred_ids"]

        tmask = task_mask[:,1:]
        pmask = ~inpt_pmask
        omask = ~out_pmask
        if resp_signal_only:
            pmask = tmask&pmask
            omask = tmask&omask

        l = lossfxn(
            logits=logits,
            targs=output_ids,
            pred_mask=pmask,
            targ_mask=omask,
            reduce=False)
        losses.append(l.cpu())
        lmean = l.mean()
        if optimizer is not None:
            lmean.backward()
            optimizer.step()
        loss += lmean.item()

        # Accuracy is always restricted to the task positions
        pmask = tmask&pmask
        omask = tmask&omask
        a = accfxn(
                preds=pred_ids,
                targs=output_ids,
                pred_mask=pmask,
                targ_mask=omask,
                reduce=False)
        accs.append(a.cpu())
        acc += a.mean().item()

        # A sample is correct only if every task position matches
        rights = torch.zeros_like(pred_ids).bool()
        rights[tmask] = pred_ids[tmask]==output_ids[tmask]
        rights = rights.long().sum(-1)==tmask.long().sum(-1)
        corrects.append(rights.cpu())
        correct += rights.float().mean().item()
        print(round(100*batch/n_samples), end="%           \r")

    # The printout uses the tensors of the final batch, so it is only
    # possible if at least one batch was processed.
    if verbose and n_loops>0:
        src = None
        if source_data is not None:
            src = source_data[key]["input_ids"][idxs]
        wrong_rows = torch.nonzero(~rights.cpu()).flatten()[:3].tolist()
        if wrong_rows:
            print(f"Wrong {key} Examples:")
            for row in wrong_rows:
                _print_example(
                    row, comms_dict, output_ids, pred_ids, task_mask,
                    src=src)

        print(f"{key} Examples:")
        for row in range(min(3,len(output_ids))):
            _print_example(
                row, comms_dict, output_ids, pred_ids, task_mask,
                src=src)
    return loss/n_loops, acc/n_loops, correct/n_loops, {
        "losses": torch.cat(losses,dim=0),
        "accs": torch.cat(accs, dim=0),
        "corrects": torch.cat(corrects,dim=0)}

def collect_activations(
        model,
        input_ids,
        pad_mask=None,
        task_mask=None,
        layers=None,
        comms_dict=None,
        batch_size=500,
        to_cpu=True,
        ret_attns=False,
        ret_pred_ids=False,
        tforce=False,
        n_steps=0,
        ret_gtruth=False,
        verbose=False,):
    """
    Get the response from the argued layers in the model.

    The model is run batch by batch over `input_ids`. After each batch
    the values found in `comms_dict` under each layer name are moved
    into the returned dict and the `comms_dict` entry is reset to an
    empty list. The hooks are always removed before returning, even if
    the forward pass raises.

    Args:
        model: torch Module or torch gpu Module
            called with the keyword arguments `inpts`, `task_mask`,
            `pad_mask`, `output_attentions`, `n_steps`, `ret_gtruth`
            and `tforce`, and expected to return a dict.
        input_ids: long tensor (N,S)
        pad_mask: bool tensor (N,S)
            true means padding
        task_mask: bool tensor (N,S)
            false means teacher force the corresponding id
        layers: None or sequence of str
            Name of layers to collect activations from. If None will
            collect outputs from every layer.
        comms_dict: None or dict
            the dict handed to the activation hooks. A fresh dict is
            created if None.
        batch_size: int or None
            Optionally break the processing into batches. None means
            a single batch holding all of `input_ids`.
        to_cpu: bool
            If true, torch tensors will be placed onto the cpu.
        ret_attns: bool
            if true, will return the first entry of the model's
            "attentions" output under "attentions". Requires that all
            inputs fit into a single batch.
        ret_pred_ids: bool
            if true, will return the prediction ids under "pred_ids"
        tforce: bool
            will use teacher forcing on all inputs if true. if false,
            uses the task mask to determine when to teacher force.
            false indices of the task mask indicate do teacher force.
        n_steps: int
            the number of additional steps to generate on top of the
            initial input.
        ret_gtruth: bool
            if true, the model will return the ground truth pred_ids
            where tmask is false
        verbose: bool
            currently unused.
    Returns:
        outputs: dict
            The keys will consist of the corresponding layer name. The
            values will be the activations at that layer. Also holds
            "attentions" and/or "pred_ids" when requested.

            "layer_name": torch tensor (N, ...)
    Raises:
        AssertionError: if `ret_attns` is true and the inputs do not
            fit in one batch, or if a hooked layer produced nothing.
        NotImplementedError: if a layer produced a list of
            activations that have more than 3 dims each.
    """
    if comms_dict is None: comms_dict = dict()

    # Layers is modified to only the layers that were found in the model.
    # If you argue a layer not found in the model, it will be ignored.
    # There will, however, be a print statement that indicates that the
    # layer was not found.
    handles, layers = register_activation_hooks(
        model=model, layers=layers, comms_dict=comms_dict, to_cpu=to_cpu,)

    try:
        if batch_size is None: batch_size = len(input_ids)

        device = device_fxn(next(model.parameters()).get_device())
        outputs = {key:[] for key in layers}
        if ret_attns:
            if len(input_ids)>batch_size:
                raise AssertionError(
                    "ret_attns requires all inputs to fit in one batch")
            outputs["attentions"] = []
        if ret_pred_ids: outputs["pred_ids"] = []
        for batch in range(0,len(input_ids), batch_size):
            x = input_ids[batch:batch+batch_size]
            pmask = None
            tmask = None
            if pad_mask is not None:
                pmask = pad_mask[batch:batch+batch_size].to(device)
            if task_mask is not None:
                tmask = task_mask[batch:batch+batch_size].to(device)
            out_dict = model(
                inpts=x.to(device),
                task_mask=tmask,
                pad_mask=pmask,
                output_attentions=ret_attns,
                n_steps=n_steps,
                ret_gtruth=ret_gtruth,
                tforce=tforce,)
            if ret_attns:
                outputs["attentions"].append(out_dict["attentions"][0])
            if ret_pred_ids:
                outputs["pred_ids"].append(out_dict["pred_ids"])
            for k in layers:
                output = comms_dict[k]
                # There could be an internal for loop in model, in which
                # case the hook has stored one entry per call.
                if isinstance(output, list):
                    if len(output)==0:
                        print(k, "isn't producing")
                        raise AssertionError(f"{k} isn't producing")
                    if len(output[0].shape)<=3:
                        # Stack the per-call activations along dim 1
                        output = torch.stack(output, dim=1)
                    else:
                        raise NotImplementedError(
                            "Collecting lists of activations with more "
                            "than 3 dims is not supported")
                outputs[k].append(output)
                # Reset so the next batch starts from an empty list
                comms_dict[k] = []
        # Every entry receives exactly one item per batch, so deciding
        # per key is equivalent to deciding once, and it also works when
        # no layers were hooked.
        outputs = {
            k: (torch.cat(v,dim=0) if len(v)>1 else v[0])
            for k,v in outputs.items()}
        if to_cpu:
            outputs = {k:v.cpu() for k,v in outputs.items()}
    finally:
        # Ensure we do not create a memory leak, even on failure
        for i in range(len(handles)):
            handles[i].remove()

    return outputs

def deactivate_model_gradients(model):
    """
    Switches off gradient tracking for every parameter of the argued
    model and remembers what each parameter's setting was beforehand.
    Only relies on the model being a torch module.

    Args:
        model: torch module
    Returns:
        prev_states: dict
            keys: torch Parameters
            values: bool
                the requires_grad value each parameter had before it
                was set to False. Hand this to
                `reactivate_model_gradients` to undo the change.
    """
    saved = {}
    for param in model.parameters():
        if not hasattr(param, "requires_grad"):
            continue
        was_trainable = param.requires_grad
        saved[param] = was_trainable
        param.requires_grad = False
    return saved

def reactivate_model_gradients(model, prev_states):
    """
    Restores the requires_grad setting of the model's parameters from
    a previously recorded snapshot. Parameters that are missing from
    the snapshot are left untouched.

    Args:
        model: torch module
        prev_states: dict
            keys: torch Parameters
            values: bool
                the requires_grad value to restore for that parameter.
    """
    for param in model.parameters():
        try:
            recorded = prev_states[param]
        except KeyError:
            continue
        param.requires_grad = recorded

def automated_das(
        model,
        base_data: dict,
        base_idxs: dict,
        source_data: dict,
        source_idxs: dict,
        layers=None,
        boundless_das=True,
        n_neurons=None,
        resp_signal_only=True,
        mask_temperature=0.1,
        batch_size=500,
        lr=0.001,
        n_epochs=100,
        source_model=None,
        optim_type="Adam",
        info:dict=None,
        *args, **kwargs,
    ):
    """
    This function automates a DAS search over layers in the model without
    the need to modify your model.

    The model (and the source model if argued) is put into eval mode and
    its parameters are frozen for the duration of the search. The
    previous requires_grad states are restored before returning, even
    if the training raises. The eval mode is not undone.

    Args:
        model: torch Module
        base_data: dict
            The base data must consist of arbitrary input tensors with
            the corresponding interchange labels for the output tensors.

            "train": dict
                "input_ids": torch tensor, shape: or (N,S)
                "pad_mask": optional bool tensor, shape: (N,) or (N,S)
                    true means padding
                "output_ids": torch tensor, shape: (N,) or (N,S)
                "output_mask": optional bool tensor, shape: (N,) or (N,S)
            "val": dict
                "input_ids": torch tensor, shape: or (N,S)
                "pad_mask": optional bool tensor, shape: (N,) or (N,S)
                    true means padding
                "output_ids": torch tensor, shape: (N,) or (N,S)
                "output_mask": optional bool tensor, shape: (N,) or (N,S)
        base_idxs: dict or None
            Both `val` and `train` default to all zeros if None is argued.

            "train": torch tensor, shape (N,)
                The step in the sequence at which to perform the
                interchange. If the input is non-sequential, this argument
                should be None or all zeros.
            "val": torch tensor, shape (N,)
                The step in the sequence at which to perform the
                interchange. If the input is non-sequential, this argument
                should be None or all zeros.
        source_data: dict
            The source data can consist of any arbitrary input X tensors.
            The important part is that the output labels in `base_data`
            are the appropriate labels assuming a causal intervention
            from the source_data inputs.

            "train": dict
                "input_ids": torch tensor, shape: or (N,S)
                "pad_mask": optional bool tensor, shape: (N,) or (N,S)
                    true means padding
            "val": dict
                "input_ids": torch tensor, shape: or (N,S)
                "pad_mask": optional bool tensor, shape: (N,) or (N,S)
                    true means padding
        source_idxs: dict or None
            Both `val` and `train` default to all zeros if None is argued.

            "train": torch tensor, shape (N,)
                The step in the sequence at which to perform the
                interchange. If the input is non-sequential, this argument
                should be None or all zeros.
            "val": torch tensor, shape (N,)
                The step in the sequence at which to perform the
                interchange. If the input is non-sequential, this argument
                should be None or all zeros.
        layers: optional, str or list of str
            Can optionally argue a layer or list of layers that you
            would like to focus on for the DAS search. If None is
            argued, the layers returned by `get_layer_names(model)`
            are used.
        boundless_das: bool
            if true, will use Boundless DAS to learn an interchange mask.
            if False, the interchange mask uses a fixed neuron count.
        mask_temperature: float
            the temperature for the learned mask. larger values means
            less defined boundaries.
        n_neurons: optional, int or list
            If an int is argued, the interchange mask will be restricted
            to the argued neuron count. If a list (or tuple) is argued,
            the search will range from the first to the second elements
            inclusive. If None is argued and `boundless_das` is False,
            a single count equal to half of the flattened activation
            size at the layer is used. If None is argued and
            `boundless_das` is True, None is forwarded to the training.
        resp_signal_only: bool
            if true, will only using training signals from the response
            phase.
        batch_size: int
        lr: float
        n_epochs: int
        source_model: optional torch.Module
            optionally argue a separate model for the source data to
            perform MAS.
        optim_type: str
            forwarded to `das_training`.
        info: dict
            if None, will be ignored. Otherwise used for extra metrics
            that are dependent on the trigger ids and eos_id. Argue
            the info for the base data rather than the source data.
        *args: ignored.
        **kwargs: forwarded to `das_training`. Includes:
            double_rot: bool
                if true, applies 2 rotation matrices sequentially as
                opposed to 1.
            normalize: bool
                if true, will normalize each feature dim of the
                intervened activations using the mean and std of the
                training activations.
            relaxed: bool
                if true, will relax the orthonormal constraint on the
                rotation matrix to only be an invertible matrix.
            rot_first: bool
                only applies if relaxed is true. If true, will apply the
                rotation matrix before the scaling matrix.
            identity_init: bool
                if true, will initialize the rotation matrix to the
                identity matrix.
            identity_rot: bool
                if true, will always reset the rotation matrix to the
                identity. Used for debugging.
            rot_bias: bool
                if true, will include a shifting term in the rotation
            hook_type: str (options: "rnn", "transformer")
                the type of hook to use for the interventions.
            rech_reg: float
                the intensity of regularization on the h vector if using
                recurrot approaches.
    Returns:
        intervention_modules: dict
            keys are number of neurons
            values are whatever `das_training` returned as its
                intervention modules for that neuron count. Note that
                the key does not include the layer, so a later layer
                overwrites the entry of an earlier layer that used
                the same neuron count.
        performance_metrics: dict of lists
            One entry per layer, neuron count and epoch.

            "layer": list of str
                the layer at which the alignment was performed
            "n_neurons": list
                the neuron count argued to the training
            "epoch": list of ints
                the training epoch
            plus one list for every other key in the metrics returned
            by `das_training`.
    """
    if layers is None:
        layers = get_layer_names(model)
        print("Using layers:", layers)
    elif isinstance(layers, str): layers = [layers]
    metrics = {
        "layer": [], # the layer at which the alignment was performed
        "n_neurons": [], # NA means we used Boundless DAS
        "epoch": [],
    }
    reserved = {"epoch", "layer", "n_neurons"}
    intr_modules = dict()
    model.eval()
    prev_grad_states = deactivate_model_gradients(model)
    prev_source_states = None
    try:
        if source_model is not None:
            source_model.eval()
            prev_source_states = deactivate_model_gradients(source_model)
        for layer in layers:
            if not boundless_das and n_neurons is None:
                # We're only doing this to get the size of the neuron
                # vector, so we don't care about masks, and we only use
                # one sample
                with torch.no_grad():
                    actvs = collect_activations(
                        model=model,
                        input_ids=source_data["train"]["input_ids"][:1],
                        pad_mask=None,
                        task_mask=None,
                        layers=[layer],
                        ret_attns="sdp" in layer,
                        batch_size=2,
                        to_cpu=True)[layer]
                # Half of the flattened size of a single step
                neurons = [actvs[:,0].reshape(len(actvs), -1).shape[-1]//2]
            elif isinstance(n_neurons, (list, tuple)):
                neurons = list(range(n_neurons[0], n_neurons[1]+1))
            else:
                neurons = [n_neurons]

            for n in neurons:
                print(f"\n\nBeginning {n} neurons")
                intr_modus, submetrics, continue_search = das_training(
                    intr_modu=None,
                    model=model,
                    source_model=source_model,
                    layer=layer,
                    base_data=base_data,
                    base_idxs=base_idxs,
                    source_data=source_data,
                    source_idxs=source_idxs,
                    boundless_das=boundless_das,
                    mask_temperature=mask_temperature,
                    n_neurons=n,
                    resp_signal_only=resp_signal_only,
                    batch_size=batch_size,
                    lr=lr,
                    n_epochs=n_epochs,
                    optim_type=optim_type,
                    info=info,
                    **kwargs)
                # Index the sub-metrics by position rather than by the
                # epoch value. The two agree when epochs count up from
                # zero, and position stays valid when they do not.
                for row, epoch in enumerate(submetrics["epoch"]):
                    metrics["epoch"].append(epoch)
                    metrics["layer"].append(layer)
                    metrics["n_neurons"].append(n)
                    for k in submetrics:
                        if k not in metrics:
                            metrics[k] = []
                        if k not in reserved:
                            metrics[k].append(submetrics[k][row])
                intr_modules[n] = intr_modus
                # Only stops the neuron search for the current layer
                if not continue_search: break
    finally:
        # Restore in reverse order of deactivation so that parameters
        # shared between the two models end in their original state.
        if prev_source_states is not None:
            reactivate_model_gradients(source_model, prev_source_states)
        reactivate_model_gradients(model, prev_grad_states)
    return intr_modules, metrics

def register_interchange_hook(
        intr_modu,
        model,
        comms_dict,
        layer,
        source_actvs_key="source",
        base_idx_key="base_idxs",
        source_idx_key="source_idxs",
        loop_count_key="loop_count",
        hook_type="rnn",
        seq_idx=1,
        batch_first=True,
        subspaces=None,):
    """
    Registers a forward hook for the interchange intervention on the
    module of `model` whose name equals `layer`.

    Args:
        intr_modu: the intervention module forwarded to the hook getter
        model: torch Module
            searched with `named_modules()` for the argued layer.
        comms_dict: dict
            the communications dict forwarded to the hook getter.
        layer: str
            the name of the module to hook. Only the first module with
            this exact name is hooked.
        source_actvs_key: str
        base_idx_key: str
        source_idx_key: str
        loop_count_key: str
            key names forwarded to the hook getter (presumably keys
            into `comms_dict`; confirm against the hook getters).
        hook_type: str
            (options: "rnn", "mamba", "transformer", "context",
            "recurrot")
            selects the hook getter. "rnn", "mamba" and "transformer"
            all use the state interchange hook.
        seq_idx: int
        batch_first: bool
        subspaces: optional
            forwarded to the hook getter unchanged.
    Returns:
        handle: the handle returned by `register_forward_hook`. Call
            `handle.remove()` to detach the hook.
    Raises:
        ValueError: if `hook_type` is not one of the known options.
        AssertionError: if no module named `layer` exists in the model.
    """
    if hook_type=="rnn" or hook_type=="mamba":
        hook_getter = get_state_interchange_hook
    elif hook_type=="transformer":
        # The state hook is used here in place of a dedicated
        # transformer interchange hook.
        hook_getter = get_state_interchange_hook
    elif hook_type=="context":
        hook_getter = get_context_interchange_hook
    elif hook_type=="recurrot":
        hook_getter = get_recurrot_interchange_hook
    else:
        # Fail early with a readable message instead of an unbound
        # local error once the layer is found.
        raise ValueError(
            f"Unknown hook_type {hook_type!r}. Expected one of: 'rnn', "
            "'mamba', 'transformer', 'context', 'recurrot'")
    handle = None
    for name, mod in model.named_modules():
        if name == layer:
            hook = hook_getter(
                intr_modu=intr_modu,
                comms_dict=comms_dict,
                source_actvs_key=source_actvs_key,
                base_idx_key=base_idx_key,
                source_idx_key=source_idx_key,
                loop_count_key=loop_count_key,
                seq_idx=seq_idx,
                batch_first=batch_first,
                subspaces=subspaces,)
            handle = mod.register_forward_hook(hook)
            break
    if handle is None:
        # Raised explicitly so the check survives `python -O`
        raise AssertionError("Failed to find layer")
    return handle

def lossfxn(logits, targs, pred_mask=None, targ_mask=None, reduce=True):
    """
    Calculates the cross entropy between logits and targs.

    Args:
        logits: torch FloatTensor, shape (B, S, ...)
        targs: torch LongTensor, shape (B, S)
        pred_mask: torch bool tensor, shape (B, S)
            optionally argue a mask to leave some calculations out of
            the loss calculation. true values do get included in the
            loss. false values do not.
        targ_mask: torch bool tensor, shape (B, S)
            optionally argue a mask to select the targets that pair
            with the logits selected by `pred_mask`. true values do
            get included in the loss. false values do not. Only used
            when `pred_mask` is argued. Defaults to `pred_mask`.
        reduce: bool
            if true, will take the average over all loss. Otherwise,
            returns the loss averaged along the last dimension of the
            original shape, only counting tokens within the pred_mask.
    Returns:
        loss: torch FloatTensor
            a scalar if `reduce` is true. Otherwise one value per
            leading index of the original shape. Rows with no true
            entries in `pred_mask` divide by zero and come back as nan.
    """
    og_shape = logits.shape[:-1]
    if len(targs.shape)==2:
        B,S = targs.shape
        K = B*S
    else: K = targs.shape[0]
    logits = logits.reshape(K, -1)
    targs = targs.reshape(K)
    if pred_mask is not None:
        # Without a separate target mask, select the same positions
        # from the targets as from the logits.
        if targ_mask is None: targ_mask = pred_mask
        pred_mask = pred_mask.reshape(K).bool()
        logits = logits[pred_mask]
        targ_mask = targ_mask.reshape(K).bool()
        targs = targs[targ_mask]
    r = "mean" if reduce else "none"
    loss = torch.nn.functional.cross_entropy( logits, targs, reduction=r )
    if not reduce:
        if pred_mask is not None:
            per_token = loss
            # new_zeros keeps the dtype and device of the loss. Using
            # get_device() here fails on the cpu because it returns -1.
            loss = per_token.new_zeros(og_shape)
            mask = pred_mask.reshape(og_shape)
            # Scatter the per token loss back to its original position
            loss[mask] = per_token
            denoms = mask.long().sum(-1)
            loss = loss.sum(-1)/denoms
        else:
            loss = loss.reshape(og_shape).mean(-1)
    return loss

def accfxn(logits=None,
           targs=None,
           preds=None,
           pred_mask=None,
           targ_mask=None,
           reduce=True):
    """
    Calculates the accuracy of the predictions against targs.

    Args:
        logits: float tensor, (B, S, ...)
            only used when `preds` is None.
        targs: torch LongTensor, shape (B, S)
        preds: optional long tensor, (B, S)
            optionally argue the prediction ids directly. otherwise
            will use argmax of logits.
        pred_mask: torch bool tensor, shape (B, S)
            optionally argue a mask to leave some calculations out of
            the acc calculation. true values do get included in the
            acc. false values do not.
        targ_mask: torch bool tensor, shape (B, S)
            optionally argue a mask to select the targets that pair
            with the predictions selected by `pred_mask`. true values
            do get included in the acc. false values do not. Only used
            when `pred_mask` is argued. Defaults to `pred_mask`.
        reduce: bool
            if true, will take the average over all tokens. Otherwise,
            returns the acc averaged along the S dimension only for
            tokens within the pred_mask
    Returns:
        acc: torch FloatTensor
            a scalar if `reduce` is true. Otherwise one value per
            leading index of the prediction shape. Rows with no true
            entries in `pred_mask` divide by zero and come back as nan.
    """
    # Resolve the predictions before reading their shape so that
    # arguing logits alone works.
    if preds is None:
        preds = torch.argmax(logits, dim=-1).long()
    og_shape = preds.shape
    if len(targs.shape)==2:
        B,S = targs.shape
        K = B*S
    else: K = targs.shape[0]
    preds = preds.reshape(K)
    targs = targs.reshape(K)
    if pred_mask is not None:
        # Without a separate target mask, select the same positions
        # from the targets as from the predictions.
        if targ_mask is None: targ_mask = pred_mask
        pred_mask = pred_mask.reshape(K).bool()
        preds = preds[pred_mask]
        targ_mask = targ_mask.reshape(K).bool()
        targs = targs[targ_mask]
    acc = (preds==targs).float()
    if reduce: acc = acc.mean()
    else:
        if pred_mask is not None:
            per_token = acc
            # new_zeros keeps the dtype and device of the acc. Using
            # get_device() here fails on the cpu because it returns -1.
            acc = per_token.new_zeros(og_shape)
            mask = pred_mask.reshape(og_shape)
            # Scatter the per token acc back to its original position
            acc[mask] = per_token
            denoms = mask.long().sum(-1)
            acc = acc.sum(-1)/denoms
        else:
            acc = acc.reshape(og_shape).mean(-1)
    return acc

def das_training(
        model,
        layer,
        base_data: dict,
        base_idxs: dict,
        source_data: dict,
        source_idxs: dict,
        intr_modu=None,
        boundless_das=True,
        mask_temperature=0.1,
        n_neurons=None,
        resp_signal_only=True,
        batch_size=500,
        lr=0.001,
        rech_reg=0.01,
        n_epochs=10,
        double_rot=False,
        normalize=False,
        relaxed=False,
        rot_first=False,
        identity_init=False,
        identity_rot=False,
        rot_bias=False,
        source_model=None,
        optim_type="Adam",
        hook_type=None,
        full_gpu=True,
        info:dict=None,
        early_stopping=20,
        **kwargs,
    ):
    """
    Args:
        model: torch Module
        layer: str
            The layer that you would like to focus on for the DAS
            training.
        base_data: dict
            The base data must consist of arbitrary input input tensors
            with the corresponding interchange labels for the y tensors.

            "train": dict
                "input_ids": torch tensor, shape: or (N,S)
                "pad_mask": optional bool tensor, shape: (N,) or (N,S)
                    true means padding
                "output_ids": torch tensor, shape: (N,) or (N,S)
                "output_mask": optional bool tensor, shape: (N,) or (N,S)
            "val": dict
                "input_ids": torch tensor, shape: or (N,S)
                "pad_mask": optional bool tensor, shape: (N,) or (N,S)
                    true means padding
                "output_ids": torch tensor, shape: (N,) or (N,S)
                "output_mask": optional bool tensor, shape: (N,) or (N,S)
        base_idxs: dict or None
            Both `val` and `train` default to all zeros if None is argued.

            "train": torch tensor, shape (N,)
                The step in the sequence at which to perform the
                interchange. If the input is non-sequential, this argument
                should be None or all zeros.
            "val": torch tensor, shape (N,)
                The step in the sequence at which to perform the
                interchange. If the input is non-sequential, this argument
                should be None or all zeros.
        source_data: dict
            The source data can consist of any arbitrary input tensors.
            The important part is that the output labels in `base_data`
            are the appropriate labels assuming a causal intervention
            from the source_Data inputs.

            "train": dict
                "input_ids": torch tensor, shape: or (N,S)
                "pad_mask": optional bool tensor, shape: (N,) or (N,S)
                    true means padding
            "val": dict
                "input_ids": torch tensor, shape: or (N,S)
                "pad_mask": optional bool tensor, shape: (N,) or (N,S)
                    true means padding
        source_idxs: dict or None
            Both `val` and `train` default to all zeros if None is argued.

            "train": torch tensor, shape (N,)
                The step in the sequence at which to perform the
                interchange. If the input is non-sequential, this argument
                should be None or all zeros.
            "val": torch tensor, shape (N,)
                The step in the sequence at which to perform the
                interchange. If the input is non-sequential, this argument
                should be None or all zeros.
        boundless_das: bool
            if true, will use Boundless DAS to learn an interchange mask.
            if False, will search over 1 to all neurons minus 1 for the
            interchange mask.
        mask_temperature: float
            the temperature for the learned mask. larger values means
            less defined boundaries.
        n_neurons: optional, int
            If an int is argued, the interchange mask will be restricted
            to the argued neuron count. If None is argued, the function
            will search over neuron counts of 1 to all neurons minus 1.
        resp_signal_only: bool
            if true, will only using training signals from the response
            phase.
        batch_size: int
        lr: float
        rech_reg: float
            the intensity of regularization on the h vector if using
            recurrot approaches.
        n_epochs: int
        double_rot: bool
            if true, applies 2 rotation matrices sequentially as
            opposed to 1.
        normalize: bool
            if true, will normalize the data using the mean and std of
            the training data.
        hook_type: str (options: "rnn", "transformer")
            the type of hook to use for the interventions.
        relaxed: bool
            if true, will relax the orthonormal constraint on the
            rotation matrix to only be an invertible matrix.
        rot_first: bool
            only applies if relaxed is true. If true, will apply the
            rotation matrix before the scaling matrix.
        identity_init: bool
            if true, will initialize the rotation matrix to the identity
            matrix.
        identity_rot: bool
            if true, will always reset the rotation matrix to the
            identity. Used for debugging.
        rot_bias: bool
            if true, will include a shifting term in the rotation
        source_model: optional torch Module
            optionally argue another model to perform MAS
        full_gpu: bool
            if true, will put all data onto the appropriate device
            and leave it on that device
        info: dict
            if None, ignored.

            "trigger_ids": list of ints
            "eos_id": int
        early_stopping: int or None
            if None, will not do early stopping. Otherwise will stop the
            training if the loss has not decreased over the last
            `early_stopping` epochs.
    Returns:
        modus: dict
            intervention modules
            "best_train": best checkpoint by training accuracy
            "best_val": best checkpoint by validation accuracy
            "best": best checkpoint by both train and val accuracy
            "last": last checkpoint in training
        performance_metrics: pd DataFrame
            "intervention_idx": list of ints
                the index of the corresponding intervention_module in
                the returned intervention_modules list
            "epoch": list of ints
                the training epoch
            "train_acc": list of floats
                the accuracy of the alignment evaluated on the argued
                training data tokens
            "train_loss": list of floats
                the loss of the alignment on the argued training data
            "train_correct": list of floats
                the accuracy of the alignment evaluated on the argued
                training count responses
            "val_acc": list of floats
                the accuracy of the alignment evaluated on the argued
                validation data tokens
            "val_loss": list of floats
                the loss of the alignment on the argued validation data
            "val_correct": list of floats
                the accuracy of the alignment evaluated on the argued
                validation data count responses
    """
    signal.signal(signal.SIGINT, error_handle)
    device = device_fxn(next(model.parameters()).get_device())
    if info is None: info = dict()

    sep_rot = True # Determines whether there will be a src rot mtx
    if source_model is None:
        source_model = model
        sep_rot = False

    with torch.no_grad():
        # source_train is tensor of shape (N,...)
        source_train = collect_activations(
            model=source_model,
            input_ids=source_data["train"]["input_ids"],
            pad_mask=source_data["train"].get("pad_mask", None),
            task_mask=source_data["train"].get("task_mask",None),
            layers=[layer],
            batch_size=batch_size,
            ret_attns="sdp" in layer,
            to_cpu=True)[layer]
        source_val = collect_activations(
            model=source_model,
            input_ids=source_data["val"]["input_ids"],
            pad_mask=source_data["val"].get("pad_mask", None),
            task_mask=source_data["val"].get("task_mask",None),
            layers=[layer],
            batch_size=batch_size,
            ret_attns="sdp" in layer,
            to_cpu=True)[layer]

        source_actvs = {"train": source_train, "val": source_val}

        if "test" in source_data:
            source_test = collect_activations(
                model=source_model,
                input_ids=source_data["test"]["input_ids"],
                pad_mask=source_data["test"].get("pad_mask", None),
                task_mask=source_data["test"].get("task_mask",None),
                layers=[layer],
                batch_size=batch_size,
                ret_attns="sdp" in layer,
                to_cpu=True)[layer]
            source_actvs["test"] = source_test

    if intr_modu is None:
        temp = source_actvs["train"][0,0] # Remove batch and seq dims
        print("State Shape:", temp.shape)
        embed_dim = temp.reshape(-1).shape[0]
        print("Full Rotation Dim:", embed_dim)

        if hook_type and hook_type=="context":
            s = source_data["train"]["input_ids"].shape[1]
            embed_dim = embed_dim*s
        Interchange = CausalInterchange
        if hook_type and hook_type=="recurrot":
            Interchange = RecurRotCausalInterchange
        mu = 0
        sigma = 1
        if normalize:
            actvs = source_actvs["train"]
            mu = actvs.reshape(-1,actvs.shape[-1]).mean(0)
            sigma = actvs.reshape(-1,actvs.shape[-1]).std(0)

            # Try to project into a well conditioned space
            #temp = actvs.reshape(-1,actvs.shape[-1])
            #temp = (temp-mu)/sigma
            ##print("sigma:", sigma.shape)
            ##print("temp:", temp.shape)
            #A = torch.matmul(temp.T, temp)
            #print("A:", torch.linalg.cond(A))
            #U,s,V = torch.linalg.svd(A)
            #print("newsig:", sigma)
        intr_modu = Interchange(
            size=embed_dim,
            temperature=mask_temperature,
            fixed=None if boundless_das else n_neurons,
            full_boundary=False,
            double_rot=double_rot,
            mu=mu,
            sigma=sigma,
            relaxed=relaxed,
            rot_first=rot_first,
            identity_init=identity_init,
            identity_rot=identity_rot, #used for debugging
            rot_bias=rot_bias,
            sep_rot=sep_rot,
        )
        #if boundless_das:
        #    intr_modu = BoundlessRotatedSpaceIntervention(
        #        embed_dim=embed_dim)
        #else:
        #    intr_modu = RotatedSpaceIntervention(embed_dim=embed_dim)
        #    intr_modu.set_interchange_dim(interchange_dim=n_neurons)
    intr_modu.to(device)

    comms_dict = reset_comms_dict()

    if hook_type is None:
        hook_type = "rnn"
        if type(model).__name__=="Transformer":
            hook_type = "transformer"
        if type(model).__name__=="Mamba":
            hook_type = "mamba"
    handle = register_interchange_hook(
        model=model,
        layer=layer,
        intr_modu=intr_modu,
        comms_dict=comms_dict,
        source_actvs_key="source",
        base_idx_key="base_idxs",
        source_idx_key="source_idxs",
        loop_count_key="loop_count",
        hook_type=hook_type,
        subspaces=None,)

    optimizer = getattr(torch.optim,optim_type)(
        intr_modu.parameters(),
        lr=lr)

    if full_gpu:
        for k in ["train", "val", "test"]:
            if k in source_actvs:
                source_actvs[k] = source_actvs[k].to(device)
                source_idxs[k] = source_idxs[k].to(device)
                base_idxs[k] = base_idxs[k].to(device)
                for kk in base_data[k].keys():
                    base_data[k][kk] = base_data[k][kk].to(device)

    metrics = {
        "epoch": [],
        "train_acc": [],
        "train_loss": [],
        "train_correct": [],
        "val_acc": [],
        "val_loss": [],
        "val_correct": [],
    }
    if "test" in source_actvs:
        keys = [ "test_acc", "test_loss", "test_correct" ]
        metrics = {**metrics, **{k: [] for k in keys}}
    vbatch_size = batch_size*4
    n_train_loops = max(len(source_actvs["train"])//batch_size, 1)
    n_val_loops = max(len(source_actvs["val"])//vbatch_size, 1)
    N = int(n_train_loops*batch_size)
    NV = int(n_val_loops*vbatch_size)
    continue_automated_search = True
    best_train_correct = 0
    best_val_correct = 0
    min_val_loss = np.inf
    early_stop_count = 0
    best_train_modu = None
    best_val_modu = None
    best_modu = None
    try:
        for epoch in range(n_epochs):
            start_time = time.time()
            print("Epoch:", epoch, "N Neurons:", n_neurons)
            if "save_folder" in info:
                print("Model:", info["save_folder"])
            if "das_save_name" in info:
                print("Save Name:", info["das_save_name"])

            # TRAIN
            ########################################################
            # Important to maintain reference to same comms_dict unless
            # you want to replace the hook.
            comms_dict = reset_comms_dict(comms_dict=comms_dict)
            train_loss, train_acc, train_correct, _ = process_samples(
                model,
                source_data=source_data,
                source_actvs=source_actvs,
                source_idxs=source_idxs,
                base_data=base_data,
                base_idxs=base_idxs,
                comms_dict=comms_dict,
                key="train",
                resp_signal_only=resp_signal_only,
                layer=layer,
                optimizer=optimizer,
                device=device,
                batch_size=batch_size,
                N=N,
                verbose=True,
            )
            optimizer.zero_grad()

            # Validation loop
            #########################################################
            comms_dict = reset_comms_dict(comms_dict=comms_dict)
            with torch.no_grad():
                val_loss, val_acc, val_correct, _ = process_samples(
                    model,
                    source_data=source_data,
                    source_actvs=source_actvs,
                    source_idxs=source_idxs,
                    base_data=base_data,
                    base_idxs=base_idxs,
                    comms_dict=comms_dict,
                    key="val",
                    layer=layer,
                    optimizer=None,
                    device=device,
                    batch_size=vbatch_size,
                    N=NV,
                    verbose=True,
                )

            print("Train Loss:", train_loss, "- Token:", train_acc)
            print("Train Acc:", train_correct)
            print("Val Loss:", val_loss, "- Token:", val_acc)
            print("Val Acc:", val_correct)

            # Test loop
            #########################################################
            if "test" in source_actvs:
                comms_dict = reset_comms_dict(comms_dict=comms_dict)
                with torch.no_grad():
                    test_loss, test_acc, test_correct, _ = process_samples(
                        model,
                        source_data=source_data,
                        source_actvs=source_actvs,
                        source_idxs=source_idxs,
                        base_data=base_data,
                        base_idxs=base_idxs,
                        comms_dict=comms_dict,
                        key="test",
                        layer=layer,
                        optimizer=None,
                        device=device,
                        batch_size=vbatch_size,
                        N=NV,
                        verbose=False
                    )
                print("Test Loss:", test_loss, "- Token:", test_acc)
                print("Test Acc:",  test_correct)
                metrics["test_acc"].append(test_acc)
                metrics["test_loss"].append(test_loss)
                metrics["test_correct"].append(test_correct)

            print("Exec Time:", time.time()-start_time)
            print()

            metrics["epoch"].append(epoch)
            metrics["train_acc"].append(train_acc)
            metrics["train_loss"].append(train_loss)
            metrics["train_correct"].append(train_correct)
            metrics["val_acc"].append(val_acc)
            metrics["val_loss"].append(val_loss)
            metrics["val_correct"].append(val_correct)

            if train_correct>best_train_correct:
                best_train_correct = train_correct
                best_train_modu = copy.deepcopy(intr_modu.cpu())
                best_train_modu.epoch = epoch
            if val_correct>=best_val_correct:
                best_val_correct = val_correct
                best_val_modu = copy.deepcopy(intr_modu.cpu())
                best_val_modu.epoch = epoch
                if train_correct>=(best_train_correct-0.02):
                    best_modu = copy.deepcopy(intr_modu.cpu())
                    best_modu.epoch = epoch
            intr_modu.to(device)

            if train_correct<=0.05 and val_correct<=0.05 and epoch>10:
                print("Ending training due to insufficient performance")
                continue_automated_search = True
                break
            if train_correct>0.990 and val_correct>0.990:
                print("Ending training due to sufficient performance")
                continue_automated_search = False
                break

            early_stop_count += 1
            if val_loss<min_val_loss-0.001:
                min_val_loss = val_loss
                early_stop_count = 0
            if early_stopping is not None:
                if early_stop_count>=early_stopping:
                    print("Ending training due to early stopping")
                    continue_automated_search = True
                    break
    except KeyboardInterrupt as e:
        continue_automated_search = False
        print("Interrupted by an error")
    handle.remove() # Clean up hook
    modus = {
        "best_train": best_train_modu if best_train_modu is not None else intr_modu,
        "best_val": best_val_modu if best_val_modu is not None else intr_modu,
        "best": best_modu if best_modu is not None else intr_modu,
        "last": intr_modu,
    }
    return modus, metrics, continue_automated_search

def das_eval(
        model,
        layer,
        intr_modu,
        base_data: dict,
        base_idxs: dict,
        source_data: dict,
        source_idxs: dict,
        batch_size=None,
        source_model=None,
        source_actvs=None,
        hook_type=None,
        info:dict=None,
    ):
    """
    Evaluates an existing intervention module on a single split of
    data. An interchange hook is registered on `layer` of `model`, one
    gradient-free pass of `process_samples` is run over the argued
    data, and the hook is removed again before returning (even when
    the pass raises).

    Args:
        model: torch Module
        layer: str or None
            The layer at which to perform the interchange. If None,
            the first name returned by `get_layer_names(model)` is
            used.
        intr_modu: torch Module
            the intervention module to evaluate. It is moved to the
            device of the first parameter of `model`.
        base_data: dict
            The base data must consist of arbitrary input tensors
            with the corresponding interchange labels for the y tensors.
            Keys:
                "input_ids": torch tensor, shape: (N,) or (N,S)
                "pad_mask": optional bool tensor, shape: (N,) or (N,S)
                    true means padding
                "task_mask": bool tensor (N,S+1)
                    zeros denote teacher forcing
                "output_ids": torch tensor, shape: (N,) or (N,S)
                "output_mask": optional bool tensor, shape: (N,) or (N,S)
        base_idxs: None or torch tensor, shape (N,)
            The step in the sequence at which to perform the
            interchange. This value is forwarded untouched to
            `process_samples`.
        source_data: dict
            The source data can consist of any arbitrary input tensors.
            The important part is that the output labels in `base_data`
            are the appropriate labels assuming a causal intervention
            from the source_data inputs.
            Keys:
                "input_ids": torch tensor, shape: (N,) or (N,S)
                "pad_mask": optional bool tensor, shape: (N,) or (N,S)
                    true means padding
                "task_mask": bool tensor (N,S)
                    zeros denote teacher forcing
        source_idxs: None or torch tensor, shape (N,)
            The step in the source sequence from which to take the
            interchange activations. This value is forwarded untouched
            to `process_samples`.
        batch_size: int or None
            if None, defaults to the length of base_data["input_ids"]
        source_model: optional torch Module
            optionally argue another model from which to collect the
            source activations. Defaults to `model`.
        source_actvs: optional torch tensor (N,S,D)
            optionally argue precomputed source activations. If None,
            they are collected from `source_model` at `layer`.
        hook_type: str (options: "rnn", "transformer", "mamba")
            the type of hook to use for the interventions. If None, it
            is inferred from the class name of `model`, falling back to
            "rnn".
        info: dict
            if None, a default dict is used.
            Keys read by this function (only used for printing):
                "model_folder": optional
                "das_save_name": optional
    Returns:
        metrics:
            the fourth value returned by `process_samples` for the
            evaluation pass, or None if the pass was interrupted by a
            KeyboardInterrupt before it finished.
    """
    if layer is None:
        layer = get_layer_names(model)[0]
        print("Using layer:", layer)
    # get_device returns a negative value for cpu tensors, device_fxn
    # converts that to "cpu"
    device = device_fxn(next(model.parameters()).get_device())
    intr_modu.to(device)

    if info is None:
        # TODO: Fix the magic numbers
        # TODO: non generalizable code
        info = { "eos_id": 2, "trigger_ids": [7] }
    if batch_size is None: batch_size = len(base_data["input_ids"])

    if source_model is None: source_model = model

    if source_actvs is None:
        with torch.no_grad():
            # source activations are collected from the source model
            # at the same layer name as the intervention
            print("batch-size", batch_size)
            print("inpt_ids:", source_data["input_ids"].shape)
            print("layer:", layer)
            source_actvs = collect_activations(
                model=source_model,
                input_ids=source_data["input_ids"],
                pad_mask=source_data.get("pad_mask", None),
                task_mask=source_data.get("task_mask", None),
                layers=[layer],
                batch_size=batch_size,
                ret_attns="sdp" in layer,
                to_cpu=True)[layer]

    comms_dict = reset_comms_dict()

    if hook_type is None:
        # Infer the hook from the model class name. Mirrors the
        # inference used by the training code in this file.
        model_type = type(model).__name__
        if model_type=="Transformer":
            hook_type = "transformer"
        elif model_type=="Mamba":
            hook_type = "mamba"
        else:
            hook_type = "rnn"
    handle = register_interchange_hook(
        model=model,
        layer=layer,
        intr_modu=intr_modu,
        comms_dict=comms_dict,
        source_actvs_key="source",
        base_idx_key="base_idxs",
        source_idx_key="source_idxs",
        loop_count_key="loop_count",
        hook_type=hook_type,
        subspaces=None,)

    n_loops = max(len(source_actvs)//batch_size, 1)
    N = int(n_loops*batch_size)
    # Bound before the try so that an interrupted evaluation returns
    # None instead of raising an UnboundLocalError.
    metrics = None
    try:
        print("Beginning Evaluation")
        if "model_folder" in info: print("Model:", info["model_folder"])
        if "das_save_name" in info:
            print("Save Name:", info["das_save_name"])
        # Important to maintain reference to same comms_dict unless
        # you want to replace the hook.
        comms_dict = reset_comms_dict(comms_dict=comms_dict)
        with torch.no_grad():
            # process_samples indexes its data dicts by split name, so
            # the argued data is wrapped under a single key
            key = "test"
            _, _, _, metrics = process_samples(
                model,
                source_data={key: source_data},
                source_actvs={key: source_actvs},
                source_idxs={key: source_idxs},
                base_data={key: base_data},
                base_idxs={key: base_idxs},
                comms_dict=comms_dict,
                key=key,
                layer=layer,
                optimizer=None,
                device=device,
                batch_size=batch_size,
                N=N,
                verbose=True,
            )
    except KeyboardInterrupt:
        print("Interrupted by an error")
    finally:
        # Always clean up the hook, otherwise it stays attached to the
        # model when the evaluation raises.
        handle.remove()
    return metrics
