import torch
import torch.nn.functional as F
from torch import nn

from backbones import build_conv_block
from subspace.abstract_alg import AbstractAlg
from utils.kernel_utils import KernelZoo as K
from utils.progress_utils import StageEnum
from utils.tensor_utils import split_support_query_for_x_in_cls, split_support_query


class BaseModel(nn.Module):
    """Shared backbone: a fixed conv encoder plus per-subspace conv kernel bases.

    The encoder is two ``build_conv_block`` layers (3 -> 64 -> 64 channels).
    On top of it the model stores, for each of ``num_conv`` conv layers and each
    of ``K`` subspaces, an ``n_dim``-column basis of flattened 3x3, 64->64 conv
    kernels (``convs``), plus ``num_conv * K`` ``BatchNorm2d(64)`` layers.

    Args:
        config: mapping read for ``"k"`` (number of subspaces, default 2) and
            ``"n_dim"`` (subspace dimension, default 1).
        embedding_dim: stored as ``self.embedding_dim``; not otherwise used here.
        num_class: accepted for interface compatibility; not used here.
        head_hidden_dim: stored as ``self.hidden_dim``; not otherwise used here.
    """

    def __init__(self, config, embedding_dim, num_class, head_hidden_dim=512):
        super().__init__()
        self.config = config

        # Fixed architecture constants.
        self.input_channel = 3
        self.hidden_dim_conv = 64
        self.conv_kernel_size = 3  # kernel size in conv layer
        self.num_conv = 2

        # Subspace hyper-parameters taken from the config.
        self.K = config.get("k", 2)  # current num of subspace
        self.n_dim = config.get("n_dim", 1)  # subspace dim

        self.hidden_dim = head_hidden_dim
        self.embedding_dim = embedding_dim

        width = self.hidden_dim_conv
        # NOTE: the encoder is built before `convs` on purpose so the random
        # number stream consumed during construction keeps the same order.
        self.encoder = nn.Sequential(
            build_conv_block(self.input_channel, width),
            build_conv_block(width, width),
        )

        norm_count = self.num_conv * self.K
        self.batch_norm = nn.ModuleList(nn.BatchNorm2d(width) for _ in range(norm_count))

        # One flattened (out, in, kH, kW) kernel per basis column.
        kernel_numel = width * width * self.conv_kernel_size ** 2
        self.convs = nn.Parameter(torch.randn(self.num_conv, self.K, kernel_numel, self.n_dim))

class MUSML_Model(nn.Module):
    """Wraps a ``BaseModel`` together with the conv settings MUSML needs.

    The learnable state (encoder, per-subspace conv bases and batch norms) is
    owned by ``self.base_model``. This class only adds shape constants and a
    max-pool layer used between subspace conv layers.

    Args:
        config: mapping read for ``"k"`` (number of subspaces, default 2) and
            ``"n_dim"`` (subspace dimension, default 1); forwarded to BaseModel.
        embedding_dim: stored and forwarded to BaseModel.
        num_class: forwarded to BaseModel; not used here.
        head_hidden_dim: stored as ``self.hidden_dim`` and forwarded to BaseModel.
    """

    def __init__(self, config, embedding_dim, num_class, head_hidden_dim=512):
        super(MUSML_Model, self).__init__()
        self.config = config

        self.input_channel = 3
        self.num_conv = 2
        self.hidden_dim_conv = 64
        self.K = self.config.get("k", 2)  # current num of subspace
        self.n_dim = self.config.get("n_dim", 1)  # subspace dim

        self.conv_kernel_size = 3  # kernel size in conv layer
        self.hidden_dim = head_hidden_dim
        self.embedding_dim = embedding_dim
        self.base_model = BaseModel(config, embedding_dim, num_class, head_hidden_dim=head_hidden_dim)

        # Weight layout expected by F.conv2d: (out_channels, in_channels, kH, kW).
        self.conv_shape = (self.hidden_dim_conv, self.hidden_dim_conv, self.conv_kernel_size, self.conv_kernel_size)
        # Not referenced by the visible code; kept for external users.
        self.input_shape = (self.hidden_dim_conv, self.input_channel, self.conv_kernel_size, self.conv_kernel_size)

        self.pool_kernel_size = 2
        self.max_pool = nn.MaxPool2d(kernel_size=self.pool_kernel_size)

    def forward_feature(self, x):
        """Return the shared encoder's features for ``x``.

        The encoder is owned by ``self.base_model``. This class has no
        ``encoder`` attribute of its own, so we delegate to the base model.
        """
        return self.base_model.encoder(x)


class MUSML(AbstractAlg):
    """Multi-subspace meta-learning on top of a cosine prototype classifier.

    For each task, every one of the model's ``K`` subspaces gets fresh mixing
    weights. Only those weights are adapted on the support set. The subspace
    with the lowest final support loss is selected, and the query set is
    classified against class prototypes computed in that subspace.
    """

    # Learning rate of the per-task SGD that adapts the subspace mixing weights.
    subspace_lr = 0.05

    def __init__(self, config):
        super(MUSML, self).__init__(config=config)

    def is_last(self, layer_id):
        """Return True if ``layer_id`` is the final subspace conv layer."""
        return layer_id == self.model.num_conv - 1

    def _new_subspace_weights(self):
        """Create per-task mixing weights initialised to a uniform average of the basis."""
        n_dim = self.model.n_dim
        weights = nn.Linear(in_features=n_dim, out_features=1, bias=False)
        with torch.no_grad():
            weights.weight.fill_(1.0 / n_dim)
        # Follow the model's device instead of hard-coding CUDA.
        return weights.to(self.model.base_model.convs.device)

    def _subspace_forward(self, x, subspace_weights, subspace_id):
        """Embed ``x`` with the encoder followed by one subspace's conv stack.

        Each layer's kernel is ``subspace_weights`` applied to that layer's basis
        in ``base_model.convs``. Every layer is followed by its subspace-specific
        BatchNorm, and all but the last layer also get ReLU + max-pool.
        Returns features flattened to ``(x.size(0), -1)``.
        """
        model = self.model
        base = model.base_model
        feats = base.encoder(x)
        for layer_id in range(model.num_conv):
            conv_w = subspace_weights(base.convs[layer_id, subspace_id])
            conv_w = conv_w.squeeze(-1).contiguous().view(model.conv_shape)
            # "same" padding for the odd kernel size (padding=1 for 3x3).
            feats = F.conv2d(feats, conv_w, bias=None, stride=1,
                             padding=model.conv_kernel_size // 2, dilation=1, groups=1)
            feats = base.batch_norm[subspace_id * model.num_conv + layer_id](feats)
            if not self.is_last(layer_id):
                feats = model.max_pool(F.relu(feats))
        return feats.flatten(start_dim=1)

    def _support_loss(self, x_support, subspace_weights, subspace_id, n_way, n_support, y_support):
        """Prototype-network cross-entropy of the support set in one subspace.

        Assumes ``x_support`` is ordered class-major (``n_support`` samples per
        class, consecutively) since it is viewed as ``(n_way, n_support, -1)``.
        """
        feats = self._subspace_forward(x_support, subspace_weights, subspace_id)
        feats = feats - feats.mean(0)  # centre the task's embeddings
        protos = feats.contiguous().view(n_way, n_support, -1).mean(dim=1)
        y_support_pred = K.compute_cosine(feats, protos, is_batch=False)
        return F.cross_entropy(y_support_pred, y_support)

    def meta_update(self, x,  stage):
        """Run one episode: pick the best subspace on the support set, then score the query set.

        E-step: each subspace's fresh mixing weights are adapted with
        ``n_inner_step_dict[stage]`` SGD steps on the support loss. The subspace
        whose last recorded support loss is lowest is kept. With zero steps, the
        initial weights are scored instead.

        Args:
            x: episode inputs. Presumably class-major, with each class's support
                samples followed by its query samples, since the embeddings are
                viewed as ``(n_way, n_support + n_query, -1)``. TODO: confirm
                against the data loader.
            stage: stage key passed to ``get_basic_expt_info`` and used to index
                ``n_inner_step_dict``.

        Returns:
            dict with ``y_query_pred`` (cosine logits), ``query_loss`` and
            ``meta_loss`` (the same tensor), and ``best_cluster_id``.

        Raises:
            ValueError: if no subspace produced a comparable support loss
                (e.g. config ``k`` is 0, or every support loss is NaN).
        """
        n_way, n_support, n_query, y_support, y_query = self.get_basic_expt_info(stage)
        n_steps = self.n_inner_step_dict[stage]

        x_support, _ = split_support_query(x, n_support=n_support, n_query=n_query, n_way=n_way)

        # Leave the model's gradients cleared on return, as before. The inner loop
        # below never writes into them, so no support-set gradient can leak into
        # the meta gradient computed from the returned loss.
        self.model.zero_grad()

        # E-step: find the best subspace.
        best_subspace_id = 0
        best_score = float("inf")
        best_weight = None
        for k in range(self.model.K):
            subspace_weights = self._new_subspace_weights()
            subspace_optimizer = torch.optim.SGD(subspace_weights.parameters(), lr=self.subspace_lr)

            score = None
            for _ in range(n_steps):
                support_loss = self._support_loss(x_support, subspace_weights, k, n_way, n_support, y_support)
                # Differentiate w.r.t. the mixing weights only: this skips the
                # encoder backward and keeps model .grad untouched. Each step
                # rebuilds its own graph, so nothing needs to be retained.
                (grad,) = torch.autograd.grad(support_loss, subspace_weights.weight)
                subspace_weights.weight.grad = grad
                subspace_optimizer.step()
                score = support_loss.item()

            if score is None:
                # No inner steps configured: score the initial (uniform) weights.
                with torch.no_grad():
                    score = self._support_loss(x_support, subspace_weights, k, n_way, n_support, y_support).item()

            if score < best_score:
                best_subspace_id = k
                best_score = score
                best_weight = subspace_weights

        if best_weight is None:
            raise ValueError("MUSML found no usable subspace: need config 'k' >= 1 and a finite support loss")

        # Query pass through the selected subspace's num_conv conv layers.
        x_mod = self._subspace_forward(x, best_weight, best_subspace_id)
        x_mod = x_mod - x_mod.mean(0)

        # Build prototypes from the support part, classify the query part.
        x_f = x_mod.view(n_way, n_support + n_query, -1)
        x_support_mod, x_query_mod = split_support_query_for_x_in_cls(x_f, n_support)
        protos = x_support_mod.contiguous().view(n_way, n_support, -1).mean(dim=1)
        x_query_mod = x_query_mod.contiguous().view(n_way * n_query, -1)
        y_query_pred = K.compute_cosine(x_query_mod, protos, is_batch=False)
        query_loss = F.cross_entropy(y_query_pred, y_query)

        result_dict = {
            "y_query_pred": y_query_pred,
            "query_loss": query_loss,
            "meta_loss": query_loss,
            "best_cluster_id": best_subspace_id,
        }

        return result_dict

    def get_model(self):
        """Build the MUSML_Model for this config (flattened embedding size fixed at 1600)."""
        num_class = self.ways["train"]
        embedding_dim = 1600
        head_hidden_dim = self.config.get("head_hidden_dim", 512)
        return MUSML_Model(config=self.config, embedding_dim=embedding_dim, num_class=num_class, head_hidden_dim=head_hidden_dim)
