from copy import deepcopy
import math
import torch
from torch import nn
import torch.nn.functional as F


class TaskGate(nn.Module):
    """Produce a single sigmoid gate value from a set of image features.

    The features are averaged over their first dimension and the resulting
    context vector is passed through a two-layer MLP ending in a sigmoid,
    so the output lies in [0, 1].

    Args:
        feature_dim: Size of the last dimension of the incoming features.
        hidden_dim: Width of the MLP's hidden layer.
    """

    def __init__(self, feature_dim, hidden_dim=128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, image_features):
        """Return the gate value for ``image_features``.

        Args:
            image_features: Floating-point tensor whose last dimension is
                ``feature_dim``; presumably ``(batch, feature_dim)`` --
                TODO confirm against the caller.

        Returns:
            Tensor with the first dimension of ``image_features`` reduced
            away and a trailing dimension of size 1, in the MLP's dtype.
        """
        # Pool over the first dimension so one gate value is produced for
        # the whole set of features rather than one per row.
        context = image_features.mean(dim=0)
        # The MLP is created with the default dtype, while the incoming
        # features may be in another precision (e.g. half); align them so
        # the linear layers do not fail on a dtype mismatch. This is a no-op
        # when the dtypes already agree.
        context = context.to(self.mlp[0].weight.dtype)
        return self.mlp(context)


class AdaptFormer(nn.Module):
    """Bottleneck adapter: LayerNorm -> down-projection -> ReLU -> up-projection.

    The up-projection (weight and bias) and the down-projection bias are
    zero-initialised, so a freshly built adapter outputs all zeros. The
    down-projection weight keeps ``nn.Linear``'s default initialisation.

    Args:
        in_dim: Size of the last dimension of the input (and of the output).
        bottle_dim: Width of the bottleneck between the two projections.
        dtype: Optional dtype forwarded to the LayerNorm and Linear layers.
        scale: Constant factor multiplied into the adapter output.
    """

    def __init__(self, in_dim, bottle_dim, dtype=None, scale: float = 1.0):
        super().__init__()
        layer_kwargs = {"dtype": dtype}
        self.ln = nn.LayerNorm(in_dim, **layer_kwargs)
        self.down_proj = nn.Linear(in_dim, bottle_dim, **layer_kwargs)
        self.relu = nn.ReLU(inplace=True)
        self.up_proj = nn.Linear(bottle_dim, in_dim, **layer_kwargs)
        self.scale = scale

        # Zero everything except the down-projection weight.
        for tensor in (self.up_proj.weight, self.down_proj.bias, self.up_proj.bias):
            nn.init.zeros_(tensor)

    @property
    def dtype(self):
        """Dtype of the adapter, read from the LayerNorm weight."""
        return self.ln.weight.dtype

    def forward(self, x):
        """Apply the adapter to ``x`` and return the scaled result."""
        hidden = self.relu(self.down_proj(self.ln(x)))
        return self.up_proj(hidden) * self.scale
        
class Tuner(nn.Module):
    """Container for the per-layer adapters and the task gate.

    Args:
        clip_model: Object exposing ``visual`` (with ``transformer.resblocks``
            and ``positional_embedding``), ``transformer.resblocks``,
            ``positional_embedding`` and ``dtype``; presumably a CLIP model --
            verify against the caller.
        use_image_tuner: When truthy, build the two image adapter stacks
            (``image_tuner_new`` and ``image_tuner_stable``); otherwise both
            attributes are ``None``.
        use_text_tuner: When truthy, build ``text_tuner``; otherwise it is
            ``None``.
    """

    def __init__(self, clip_model, use_image_tuner, use_text_tuner):
        super().__init__()
        visual_encoder = clip_model.visual
        adapter_dtype = clip_model.dtype

        def build_stack(depth, width, bottleneck, scale):
            # One AdaptFormer per transformer block, all with the same shape.
            return nn.ModuleList([
                AdaptFormer(in_dim=width, bottle_dim=bottleneck, dtype=adapter_dtype, scale=scale)
                for _ in range(depth)
            ])

        # Image side: two parallel adapter stacks that differ only in their
        # output scale (0.5 for "new", 0.1 for "stable").
        vis_depth = len(visual_encoder.transformer.resblocks)
        vis_width = visual_encoder.positional_embedding.shape[1]
        vis_bottleneck = 64
        self.image_tuner_new = (
            build_stack(vis_depth, vis_width, vis_bottleneck, 0.5) if use_image_tuner else None
        )
        self.image_tuner_stable = (
            build_stack(vis_depth, vis_width, vis_bottleneck, 0.1) if use_image_tuner else None
        )

        # Text side: a single adapter stack.
        text_depth = len(clip_model.transformer.resblocks)
        text_width = clip_model.positional_embedding.shape[1]
        text_bottleneck = 64
        self.text_tuner = (
            build_stack(text_depth, text_width, text_bottleneck, 0.5) if use_text_tuner else None
        )

        # NOTE(review): feature_dim is hard-coded to 512 and the gate is built
        # with the default dtype rather than clip_model.dtype -- confirm this
        # matches the image features it will receive.
        self.task_gate = TaskGate(feature_dim=512)