import numpy as np
import torch
from PIL import Image
import os.path
import argparse
from pathlib import Path

from torch.utils.data import DataLoader
import tqdm


from typing import Callable
from PIL import Image
class ManualDataset:
    """Map-style dataset over in-memory image arrays and their labels.

    Item ``idx`` is ``(trans(Image.fromarray(data[idx])), targets[idx])``; the
    dataset length is taken from ``targets``.
    """

    def __init__(self, data: np.ndarray, targets: np.ndarray, trans: Callable):
        # data[idx] must be something PIL's Image.fromarray accepts
        # (e.g. an H x W x C uint8 array); trans receives the PIL image.
        self.data, self.targets, self.trans = data, targets, trans

    def __getitem__(self, idx):
        label = self.targets[idx]
        sample = self.trans(Image.fromarray(self.data[idx]))
        return sample, label

    def __len__(self):
        return len(self.targets)


def get_args_parser():
    """Build the command-line parser for the residual-stream projection script.

    The parser is created with ``add_help=False``, so ``-h/--help`` is not
    registered on it.

    NOTE(review): the ``--model``/``--pretrained`` defaults look like open_clip
    names, but ``main()`` passes ``--model`` to Hugging Face ``from_pretrained``;
    pass a HF model id or local path. ``--pretrained``, ``--data_path`` and
    ``--dataset`` are not read by ``main()`` in this file.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser("Project Residual Stream", add_help=False)
    option_specs = (
        ("--batch_size", dict(default=2, type=int, help="Batch size")),
        # Model parameters
        (
            "--model",
            dict(default="ViT-H-14", type=str, metavar="MODEL", help="Name of model to use"),
        ),
        ("--pretrained", dict(default="laion2b_s32b_b79k", type=str)),
        # Dataset parameters
        ("--data_path", dict(default="/shared/group/ilsvrc", type=str, help="dataset path")),
        ("--dataset", dict(type=str, default="imagenet", help="imagenet, cub or waterbirds")),
        ("--num_workers", dict(default=10, type=int)),
        ("--output_dir", dict(default="./output_dir", help="path where to save")),
        ("--device", dict(default="cuda:0", help="device to use for testing")),
        # Additional args
        ("--pretrained_path", dict(default=None, type=str)),
        ("--manual_npzdata_path", dict(default=None, type=str)),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser


def _to_uint8_nhwc(images: np.ndarray) -> np.ndarray:
    """Convert a channel-first float image batch to channel-last uint8.

    Args:
        images: 4-D array laid out as (N, C, H, W); values are assumed to lie
            in [0, 1] (they are scaled by 255).

    Returns:
        uint8 array of shape (N, H, W, C). Values are rounded and clipped to
        [0, 255] before the cast so slightly out-of-range inputs saturate
        instead of producing wrapped / platform-dependent uint8 values.
    """
    scaled = np.clip(np.round(images * 255), 0, 255)
    return scaled.astype(np.uint8).transpose([0, 2, 3, 1])


def main(args):
    """Compute per-layer, per-head attention contributions to CLIP's image embedding.

    Loads a Hugging Face CLIP vision model (``args.model``), optionally
    overrides its weights from ``args.pretrained_path``, and hooks every
    ``CLIPEncoderLayer`` to capture each attention head's contribution to the
    CLS token. Each contribution is mapped through ``post_layernorm`` (using
    the LayerNorm statistics of the full CLS residual) and
    ``visual_projection``. The clean and poison splits stored in
    ``args.manual_npzdata_path`` are processed separately and saved as::

        {output_dir}/{npz stem}_attn_{model with "/" -> "-"}_{clean|poison}.npy

    each of shape (num_images, num_layers, num_heads, projection_dim).

    Raises:
        ValueError: if ``args.manual_npzdata_path`` is not set.
    """
    if args.manual_npzdata_path is None:
        # Checked before the slow model load; an explicit raise (unlike an
        # assert) is not stripped under ``python -O``.
        raise ValueError("--manual_npzdata_path is required")

    from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
    from transformers.models.clip.modeling_clip import CLIPEncoderLayer

    model = CLIPVisionModelWithProjection.from_pretrained(args.model)
    processor = CLIPImageProcessor.from_pretrained(args.model)

    # NOTE(review): a nested function is not picklable, so DataLoader workers
    # (num_workers > 0) rely on the "fork" multiprocessing start method.
    def preprocess(image):
        # The processor returns a batch of one image; drop the batch axis.
        return processor(image, return_tensors="pt")["pixel_values"].squeeze(0)

    if args.pretrained_path is not None:
        # NOTE(review): torch.load unpickles arbitrary objects - only load
        # trusted checkpoints.
        state_dict = torch.load(args.pretrained_path, map_location=torch.device("cpu"))
        # Drop keys presumably absent from the vision-only model (classifier
        # head, text tower, logit scale) so the strict load below succeeds.
        excluded_tokens = ("classifier", "text", "logit_scale")
        state_dict = {
            name: value
            for name, value in state_dict.items()
            if not any(token in name for token in excluded_tokens)
        }
        model.load_state_dict(state_dict)
        del state_dict

    model.to(args.device)
    model.eval()

    # Per-batch store: encoder layer -> (batch, num_heads, embed_dim) CLS
    # contributions, filled by the hooks in layer execution order.
    feature_map = {}

    @torch.no_grad()
    def feature_hook(module, inputs):
        """Forward pre-hook on a CLIPEncoderLayer recording per-head CLS terms.

        Recomputes the layer's self-attention (after ``layer_norm1``) and
        splits the ``out_proj`` output for the CLS token (position 0) into one
        term per head; the ``out_proj`` bias is divided evenly across heads so
        the terms sum to the full attention output. Stores a
        (batch, num_heads, embed_dim) tensor in ``feature_map[module]``.

        Adapted from:
            https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py
        """
        if len(inputs) != 3 or not isinstance(inputs[0], torch.Tensor):
            raise NotImplementedError(
                "expected positional (hidden_states, attention_mask, causal_attention_mask)"
            )
        hidden_states, attention_mask, causal_attention_mask = inputs
        if attention_mask is not None or causal_attention_mask is not None:
            # Only the unmasked vision path is supported.
            raise NotImplementedError("attention masks are not supported")

        attn = module.self_attn
        hidden_states = module.layer_norm1(hidden_states)
        bsz, _, embed_dim = hidden_states.size()

        # (bsz, seq, embed_dim) -> (bsz, num_heads, seq, head_dim)
        query_states = attn.q_proj(hidden_states).view(bsz, -1, attn.num_heads, attn.head_dim).transpose(1, 2)
        key_states = attn.k_proj(hidden_states).view(bsz, -1, attn.num_heads, attn.head_dim).transpose(1, 2)
        value_states = attn.v_proj(hidden_states).view(bsz, -1, attn.num_heads, attn.head_dim).transpose(1, 2)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=None,
            dropout_p=attn.dropout if attn.training else 0.0,
            scale=attn.scale,
        )

        # (bsz, num_heads, seq, head_dim) -> (bsz, seq, num_heads, head_dim)
        attn_output = attn_output.transpose(1, 2)
        head_num, head_dim = attn_output.shape[-2:]

        # CLS token with heads concatenated. reshape, not view: after the
        # transpose the tensor is only viewable for some SDPA backends'
        # memory layouts (the math backend would make .view raise).
        cls_out = attn_output[:, 0:1].reshape(bsz, 1, embed_dim)
        # Broadcast against out_proj.weight (out, in), then sum each head's
        # slice of the input dimension separately -> (bsz, head_num, embed_dim).
        per_head = (cls_out * attn.out_proj.weight).view(bsz, embed_dim, head_num, head_dim)
        per_head = per_head.sum(dim=-1).transpose(-2, -1)
        per_head = per_head + attn.out_proj.bias / head_num

        # Kept on the model's device; it is consumed there right after forward.
        feature_map[module] = per_head

    for layer in model.modules():
        if isinstance(layer, CLIPEncoderLayer):
            layer.register_forward_pre_hook(feature_hook)

    # Arrays are read eagerly by NpzFile.__getitem__, so the file can be
    # closed once the datasets are built.
    with np.load(args.manual_npzdata_path) as npz:
        clean_ds = ManualDataset(_to_uint8_nhwc(npz["clean_data"]), npz["clean_targets"], preprocess)
        poison_ds = ManualDataset(_to_uint8_nhwc(npz["poison_data"]), npz["poison_targets"], preprocess)

    def make_loader(dataset):
        # Ordered and complete, so output rows line up with the input arrays.
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=args.num_workers,
        )

    last_ln = model.vision_model.post_layernorm

    def attn_extraction_pipeline(dataloader):
        """Return (num_images, num_layers, num_heads, projection_dim) contributions."""
        batch_results = []
        with torch.no_grad():
            for image, _ in tqdm.tqdm(dataloader):
                feature_map.clear()
                output = model(pixel_values=image.to(args.device))

                bs = len(image)
                # CLS residual treated as the input to post_layernorm; its
                # statistics are shared by every additive component.
                before_ln = output.last_hidden_state[:, 0, :]
                ln_mean = torch.mean(before_ln, dim=1).view(bs, 1, 1)
                ln_var = torch.var(before_ln, dim=1, unbiased=False)
                weight = 1 / torch.sqrt(ln_var + last_ln.eps).unsqueeze(-1) * last_ln.weight.data
                weight = weight.unsqueeze(1)  # (bs, 1, embed_dim)
                bias = last_ln.bias.data.view(1, 1, -1)

                layer_terms = []
                for per_head in feature_map.values():  # layer execution order
                    head_num = per_head.shape[1]
                    # Split the LN mean and bias evenly across heads, mirroring
                    # how the out_proj bias was split in the hook.
                    normed = (per_head - ln_mean / head_num) * weight + bias / head_num
                    layer_terms.append(model.visual_projection(normed).unsqueeze(1))
                # (bs, num_layers, num_heads, projection_dim)
                batch_results.append(torch.cat(layer_terms, dim=1).cpu().numpy())
        return np.concatenate(batch_results)

    clean_attn_res = attn_extraction_pipeline(make_loader(clean_ds))
    poison_attn_res = attn_extraction_pipeline(make_loader(poison_ds))

    # Path.stem keeps inner dots ("x_rate0.1.npz" -> "x_rate0.1"); splitting on
    # the first "." made distinct inputs collide and overwrite each other.
    npzdata_name = Path(args.manual_npzdata_path).stem
    model_name = args.model.replace("/", "-")

    np.save(os.path.join(args.output_dir, f"{npzdata_name}_attn_{model_name}_clean.npy"), clean_attn_res)
    np.save(os.path.join(args.output_dir, f"{npzdata_name}_attn_{model_name}_poison.npy"), poison_attn_res)


if __name__ == "__main__":
    # Parse CLI options, make sure the output directory exists, then run.
    args = get_args_parser().parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(exist_ok=True, parents=True)
    main(args)
