import collections
import torch
import torch.nn as nn
import CLIP_.clip as clip
from torchvision import models, transforms

def l2_layers(xs_conv_features, ys_conv_features, clip_model_name):
    """Return the mean squared difference for each pair of feature maps.

    Parameters
    ----------
    xs_conv_features, ys_conv_features : sequences of torch.Tensor
        Feature maps compared pairwise, in order; iteration stops at the
        shorter sequence.
    clip_model_name : str
        Unused; accepted so every distance metric shares one signature.

    Returns
    -------
    list of 0-dim torch.Tensor, one mean-squared-error value per pair.
    """
    per_layer = []
    for feat_x, feat_y in zip(xs_conv_features, ys_conv_features):
        diff = feat_x - feat_y
        per_layer.append(torch.square(diff).mean())
    return per_layer

class CLIPConvLoss(torch.nn.Module):
    """Perceptual loss comparing CLIP ResNet convolutional features.

    Both inputs are passed through the CLIP ResNet visual backbone. The loss
    is the L2 distance between the feature maps of the last residual stage
    (``layer4``).
    """

    def __init__(self, clip_model_name="RN101"):
        """Load the CLIP model and prepare the input transforms.

        Parameters
        ----------
        clip_model_name : str
            Name passed to ``clip.load``. It must be a ResNet ("RN...")
            backbone, because the feature extraction below walks the
            ResNet stem and residual stages.
        """
        super(CLIPConvLoss, self).__init__()
        if not clip_model_name.startswith("RN"):
            raise ValueError(
                "CLIPConvLoss requires a ResNet CLIP backbone "
                f"(e.g. 'RN101'), got {clip_model_name!r}")
        self.clip_model_name = clip_model_name
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.clip_conv_loss_type = "L2"
        self.distance_metrics = {
            "L2": l2_layers,
        }

        self.model, clip_preprocess = clip.load(
            self.clip_model_name, self.device, jit=False)
        self.model.eval()

        # Look the stages up by name rather than by position in children(),
        # so module registration order cannot silently shift them.
        visual = self.model.visual
        self.visual_model = visual
        self.layer1 = visual.layer1
        self.layer2 = visual.layer2
        self.layer3 = visual.layer3
        self.layer4 = visual.layer4
        self.att_pool2d = visual.attnpool

        self.img_size = clip_preprocess.transforms[1].size
        # Converts a PIL image / ndarray to a tensor; no normalisation here.
        self.target_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        # CLIP's own preprocessing minus ToTensor: operates on tensors.
        self.normalize_transform = transforms.Compose([
            clip_preprocess.transforms[0],  # Resize
            clip_preprocess.transforms[1],  # CenterCrop
            clip_preprocess.transforms[-1],  # Normalize
        ])

        # Second "view": plain square resize to the backbone's input
        # resolution (224 for RN101) followed by CLIP normalisation.
        resolution = visual.input_resolution
        self.augment_trans = transforms.Compose([
            transforms.Resize((resolution, resolution)),
            transforms.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711)),
        ])

        # Kept for attribute compatibility; not read anywhere in this class.
        self.clip_conv_layer_dims = None  # self.args.clip_conv_layer_dims
        self.counter = 0

    def forward(self, sketch, target):
        """Compute the conv-feature loss between ``sketch`` and ``target``.

        Parameters
        ----------
        sketch: Torch Tensor [B, C, H, W]
        target: Torch Tensor [B, C, H, W] (same batch size as ``sketch``;
            values presumably in [0, 1], since CLIP normalisation is applied)

        Returns
        -------
        Scalar tensor: the distance computed on the last returned feature
        map (``layer4``).
        """
        x = sketch.to(self.device)
        y = target.to(self.device)

        # Two views per input: CLIP preprocessing and the square-resize view.
        # Each input is transformed on its own, so batches of any size stay
        # aligned (sketch view i is always compared with target view i).
        xs = torch.cat([self.normalize_transform(x), self.augment_trans(x)], dim=0)
        ys = torch.cat([self.normalize_transform(y), self.augment_trans(y)], dim=0)

        _, xs_conv_features = self.forward_inspection_clip_resnet(
            xs.contiguous())
        # detach(): no gradient should flow back into the target image.
        _, ys_conv_features = self.forward_inspection_clip_resnet(
            ys.detach())

        conv_loss = self.distance_metrics[self.clip_conv_loss_type](
            xs_conv_features, ys_conv_features, self.clip_model_name)

        return conv_loss[-1]

    def forward_inspection_clip_resnet(self, x):
        """Run the CLIP ResNet and return its pooled output and stage features.

        Returns
        -------
        (pooled, [stem, layer1, layer2, layer3, layer4])
        """
        def stem(m, x):
            # Three conv-bn-relu pairs, then average pooling.
            for conv, bn in [(m.conv1, m.bn1), (m.conv2, m.bn2), (m.conv3, m.bn3)]:
                x = m.relu(bn(conv(x)))
            x = m.avgpool(x)
            return x

        # Match the model's parameter dtype (e.g. fp16 when loaded on CUDA).
        x = x.type(self.visual_model.conv1.weight.dtype)
        x = stem(self.visual_model, x)
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        y = self.att_pool2d(x4)
        return y, [x, x1, x2, x3, x4]
