import torch
import torch.nn as nn
from models.base_vit import VisionTransformer, PatchEmbed, Block, resolve_pretrained_cfg, build_model_with_cfg, \
    checkpoint_filter_fn, Attention_LoRA
import helper


class ViT_lora_co(VisionTransformer):
    """VisionTransformer whose blocks are called with a task id and a LoRA switch.

    Construction is delegated unchanged to ``VisionTransformer`` (every argument
    is forwarded by keyword). Only ``forward`` differs: it threads ``task_id`` and
    ``use_lora`` through each entry of ``self.blocks``.
    """

    def __init__(
            self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None,
            embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block, n_tasks=10):
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            num_classes=num_classes,
            global_pool=global_pool,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            representation_size=representation_size,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            weight_init=weight_init,
            init_values=init_values,
            embed_layer=embed_layer,
            norm_layer=norm_layer,
            act_layer=act_layer,
            block_fn=block_fn,
            n_tasks=n_tasks,
        )

    def forward(self, x, task_id, use_lora):
        """Embed ``x``, run every block with ``(task_id, use_lora)``, and return normed tokens.

        The class token is prepended along dim 1, and the positional embedding is
        sliced to the resulting token count before dropout. The full normalized
        token sequence is returned; no pooling or head is applied here.
        NOTE(review): assumes ``patch_embed`` yields a rank-3 (batch, tokens, dim)
        tensor, as implied by the ``expand(..., -1, -1)``/``cat(dim=1)`` below.
        """
        tokens = self.patch_embed(x)
        batch_cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat((batch_cls, tokens), dim=1)

        positional = self.pos_embed[:, :tokens.size(1), :]
        tokens = self.pos_drop(tokens + positional)

        for block in self.blocks:
            tokens = block(tokens, task_id, use_lora)

        return self.norm(tokens)


def _create_vision_transformer(variant, pretrained=False, **kwargs):
    """Build a ``ViT_lora_co`` for ``variant`` via ``build_model_with_cfg``.

    Args:
        variant: model name used to look up the pretrained config.
        pretrained: forwarded to ``build_model_with_cfg``.
        **kwargs: remaining model arguments. ``representation_size`` is consumed
            here; ``features_only`` is rejected.

    Raises:
        RuntimeError: if a truthy ``features_only`` is requested.
    """
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    cfg = resolve_pretrained_cfg(variant)
    cfg_num_classes = cfg['num_classes']
    requested_num_classes = kwargs.get('num_classes', cfg_num_classes)

    # Representation-size handling for in21k pretrained models: drop the
    # requested representation layer when the classifier width differs from
    # the config's.
    representation_size = kwargs.pop('representation_size', None)
    if representation_size is not None:
        if requested_num_classes != cfg_num_classes:
            representation_size = None

    # Checkpoints whose URL mentions 'npz' go through the custom loader.
    use_custom_load = 'npz' in cfg['url']

    return build_model_with_cfg(
        ViT_lora_co, variant, pretrained,
        pretrained_cfg=cfg,
        representation_size=representation_size,
        pretrained_filter_fn=checkpoint_filter_fn,
        pretrained_custom_load=use_custom_load,
        **kwargs,
    )


def split_qkv(model, logger, verbose=False):
    """Split the fused `qkv` layer of every `Attention_LoRA` into `query`, `key` and `value`.

    For each `Attention_LoRA` in `model`, the fused projection weight (and bias,
    when present) is cut into three equal chunks along dim 0, in q, k, v order.
    The chunks are copied into three new `nn.Linear(module.dim, module.dim)`
    layers attached as `module.query`, `module.key` and `module.value`. The
    original `module.qkv` is then deleted. `model` is modified in place.

    The new layers are moved to the fused weight's device and dtype and inherit
    its `requires_grad` flags. This keeps the split correct on a model that was
    already moved to GPU, cast, or frozen. Modules without a `qkv` attribute
    (e.g. already split by an earlier call) are skipped.

    Args:
        model: the `nn.Module` tree to rewrite.
        logger: forwarded to `helper.log_and_print` for the per-module message.
        verbose: forwarded to `helper.log_and_print`.
    """
    # Snapshot the traversal: the loop body adds and deletes submodules.
    for name, module in list(model.named_modules()):
        if not isinstance(module, Attention_LoRA) or not hasattr(module, 'qkv'):
            continue

        fused = module.qkv
        has_bias = fused.bias is not None
        dim = module.dim

        # Chunk along the output dimension; the fused weight is expected to be (3 * dim, dim).
        weight_chunks = torch.chunk(fused.weight, chunks=3, dim=0)
        if has_bias:
            bias_chunks = torch.chunk(fused.bias, chunks=3, dim=0)
        else:
            bias_chunks = (None, None, None)

        for attr, w_chunk, b_chunk in zip(('query', 'key', 'value'), weight_chunks, bias_chunks):
            # Build on CPU (same RNG use as before), then move to the fused layer's
            # device/dtype so forward does not hit a device or dtype mismatch.
            layer = nn.Linear(dim, dim, bias=has_bias).to(
                device=fused.weight.device, dtype=fused.weight.dtype)
            with torch.no_grad():
                layer.weight.copy_(w_chunk)
                if has_bias:
                    layer.bias.copy_(b_chunk)
            # Preserve freezing so pretrained projections stay frozen under LoRA.
            layer.weight.requires_grad_(fused.weight.requires_grad)
            if has_bias:
                layer.bias.requires_grad_(fused.bias.requires_grad)
            setattr(module, attr, layer)

        # Remove the original fused projection.
        del module.qkv

        helper.log_and_print(f"Replaced `qkv` with `query`, `key`, `value` in {name}", logger, verbose)


class ViTPEARL(nn.Module):
    """Pretrained LoRA ViT encoder followed by one linear classification head per task.

    Encoder width and depth are fixed here (patch 16, dim 768, depth 12, 12
    heads). ``args`` must provide ``n_tasks``, ``vit_name``, ``image_size`` and
    ``n_classes_per_task``.
    """

    def __init__(self, args):
        super().__init__()

        encoder_cfg = {
            'patch_size': 16,
            'embed_dim': 768,
            'depth': 12,
            'num_heads': 12,
            'n_tasks': args.n_tasks,
            'num_classes': 0,
        }
        self.image_encoder = _create_vision_transformer(
            args.vit_name, pretrained=True, img_size=args.image_size, **encoder_cfg)
        # `remove_components` comes from the encoder's base class (defined elsewhere).
        self.image_encoder.remove_components()

        self.class_num = args.n_classes_per_task
        feature_dim = encoder_cfg['embed_dim']
        heads = [nn.Linear(feature_dim, self.class_num, bias=True) for _ in range(args.n_tasks)]
        self.classifier_pool = nn.ModuleList(heads)

    def forward(self, image, task_id, use_lora=False, *args, **kwargs):
        """Classify ``image`` with the head belonging to ``task_id``.

        The token at index 0 along dim 1 of the encoder output is used as
        the feature vector, flattened per sample, and passed to
        ``classifier_pool[task_id]``. Extra positional/keyword arguments are
        accepted and ignored.
        """
        tokens = self.image_encoder(image, task_id=task_id, use_lora=use_lora)
        cls_features = tokens[:, 0, :]
        cls_features = cls_features.view(cls_features.shape[0], -1)
        head = self.classifier_pool[task_id]
        return head(cls_features)
