import torch
import torch.nn as nn
import torch.nn.functional as F

import models.encoder as Encoder
from collections import OrderedDict 

### for the teacher model
class EMA():
    """Stateless helper that blends an old value with a new one."""

    def __init__(self):
        super().__init__()

    def update_average(self, old, new, beta):
        """Return ``old * beta + (1 - beta) * new``.

        When ``old`` is ``None`` there is nothing to blend with, so ``new``
        is returned unchanged.
        """
        if old is not None:
            return old * beta + (1 - beta) * new
        return new

### only supports ViT-Base
class UniMedI(nn.Module):
    """Image/text model with an optional EMA teacher image encoder.

    The model owns a text encoder (``Encoder.BertEncoder``) and a "student"
    image encoder (``Encoder.create_vit``) initialised from a DeiT-base
    checkpoint.  When ``with_distill`` is true a second "teacher" image
    encoder is created, initialised from the student, excluded from gradient
    updates, and meant to be refreshed through ``update_moving_average``.

    Args:
        image2D_size, hidden_dim, patch_out_dim, mask_ratio, drop_path_rate:
            forwarded unchanged to ``Encoder.create_vit``.
        emb_dim: passed as ``output_dim`` to both the text and image encoders.
        mask_type: ``'rand'`` makes the student run without the teacher's
            attention (``attn=None``); any other value forwards the teacher's
            second output to the student as ``attn``.
        with_distill: enable the teacher encoder and the distillation outputs.
        freeze_bert: forwarded to ``Encoder.BertEncoder``.
    """

    # DeiT-base checkpoint used to initialise the student image encoder.
    DEIT_BASE_URL = "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth"

    def __init__(self, image2D_size=224, hidden_dim=2048, emb_dim=128, drop_path_rate=0.0,
                 patch_out_dim=8192, mask_type='attn', mask_ratio=0., with_distill=False, freeze_bert=False):
        super().__init__()
        # text model
        self.text_encoder_q = Encoder.BertEncoder(output_dim=emb_dim, freeze_bert=freeze_bert)

        def build_vit(masked_im_modeling):
            # Both encoders share every setting except ``masked_im_modeling``.
            return Encoder.create_vit(
                image2D_size=image2D_size, hidden_dim=hidden_dim, output_dim=emb_dim,
                patch_out_dim=patch_out_dim, with_distill=with_distill,
                masked_im_modeling=masked_im_modeling, mask_ratio=mask_ratio,
                drop_path_rate=drop_path_rate)

        # The student only uses masked image modelling in the distillation setup.
        # Creation order (student, then teacher) is kept so that random
        # initialisation and parameter registration order stay unchanged.
        self.image_encoder_q_student = build_vit(masked_im_modeling=bool(with_distill))
        if with_distill:
            self.image_encoder_q_teacher = build_vit(masked_im_modeling=False)

        # strict=False: keys that do not match between the checkpoint and the
        # encoder are skipped instead of raising.
        checkpoint = torch.hub.load_state_dict_from_url(
            url=self.DEIT_BASE_URL, map_location="cpu", check_hash=True)
        self.image_encoder_q_student.load_state_dict(checkpoint["model"], strict=False)

        if with_distill:
            ### setting for distill --- teacher model
            self.ema_updater = EMA()
            # The teacher starts as a copy of the student and is never trained
            # by the optimizer.
            self.image_encoder_q_teacher.load_state_dict(
                self.image_encoder_q_student.state_dict(), strict=False)
            for p in self.image_encoder_q_teacher.parameters():
                p.requires_grad = False

        self.mask_type = mask_type
        self.with_distill = with_distill

    @torch.no_grad()
    def update_moving_average(self, beta):
        """Refresh the teacher image encoder from the student.

        For every parameter name present in both encoders:

        * names starting with ``'patch_embed'`` are copied verbatim from the
          student (EMA without patch embedding);
        * all other parameters become ``beta * teacher + (1 - beta) * student``
          via ``self.ema_updater``.

        Parameters are paired by name, so the result does not depend on the
        two encoders enumerating their parameters in the same order.
        Requires the model to have been built with ``with_distill=True``
        (otherwise there is no teacher / ``ema_updater``).

        Alternative strategies considered by the original author (TODO):
        applying the EMA to all parameters including the patch embedding, or
        leaving the teacher's patch embedding untouched.
        """
        student_params = dict(self.image_encoder_q_student.named_parameters())
        for name, teacher_param in self.image_encoder_q_teacher.named_parameters():
            student_param = student_params.get(name)
            if student_param is None:
                # Parameter exists only in the teacher: nothing to update from.
                continue
            if name.startswith('patch_embed'):
                # Direct copy instead of ``old * 0.0 + new`` so a non-finite
                # teacher value cannot leak into the result as NaN.
                teacher_param.copy_(student_param)
            else:
                teacher_param.copy_(
                    self.ema_updater.update_average(teacher_param, student_param, beta))

    def _embed_image(self, img_feat_q_student):
        """Project the student's index-0 token feature and L2-normalise it."""
        # Index 0 along dim 1 is presumably the [CLS] token -- confirm in Encoder.
        img_feat_q = img_feat_q_student[:, 0, :]
        img_emb_q = self.image_encoder_q_student.global_embedding(img_feat_q)
        return F.normalize(img_emb_q, dim=-1)

    def _embed_text(self, batch):
        """Encode the report tokens in ``batch`` and L2-normalise the embedding."""
        report_feat_q, _, _, _ = self.text_encoder_q(
            batch["caption_ids"], batch["attention_mask"], batch["token_type_ids"])
        report_emb_q = self.text_encoder_q.global_embed(report_feat_q)
        return F.normalize(report_emb_q, dim=-1)

    def _forward_valid(self, batch):
        """Compute image/report embeddings with student masking switched off."""
        student = self.image_encoder_q_student
        # Remember the configured value so validation does not leave the
        # student in a different masking mode than it was built with.  The
        # fallback ``True`` is the value previously hard-coded after validation.
        previous = getattr(student, "masked_im_modeling", True)
        student.masked_im_modeling = False
        try:
            img_feat_q_student, _ = student(batch["imgs"])
            # use the student's global embedding
            img_emb_q = self._embed_image(img_feat_q_student)
            report_emb_q = self._embed_text(batch)
        finally:
            # Restore even if the forward pass raises.
            student.masked_im_modeling = previous
        return img_emb_q, report_emb_q

    def forward(self, batch, status='train'):
        """Embed a batch of images and reports.

        Args:
            batch: mapping with keys ``"imgs"``, ``"caption_ids"``,
                ``"attention_mask"`` and ``"token_type_ids"``.
            status: ``'valid'`` runs the student with ``masked_im_modeling``
                temporarily disabled and returns only the two embeddings; any
                other value takes the training path.

        Returns:
            ``(img_emb_q, report_emb_q)`` for validation or when
            ``with_distill`` is false; otherwise
            ``(img_emb_q, report_emb_q, student_output, teacher_output, mask)``
            where the two outputs come from each encoder's ``local_embedding``
            applied to its index-0 token feature and ``mask`` is the student
            encoder's second return value.
        """
        # valid the performance of our model
        if status == 'valid':
            return self._forward_valid(batch)

        # Forward of query image encoder
        if self.with_distill:
            img_feat_q_teacher, attn = self.image_encoder_q_teacher(batch["imgs"])
            # 'rand' masking ignores the teacher's second output; every other
            # mask type hands it to the student.
            student_attn = None if self.mask_type == 'rand' else attn
            img_feat_q_student, mask = self.image_encoder_q_student(batch["imgs"], attn=student_attn)
        else:
            img_feat_q_student, _ = self.image_encoder_q_student(batch["imgs"])
        img_emb_q = self._embed_image(img_feat_q_student)

        # Forward of query text encoder
        report_emb_q = self._embed_text(batch)

        if not self.with_distill:
            return img_emb_q, report_emb_q

        student_output = self.image_encoder_q_student.local_embedding(img_feat_q_student[:, 0, :])
        teacher_output = self.image_encoder_q_teacher.local_embedding(img_feat_q_teacher[:, 0, :])
        return img_emb_q, report_emb_q, student_output, teacher_output, mask



