import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoderLayer, TransformerEncoder
import torchvision.transforms.functional as TF

from .elmes import get_elmes
from .freeze import freeze_clip_layers
from .components import ConvPreEncoder, LinearHead, get_encoder

class ClassificationWrapper(nn.Module):
    """
    Wrap a feature backbone with an optional support-set branch and a
    Transformer or linear classification head.

    The query image (optionally concatenated channel-wise with a reference
    image) is encoded at every configured scale and the per-scale features
    are averaged. When ``support_size`` is non-zero, the support images are
    encoded the same way, reduced to one prototype per class by a small
    Transformer encoder, paired with a projected label embedding and
    stacked behind the query token before being handed to the head.

    Args:
        support_size: Stored as ``support_set_size``; ``0`` disables the
            whole support-set branch.
        model: Backbone. Called as ``model.encode_image(x)`` when
            ``is_clip`` is true, otherwise as ``model(x)``.
        num_classes: Forwarded to the head and to ``get_elmes``.
        output_dim: Accepted for interface compatibility; not used here.
        img_dim: Stored, and passed to ``ConvPreEncoder`` as ``output_dim``.
        is_clip: Selects the CLIP calling convention, freezes the backbone
            through ``freeze_clip_layers`` and changes the input width of
            the context projection.
        head_type: ``'transformer'`` builds the head with ``get_encoder``;
            any other value builds a ``LinearHead``.
        fe_dim: ``d_model`` of the prototype generator.
        t_dim: Feature width the head is sized for.
        elmes_dim: Width of the label embeddings.
        no_query_fusion: If true, prototypes are the mean of the encoded
            support tokens; otherwise the query feature is prepended to
            each class group and its output token is the prototype.
        scale_factors: Iterable of image scale factors; copied per instance.
        in_channels: Forwarded to ``ConvPreEncoder``.
        out_channels: Forwarded to ``ConvPreEncoder``.
        randomize_class_order: Shuffle the class order (prototypes and
            label codes together) on every forward pass.
        use_adapter: Build a ``PostEncoderAdapter`` (transformer head only).

    Raises:
        ValueError: If ``scale_factors`` is empty.
    """

    # Side length of the centre crop fed to the backbone.
    _CROP_SIZE = 224

    def __init__(
        self,
        support_size,
        model,
        num_classes,
        output_dim,
        img_dim,
        is_clip,
        head_type='transformer',
        fe_dim=768,
        t_dim=768,
        elmes_dim=256,
        no_query_fusion=True,
        scale_factors=(1.0,),
        in_channels=3,
        out_channels=3,
        randomize_class_order=False,
        use_adapter=False,
    ):
        super().__init__()
        self.support_set_size = support_size
        self.model = model
        self.num_classes = num_classes
        self.img_dim = img_dim
        # Private copy: the instance never aliases the caller's list (or a
        # shared default), so mutating one wrapper cannot affect another.
        self.scale_factors = list(scale_factors)
        if not self.scale_factors:
            raise ValueError("scale_factors must contain at least one scale")
        self.is_clip = is_clip
        self.randomize_class_order = randomize_class_order
        self.no_query_fusion = no_query_fusion

        uses_support = support_size != 0
        head_input_dim = t_dim + (elmes_dim if uses_support else 0)

        # Optional adapter
        if use_adapter and head_type == 'transformer':
            self.adapter = PostEncoderAdapter(t_dim)  # noqa: F821 (not defined by design)
        else:
            self.adapter = None

        if uses_support:
            # Label slot for the query token, whose class is unknown.
            self.unk_emb = nn.Parameter(torch.zeros(1, 1, elmes_dim))
            self.elmes_scale = nn.Parameter(torch.ones(1))
            # NOTE(review): forward() always feeds the projection with
            # prototype features concatenated with label codes; with
            # is_clip=True that width must therefore equal t_dim - confirm.
            context_dim = (fe_dim + elmes_dim) if not is_clip else t_dim
            self.context_projection = nn.Linear(context_dim, elmes_dim)
            # Fixed (non-trainable) label codes, one row per class.
            self.label_elmes = nn.Parameter(get_elmes(elmes_dim, num_classes),
                                            requires_grad=False)

            encoder_layer = TransformerEncoderLayer(
                d_model=fe_dim, nhead=4, dim_feedforward=4 * fe_dim, batch_first=True
            )
            self.prototype_generator = TransformerEncoder(encoder_layer, num_layers=3)

        if head_type == 'transformer':
            self.head = get_encoder(
                size='small',
                image_dim=head_input_dim,
                num_classes=num_classes,
                label_elmes=uses_support,
                orig=False
            )
        else:
            self.head = LinearHead(head_input_dim, num_classes)

        if is_clip:
            freeze_clip_layers(self.model)

        # NOTE(review): constructed but never called in forward() - confirm
        # whether it is meant to run before the backbone.
        self.pre_encoder = ConvPreEncoder(
            in_channels=in_channels, out_channels=out_channels, output_dim=img_dim
        )

    @staticmethod
    def _rescale(images, scale):
        """Bilinearly resize ``images`` by ``scale``; a scale of 1.0 is a no-op."""
        if scale == 1.0:
            return images
        return F.interpolate(
            images, scale_factor=scale, mode='bilinear', align_corners=False
        )

    def _encode_multiscale(self, images, ref=None):
        """Encode ``images`` at every configured scale and average the features.

        When ``ref`` is given it is rescaled alongside ``images`` and
        concatenated to them along dim 1 before cropping.
        """
        per_scale = []
        for scale in self.scale_factors:
            scaled = self._rescale(images, scale)
            if ref is not None:
                scaled = torch.cat((scaled, self._rescale(ref, scale)), dim=1)

            scaled = TF.center_crop(scaled, self._CROP_SIZE)
            if self.is_clip:
                encoded = self.model.encode_image(scaled)
            else:
                encoded = self.model(scaled)
            per_scale.append(encoded)
        return torch.stack(per_scale, dim=1).mean(dim=1)

    def forward(self, x, ref, targets, support_set):
        """Run the wrapper on a batch and return the head's output.

        Args:
            x: Query images; bilinear resizing requires a 4-D tensor.
            ref: Optional reference images concatenated with ``x`` along
                dim 1, or ``None``.
            targets: Unused; kept so the call signature stays stable.
            support_set: Tensor whose first three dims are unpacked as
                (batch, classes, shots) and whose remaining dims describe
                one image. Ignored when the support branch is disabled.

        Returns:
            Whatever ``self.head`` returns for the assembled sequence.

        Raises:
            ValueError: If the support branch is enabled but
                ``support_set`` is ``None``.
        """
        uses_support = self.support_set_size != 0
        if uses_support and support_set is None:
            # Fail before any backbone compute rather than with an opaque
            # AttributeError further down.
            raise ValueError(
                "support_set is required when support_size is non-zero"
            )

        aggregated = self._encode_multiscale(x, ref)
        if self.adapter is not None:
            aggregated = self.adapter(aggregated)

        if not uses_support:
            return self.head(aggregated.unsqueeze(1))

        B, C, K = support_set.shape[:3]
        # reshape (not view) so non-contiguous support tensors also work.
        sup_flat = support_set.reshape(B * C * K, *support_set.shape[3:])
        sup_emb = self._encode_multiscale(sup_flat).reshape(B, C, K, -1)

        if self.no_query_fusion:
            # Prototype = mean over the encoded support tokens of the class.
            reps = [
                self.prototype_generator(sup_emb[:, i]).mean(dim=1, keepdim=True)
                for i in range(C)
            ]
        else:
            # Prototype = output token of the query after it attended to
            # the class's support tokens.
            query_token = aggregated.unsqueeze(1)
            reps = [
                self.prototype_generator(
                    torch.cat((query_token, sup_emb[:, i]), dim=1)
                )[:, 0:1]
                for i in range(C)
            ]
        reps = torch.cat(reps, dim=1)

        label_elmes = self.label_elmes
        if self.randomize_class_order:
            # Shuffle prototypes and label codes with the same permutation
            # so each prototype keeps its own code.
            perm = torch.randperm(C, device=reps.device)
            reps = reps[:, perm]
            label_elmes = label_elmes[perm]

        context = torch.cat(
            (reps, label_elmes.unsqueeze(0).expand(B, -1, -1)), dim=-1
        )
        context = self.context_projection(context)
        context = context / (context.norm(dim=-1, keepdim=True) + 1e-8)
        # The previous implementation additionally multiplied by a one-hot
        # matrix of arange(C), i.e. the identity; that product is omitted.
        label_embeds = context * self.elmes_scale

        features = torch.cat([aggregated.unsqueeze(1), reps], dim=1)
        batched_label_embeddings = torch.cat(
            [self.unk_emb.expand(B, -1, -1), label_embeds], dim=1
        )

        demonstrations = torch.cat([features, batched_label_embeddings], dim=-1)
        return self.head(demonstrations)

class EnsembleClassifier(nn.Module):
    """
    Average the outputs of two wrappers that share one forward signature.

    Both members receive exactly the same arguments; the result is the
    element-wise mean of what they return.
    """
    def __init__(self, wrapper1: nn.Module, wrapper2: nn.Module):
        super().__init__()
        # Assigned left to right, so wrapper1 is registered before wrapper2.
        self.wrapper1, self.wrapper2 = wrapper1, wrapper2

    def forward(self, x, ref, targets, support_set):
        """Call both members on the same inputs and return their mean."""
        inputs = (x, ref, targets, support_set)
        first = self.wrapper1(*inputs)
        second = self.wrapper2(*inputs)
        total = first + second
        return total / 2.0
