import torch
from torch import nn
import math
from functools import reduce
from operator import mul
from torch.nn.modules.utils import _pair
from collections import OrderedDict

class PromptedEncoder(nn.Module):
    """Encoder wrapper that adds learnable prompt tokens and a classifier head.

    A block of ``num_tokens`` learnable embeddings is inserted into the token
    sequence right after the class token, and the class-token output of the
    encoder is fed to a small classification head.  Only the prompt embeddings
    and the classification head are left trainable; every other parameter
    (i.e. the wrapped encoder) has ``requires_grad`` set to ``False``.

    The wrapped encoder must expose ``patch_embed``, ``cls_token``,
    ``pos_embed``, ``pos_drop``, ``blocks``, ``norm`` and ``num_features``;
    these are the attributes read by ``__init__`` and the forward path.

    Args:
        encoder_model: Backbone whose parameters are frozen by this wrapper.
        num_tokens: Number of prompt tokens to insert.
        num_classes: Output size of the classification head.
        prompt_dropout_value: Dropout probability applied to the prompt
            tokens (``0.0`` disables it).
    """

    def __init__(
        self,
        encoder_model: nn.Module,
        num_tokens: int,
        num_classes: int,
        prompt_dropout_value: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_tokens = num_tokens
        self.num_classes = num_classes
        self.encoder = encoder_model

        self.prompt_dropout = nn.Dropout(prompt_dropout_value)

        # Token width: last dimension of the positional embedding when the
        # encoder has one, otherwise the output channels of its ``conv1``.
        try:
            embed_dim = self.encoder.pos_embed.shape[-1]
        except AttributeError:
            embed_dim = self.encoder.conv1.out_channels

        # Uniform init bound sqrt(6 / (3 * patch_area + embed_dim)).
        # NOTE(review): the factor 3 is presumably the number of input image
        # channels -- confirm against the encoders actually used.
        patch_size = _pair(getattr(self.encoder.patch_embed, 'patch_size', (16, 16)))
        val = math.sqrt(6. / float(3 * reduce(mul, patch_size, 1) + embed_dim))
        self.prompt_embeddings = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
        nn.init.uniform_(self.prompt_embeddings, -val, val)

        encoder_out_dim = self.encoder.num_features
        self.classification_head = nn.Sequential(
            nn.Linear(encoder_out_dim, 512),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

        # Freeze everything except the prompts and the head, remembering the
        # names of the parameters that stay trainable.
        self.trainable_keys = []
        for name, param in self.named_parameters():
            if "prompt_embeddings" in name or "classification_head" in name:
                param.requires_grad = True
                self.trainable_keys.append(name)
            else:
                param.requires_grad = False

    def incorporate_prompt(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` and insert the prompt tokens after the class token.

        Args:
            x: Input batch passed to ``encoder.patch_embed``; the first
                dimension is the batch size.

        Returns:
            The token sequence ``[cls, prompts, patches]`` after
            ``encoder.pos_drop``.
        """
        batch_size = x.shape[0]

        # assumes patch_embed returns (batch, num_patches, embed_dim) so that
        # it can be concatenated with the class token along dim=1 -- TODO confirm
        x = self.encoder.patch_embed(x)
        cls_token = self.encoder.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        # Positional embeddings are added before the prompts are inserted, so
        # the prompt tokens themselves carry no positional embedding.
        x = x + self.encoder.pos_embed

        # Expand to the batch first and only then apply dropout: dropping on
        # the (1, num_tokens, dim) tensor and expanding afterwards would
        # broadcast a single mask to every sample in the batch.
        prompt = self.prompt_dropout(
            self.prompt_embeddings.expand(batch_size, -1, -1)
        )
        x = torch.cat((
            x[:, :1, :],   # class token
            prompt,        # learnable prompt tokens
            x[:, 1:, :]    # patch tokens
        ), dim=1)

        return self.encoder.pos_drop(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return classification logits computed from the class-token output.

        Args:
            x: Input batch, forwarded to :meth:`incorporate_prompt`.

        Returns:
            Output of the classification head applied to token 0 of the
            normalised encoder output.
        """
        x = self.incorporate_prompt(x)

        x = self.encoder.blocks(x)
        x = self.encoder.norm(x)

        # Token 0 is the class token (it is placed first by incorporate_prompt).
        cls_output = x[:, 0]
        logits = self.classification_head(cls_output)

        return logits