from collections import defaultdict
from transformers import CLIPProcessor, CLIPVisionModel, CLIPModel, CLIPVisionModelWithProjection, \
    CLIPTextModelWithProjection
import types
import os
import torch
import argparse
from PIL import Image
from tqdm import tqdm
class CLIPVisionEncoderWithOutputs:
    """Wrap a HuggingFace CLIP vision tower and record per-layer internals.

    Each encoder layer's ``self_attn.forward`` is monkey-patched and forward hooks are
    registered, so one call to :meth:`extract_features` fills ``self.layer_outputs``
    (a ``defaultdict(dict)`` keyed by layer index) with:

    * ``layer_input`` / ``layer_output`` -- hidden states entering / leaving the layer,
    * ``attention_weights`` -- attention probabilities of the layer,
    * ``cls_subvalues`` -- attention-weighted value vectors of the CLS query
      (query position 0), shape ``[bsz, num_heads, seq_len, head_dim]``,
    * ``attention_output`` -- output of the self-attention module,
    * ``residual_output`` -- the input of ``layer_norm2``,
    * ``ffn_output`` -- output of ``mlp``; ``coefficient`` -- output of ``mlp.activation_fn``,
    * ``reconstruction_error`` -- mean |layer_input + attention_output + ffn_output - layer_output|.
    """

    def __init__(self, model_name="openai/clip-vit-large-patch14", device="cuda"):
        """Load the processor and vision tower, then install attention patches and hooks.

        Args:
            model_name: model id passed to ``from_pretrained``.
            device: torch device the vision model and the processed inputs are moved to.
        """
        self.processor = CLIPProcessor.from_pretrained(model_name)
        # Eager attention is requested; the patched forward relies on receiving
        # attention weights from the original implementation.
        self.model = CLIPVisionModelWithProjection.from_pretrained(model_name, attn_implementation="eager").to(device)
        self.projection = self.model.visual_projection
        # Shared by reference with every patched attention module, so it must be
        # cleared in place (never reassigned) between forward passes.
        self.layer_outputs = defaultdict(dict)
        self.device = device
        self._modify_attention_modules()
        self._register_hooks()

    def _modify_attention_modules(self):
        """Patch every layer's ``self_attn.forward`` to record attention internals.

        The patched forward always asks the original implementation for attention
        weights, stores them, the CLS sub-values and the attention output in
        ``layer_outputs[layer_idx]``, and returns ``(attn_output, attn_weights)``,
        with ``attn_weights`` replaced by ``None`` when the caller did not request them.
        """
        for i, layer in enumerate(self.model.vision_model.encoder.layers):
            attention_module = layer.self_attn
            original_forward = attention_module.forward

            # ``layer_idx=i`` binds the loop value at definition time; a plain closure
            # over ``i`` would only ever see the last layer index.
            def new_forward(self, hidden_states, attention_mask=None, causal_attention_mask=None,
                            output_attentions=False, layer_idx=i):
                # Always request weights: they are needed for the recordings below even
                # when the caller did not ask for them.
                attn_output, attn_weights = self._original_forward(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    causal_attention_mask=causal_attention_mask,
                    output_attentions=True,
                )
                # The caller unpacks two values, so always return a 2-tuple and hide
                # the weights when they were not requested.
                returned_weights = attn_weights if output_attentions else None

                layer_outputs = getattr(self, 'layer_outputs', None)
                if layer_outputs is None:
                    return attn_output, returned_weights
                layer_outputs[layer_idx]['attention_weights'] = attn_weights.detach().clone()

                bsz, _, embed_dim = hidden_states.size()
                v_proj = getattr(self, 'v_proj', None) or getattr(self, 'value', None)
                num_heads = getattr(self, 'num_heads', None) or getattr(self, 'num_attention_heads', None)
                head_dim = getattr(self, 'head_dim', None)
                if head_dim is None and num_heads is not None:
                    head_dim = embed_dim // num_heads

                if v_proj is not None and num_heads is not None and head_dim is not None:
                    # [bsz, seq_len, embed_dim] -> [bsz, num_heads, seq_len, head_dim]
                    value_states = v_proj(hidden_states).view(bsz, -1, num_heads, head_dim).transpose(1, 2)
                    # Attention of the CLS query over all keys: [bsz, num_heads, seq_len]
                    cls_attn_weights = attn_weights[:, :, 0, :]
                    # Per-key contribution to each CLS head output:
                    # [bsz, num_heads, seq_len, 1] * [bsz, num_heads, seq_len, head_dim]
                    cls_subvalues = cls_attn_weights.unsqueeze(-1) * value_states
                    layer_outputs[layer_idx]['cls_subvalues'] = cls_subvalues.detach().clone()

                layer_outputs[layer_idx]['attention_output'] = attn_output.detach().clone()
                return attn_output, returned_weights

            attention_module._original_forward = original_forward
            attention_module.forward = types.MethodType(new_forward, attention_module)
            attention_module.layer_outputs = self.layer_outputs

    def _register_hooks(self):
        """Attach the recording hooks to every encoder layer.

        Per layer: a pre-hook on the layer (``layer_input``), a forward hook on the layer
        (``layer_output`` / ``reconstruction_error``), a hook on ``layer_norm2`` capturing
        its input (``residual_output``), a hook on ``mlp`` (``ffn_output``) and a hook on
        ``mlp.activation_fn`` (``coefficient``).
        """
        for i, layer in enumerate(self.model.vision_model.encoder.layers):
            layer.register_forward_pre_hook(self._get_layer_input_hook(i))
            layer.register_forward_hook(self._get_layer_output_hook(i))
            layer.layer_norm2.register_forward_hook(self._get_residual_input_hook(i))
            layer.mlp.register_forward_hook(self._get_ffn_output_hook(i))
            layer.mlp.activation_fn.register_forward_hook(self._get_activation_output_hook(i))

    def _get_layer_input_hook(self, layer_idx):
        """Return a forward pre-hook storing the layer's first positional input."""
        def hook(module, input):
            self.layer_outputs[layer_idx]['layer_input'] = input[0].detach().clone()

        return hook

    def _get_residual_input_hook(self, layer_idx):
        """Return a hook storing the *input* of ``layer_norm2`` as ``residual_output``."""
        def hook(module, input, output):
            self.layer_outputs[layer_idx]['residual_output'] = input[0].detach().clone()

        return hook

    def _get_activation_output_hook(self, layer_idx):
        """Return a hook storing the MLP activation output as ``coefficient``."""
        def hook(module, input, output):
            self.layer_outputs[layer_idx]['coefficient'] = output.detach().clone()

        return hook

    def _get_layer_output_hook(self, layer_idx):
        """Return a hook storing the layer output and the residual reconstruction error."""
        def hook(module, input, output):
            # Encoder layers may return a tuple (hidden_states, ...) or a bare tensor.
            hidden = output[0] if isinstance(output, (tuple, list)) else output
            self.layer_outputs[layer_idx]['layer_output'] = hidden.detach().clone()

            # The layer's own forward hook fires after its mlp hook, so all parts are
            # available here when the corresponding hooks/patches ran.
            if all(k in self.layer_outputs[layer_idx] for k in ['layer_input', 'attention_output', 'ffn_output']):
                layer_input = self.layer_outputs[layer_idx]['layer_input']
                attn_output = self.layer_outputs[layer_idx]['attention_output']
                ffn_output = self.layer_outputs[layer_idx]['ffn_output']
                layer_output = self.layer_outputs[layer_idx]['layer_output']

                reconstructed = layer_input + attn_output + ffn_output
                error = torch.mean(torch.abs(reconstructed - layer_output))
                self.layer_outputs[layer_idx]['reconstruction_error'] = error.item()

        return hook

    def _get_ffn_output_hook(self, layer_idx):
        """Return a hook storing the MLP output as ``ffn_output``."""
        def hook(module, input, output):
            self.layer_outputs[layer_idx]['ffn_output'] = output.detach().clone()

        return hook

    def extract_features(self, image_path, save_outputs=False):
        """Run the vision tower on one image and return its embedding and recordings.

        Args:
            image_path: path of the image file; it is converted to RGB.
            save_outputs: if True, also ``torch.save`` the recordings to
                ``clip_vision_outputs_transformers.pt`` in the working directory.

        Returns:
            ``(image_embeds, layer_outputs)`` -- the projected image embedding and the
            shared ``self.layer_outputs`` dict, which the next call clears and refills.
        """
        # Close the file handle as soon as the pixels are decoded.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)

        self.layer_outputs.clear()

        with torch.no_grad():
            outputs = self.model(
                pixel_values=inputs.pixel_values,
                output_attentions=True
            )
        attention_weights = outputs.attentions
        for i, attn in enumerate(attention_weights):
            self.layer_outputs[i]['attention_weights'] = attn.detach().clone()
        if save_outputs:
            torch.save(self.layer_outputs, "clip_vision_outputs_transformers.pt")
            print("Saved all outputs to clip_vision_outputs_transformers.pt")

        return outputs.image_embeds, self.layer_outputs


def get_fc2_params(model, layer_num):
    """Return the ``mlp.fc2`` weight of vision encoder layer ``layer_num``, transposed.

    Row ``j`` of the result is column ``j`` of the stored weight matrix.
    """
    encoder_layer = model.vision_model.encoder.layers[layer_num]
    fc2_weight = encoder_layer.mlp.fc2.weight.data
    return fc2_weight.transpose(0, 1)


def get_fc2_bias(model, layer_num):
    """Return the ``mlp.fc2`` bias of vision encoder layer ``layer_num`` as a 1 x D row."""
    encoder_layer = model.vision_model.encoder.layers[layer_num]
    fc2_bias = encoder_layer.mlp.fc2.bias.data
    return fc2_bias[None, :]


def get_bsvalues(vector, extractor, u, s, label_emb):
    """Score ``vector`` against label embeddings through CLIP's output head.

    ``vector`` is standardised with the caller-supplied mean ``u`` and variance ``s``
    (they are not recomputed from ``vector``). The affine part of ``post_layernorm``
    is then applied, followed by ``extractor.projection``, and the result is dotted
    with every row of ``label_emb``.

    Returns:
        ``100 * projection(weight * (vector - u) / sqrt(s + 1e-12) + bias) @ label_emb.T``.
    """
    post_ln = extractor.model.vision_model.post_layernorm
    standardized = (vector - u) / torch.sqrt(s + 1e-12)
    # View the 1-D LayerNorm parameters so they broadcast along the last dim only.
    bcast_shape = tuple([1] * (standardized.dim() - 1)) + (-1,)
    scale = post_ln.weight.reshape(bcast_shape)
    shift = post_ln.bias.reshape(bcast_shape)
    projected = extractor.projection(scale * standardized + shift)
    return 100 * projected @ label_emb.T


def get_consine(vector, extractor, s, label_emb):
    """Cosine similarity between the projected direction of ``vector`` and one label.

    ``vector`` is centred with its own mean and scaled by the caller-supplied variance
    ``s``. Only the ``post_layernorm`` *weight* is applied; the LayerNorm bias is not
    added. The result is projected with ``extractor.projection``, L2-normalised and
    dotted with ``label_emb``.

    Args:
        vector: tensor whose last dimension is the hidden size.
        extractor: object exposing ``model.vision_model.post_layernorm`` and ``projection``.
        s: variance used for scaling, broadcastable against ``vector``.
        label_emb: a single embedding; it is flattened into a column, so it must hold
            exactly projection-dim elements. It is presumably unit-normalised by the
            caller (needed for a true cosine) -- TODO confirm.

    Returns:
        Tensor of similarities with the leading dimensions of the projected ``vector``.
    """
    u = vector.mean(dim=-1, keepdim=True)
    normed = (vector - u) / torch.sqrt(s + 1e-12)
    shape = [1] * (normed.dim() - 1) + [-1]
    weight = extractor.model.vision_model.post_layernorm.weight.reshape(shape)
    visual_embedding = extractor.projection(weight * normed)
    # Out-of-place division: an in-place ``/=`` overwrites a tensor that ``norm`` saved
    # for backward, which makes ``backward()`` fail when gradients are enabled.
    visual_embedding = visual_embedding / visual_embedding.norm(dim=-1, keepdim=True)
    cosine = visual_embedding @ label_emb.reshape(-1, 1)
    return cosine[..., 0]


def dict_to_list(dictionary, div=1):
    """Return ``[[key, value / div], ...]`` in the dictionary's iteration order."""
    pairs = []
    for key, value in dictionary.items():
        pairs.append([key, value / div])
    return pairs


def sum_by_first_element(data_list):
    """Sum values grouped by the integer prefix of their key.

    Each item is indexable as ``item[0]`` (a string such as ``"12_xyz"``) and ``item[1]``
    (anything ``float()`` accepts). The text before the first underscore, parsed as
    ``int``, selects the group.

    Returns:
        ``[[group, total], ...]`` sorted by group in ascending order.
    """
    totals = {}
    for entry in data_list:
        group = int(entry[0].partition('_')[0])
        amount = float(entry[1])
        totals[group] = totals[group] + amount if group in totals else amount
    return [[group, totals[group]] for group in sorted(totals)]

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Pass the detected device explicitly: the extractor's default is "cuda", which
    # fails on CPU-only machines and would diverge from the text model's device.
    extractor = CLIPVisionEncoderWithOutputs(model_name="openai/clip-vit-large-patch14", device=device)
    model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14",
                                                        attn_implementation="eager").to(device)
    LAYER_NUM = len(extractor.model.vision_model.encoder.layers)
    # Layer whose attention out_proj is edited (23 for clip-vit-large-patch14,
    # which has 24 encoder layers).
    EDIT_LAYER = LAYER_NUM - 1
    # FFN neuron whose contribution is removed in the knowledge-removal experiment.
    REMOVE_LAYER = 22
    REMOVE_NEURON = 1519

    def transfer_output(model_output):
        """Convert recorded per-layer tensors (batch element 0) into nested Python lists.

        Returns eight per-layer lists: layer inputs (layer 0's own input, otherwise the
        previous layer's output), attention outputs, residual outputs, FFN outputs,
        layer outputs, CLS attention sub-values, activation coefficients, and the raw
        (unconverted) attention-weight tensors. Converting every float to a Python object
        is very slow, so the experiments below index ``layer_outputs`` directly.
        """
        all_pos_layer_input = []
        all_pos_attn_output = []
        all_pos_residual_output = []
        all_pos_ffn_output = []
        all_pos_layer_output = []
        all_last_attn_subvalues = []
        all_pos_coefficient_scores = []
        all_attn_scores = []
        for layer_i in range(LAYER_NUM):
            if layer_i > 0:
                cur_layer_input = model_output[layer_i - 1]['layer_output']
            else:
                cur_layer_input = model_output[layer_i]['layer_input']
            cur_attn_output = model_output[layer_i]['attention_output']
            cur_residual_output = model_output[layer_i]['residual_output']
            cur_ffn_output = model_output[layer_i]['ffn_output']
            cur_layer_output = model_output[layer_i]['layer_output']
            cur_last_attn_subvalues = model_output[layer_i]['cls_subvalues']
            cur_coefficient_scores = model_output[layer_i]['coefficient']
            cur_attn_weights = model_output[layer_i]['attention_weights']
            all_pos_layer_input.append(cur_layer_input[0].tolist())
            all_pos_attn_output.append(cur_attn_output[0].tolist())
            all_pos_residual_output.append(cur_residual_output[0].tolist())
            all_pos_ffn_output.append(cur_ffn_output[0].tolist())
            all_pos_layer_output.append(cur_layer_output[0].tolist())
            all_last_attn_subvalues.append(cur_last_attn_subvalues[0].tolist())
            all_pos_coefficient_scores.append(cur_coefficient_scores[0].tolist())
            all_attn_scores.append(cur_attn_weights)
        return all_pos_layer_input, all_pos_attn_output, all_pos_residual_output, all_pos_ffn_output, \
            all_pos_layer_output, all_last_attn_subvalues, all_pos_coefficient_scores, all_attn_scores

    # Knowledge insertion: for every class directory, compute a rank-one update of the
    # EDIT_LAYER attention out_proj mapping k* (mean CLS out_proj input over the clean
    # images) to v* (mean CLS attention output over the overlay images), weighted by
    # the matrix C from K_attn.pt (presumably a second-moment matrix of out_proj
    # inputs -- TODO confirm), and save the edited weight to "<item>.pt".
    out_proj = extractor.model.vision_model.encoder.layers[EDIT_LAYER].self_attn.out_proj
    # Loaded once: C does not depend on the class being edited.
    C = torch.load("K_attn.pt")[EDIT_LAYER].to(device)
    for item in os.listdir("dataset/bird"):
        num_images = len(os.listdir(f"dataset/bird/{item}"))

        K = []
        for img in tqdm(range(num_images)):
            image_path = os.path.join(f"dataset/bird/{item}", str(img).zfill(3) + ".jpg")
            with torch.no_grad():
                _, layer_outputs = extractor.extract_features(image_path, save_outputs=False)
                # [num_heads, seq_len, head_dim] summed over keys -> per-head CLS
                # outputs [num_heads, head_dim] -> concatenated row [1, embed_dim].
                cls_subvalues = layer_outputs[EDIT_LAYER]['cls_subvalues'][0]
                K.append(cls_subvalues.sum(dim=1).reshape(1, -1))
        k_star = torch.cat(K, dim=0).mean(dim=0)

        V = []
        for img in tqdm(range(num_images)):
            image_path = os.path.join(f"dataset/bird_overlay/{item}", str(img).zfill(3) + ".jpg")
            with torch.no_grad():
                _, layer_outputs = extractor.extract_features(image_path, save_outputs=False)
                # CLS row of the attention output; clone() so the kept slice does not
                # pin the whole recorded tensor in memory.
                V.append(layer_outputs[EDIT_LAYER]['attention_output'][0, :1].clone())
        v_star = torch.cat(V, dim=0).mean(dim=0)
        # Subtract the out_proj bias so v* is the target for W @ k* alone.
        v_star = v_star - out_proj.bias.data
        W = out_proj.weight.data

        k_star = k_star.unsqueeze(1)  # [embed_dim, 1]
        v_star = v_star.unsqueeze(1)  # [embed_dim, 1]
        # W* = W + (v* - W k*) (C^-1 k*)^T / ((C^-1 k*)^T k*); solve() avoids forming
        # C^-1 explicitly (the inverse was previously computed twice).
        c_inv_k = torch.linalg.solve(C, k_star)
        A = (v_star - W @ k_star) / (c_inv_k.T @ k_star)
        W_star = W + A @ c_inv_k.T
        torch.save(W_star, f"{item}.pt")

    # Knowledge removal: subtract one FFN neuron's contribution from the final CLS
    # hidden state and compare the cosine similarity to the text prompt before/after.
    item = "taylor swift"
    inputs = extractor.processor(text=[item], return_tensors="pt", padding=True).to(device)
    original_cos = []
    top1_cos = []
    with torch.no_grad():
        text_features = model(**inputs).text_embeds
        text_features /= text_features.norm(dim=1, keepdim=True)
        neuron = get_fc2_params(extractor.model, REMOVE_LAYER)[REMOVE_NEURON]
        for i in tqdm(range(1, 11)):
            # Use the loop index; the stale `img` left over from the insertion loop
            # would read the same file on every iteration.
            image_path = os.path.join("VisEnt/celebrity/taylor swift", str(i).zfill(3) + ".jpeg")
            features, layer_outputs = extractor.extract_features(image_path, save_outputs=False)
            features_unit = features / features.norm(dim=1, keepdim=True)
            # Final-layer hidden state of batch element 0, token 0 (CLS); cloned because
            # it is modified in place below.
            new_hidden = layer_outputs[LAYER_NUM - 1]['layer_output'][0, 0].clone()
            # The neuron's activation at the CLS token; negative activations are ignored.
            coefficient = layer_outputs[REMOVE_LAYER]['coefficient'][0, 0, REMOVE_NEURON].clamp(min=0)
            new_hidden -= coefficient * neuron
            features_new = extractor.projection(extractor.model.vision_model.post_layernorm(new_hidden))
            features_new = features_new / features_new.norm(dim=-1, keepdim=True)
            top1_cos.append((features_new @ text_features.T).item())
            original_cos.append((features_unit @ text_features.T).item())
        print("Cosine similarity before removal: ")
        print(sum(original_cos) / len(original_cos))
        print("Cosine similarity after removal: ")
        print(sum(top1_cos) / len(top1_cos))