import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import collections
import os
import sys
import math
sys.path.insert(1, "../")

import dl_utils.save_io as savio
from dl_utils.utils import pad_to, get_mask_past_id, num2base

from utils import check_correct_count, get_counts, run_til_idx
from seq_models import *
import datas
from automated_utils import collect_activations, register_interchange_hook

def get_attn_proj_hook(
        comms_dict,
        count_vec_key="count_vec",
        count_mag_key="count_mag",
        strength_key="strength",
        idx_key="idx",
        loop_key="loop_count",
        *args, **kwargs):
    """
    Build a forward hook that re-weights the input to an attention output
    projection before the projection is applied.

    The input activations are treated as an already-normalized attention
    output. They are "unnormalized", a multiple of a custom count vector is
    added, and the sum is renormalized. The hook computes:

        div   = (loop_count + incl_trig + |count_mag|) * strength
        intrv = (loop_count*strength*x + count_mag*count_vec) / div
        out   = intrv @ module.weight.T (+ module.bias)

    Here `x` is the module's input and `loop_count` is the number of times
    this hook has fired, including the current call.

    Args:
        comms_dict: dict
            The communications dict. Each key is a string.

            count_vec_key: torch tensor (B,D)
                A vector of size d_model that we believe is the direction
                of the count in the model. Positive count_dir is in the
                direction of the demonstration tokens.
            count_mag_key: torch Tensor (B,) or scalar
                Multiplied by the count vector to change the existing
                count by this magnitude. For example, to increase the count
                by 1, put a 1 here.
            strength_key: float
                The e^(qk/sqrt(d)) term for the response self-attention
                strength on itself. q and k are the response query and key
                vectors.
            idx_key: torch tensor (B,)
                The positions of the representations in the sequence
                that we want to intervene upon (currently unused).
            loop_key: int
                The number of forward passes seen so far. It is incremented
                by the hook on every call.
            incl_trig (key configurable via kwargs["incl_trig_key"]): int/bool
                Optional. Extra count added to the normalizer. If it is
                absent, a module-level `incl_trig` global is used (the
                `__main__` section of this file defines one).
        kwargs:
            incl_trig_key: str, default "incl_trig"
                The comms_dict key holding the `incl_trig` value.
            dict_key: str, default "attn_out"
                When the hooked module returns a dict, the key under which
                the intervened tensor is stored.
    Returns:
        hook: pytorch hook function
            Attach this hook to your desired module, which must expose
            `weight` and `bias` (e.g. an nn.Linear out_proj).
    """
    incl_trig_key = kwargs.get("incl_trig_key", "incl_trig")
    dict_key = kwargs.get("dict_key", "attn_out")

    def _get_incl_trig():
        # Prefer the explicit comms_dict value. Fall back to the module
        # global that the driver script defines, for backwards compatibility.
        if incl_trig_key in comms_dict:
            return comms_dict[incl_trig_key]
        if "incl_trig" in globals():
            return globals()["incl_trig"]
        raise KeyError(
            f"get_attn_proj_hook needs comms_dict['{incl_trig_key}'] "
            "(or a module-level `incl_trig`)"
        )

    def hook(module, base_actvs, out):
        comms_dict[loop_key] += 1
        loop_count = comms_dict[loop_key]
        # Forward hooks receive the module's positional inputs as a tuple.
        if isinstance(base_actvs, tuple):
            base_actvs = base_actvs[0]
        device = base_actvs.get_device()
        if device < 0:
            device = "cpu"

        incl_trig = _get_incl_trig()
        count_vec = comms_dict[count_vec_key].to(device)
        count_mags = comms_dict[count_mag_key]
        strength = comms_dict[strength_key]

        mags = count_mags*count_vec
        div = (loop_count + incl_trig + abs(count_mags))*strength
        intrv = (loop_count*strength*base_actvs + mags)/div

        # Re-apply the projection ourselves; the original output is discarded.
        intrv = torch.matmul(intrv, module.weight.T)
        if module.bias is not None:
            intrv = intrv + module.bias

        if isinstance(out, dict):
            # Keep the other entries of the module's output dict.
            intrv = {**out, dict_key: intrv}
        return intrv
    return hook

def get_sdp_attn_hook(
        comms_dict,
        count_vec_key="count_vec",
        count_mag_key="count_mag",
        strength_key="strength",
        idx_key="idx",
        loop_key="loop_count",
        *args, **kwargs):
    """
    Build a forward hook that recomputes scaled-dot-product attention with
    an injected count vector.

    The hook computes the unnormalized strengths exp(qk/sqrt(d)) and
    multiplies them by the values without normalizing. It then adds
    `count_mag * count_vec` and divides by the original normalizer plus
    `strength * |count_mag|`:

        attn_out = (sum_s e_s v_s + mag*count_vec) / (sum_s e_s + |mag|*strength)

    Apply this hook to the sdp_attn module, i.e. the module whose inputs are
    (q, k, v, mask) and whose output is a dict containing "attentions".

    Args:
        comms_dict: dict
            The communications dict. Each key is a string.

            count_vec_key: torch tensor (D,) or (B,D)
                A vector that we believe is the direction of the count in
                the model. Positive count_dir is in the direction of the
                demonstration tokens.
            count_mag_key: torch Tensor (B,) or scalar (int/float)
                Multiplied by the count vector to change the existing count
                by this magnitude. A (B,) tensor gives one magnitude per
                batch row.
            strength_key: float
                The e^(qk/sqrt(d)) term for the response self-attention
                strength on itself. q and k are the response query and key
                vectors.
            idx_key: torch tensor (B,)
                Positions to intervene upon (currently unused).
            loop_key: int
                The number of forward passes seen so far. It is incremented
                by the hook on every call.
    Returns:
        hook: pytorch hook function
            The hook returns {"attentions": out["attentions"],
            "attn_out": (B,N,L,P) tensor}.
    """
    def hook(module, inpt, out):
        comms_dict[loop_key] += 1

        q, k, v, mask = inpt
        device = q.get_device()
        if device < 0:
            device = "cpu"

        count_vec = comms_dict[count_vec_key].to(device)
        count_mag = comms_dict[count_mag_key]
        count_stren = comms_dict[strength_key]

        scale = math.sqrt(k.shape[-1])
        pre_strens = torch.einsum("bnld,bnsd->bnls", q, k)/scale
        if mask is not None:
            pre_strens = pre_strens.masked_fill(~mask, float("-inf"))
        strens = torch.exp(pre_strens)
        stren_vals = torch.einsum("bnls,bnsp->bnlp", strens, v)
        norm = strens.sum(-1)  # (B,N,L)

        if torch.is_tensor(count_mag) and count_mag.dim() > 0:
            # Per-batch magnitudes: broadcast along the batch axis, not the
            # trailing sequence axis.
            count_mag = count_mag.to(device)
            stren_vals = stren_vals + (count_mag[:, None]*count_vec)[:, None, None]
            norm = norm + (torch.abs(count_mag)*count_stren)[:, None, None]
        else:
            # Scalar magnitude (Python int/float or 0-dim tensor).
            stren_vals = stren_vals + count_mag*count_vec
            norm = norm + abs(count_mag)*count_stren
        attn_out = stren_vals/norm[..., None]

        return {
            "attentions": out["attentions"],
            "attn_out": attn_out,
        }
    return hook

def get_mag_attn_hook(
        comms_dict,
        count_vec_key="count_vec",
        count_mag_key="count_mag",
        strength_key="strength",
        idx_key="idx",
        loop_key="loop_count",
        *args, **kwargs):
    """
    Build a forward hook that broadcasts one position's attention output to
    every position.

    Attention is computed as normal from the module's "attentions" output and
    the values. Then, for each batch row b, the attention output at sequence
    position `count_mag[b]` is copied to every position of that row.
    Count vec and strength are ignored. This helps determine whether the
    attention at the hooked layer does all of the important computation, or
    whether some index-specific state-attention interaction is necessary for
    the proper computations.

    Args:
        comms_dict: dict
            The communications dict. Each key is a string.

            count_vec_key: ignored by this hook.
            count_mag_key: torch Tensor (B,) or int
                The sequence position whose attention output is copied to
                all positions. A scalar applies the same position to every
                batch row.
            strength_key: ignored by this hook.
            idx_key: ignored by this hook.
            loop_key: int
                The number of forward passes seen so far. It is incremented
                by the hook on every call.
    Returns:
        hook: pytorch hook function
            Attach it to an sdp_attn module whose inputs are (q, k, v, mask)
            and whose output dict contains "attentions".
    """
    def hook(module, inpt, out):
        comms_dict[loop_key] += 1

        q, k, v, mask = inpt

        count_mag = comms_dict[count_mag_key]
        attentions = out["attentions"]
        attn_out = torch.einsum("bnls,bnsp->bnlp", attentions, v)
        device = attn_out.device
        n_batch = attn_out.shape[0]

        src_pos = torch.as_tensor(count_mag, device=device).long()
        if src_pos.dim() == 0:
            src_pos = src_pos.expand(n_batch)
        batch_idx = torch.arange(n_batch, device=device)
        # Advanced indices separated by a slice give shape (B,N,P); unsqueeze
        # to (B,N,1,P) so the value broadcasts over every sequence position.
        attn_out[:, :] = attn_out[batch_idx, :, src_pos].unsqueeze(-2)
        return {
            "attentions": out["attentions"],
            "attn_out": attn_out,
        }

    return hook

def project(vecs, canon):
    """
    Project vectors onto a canonical direction and compute the angle to it.

    Args:
        vecs: torch tensor (..., D)
        canon: torch tensor (D,)
    Returns:
        projections: torch tensor (...)
            The scalar projection of each vector onto `canon`
            (dot product divided by ||canon||).
        angles: torch tensor (...)
            The angle in radians between each vector and `canon`.
    """
    norms = torch.norm(vecs, 2, dim=-1)
    canon_norm = torch.norm(canon, 2, dim=-1)
    projections = torch.matmul(vecs, canon)/canon_norm
    # Rounding can push the cosine slightly outside [-1, 1], which would
    # make arccos return NaN for (anti)parallel vectors.
    cosines = torch.clamp(projections/norms, -1.0, 1.0)
    angles = torch.arccos(cosines)
    return projections, angles

if __name__=="__main__":
    # ------------------------------------------------------------------
    # Experiment configuration.
    # ------------------------------------------------------------------
    # Checkpoint folders to evaluate. This name was referenced but never
    # defined; the folders are now taken from the command line.
    model_folders = sys.argv[1:]
    # Layer whose attention module gets hooked.
    # NOTE(review): previously undefined; 0 is an assumed default.
    layer_idx = 0
    # Fixed response index for harvesting. Set to None to sweep over
    # range(-harvest_targ_count-1, -1). Kept separate from the loop variable
    # so that the sweep is not collapsed after the first harvest count.
    resp_idx_setting = -2
    rel_stren = False # if true, stren_idx will refer to the relative index after the input phase
    # Same as resp_idx_setting, but for the strength/count-vector index.
    stren_idx_setting = -2
    n_reps = 20
    max_num = 21
    mag_sign = 1
    # NOTE(review): previously undefined. The magnitudes are overwritten by a
    # uniform scalar after the rep loop regardless of this flag.
    mags_all_same = True
    verbose = False
    # comms_dict key names (the hook factories' defaults).
    count_vec_key = "count_vec"
    count_mag_key = "count_mag"
    strength_key = "strength"
    idx_key = "idx"
    loop_key = "loop_count"
    # The sdp_attn hooks attach to sdp_attn; get_attn_proj_hook would
    # instead attach to e.g. "layers.0.self_attn.out_proj".
    layers = [f"layers.{layer_idx}.self_attn.sdp_attn"]
    hook_fxns = {
        #"get_mag_attn_hook": get_mag_attn_hook,  # hook an sdp_attn module 
        "get_sdp_attn_hook": get_sdp_attn_hook,  # hook an sdp_attn module 
        #"get_attn_proj_hook": get_attn_proj_hook,# hook an out_proj module 
    }
    handle = None  # currently registered forward-hook handle, if any
    for model_folder in model_folders:
        checkpt = savio.load_checkpoint(model_folder)
        config = checkpt["config"]
        model = globals()[config.get("model_type", "LSTM")](**config)
        temp = LossWrapper(model=model, config=config)
        temp.load_state_dict(checkpt["state_dict"])
        model = temp.model
        model.eval()
        
        if "word2id" not in config:
            name2word = {
                "pad_id": "<PAD>",
                "bos_id": "<BOS>",
                "eos_id": "<EOS>",
                "resp_id": "R",
                "demo_ids0": "D0",
                "demo_ids1": "D1",
                "demo_ids2": "D2",
                "trigger_id": "T",
                "unk_id": "<UNK>",
            }
            word2id = dict()
            for name,token in name2word.items():
                if name in config:
                    word2id[token] = config[name]
                elif name[:-1] in config:
                    idx = int(name[-1])
                    word2id[token] = config[name[:-1]][idx]
                else:
                    print("could not add", name, token)
            config["word2id"] = word2id
                
        if "word2id" in config:
            mapping = {v:k[:3] for k,v in config["word2id"].items()}
            remap = {
                "<PA": "P",
                "<BO": "B",
                "<EO": "E",
                "<UN": "U",
                "D0": "D0",
                "D1": "D1",
                "D2": "D2",
                "R": "R",
                "T": "T",
            }
            mapping = {k: remap[v] for k,v in mapping.items()}
        else:
            mapping = {
                0: "P",
                1: "B",
                2: "E",
                3: "R",
                4: "N1",
                5: "N2",
                6: "N3",
                7: "T",
            }
        print("Mapping:", mapping)
        pad_id =   config["pad_id"]
        bos_id =   config["bos_id"]
        eos_id =   config["eos_id"]
        try:
            word2id = config["word2id"]
            resp_id =  word2id[config["resp_token"]]
            trig_ids = [word2id[t] for t in config["trigger_tokens"]]
            demo_ids = [word2id[d] for d in config["demo_tokens"]]
        except (KeyError, TypeError):
            resp_id =  config["resp_id"]
            trig_ids = [t for t in config["trigger_ids"]]
            demo_ids = [d for d in config["demo_ids"]]
        trig_id = trig_ids[0]
        
        add_bos = config.get("add_bos", True)
        incl_trig = config.get("task_config", {"incl_trigger": True}).get("incl_trigger", True)
        # Explicit config ids take precedence. Otherwise keep the ids parsed
        # above. (The previous hard-coded fallbacks, e.g. demo_ids=7, made
        # np.random.choice sample from range(7).)
        trig_id = config.get("trigger_id", trig_id)
        demo_ids = config.get("demo_ids", demo_ids)
        resp_id = config.get("resp_id", resp_id)
        eos_id = config.get("eos_id", eos_id)
        
        results_dict = {
            "acc": [],
            "intrv_acc": [],
            "eos_acc": [],
            "full_acc": [],
            "targ_count": [],
            "pred_count": [],
            "expect_count": [],
            "harvest_targ_count": [],
            "resp_idx": [],
            "stren_idx": [],
            "n_wrong": [],
            "magnitude": [],
            "layer": [],
            "hook_fxn": [],
        }
        if handle is not None:
            handle.remove()
            handle = None
        model.cuda()
        device = next(model.parameters()).device
        attn_mod = model.layers[layer_idx].self_attn
        for hook_name, hook_fxn in hook_fxns.items():
            for layer in layers:
                for harvest_targ_count in tqdm(range(5,6)):
                    if resp_idx_setting is not None:
                        resp_range = [resp_idx_setting]
                    else:
                        resp_range = range(-harvest_targ_count-1,-1)
                    for resp_idx in resp_range:
                        if stren_idx_setting is not None:
                            stren_range = [stren_idx_setting]
                        else:
                            stren_range = range(-harvest_targ_count-1,-1)
                        for stren_idx in stren_range:
                            # Harvest the count vector
                            harv_seq = add_bos*[bos_id]+[np.random.choice(demo_ids) for _ in range(harvest_targ_count)] + incl_trig*[trig_id] +\
                                  [resp_id for _ in range(harvest_targ_count)] + [eos_id] 
                            
                            if verbose:
                                print("Layer Idx:", layer_idx)
                                print("Harv Seq:", harv_seq)
                                print("Resp:", harv_seq[resp_idx],"- Idx:", resp_idx)
                                print("Stren Idx:", harv_seq[stren_idx],"- Idx:", stren_idx)
                            
                            harv_seq = torch.LongTensor(harv_seq)[None]
                            if handle is not None:
                                handle.remove()
                                handle = None
                            with torch.no_grad():
                                actvs = collect_activations(
                                    model=model,
                                    input_ids=harv_seq,
                                    pad_mask=None,
                                    layers=[f"layers.{i}.self_attn."+k for k in ["q_proj", "k_proj", "v_proj"] for i in range(len(model.layers))],
                                    batch_size=None,
                                    ret_attns=True,
                                    to_cpu=True,)
                            N = attn_mod.nhead
                            P = attn_mod.proj_dim
                            B,L = harv_seq.shape
                            qs = actvs[f"layers.{layer_idx}.self_attn.q_proj"].reshape(B,L,N,P).permute(0,2,1,3)
                            ks = actvs[f"layers.{layer_idx}.self_attn.k_proj"].reshape(B,L,N,P).permute(0,2,1,3)
                            vs = actvs[f"layers.{layer_idx}.self_attn.v_proj"].reshape(B,L,N,P).permute(0,2,1,3)
                            if config["encoder_layer_class"]=="RotaryEncoderLayer":
                                # Use the hooked layer's embedding function, matching q/k above.
                                qs,ks = attn_mod.emb_fxn(qs,ks)
                            scale = math.sqrt(ks.shape[-1])
                            strens = torch.einsum("bnlp,bnsp->bnls",qs,ks)/scale
                            strens = strens.squeeze()[resp_idx] # only want the response strengths
                            strens = torch.exp(strens)
                            
                            if rel_stren:
                                sidx = harvest_targ_count+incl_trig+add_bos + stren_idx
                            else:
                                sidx = stren_idx
                            
                            strength = strens[sidx]
                            str_vals = (vs.squeeze().T*strens).T
                            # Take the count vector from the same position as its
                            # strength (these differed when rel_stren=True).
                            count_vec = str_vals[sidx]
                            assert torch.all(torch.isclose(count_vec, strength*actvs[f"layers.{layer_idx}.self_attn.v_proj"][0,sidx]))
                            
                            if verbose:
                                print("Strengths:", strens)
                                print("Normed:", strens/strens.sum())
                                print("Strength:", strength)
                                print("resp_idx:", resp_idx)
                                print("resp_token:", harv_seq[0,resp_idx])
                                print("stren_idx:", stren_idx)
                                print("stren_token:", harv_seq[0,stren_idx])
                                print("count_vec:",count_vec.shape)
                                
                            targ_counts = torch.arange(1,max_num)
                            for mag_offset in range(1,max_num):
                                all_seqs = []
                                all_labels = []
                                all_tmasks = []
                                all_phase_masks = []
                                all_expects = []
                                all_tcnts = []
                                for rep in range(n_reps):
                                    
                                    el = [[np.random.choice(demo_ids) for _ in range(j)] + incl_trig*[trig_id] +\
                                          [resp_id for _ in range(j)] + [eos_id for _ in range(2*max_num-2*j)] for j in range(1,max_num) ]
                                    if add_bos: el = [[bos_id] + e for e in el]
                                    seqs = torch.LongTensor(el)
                                    
                                    labels = [ [ e for e in el[j-1][1:2*j+int(incl_trig)+int(add_bos)] ] + [eos_id] + [0 for _ in range(2*max_num-2*j)] for j in range(1,max_num)]
                                    labels = torch.LongTensor(labels)
                                    tmask = torch.BoolTensor([ 
                                        [1 if s in {resp_id,eos_id} else 0 for s in seq] for seq in el
                                    ])
                                    
                                    phase_mask = torch.BoolTensor([
                                        [0 if s in {*demo_ids,trig_id} else 1 for s in seq] for seq in el
                                    ])
                                    if mags_all_same:
                                        mags = mag_sign*mag_offset
                                    else:
                                        mags = torch.FloatTensor([mag_sign*j*2 for j in range(0,max_num-1)])+mag_offset
                                        mags = mags.to(device)
                                    expecteds = [
                                        [demo_ids[0] for _ in range(j+add_bos+incl_trig)] +\
                                        [resp_id for _ in range(max(j+1-mag_sign*mag_offset,0))]\
                                        for j in range(max_num-1)
                                    ]
                                    expecteds = torch.LongTensor([expecteds[i] + [eos_id for _ in range(len(seqs[i])-len(expecteds[i]))] for i in range(len(seqs))])
                                    intrv_idxs = torch.LongTensor([j for j in range(1,max_num)])
                                    
                                    all_seqs.append(seqs)
                                    all_labels.append(labels)
                                    all_tmasks.append(tmask)
                                    all_phase_masks.append(phase_mask)
                                    all_expects.append(expecteds)
                                    all_tcnts.append(targ_counts)
                                seqs = torch.vstack(all_seqs)
                                labels = torch.vstack(all_labels)
                                tmask = torch.vstack(all_tmasks)
                                phase_mask = torch.vstack(all_phase_masks)
                                # NOTE(review): a uniform scalar magnitude is always used
                                # here; per-rep tensor magnitudes above are not collected.
                                mags = mag_sign*mag_offset
                                expecteds = torch.vstack(all_expects)
                                t_counts = torch.vstack(all_tcnts).reshape(-1)
                                    
                                comms_dict = {
                                    count_vec_key: count_vec,
                                    count_mag_key: mags,
                                    strength_key: strength,
                                    idx_key: intrv_idxs,
                                    loop_key: 0,
                                    # read by get_attn_proj_hook's normalizer
                                    "incl_trig": incl_trig,
                                    # TODO: 
                                    #trig_strength: include a value here
                                }
                                hook = hook_fxn(
                                    comms_dict=comms_dict,
                                    count_vec_key=count_vec_key,
                                    count_mag_key=count_mag_key,
                                    strength_key=strength_key,
                                    idx_key=idx_key,
                                    loop_key=loop_key,
                                )
                                if handle is not None:
                                    handle.remove()
                                    handle = None
                                for name,modu in model.named_modules():
                                    if name==layer:
                                        handle = modu.register_forward_hook(hook)
                                if handle is None:
                                    raise ValueError(f"Could not find module '{layer}' to hook")
                                
                                model.eval()
                                with torch.no_grad():
                                    inpts = seqs.cuda()
                                    t = tmask.cuda()
                                    try:
                                        ret_dict = model(input_ids=inpts, task_mask=t, tforce=True, n_steps=0, ret_gtruth=False, output_attentions=True)
                                        manip_preds = ret_dict["pred_ids"].cpu()
                                    finally:
                                        # Always detach the hook, even if the forward pass fails.
                                        handle.remove()
                                        handle = None
                                    ret_dict = model(input_ids=inpts, task_mask=t, tforce=False, n_steps=1, ret_gtruth=False, output_attentions=True)
                                    og_preds = ret_dict["pred_ids"].cpu()
                                raw_accs = torch.ones_like(manip_preds).bool()
                                mask = tmask&phase_mask
                                raw_accs[mask] = (manip_preds[mask]==expecteds[mask])
                                accs = raw_accs.float().mean(-1)
                                intrv_mask = ((expecteds==eos_id)|(expecteds==resp_id))&(labels==resp_id)
                                intrv_accs = torch.zeros_like(manip_preds).bool()
                                intrv_accs[intrv_mask] = (manip_preds[intrv_mask]==expecteds[intrv_mask])
                                intrv_accs = intrv_accs.long().sum(-1)/intrv_mask.long().sum(-1)
                                eos_accs = torch.zeros_like(manip_preds).long()
                                eos_mask = (labels==eos_id)
                                eos_accs[eos_mask] = (manip_preds[eos_mask]==expecteds[eos_mask]).long()
                                eos_accs = eos_accs.sum(-1)
                                full_mask = ((expecteds==eos_id)|(expecteds==resp_id))&((labels==resp_id)|(labels==eos_id))
                                full_accs = torch.zeros_like(manip_preds).long()
                                full_accs[full_mask] = (manip_preds[full_mask]==expecteds[full_mask]).long()
                                full_accs = full_accs.sum(-1)/full_mask.long().sum(-1)
                                
                                if verbose:
                                    print("manip:", manip_preds.shape)
                                    print("exprt:", expecteds.shape)
                                    print("mask:", mask.shape)
                                    print("mag:", mag_offset)
                                    print("harv:", harvest_targ_count)
                                    print("resp_idx:", resp_idx)
                                    print("stren_idx:", stren_idx)
                                        
                                    perm = np.random.permutation(len(seqs))
                                    for j in range(min(3,len(perm))):
                                        m = int(perm[j])
                                        print("Target Count:", t_counts[m])
                                        if type(mags)==int:
                                            print("Added Magitude:", mags)
                                        else:
                                            print("Added Magitude:", mags[m])
                                        print("SeqIDX:  ", " ".join(["{:2}".format(str(i)) for i in range(len(ret_dict["pred_ids"][m])+1)]))
                                        print("Inpts:   ", " ".join(["{:2}".format(mapping[int(i)]) for i in seqs[m]]))
                                        print("Targs:   ", " ".join(["{:2}".format(mapping[int(i)]) for i in labels[m]]))
                                        print("OG Preds:", " ".join(["{:2}".format(mapping[int(i)]) for i in og_preds[m]]))
                                        print("Intrvnd: ", " ".join(["{:2}".format(mapping[int(i)]) for i in manip_preds[m]]))
                                        print("Expectds:", " ".join(["{:2}".format(mapping[int(i)]) for i in expecteds[m]]))
                                        print("Tmask:   ", " ".join(["{:2}".format(str(int(i))) for i in tmask[m]]))
                                        print()
                                pred_counts = (mask&(manip_preds==resp_id)).long().sum(-1)
                                results_dict["acc"] += accs.tolist()
                                results_dict["intrv_acc"] += intrv_accs.tolist()
                                results_dict["eos_acc"] += eos_accs.tolist()
                                results_dict["full_acc"] += full_accs.tolist()
                                results_dict["pred_count"] += pred_counts.tolist()
                                results_dict["expect_count"] += (t_counts-mags).tolist()
                                results_dict["targ_count"] += t_counts.tolist()
                                results_dict["n_wrong"] += [int(c) for c in (~raw_accs).long().sum(-1)]
                                results_dict["magnitude"] += [int(mags) for _ in range(len(t_counts))]
                                results_dict["harvest_targ_count"] += [int(harvest_targ_count) for _ in range(len(t_counts))]
                                results_dict["resp_idx"] += [resp_idx for _ in range(len(t_counts))]
                                results_dict["stren_idx"] += [stren_idx for _ in range(len(t_counts))]
                                results_dict["layer"] += [layer for _ in range(len(t_counts))]
                                results_dict["hook_fxn"] += [hook_name for _ in range(len(t_counts))]
                                
                                if verbose:
                                    print("Accuracy:", accs)
                                    print("Accuracy:", accs.mean())
                        
        model.cpu()
        res_df = pd.DataFrame(results_dict)
        path = os.path.join(model_folder, "tformer_intrvs.csv")
        print("Saving to", path)
        res_df.to_csv(path, header=True, index=False)