import torch
import torch.nn as nn
import numpy as np
from visual import Linear_fw, LayerNorm_fw
from transformers.activations import gelu
from data.meter import Meter
from utils import *

class TextEncoder(nn.Module):
    """Turn class names into unit-norm text embeddings via a masked prompt.

    Each class name is inserted into the template
    ``'The appearance of <classname> is <mask>.'``, the prompt is run through
    ``roberta_model`` and the hidden state at the ``<mask>`` position is passed
    through dense -> GELU -> LayerNorm -> linear projection, then L2-normalised.

    Args:
        roberta_model: Callable language model; called as
            ``roberta_model(input_ids, attention_mask=...)`` and expected to
            return a sequence whose first element is the per-token hidden state.
        roberta_tokenizer: Tokenizer called on the list of prompts with
            ``return_tensors='pt'``; must return ``'input_ids'`` and
            ``'attention_mask'``.
        roberta_config: Config object providing ``hidden_size`` and
            ``layer_norm_eps``.
        image_embedding_dim: Output dimension of the final projection.
    """

    def __init__(self, roberta_model, roberta_tokenizer, roberta_config, image_embedding_dim):
        super().__init__()
        text_embedding_dim = roberta_config.hidden_size

        self.roberta_model = roberta_model
        self.roberta_tokenizer = roberta_tokenizer
        self.roberta_config = roberta_config
        self.dense = nn.Linear(text_embedding_dim, text_embedding_dim)
        self.layer_norm = nn.LayerNorm(text_embedding_dim, roberta_config.layer_norm_eps)
        self.project = nn.Linear(text_embedding_dim, image_embedding_dim)

    def forward(self, classnames):
        """Embed ``classnames`` (an iterable of strings).

        Returns:
            Tensor with one L2-normalised row of size ``image_embedding_dim``
            per class name, on the same device as this module's weights.
        """
        hidden_dim = self.roberta_config.hidden_size
        # Follow the module's own placement instead of assuming CUDA, so the
        # encoder also works on CPU or on a non-default GPU.
        device = self.project.weight.device

        prompts = ['The appearance of ' + classname + ' is <mask>.' for classname in classnames]
        # One batched tokenizer call; every prompt is padded to 16 tokens.
        # Truncation is deliberately not enabled: it would cut off the <mask>.
        encoded = self.roberta_tokenizer(prompts, return_tensors='pt', padding='max_length', max_length=16)
        tokenized_prompts = encoded['input_ids'].to(device)
        tokenized_attention_masks = encoded['attention_mask'].to(device)

        # The attended tokens of each prompt end with "<mask>", ".", and one
        # closing special token, so <mask> is the third attended token from the
        # end (assumes the tokenizer appends exactly one end token and right-pads
        # -- TODO confirm for non-RoBERTa tokenizers).
        mask_positions = tokenized_attention_masks.sum(dim=1) - 3
        # expand() gives a broadcast view; no need to materialise hidden_dim copies.
        mask_gather_ids = mask_positions.view(-1, 1, 1).expand(-1, 1, hidden_dim)

        output = self.roberta_model(tokenized_prompts, attention_mask=tokenized_attention_masks)
        sequence_output = output[0]
        mask_output = torch.gather(sequence_output, 1, mask_gather_ids).squeeze(dim=1)

        text_embeddings = self.dense(mask_output)
        text_embeddings = gelu(text_embeddings)
        text_embeddings = self.layer_norm(text_embeddings)
        text_embeddings = self.project(text_embeddings)
        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

        return text_embeddings


class MetricImage2Text(nn.Module):
    """Score image embeddings against linearly transformed text embeddings.

    Holds a single square ``Linear_fw`` layer (``image_embedding_dim`` in and
    out) that is applied to the text side before the similarity product.
    """

    def __init__(self, image_embedding_dim):
        super().__init__()
        self.metric = Linear_fw(image_embedding_dim, image_embedding_dim)

    def forward(self, image_embeddings, text_embeddings):
        """Return ``image_embeddings @ metric(text_embeddings).T``.

        The result has one row per image embedding and one column per text
        embedding.
        """
        transformed_text = self.metric(text_embeddings)
        return torch.matmul(image_embeddings, transformed_text.t())


class EnsembleModel(nn.Module):
    def __init__(self, roberta_model, roberta_tokenizer ,roberta_config, model_func):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()
        self.feature = model_func()
        self.text_encoder = TextEncoder(roberta_model, roberta_tokenizer, roberta_config, self.feature.feat_dim)
        self.metric_image2text = MetricImage2Text(self.feature.feat_dim)

    def forward(self, x, classnames):
        img_embeddings = self.feature.forward(x)
        img_embeddings = img_embeddings / img_embeddings.norm(dim=-1, keepdim=True)

        text_embeddings = self.text_encoder(classnames)

        return img_embeddings, text_embeddings

    def train_loop(self, train_loader, optimizer, params):
        loss_meter = Meter()
        acc_meter = Meter()
        optimizer.zero_grad()
        for i, (x, y) in enumerate(train_loader):
            classnames = y[0]
            x = x.cuda()

            x_support = x[:, :params.n_shot]
            x_query = x[:, params.n_shot:]

            x_support = x_support.contiguous().view(params.n_way * params.n_shot, *x_support.size()[2:])
            x_query = x_query.contiguous().view(params.n_way * params.n_query, *x_query.size()[2:])

            fast_parameters = list(parameter for parameter in self.metric_image2text.parameters())

            for weight in self.parameters():
                weight.fast = None
            self.zero_grad()

            for update_step in range(params.update_num):
                self.feature.eval()
                
                images_embeddings_support, text_embeddings = self.forward(x_support, classnames)
                logits_support = self.set_forward(params.n_way, params.n_shot, images_embeddings_support, text_embeddings)
                logits_phi = logits_support - params.margin
                y_query = torch.tensor(np.repeat(range(params.n_way), params.n_shot), dtype=torch.int64)
                logits_support_margin = torch.where(one_hot(y_query, params.n_way).bool().cuda(), logits_phi, logits_support) / params.temperature_inner
                
                loss_support = self.set_forward_loss(logits_support_margin)
                grad = torch.autograd.grad(loss_support, fast_parameters)

                fast_parameters = []
                for k, weight in enumerate(list(parameter for parameter in self.metric_image2text.parameters())):
                    if weight.fast is None:
                        weight.fast = weight - params.update_lr * grad[k]
                    else:
                        weight.fast = weight.fast - params.update_lr * grad[k]
                    fast_parameters.append(weight.fast)
            
            self.feature.train()
            
            images_embeddings_query, text_embeddings = self.forward(x_query, classnames)
            logits_query = self.set_forward(params.n_way, params.n_query, images_embeddings_query, text_embeddings) / params.temperature_outer
            loss_query = self.set_forward_loss(logits_query)

            correct_this, count_this = self.correct(logits_query)
            avg_acc = correct_this / count_this * 100

            loss_query.backward()
            optimizer.step()
            optimizer.zero_grad()

            loss_meter.update(loss_query.item())
            acc_meter.update(avg_acc)

        return loss_meter.avg(), acc_meter.avg(), acc_meter.confidence_interval()


    def test_loop(self, test_loader, params):
        loss_meter = Meter()
        acc_meter = Meter()
        for i, (x, y) in enumerate(test_loader):
            classnames = y[0]
            x = x.cuda()

            x_support = x[:, :params.n_shot]
            x_query = x[:, params.n_shot:]

            x_support = x_support.contiguous().view(params.n_way * params.n_shot, *x_support.size()[2:])
            x_query = x_query.contiguous().view(params.n_way * params.n_query, *x_query.size()[2:])

            fast_parameters = list(parameter for parameter in self.metric_image2text.parameters())

            for weight in self.parameters():
                weight.fast = None
            self.zero_grad()

            for update_step in range(params.update_num):
                images_embeddings_support, text_embeddings = self.forward(x_support, classnames)
                logits_support = self.set_forward(params.n_way, params.n_shot, images_embeddings_support, text_embeddings)
                logits_phi = logits_support - params.margin
                y_query = torch.tensor(np.repeat(range(params.n_way), params.n_shot), dtype=torch.int64)
                logits_support_margin = torch.where(one_hot(y_query, params.n_way).bool().cuda(), logits_phi, logits_support) / params.temperature_inner
                
                loss_support = self.set_forward_loss(logits_support_margin)
                grad = torch.autograd.grad(loss_support, fast_parameters)

                fast_parameters = []
                for k, weight in enumerate(list(parameter for parameter in self.metric_image2text.parameters())):
                    if weight.fast is None:
                        weight.fast = weight - params.update_lr * grad[k]
                    else:
                        weight.fast = weight.fast - params.update_lr * grad[k]
                    fast_parameters.append(weight.fast)

            with torch.no_grad():
                images_embeddings_query, text_embeddings = self.forward(x_query, classnames)
                logits_query = self.set_forward(params.n_way, params.n_query, images_embeddings_query, text_embeddings) / params.temperature_outer

                loss_query = self.set_forward_loss(logits_query)

                correct_this, count_this = self.correct(logits_query)
                avg_acc = correct_this / count_this * 100
                
                loss_meter.update(loss_query.item())
                acc_meter.update(avg_acc)

        return loss_meter.avg(), acc_meter.avg(), acc_meter.confidence_interval()


    def set_forward(self, n_way, n, image_embeddings, text_embeddings):
        image_embeddings = image_embeddings.contiguous().view(n_way * n, -1)
        logits_image2text = self.metric_image2text(image_embeddings, text_embeddings)
        return logits_image2text

    def set_forward_loss(self, logits):
        y_query = torch.from_numpy(np.repeat(range(logits.size(1)), logits.size(0) / logits.size(1)))
        y_query = y_query.cuda()
        loss = self.loss_fn(logits, y_query.long())
        return loss

    def correct(self, logits):
        y_query = np.repeat(range(logits.size(1)), logits.size(0) / logits.size(1))
        topk_logits, topk_labels = logits.data.topk(1, 1, True, True)
        topk_ind = topk_labels.cpu().numpy()
        top1_correct = np.sum(topk_ind[:, 0] == y_query)
        return float(top1_correct), len(y_query)
    