import copy
import torch
import random
from torch import nn
import numpy as np
from collections import defaultdict
from lib.utils import (
    layername, make_inputs, get_data
)
from typing import Optional, Tuple, Union
from torch.utils.data import DataLoader

import torch.nn.functional as F

from lib import nethook
import types

def _split_heads(self, tensor, num_heads, attn_head_size):
    """Split the hidden dimension of ``tensor`` into attention heads.

    The last dimension is viewed as ``(num_heads, attn_head_size)`` and the head axis is
    then swapped in front of the sequence axis, giving
    ``(batch, head, seq_length, head_features)``. ``self`` is not used.

    NOTE: this module defines ``_split_heads`` a second time further down with the same
    body; that later definition is the one bound once the module is imported.
    """
    heads_view = tensor.view(*tensor.size()[:-1], num_heads, attn_head_size)
    return heads_view.permute(0, 2, 1, 3)


class CausalFlowTracer(object):
    """Record per-block activations and class predictions for clean and corrupted inputs.

    The tracing methods hook every transformer block of ``model`` (``model.model.blocks``,
    or ``model.model.transformers[i].blocks`` when ``model_name == "pit_ti_224"``) and store
    what the hooks captured on the instance:

    * ``feats_normal`` / ``feats_corrupted``: detached block outputs, in hook-call order.
    * ``feats_k_*`` / ``feats_v_*``: placeholder lists holding one ``[1]`` per block call.
    * ``scores_normal``: softmax of the clean logits for batch item 0.
    * ``scores_corrupted``: softmax averaged over the corrupted batch, with a leading dim of 1.
    * ``feats_corrupted_init``: detached input of the first hooked block on the corrupted batch.

    Stored copies are made with ``copy.deepcopy``. The original author noted that
    ``torch.clone`` produced different outputs here (not verified in this file).
    """

    def __init__(self):
        self.traced_paths = {}

    def init_node(self):
        """Reset the per-node path terms (``in_term1`` .. ``in_term4_list``) to ``None``."""
        self.in_term1 = None
        self.in_term2 = None
        self.in_term3_list = None
        self.in_term4_list = None

    @staticmethod
    def _iter_blocks(model, model_name):
        """Yield ``(index_within_stage, block)`` for every transformer block to hook.

        For ``"pit_ti_224"`` the index restarts at 0 in each stage of
        ``model.model.transformers``; otherwise it enumerates ``model.model.blocks``.
        """
        if model_name == "pit_ti_224":
            for stage in model.model.transformers:
                for block_idx, block in enumerate(stage.blocks):
                    yield block_idx, block
        else:
            yield from enumerate(model.model.blocks)

    @staticmethod
    def _extract_logits(outputs_exp):
        """Return the logits tensor from a model output.

        Handles a tuple output (the first element is used), a mapping with a ``'logits'``
        key, or a bare tensor.
        """
        if isinstance(outputs_exp, tuple):
            outputs_exp = outputs_exp[0]
        if isinstance(outputs_exp, dict) and 'logits' in outputs_exp:
            return outputs_exp['logits']
        return outputs_exp

    def _forward_with_hooks(self, model, model_name, make_hook, model_inputs):
        """Run ``model(**model_inputs)`` under ``torch.no_grad`` with a hook on every block.

        ``make_hook(block_idx)`` must return a forward-hook callable. The hooks are removed
        in a ``finally`` block so they do not leak onto the model if the forward raises.
        Returns the extracted logits.
        """
        handles = []
        try:
            for block_idx, block in self._iter_blocks(model, model_name):
                handles.append(block.register_forward_hook(make_hook(block_idx)))
            with torch.no_grad():
                outputs_exp = model(**model_inputs)
        finally:
            for handle in handles:
                handle.remove()
        return self._extract_logits(outputs_exp)

    @staticmethod
    def _shuffle_patches(batch_images, patch_size=16):
        """Return a copy of ``batch_images`` with its square patches permuted across the batch.

        ``batch_images`` is unpacked as ``(B, C, H, W)``. One ``torch.randperm`` over all
        ``B * num_patches`` patches decides where each patch goes. Pixels not covered by a
        full patch (when H or W is not a multiple of ``patch_size``) stay zero.
        """
        B, C, H, W = batch_images.shape
        num_patches_h = H // patch_size
        num_patches_w = W // patch_size
        total_patches = num_patches_h * num_patches_w

        # Extract all patches from all images: [B, C, num_patches, patch_size, patch_size].
        patches = batch_images.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
        patches = patches.contiguous().view(B, C, -1, patch_size, patch_size)

        shuffle_indices = torch.randperm(B * total_patches)

        corrupted_images = torch.zeros_like(batch_images)
        for i in range(B):
            for p in range(total_patches):
                # Which (image, patch) the shuffled slot p of image i is taken from.
                shuffle_idx = shuffle_indices[i * total_patches + p]
                src_batch_idx = shuffle_idx // total_patches
                src_patch_idx = shuffle_idx % total_patches

                # Top-left pixel of destination patch p (row-major patch order).
                h_idx = (p // num_patches_w) * patch_size
                w_idx = (p % num_patches_w) * patch_size

                corrupted_images[i, :, h_idx:h_idx + patch_size, w_idx:w_idx + patch_size] = \
                    patches[src_batch_idx, :, src_patch_idx]
        return corrupted_images

    def trace_normal(
            self, model, inp, model_name):
        """Trace clean inputs through the model and record every block's output.

        Args:
            model: wrapper whose ``model.model`` holds the transformer blocks; called as
                ``model(**inp)``.
            inp: keyword arguments for the model forward.
            model_name: ``"pit_ti_224"`` selects the PiT stage layout; anything else uses
                ``model.model.blocks``.

        Returns:
            Tensor of shape ``(1,)``: the argmax class of the softmax for batch item 0.
        """
        self.feats_normal = None
        self.feats_k_normal = None
        self.feats_v_normal = None
        self.scores_normal = None

        outputs = {"output": [], "output_k": [], "output_v": []}

        def hook_fn(module, args, output):
            outputs["output"].append(output.detach())
            outputs["output_k"].append([1])
            outputs["output_v"].append([1])

        # Every block shares the same hook; the block index is not needed here.
        logits = self._forward_with_hooks(model, model_name, lambda _block_idx: hook_fn, inp)

        scores_normal = torch.softmax(logits, dim=1)[0]  # softmax(logit) of batch item 0
        answer_t = torch.max(scores_normal, dim=0).indices.unsqueeze(0)  # argmax -> shape (1,)

        self.feats_normal = copy.deepcopy(outputs["output"])
        self.feats_k_normal = copy.deepcopy(outputs["output_k"])
        self.feats_v_normal = copy.deepcopy(outputs["output_v"])
        self.scores_normal = copy.deepcopy(scores_normal)

        return answer_t

    def trace_corrupted(
            self, model, tokenizer, images, rand_seed, model_name, dataset_name, noise_type="other", num_noise_sample=3, noise_level=0.5):
        """Trace a patch-shuffled batch of dataset images through the model.

        Seeds ``torch``, ``numpy`` and ``random`` with ``rand_seed``, draws one shuffled
        batch of ``num_noise_sample`` images from ``get_data(dataset_name)``, and permutes
        all of the batch's 16x16 patches across images. The result is moved to CUDA and fed
        to the model as ``pixel_values``.

        ``tokenizer``, ``images``, ``noise_type`` and ``noise_level`` are accepted for
        interface compatibility but are not used by the current implementation: every patch
        is shuffled regardless of ``noise_level``.

        Returns:
            Tensor of shape ``(1,)``: the argmax class of the batch-averaged softmax.

        Raises:
            ValueError: if the data loader yields no batch.
        """
        self.feats_corrupted = None
        self.feats_k_corrupted = None
        self.feats_v_corrupted = None
        self.scores_corrupted = None

        torch.manual_seed(rand_seed)
        np.random.seed(rand_seed)
        random.seed(rand_seed)

        dataset, _ = get_data(dataset_name)
        dataloader_for_corrupt = DataLoader(dataset, batch_size=num_noise_sample, shuffle=True)

        # Only the first shuffled batch is used.
        first_batch = next(iter(dataloader_for_corrupt), None)
        if first_batch is None:
            raise ValueError("Could not load images from the loader")
        batch_images, _ = first_batch
        corrupted_images = self._shuffle_patches(batch_images).cuda()

        inputs_block = {"input": []}
        outputs = {"output": [], "output_k": [], "output_v": []}

        def make_hook_fn(block_idx):
            def hook_fn(module, args, output):
                # Record the input of each stage's first block (only [0] is kept below).
                if block_idx == 0:
                    inputs_block["input"].append(args[0].detach())
                outputs["output"].append(output.detach())
                outputs["output_k"].append([1])
                outputs["output_v"].append([1])
            return hook_fn

        logits = self._forward_with_hooks(
            model, model_name, make_hook_fn, {"pixel_values": corrupted_images})

        # Softmax averaged over the corrupted batch, keeping a leading dim: (1, num_classes).
        probs_corrupted = torch.softmax(logits, dim=1).mean(dim=0).unsqueeze(0)
        # Reduce over the class dim so the result has shape (1,), like trace_normal's.
        # (dim=0 would reduce over the singleton leading dim and return all zeros.)
        answer_t = torch.max(probs_corrupted, dim=1).indices

        self.feats_corrupted_init = copy.deepcopy(inputs_block["input"][0])
        self.feats_corrupted = copy.deepcopy(outputs["output"])
        self.feats_k_corrupted = copy.deepcopy(outputs["output_k"])
        self.feats_v_corrupted = copy.deepcopy(outputs["output_v"])
        self.scores_corrupted = copy.deepcopy(probs_corrupted)

        return answer_t

def custom_sort_indices(arr, order):
    """Return the indices of ``arr`` stably ordered by each element's rank in ``order``.

    The rank of an element is its position in ``order`` (the last occurrence wins if
    ``order`` repeats a value). Raises ``KeyError`` if an element of ``arr`` is missing
    from ``order``.
    """
    rank_of = {}
    for position, label in enumerate(order):
        rank_of[label] = position
    ranks = [rank_of[arr[i]] for i in range(len(arr))]
    return sorted(range(len(arr)), key=ranks.__getitem__)
def converter(path, quiet=False, node_seq=("OB", "OM", "IM", "OA", "A", "VH", "KH", "Q", "K", "V", "IA")):
    """Render a set of path edges as sorted, human-readable lines.

    Each edge is indexed as ``edge[k][j]`` with ``k`` in {0, 1} (source, destination) and
    ``j`` selecting ``(node_type, token, block)``, all stored as strings. Duplicate edges
    are removed. Edges are ordered by source block (descending), then source token
    (ascending), then source node type by its position in ``node_seq``.

    Args:
        path: array-like of edges as described above.
        quiet: when exactly ``False``, each line is also printed.
        node_seq: node-type priority order (a tuple default avoids a shared mutable default).

    Returns:
        List of lines formatted as ``"NT.TK.BL -> NT.TK.BL"``.
    """
    path = np.unique(path, axis=0)
    # Node-type priority, built once rather than per token group.
    priority = {name: rank for rank, name in enumerate(node_seq)}

    block_idx = [int(edge[0][2]) for edge in path]
    sorted_idx = np.argsort(block_idx)[::-1]
    sorted_path = np.asarray(path)[sorted_idx]
    # Compare block ids numerically so they are parsed the same way block_idx is
    # (a string match against str(b_idx) would drop e.g. zero-padded "01").
    sorted_blocks = np.asarray(block_idx, dtype=int)[sorted_idx]

    buffer = []
    for b_idx in np.unique(block_idx)[::-1]:
        data = sorted_path[sorted_blocks == b_idx]

        tokens = data[:, 0, 1].astype(int)
        t_sorted_idx = np.argsort(tokens)
        t_sorted_data = data[t_sorted_idx]
        t_sorted_list = tokens[t_sorted_idx]

        for c_t_i in np.unique(t_sorted_list):
            curr_buffer = t_sorted_data[t_sorted_list == c_t_i]
            # Stable sort of this token group by node-type priority.
            ranks = [priority[name] for name in curr_buffer[:, 0, 0]]
            s_idx = sorted(range(len(curr_buffer)), key=ranks.__getitem__)
            buffer.extend(curr_buffer[s_idx].tolist())

    lines = []
    for edge in buffer:
        line = "{:2s}.{:2s}.{:2s} -> {:2s}.{:2s}.{:2s}".format(edge[0][0], edge[0][1], edge[0][2], edge[1][0], edge[1][1], edge[1][2])
        lines.append(line)
        if quiet is False:
            print(line)

    return lines


def get_lm_inp_mu_std(x, eps):
    """Return LayerNorm statistics of ``x`` over its last dimension.

    Returns ``(mu, std)``, both with the last dim kept at size 1, where
    ``std = sqrt(biased_variance + eps)``, the same statistics LayerNorm normalizes with.
    """
    mean = torch.mean(x, dim=-1, keepdim=True)
    biased_var = torch.var(x, dim=-1, unbiased=False, keepdim=True)
    return mean, torch.sqrt(biased_var + eps)

def make_layernorm2mm(x, layernorm_layer, prefix_mu=None, prefix_std=None):
    """Express ``layernorm_layer`` as an affine map with frozen statistics.

    With ``mu``/``std`` either supplied (both ``prefix_mu`` and ``prefix_std``) or computed
    from ``x`` via ``get_lm_inp_mu_std``, LayerNorm becomes ``x -> x @ W^T + b`` with
    ``W = diag(weight / std)`` and ``b = -weight * mu / std + bias``.

    Returns ``(W, b)``. ``W`` is built with ``torch.diag_embed``, so it has one more
    trailing dimension than ``weight / std``.
    """
    if prefix_mu is not None and prefix_std is not None:
        mu, std = prefix_mu, prefix_std
    else:
        mu, std = get_lm_inp_mu_std(x, eps=layernorm_layer.eps)

    gamma = layernorm_layer.weight
    beta = layernorm_layer.bias
    scale = gamma / std
    shift = (-gamma * mu / std + beta)
    return torch.diag_embed(scale), shift

def ln_mm(x, W, b=None):
    """Apply ``W`` to ``x`` as ``x @ W^T`` along the last dim, then add ``b`` if given.

    ``x`` is lifted to a row vector (``unsqueeze(-2)``) so that ``W`` may carry extra
    leading batch dimensions (for example per-token matrices from ``make_layernorm2mm``).
    """
    as_row = x.unsqueeze(-2)
    product = torch.matmul(as_row, W.transpose(-1, -2)).squeeze(-2)
    return product if b is None else product + b

def make_homoneneous_coord(x):
    """Append a column of ones to the last dimension of ``x`` (homogeneous coordinates).

    The ones tensor matches ``x``'s dtype and device.
    """
    return torch.cat((x, torch.ones_like(x[..., :1])), dim=-1)

def homogen_mm(x, W):
    """Multiply ``[x, 1]`` (``x`` extended with a trailing one) by ``W`` from the right.

    ``W``'s first dimension must therefore be ``x.shape[-1] + 1``; its last row acts as
    a bias.
    """
    padded = torch.cat([x, torch.ones_like(x[..., :1])], dim=-1)
    return torch.matmul(padded.unsqueeze(-2), W).squeeze(-2)

def layernorm2mm(x, layernorm_layer, prefix_mu=None, prefix_std=None, bias_divide=None):
    """Apply ``layernorm_layer`` to ``x`` through its affine form with frozen statistics.

    Computes ``weight * (x - mu) / std + bias`` using the ``(W, b)`` pair from
    ``make_layernorm2mm``. ``mu``/``std`` come from ``prefix_mu``/``prefix_std`` when both
    are given, otherwise from ``x``.

    ``bias_divide`` is accepted for signature compatibility but is not used.
    """
    # x.shape == [batch, token, hidden]
    w, b = make_layernorm2mm(x, layernorm_layer, prefix_mu, prefix_std)
    # make_layernorm2mm returns W = diag_embed(weight / std) with shape [..., hidden, hidden].
    # Multiplying x elementwise by that matrix broadcasts incorrectly, so use its diagonal.
    scale = torch.diagonal(w, dim1=-2, dim2=-1)
    out = x * scale + b
    return out

def linear2mm(x, linear_layer, bias_divide=None, ignore_bias=False, weight_slice=None):
    """Apply an ``nn.Linear``-style layer (weight ``(out, in)``) to ``x`` as a matrix product.

    Args:
        x: input whose last dim matches the (possibly sliced) input dim of the weight.
        linear_layer: layer exposing ``weight`` of shape ``(out, in)`` and optional ``bias``.
        bias_divide: if given, the bias is divided by this value (e.g. to split a shared
            bias evenly across several terms).
        ignore_bias: if true, return ``x @ W^T`` without any bias.
        weight_slice: optional ``(start, stop)`` selecting a column range of the weight
            (a slice of input features); ``None`` uses the full weight.

    Returns:
        ``x @ W^T`` (+ bias, via homogeneous coordinates) with last dim ``out``.
    """
    W = linear_layer.weight
    if weight_slice is not None:
        W = W[:, weight_slice[0]:weight_slice[1]]
    b = linear_layer.bias

    if (b is None) or ignore_bias:
        # nn.Linear layout: y = x @ W^T (matching the A.T product below).
        return torch.matmul(x, W.T)

    x_ext = make_homoneneous_coord(x)
    bias_col = b.unsqueeze(1) if bias_divide is None else b.unsqueeze(1) / bias_divide
    A = torch.cat([W, bias_col], dim=1)

    return torch.matmul(x_ext, A.T)

def subsetidx2pathnodeidx(subsetidx, num_heads):
    """Map flat path-node indices to per-term indices, plus the complement of each.

    Flat layout (``H = num_heads``):
      * ``0``                 -> term 1 (single node, local index 0)
      * ``1``                 -> term 2 (single node, local index 0)
      * ``2 .. 2 + H - 1``    -> term 3 (per head, local index ``idx - 2``)
      * ``>= 2 + H``          -> term 4 (per head, local index ``idx - 2 - H``)
    Negative indices are ignored.

    Returns:
        ``(targets, others)``: two 4-tuples of lists. ``targets[k]`` holds the local indices
        selected for term ``k + 1``; ``others[k]`` holds the sorted local indices of that
        term that were not selected.
    """
    first_mlp_head = 2 + num_heads
    selected = ([], [], [], [])

    for idx in list(subsetidx):
        if idx == 0:
            selected[0].append(idx)
        elif idx == 1:
            selected[1].append(idx - 1)
        elif 2 <= idx < first_mlp_head:
            selected[2].append(idx - 2)
        elif idx >= first_mlp_head:
            selected[3].append(idx - 2 - num_heads)

    universes = ([0], [0], np.arange(num_heads), np.arange(num_heads))
    remaining = tuple(
        np.setdiff1d(universe, chosen).tolist()
        for universe, chosen in zip(universes, selected)
    )
    return selected, remaining

def path_level_intervention_forward(
    self,
    hidden_states: Optional[Tuple[torch.FloatTensor]],
    layer_past: Optional[Tuple[torch.Tensor]] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = False,
    output_attentions: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
    """Transformer-block forward that splits the block output into additive path terms.

    Meant to be bound as a block's ``forward``. It calls ``self.attn(...)`` expecting
    ``(attn_outputs, head_wise_attn_output)``, and ``self.mlp(...)`` expecting
    ``(output, gelu_input)`` plus support for ``custom_gelu``/``bias_divide``/``ignore_bias``.
    Those are the signatures of ``custom_attn_forward`` / ``custom_mlp_forward`` in this
    module; presumably the caller patches them in (confirm).

    The block output is rebuilt as
        ``in_term1 + in_term2 + sum_h in_term3_list[h] + sum_h in_term4_list[h]`` where
      * ``in_term1``: the block input (residual stream);
      * ``in_term2``: the MLP applied to the block input, with LN2 replaced by its affine
        form (statistics frozen from the post-attention stream) and GELU replaced by the
        fixed gate ``D_a = gelu(z) / z``;
      * ``in_term3_list[h]``: head ``h``'s slice of the attention output projection, with
        the projection bias split evenly across heads;
      * ``in_term4_list[h]``: the same linearized MLP applied to head ``h``'s projection.
    The MLP's constant contributions (LN2 bias, fc1 bias, fc2 bias through the gate) are
    spread equally over ``in_term2`` and the ``num_heads`` entries of ``in_term4_list``.

    Optional control attributes on ``self`` (each read only if present):
      * ``anonymoususer_save_mode``: when truthy, the input is replaced by
        ``anonymoususer_feats_corrupted_init`` (if set), and the computed terms and
        ``org_output`` are stored on ``self.anonymoususer_flow_tracer``.
      * ``anonymoususer_trace_mode``: when truthy, terms are swapped for those in
        ``self.anonymoususer_corrupted_feats`` according to ``anonymoususer_cond``
        (``"path"``/``"contingency"``: keep the terms selected by
        ``anonymoususer_curr_subset``; ``"counterfactual"``: take all corrupted terms).

    Raises:
        ValueError: trace mode is on and ``anonymoususer_cond`` is not one of
            ``"path"``, ``"contingency"``, ``"counterfactual"``.

    Returns:
        ``(hidden_states,) + outputs`` if ``use_cache`` is truthy, else
        ``(hidden_states,) + outputs[1:]``, where ``outputs = attn_outputs[1:]``.
    """
    if hasattr(self, "anonymoususer_save_mode"):
        if self.anonymoususer_save_mode:
            if hasattr(self, "anonymoususer_feats_corrupted_init"):
                # Save mode: run the block on the stored (corrupted) input instead.
                hidden_states = copy.deepcopy(self.anonymoususer_feats_corrupted_init)

    residual = hidden_states
    if isinstance(hidden_states, tuple):
        hidden_states = hidden_states[0]

    # Untouched copy of the block input, used for in_term1 / in_term2.
    residual_in = copy.deepcopy(residual)

    hidden_states = self.norm1(hidden_states)

    attn_outputs, head_wise_attn_output = self.attn(
        hidden_states,
        layer_past=layer_past,
        attention_mask=attention_mask,
        head_mask=head_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
    )

    attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
    outputs = attn_outputs[1:]
    # Residual connection.
    # NOTE(review): when hidden_states is a (batch, seq, dim) tensor rather than a tuple,
    # residual[0] / residual_in[0] select the first batch element only; this appears to
    # assume batch size 1. Confirm before using with larger batches.
    residual = residual[0]
    hidden_states = attn_output + residual

    if encoder_hidden_states is not None:
        # add one self-attention block for cross-attention
        if not hasattr(self, "crossattention"):
            raise ValueError(
                f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                "cross-attention layers by setting `config.add_cross_attention=True`"
            )
        residual = hidden_states
        hidden_states = self.ln_cross_attn(hidden_states)
        cross_attn_outputs = self.crossattention(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
        )
        attn_output = cross_attn_outputs[0]
        # residual connection
        hidden_states = residual + attn_output
        outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights

    residual = copy.deepcopy(hidden_states.detach())
    # Freeze LN2 statistics of the post-attention stream so LN2 becomes an affine map
    # (W_ln_mlp = diag(weight / std), b_ln_mlp = bias - weight * mu / std) that can be
    # applied to each path term separately.
    ln2_mu, ln2_std = get_lm_inp_mu_std(copy.deepcopy(hidden_states.detach()), self.norm2.eps)
    W_ln_mlp, b_ln_mlp = make_layernorm2mm([], self.norm2, prefix_mu=ln2_mu, prefix_std=ln2_std)

    hidden_states_ln2 = self.norm2(hidden_states)
    feed_forward_hidden_states, gelu_input_states = self.mlp(hidden_states_ln2)

    # GELU as a fixed elementwise gate: gelu(z) = z * D_a, with D_a = 0.5 where |z| <= 1e-6.
    D_a = torch.where(gelu_input_states.abs() > 1e-6, nn.functional.gelu(gelu_input_states) / gelu_input_states, torch.full_like(gelu_input_states, 0.5))

    # residual connection
    hidden_states = residual + feed_forward_hidden_states

    # Exact block output, kept for error-free fallbacks below.
    org_output = copy.deepcopy(hidden_states)

    #-------Path Division-------#
    in_term1 = residual_in[0]
    in_term2 = self.mlp(ln_mm(residual_in[0], W_ln_mlp, b=None), custom_gelu=D_a, bias_divide=None, ignore_bias=True)[0]

    in_term3_list, in_term4_list = [], []
    for h_i in range(self.attn.num_heads):
        # nn.Linear computes x @ W^T, so head h_i uses input columns
        # [h_i * head_dim, (h_i + 1) * head_dim) of proj.weight; the bias is split per head.
        h_attn_out_proj = nn.functional.linear(
            head_wise_attn_output[:, h_i],
            weight=self.attn.proj.weight[:, self.attn.head_dim * (h_i):self.attn.head_dim * (h_i + 1)],
            bias=self.attn.proj.bias / self.attn.num_heads,
        )

        in_term3_i = h_attn_out_proj
        in_term3_list.append(in_term3_i.unsqueeze(0))

        in_term4_i = self.mlp(ln_mm(copy.deepcopy(h_attn_out_proj.detach()), W_ln_mlp, b=None), custom_gelu=D_a, bias_divide=None, ignore_bias=True)[0]
        in_term4_list.append(in_term4_i.unsqueeze(0))

    in_term3_list = torch.cat(in_term3_list, dim=0)
    in_term4_list = torch.cat(in_term4_list, dim=0)
    #-------Path Division-------#

    # Constant (input-independent) MLP contributions, which the ignore_bias calls above
    # left out: LN2 shift and fc1 bias pushed through the gate and fc2, plus fc2's bias.
    mlp_bias_term1 = ((b_ln_mlp @ self.mlp.fc1.weight.T) * D_a) @ self.mlp.fc2.weight.T
    mlp_bias_term2 = (self.mlp.fc1.bias * D_a) @ self.mlp.fc2.weight.T
    mlp_bias_term3 = self.mlp.fc2.bias
    mlp_bias_term_all = mlp_bias_term1 + mlp_bias_term2 + mlp_bias_term3

    # Spread the constant equally over in_term2 and the num_heads MLP-of-head terms.
    in_term2 += (mlp_bias_term_all / (self.attn.num_heads + 1))
    in_term4_list += (mlp_bias_term_all / (self.attn.num_heads + 1)).unsqueeze(0)

    if hasattr(self, "anonymoususer_save_mode"):
        if self.anonymoususer_save_mode:
            self.anonymoususer_flow_tracer.in_term1 = copy.deepcopy(in_term1)
            self.anonymoususer_flow_tracer.in_term2 = copy.deepcopy(in_term2)
            self.anonymoususer_flow_tracer.in_term3_list = copy.deepcopy(in_term3_list)
            self.anonymoususer_flow_tracer.in_term4_list = copy.deepcopy(in_term4_list)
            self.anonymoususer_flow_tracer.org_output = copy.deepcopy(org_output)

    if hasattr(self, "anonymoususer_trace_mode"):
        if self.anonymoususer_trace_mode:

            target_idx_term, other_idx_term = subsetidx2pathnodeidx(self.anonymoususer_curr_subset, self.attn.num_heads)
            if self.anonymoususer_cond in ["path", "contingency"]:
                # Keep the selected terms clean; swap every non-selected term for its
                # corrupted counterpart.
                if len(target_idx_term[0]) == 0:
                    in_term1 = copy.deepcopy(self.anonymoususer_corrupted_feats.in_term1)
                if len(target_idx_term[1]) == 0:
                    in_term2 = copy.deepcopy(self.anonymoususer_corrupted_feats.in_term2)
                if len(other_idx_term[2]) != 0:
                    in_term3_list[other_idx_term[2]] = copy.deepcopy(self.anonymoususer_corrupted_feats.in_term3_list[other_idx_term[2]])
                if len(other_idx_term[3]) != 0:
                    in_term4_list[other_idx_term[3]] = copy.deepcopy(self.anonymoususer_corrupted_feats.in_term4_list[other_idx_term[3]])
            elif self.anonymoususer_cond in ["counterfactual"]:
                in_term1 = copy.deepcopy(self.anonymoususer_corrupted_feats.in_term1)
                in_term2 = copy.deepcopy(self.anonymoususer_corrupted_feats.in_term2)
                in_term3_list = copy.deepcopy(self.anonymoususer_corrupted_feats.in_term3_list)
                in_term4_list = copy.deepcopy(self.anonymoususer_corrupted_feats.in_term4_list)
            else:
                raise ValueError(
                    f"Unsupported anonymoususer_cond: {self.anonymoususer_cond!r}; "
                    "expected 'path', 'contingency' or 'counterfactual'."
                )

    # Unfolded output
    hidden_states = in_term1 + in_term2 + torch.sum(in_term3_list, dim=0) + torch.sum(in_term4_list, dim=0)

    if hasattr(self, "anonymoususer_trace_mode"):
        if self.anonymoususer_trace_mode:

            if self.anonymoususer_cond in ["path", "contingency"]:
                other_idx_term_len = [len(i) for i in other_idx_term]
                # If the target subset is the entire set, use the exact output so the
                # decomposition's floating-point error is not introduced.
                if sum(other_idx_term_len) == 0:
                    hidden_states = org_output
            elif self.anonymoususer_cond in ["counterfactual"]:
                hidden_states = copy.deepcopy(self.anonymoususer_corrupted_feats.org_output)
            else:
                raise ValueError(
                    f"Unsupported anonymoususer_cond: {self.anonymoususer_cond!r}; "
                    "expected 'path', 'contingency' or 'counterfactual'."
                )

    if use_cache:
        outputs = (hidden_states,) + outputs
    else:
        outputs = (hidden_states,) + outputs[1:]

    return outputs  # hidden_states, present, (attentions, cross_attentions)



def custom_linear_forward(x, layer, bias_divide=1, ignore_bias=False):
    """Apply an ``nn.Linear``-like ``layer`` to ``x`` with optional bias control.

    Computes ``x @ layer.weight^T`` over the last dim of ``x`` (any leading shape), then:
      * adds nothing if ``ignore_bias`` is true or ``layer.bias`` is ``None``;
      * otherwise adds ``layer.bias / bias_divide`` (``bias_divide=None`` is treated as 1).

    ``reshape`` is used instead of ``view`` so non-contiguous inputs are accepted.

    Returns:
        Tensor of shape ``x.shape[:-1] + (layer.out_features,)``.
    """
    size_out = x.size()[:-1] + (layer.out_features,)
    out = x.reshape(-1, x.size(-1)) @ layer.weight.T
    if not ignore_bias and layer.bias is not None:
        divisor = 1 if bias_divide is None else bias_divide
        out = out + layer.bias / divisor
    return out.reshape(size_out)


def custom_mlp_forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]], custom_gelu=None, bias_divide=None, ignore_bias=False) -> torch.FloatTensor:
    """MLP forward (fc1 -> activation -> fc2 -> drop2) with hooks for path decomposition.

    Args:
        hidden_states: MLP input.
        custom_gelu: if given, the activation ``self.act`` is replaced by elementwise
            multiplication with this gate tensor.
        bias_divide, ignore_bias: when ``bias_divide`` is ``None`` and ``ignore_bias`` is
            false, ``self.fc1`` / ``self.fc2`` are called directly; otherwise both go
            through ``custom_linear_forward`` with these arguments.

    Returns:
        ``(output, gelu_input_states)``, where ``gelu_input_states`` is a detached deep copy
        of the fc1 output. Unlike ``org_mlp_forward`` in this module, ``drop1`` and ``norm``
        are not applied.
    """
    use_module_layers = (bias_divide == None) & (ignore_bias == False)

    def _project(tensor, layer):
        if use_module_layers:
            return layer(tensor)
        return custom_linear_forward(tensor, layer, bias_divide=bias_divide, ignore_bias=ignore_bias)

    pre_activation = _project(hidden_states, self.fc1)
    gelu_input_states = copy.deepcopy(pre_activation.detach())

    activated = self.act(pre_activation) if custom_gelu is None else pre_activation * custom_gelu

    projected = _project(activated, self.fc2)
    return self.drop2(projected), gelu_input_states

def _split_heads(self, tensor, num_heads, attn_head_size):
    """Split the hidden dimension of ``tensor`` into attention heads.

    Views the last dim as ``(num_heads, attn_head_size)`` and swaps the head axis in front
    of the sequence axis, producing ``(batch, head, seq_length, head_features)``.
    ``self`` is not used.
    """
    target_shape = tuple(tensor.size()[:-1]) + (num_heads, attn_head_size)
    return tensor.view(target_shape).permute(0, 2, 1, 3)


# def custom_attn_forward(
#     self,
#     hidden_states: Optional[Tuple[torch.FloatTensor]],
#     layer_past: Optional[Tuple[torch.Tensor]] = None,
#     attention_mask: Optional[torch.FloatTensor] = None,
#     head_mask: Optional[torch.FloatTensor] = None,
#     encoder_hidden_states: Optional[torch.Tensor] = None,
#     encoder_attention_mask: Optional[torch.FloatTensor] = None,
#     use_cache: Optional[bool] = False,
#     output_attentions: Optional[bool] = False,
# ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
#     if encoder_hidden_states is not None:
#         if not hasattr(self, "q_attn"):
#             raise ValueError(
#                 "If class is used as cross attention, the weights `q_attn` have to be defined. "
#                 "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
#             )

#         query = self.q_attn(hidden_states)
#         key, value = self.qkv(encoder_hidden_states).split(self.split_size, dim=2)
#         attention_mask = encoder_attention_mask
#     else:
#         self.split_size = self.qkv.in_features
#         query, key, value = self.qkv(hidden_states).split(self.split_size, dim=2)
#     query = self._split_heads(query, self.num_heads, self.head_dim)
#     key = self._split_heads(key, self.num_heads, self.head_dim)
#     value = self._split_heads(value, self.num_heads, self.head_dim)

#     if layer_past is not None:
#         past_key, past_value = layer_past
#         key = torch.cat((past_key, key), dim=-2)
#         value = torch.cat((past_value, value), dim=-2)

#     if use_cache is True:
#         present = (key, value)
#     else:
#         present = None

#     # if self.reorder_and_upcast_attn:
#         # print("Not Supported!")
#         # import pdb; pdb.set_trace()
#         # attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
#     # else:
#     #### TODO: an _attn implementation is probably needed here (borrow one from elsewhere). It computes softmax(q*k^T)*v.
#     attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

    

#     head_wise_attn_output = copy.deepcopy(attn_output.detach())

#     attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
#     attn_output = self.proj(attn_output)
#     attn_output = self.proj_drop(attn_output)

#     outputs = (attn_output, present)
#     if output_attentions:
#         outputs += (attn_weights,)

#     return outputs, head_wise_attn_output # a, present, (attentions)

def custom_attn_forward(self, x, **kwargs) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
    """Multi-head self-attention forward that also returns the per-head outputs.

    ``x`` is unpacked as ``(B, N, C)``. Extra keyword arguments are accepted and ignored.
    Uses ``F.scaled_dot_product_attention`` when ``self.fused_attn`` is set, otherwise an
    explicit ``softmax(q * scale @ k^T) @ v``.

    Returns:
        ``((out, out, out), head_wise_attn_output)``: ``out`` is the projected attention
        output of shape ``(B, N, C)``; ``head_wise_attn_output`` is a detached copy of the
        per-head attention output before merging heads, shape ``(B, num_heads, N, head_dim)``.
    """
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)
    q, k = self.q_norm(q), self.k_norm(k)

    if self.fused_attn:
        x = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )
    else:
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = attn @ v

    # Detach before deepcopy: copy.deepcopy raises on non-leaf tensors that require grad.
    head_wise_attn_output = copy.deepcopy(x.detach())

    x = x.transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return (x, x, x), head_wise_attn_output

def org_attn_forward(self, x, **kwargs) -> torch.Tensor:
    """Reference multi-head self-attention forward (no per-head capture).

    ``x`` is unpacked as ``(batch, tokens, width)``; extra keyword arguments are ignored.
    Returns ``((out, out, out), out)`` with the same ``out`` object in every slot, so that
    callers indexing ``[1]`` get the attention output.
    """
    batch, n_tokens, width = x.shape
    packed = self.qkv(x).reshape(batch, n_tokens, 3, self.num_heads, self.head_dim)
    query, key, value = packed.permute(2, 0, 3, 1, 4).unbind(0)
    query = self.q_norm(query)
    key = self.k_norm(key)

    if not self.fused_attn:
        weights = (query * self.scale) @ key.transpose(-2, -1)
        weights = self.attn_drop(weights.softmax(dim=-1))
        context = weights @ value
    else:
        drop_p = self.attn_drop.p if self.training else 0.
        context = F.scaled_dot_product_attention(query, key, value, dropout_p=drop_p)

    merged = context.transpose(1, 2).reshape(batch, n_tokens, width)
    out = self.proj_drop(self.proj(merged))
    return (out, out, out), out

# def org_attn_forward(
#     self,
#     hidden_states: Optional[Tuple[torch.FloatTensor]],
#     layer_past: Optional[Tuple[torch.Tensor]] = None,
#     attention_mask: Optional[torch.FloatTensor] = None,
#     head_mask: Optional[torch.FloatTensor] = None,
#     encoder_hidden_states: Optional[torch.Tensor] = None,
#     encoder_attention_mask: Optional[torch.FloatTensor] = None,
#     use_cache: Optional[bool] = False,
#     output_attentions: Optional[bool] = False,
# ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
#     if encoder_hidden_states is not None:
#         if not hasattr(self, "q_attn"):
#             raise ValueError(
#                 "If class is used as cross attention, the weights `q_attn` have to be defined. "
#                 "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
#             )

#         query = self.q_attn(hidden_states)
#         key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
#         attention_mask = encoder_attention_mask
#     else:
#         query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
#     query = self._split_heads(query, self.num_heads, self.head_dim)
#     key = self._split_heads(key, self.num_heads, self.head_dim)
#     value = self._split_heads(value, self.num_heads, self.head_dim)

#     if layer_past is not None:
#         past_key, past_value = layer_past
#         key = torch.cat((past_key, key), dim=-2)
#         value = torch.cat((past_value, value), dim=-2)

#     if use_cache is True:
#         present = (key, value)
#     else:
#         present = None

#     if self.reorder_and_upcast_attn:
#         print("Not Supported!")
#         import pdb; pdb.set_trace()
#         attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
#     else:
#         attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

#     attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
#     attn_output = self.c_proj(attn_output)
#     attn_output = self.resid_dropout(attn_output)

#     outputs = (attn_output, present)
#     if output_attentions:
#         outputs += (attn_weights,)

#     return outputs  # a, present, (attentions)



def org_attn(self, query, key, value, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
    """Scaled dot-product attention: ``softmax((query * scale) @ key^T) @ value``.

    The original version computed the scaled query but then used the unscaled one, left
    a ``pdb`` breakpoint at entry, and had unreachable code after ``return``. All three are
    fixed here.

    Extra positional/keyword arguments (e.g. an attention mask or head mask) are accepted
    for call compatibility but ignored.

    Returns:
        ``(output, attention_weights)``; dropout (``self.attn_drop``) is applied to the
        weights before they are multiplied with ``value``.
    """
    q = query * self.scale
    attn = q @ key.transpose(-2, -1)
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)
    x = attn @ value

    return x, attn

# def org_mlp_forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
#     hidden_states = self.c_fc(hidden_states)
#     hidden_states = self.act(hidden_states)
#     hidden_states = self.c_proj(hidden_states)
#     hidden_states = self.dropout(hidden_states)
#     return hidden_states

def org_mlp_forward(self, x):
    """Reference MLP forward: fc1 -> act -> drop1 -> norm -> fc2 -> drop2, in that order."""
    for stage in (self.fc1, self.act, self.drop1, self.norm, self.fc2, self.drop2):
        x = stage(x)
    return x

def org_block_forward(self, x: torch.Tensor) -> torch.Tensor:

    x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))[1]))
    x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
    return x


def org_pit_block_forward(
    self, x: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run one PiT stage step on a ``(feature_map, cls_tokens)`` pair.

    Steps:
      1. Optionally pool the feature map and class tokens via ``self.pool``.
      2. Flatten the spatial map into tokens and prepend the class tokens.
      3. Apply ``self.norm``, then one pre-norm attention residual and one
         pre-norm MLP residual.
      4. Split the class tokens back off and restore the spatial layout.

    Args:
        x: ``(feature_map, cls_tokens)``. After optional pooling,
            ``feature_map`` must unpack as ``(B, C, H, W)``. The number of
            class tokens is read from ``cls_tokens.shape[1]`` *before* pooling.

    Returns:
        ``(feature_map, cls_tokens)`` in the same layout as the input pair.
    """
    feature_map, cls_tokens = x
    token_length = cls_tokens.shape[1]
    if self.pool is not None:
        feature_map, cls_tokens = self.pool(feature_map, cls_tokens)

    B, C, H, W = feature_map.shape
    # (B, C, H, W) -> (B, H*W, C) so spatial positions become tokens.
    tokens = feature_map.flatten(2).transpose(1, 2)
    tokens = torch.cat((cls_tokens, tokens), dim=1)

    tokens = self.norm(tokens)
    tokens = tokens + self.drop_path1(self.ls1(self.attn(self.norm1(tokens))))
    tokens = tokens + self.drop_path2(self.ls2(self.mlp(self.norm2(tokens))))

    cls_tokens = tokens[:, :token_length]
    feature_map = tokens[:, token_length:].transpose(1, 2).reshape(B, C, H, W)
    return feature_map, cls_tokens

# def org_block_forward(
#     self,
#     hidden_states: Optional[Tuple[torch.FloatTensor]],
#     layer_past: Optional[Tuple[torch.Tensor]] = None,
#     attention_mask: Optional[torch.FloatTensor] = None,
#     head_mask: Optional[torch.FloatTensor] = None,
#     encoder_hidden_states: Optional[torch.Tensor] = None,
#     encoder_attention_mask: Optional[torch.FloatTensor] = None,
#     use_cache: Optional[bool] = False,
#     output_attentions: Optional[bool] = False,
# ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
    
    
#     residual = hidden_states
#     hidden_states = self.norm1(hidden_states)
#     attn_outputs = self.attn(
#         hidden_states,
#         layer_past=layer_past,
#         attention_mask=attention_mask,
#         head_mask=head_mask,
#         use_cache=use_cache,
#         output_attentions=output_attentions,
#     )
#     attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
#     outputs = attn_outputs[1:]
#     # residual connection
#     hidden_states = attn_output + residual

#     if encoder_hidden_states is not None:
#         # add one self-attention block for cross-attention
#         if not hasattr(self, "crossattention"):
#             raise ValueError(
#                 f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
#                 "cross-attention layers by setting `config.add_cross_attention=True`"
#             )
#         residual = hidden_states
#         hidden_states = self.ln_cross_attn(hidden_states)
#         cross_attn_outputs = self.crossattention(
#             hidden_states,
#             attention_mask=attention_mask,
#             head_mask=head_mask,
#             encoder_hidden_states=encoder_hidden_states,
#             encoder_attention_mask=encoder_attention_mask,
#             output_attentions=output_attentions,
#         )
#         attn_output = cross_attn_outputs[0]
#         # residual connection
#         hidden_states = residual + attn_output
#         outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights

#     residual = hidden_states
#     hidden_states = self.ln_2(hidden_states)

#     feed_forward_hidden_states = self.mlp(hidden_states)

#     # residual connection
#     hidden_states = residual + feed_forward_hidden_states

#     if use_cache:
#         outputs = (hidden_states,) + outputs
#     else:
#         outputs = (hidden_states,) + outputs[1:]

#     return outputs  # hidden_states, present, (attentions, cross_attentions)




def pass_block_forward(
    self,
    hidden_states: Optional[Tuple[torch.FloatTensor]],
    layer_past: Optional[Tuple[torch.Tensor]] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = False,
    output_attentions: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
    """Replay this block's cached "normal" outputs instead of computing them.

    ``hidden_states`` (or its first element, if a tuple is passed) is only
    unpacked as a 3-item shape. Any other rank raises ``ValueError``; the
    values themselves are ignored. All other arguments are accepted for
    signature compatibility and ignored.

    Returns:
        ``(feats, (feats_k, feats_v))``: deep copies of
        ``self.anonymoususer_feats_normal`` / ``_k_normal`` / ``_v_normal``.
        These attributes are presumably attached from a clean trace elsewhere
        -- confirm. Copying means callers never receive the cached objects
        themselves.
    """
    probe = hidden_states[0] if isinstance(hidden_states, tuple) else hidden_states
    _batch, _, _ = probe.shape  # rank-3 check only; the size is not used

    cached_k = copy.deepcopy(self.anonymoususer_feats_k_normal)
    cached_v = copy.deepcopy(self.anonymoususer_feats_v_normal)
    cached_out = copy.deepcopy(self.anonymoususer_feats_normal)
    return cached_out, (cached_k, cached_v)

def pass_corrupted_block_forward(
    self,
    hidden_states: Optional[Tuple[torch.FloatTensor]],
    layer_past: Optional[Tuple[torch.Tensor]] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = False,
    output_attentions: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
    """Replay this block's cached "corrupted" outputs instead of computing them.

    Only the shape of ``hidden_states`` (first element if it is a tuple) is
    touched: it must unpack into three values, otherwise ``ValueError`` is
    raised. Every other argument is ignored.

    Returns:
        ``(feats, (feats_k, feats_v))`` built from deep copies of
        ``self.anonymoususer_feats_corrupted`` / ``_k_corrupted`` /
        ``_v_corrupted``. These attributes are presumably set from a corrupted
        trace elsewhere -- confirm.
    """
    if isinstance(hidden_states, tuple):
        reference = hidden_states[0]
    else:
        reference = hidden_states
    _batch, _, _ = reference.shape  # enforces rank 3; value unused

    key_copy = copy.deepcopy(self.anonymoususer_feats_k_corrupted)
    value_copy = copy.deepcopy(self.anonymoususer_feats_v_corrupted)
    present = (key_copy, value_copy)
    return copy.deepcopy(self.anonymoususer_feats_corrupted), present


# Forward-method table per model family, looked up by return_forward_method_dict.
# All families share the same attention / MLP / intervention / pass-through
# implementations. Only "pit" swaps in its own block forward, because PiT
# stages pass a (feature_map, cls_tokens) pair rather than a single tensor.
# Each family gets its own dict object.
FUNCTION_MAP = {
    family: {
        "org_block_forward": (
            org_pit_block_forward if family == "pit" else org_block_forward
        ),
        "org_attn_forward": org_attn_forward,
        "org_mlp_forward": org_mlp_forward,
        "intervention_forward": path_level_intervention_forward,
        "custom_attn_forward": custom_attn_forward,
        "custom_mlp_forward": custom_mlp_forward,
        "pass_block_forward": pass_block_forward,
        "pass_corrupted_block_forward": pass_corrupted_block_forward,
    }
    for family in ("deit", "vit", "pit")
}

def return_forward_method_dict(args):
    """Return the ``FUNCTION_MAP`` entry matching ``args.model``.

    The family is chosen by substring match on ``args.model``. Families are
    checked in the order "vit", "deit", "pit", and the first match wins.

    Raises:
        ValueError: if ``args.model`` contains none of the known family names.
    """
    for family in ("vit", "deit", "pit"):
        if family in args.model:
            return FUNCTION_MAP[family]
    raise ValueError(
        f"Unsupported model {args.model!r}: expected a name containing "
        "'vit', 'deit' or 'pit'."
    )
