import os
from collections import OrderedDict
from types import SimpleNamespace
import torch
from torch import nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
import torch.nn.functional as F
from .module_clip import CLIP, convert_weights, _PT_NAME
from .module_cross import CrossModel, Transformer as TransformerClip
from .until_module import LayerNorm, AllGather, AllGather2, CrossEn
from tvr.models.co_attention_transformer_module import Co_attention_block
import numpy as np
# Function-style shortcuts for the ``apply`` entry points of the gather classes
# imported from .until_module. They are called as ``allgather(tensor, config)``
# further down (presumably to collect tensors across distributed workers --
# confirm against until_module).
allgather = AllGather.apply
allgather2 = AllGather2.apply


class ResidualLinear(nn.Module):
    """Square linear layer followed by ReLU, wrapped in a skip connection.

    Computes ``x + relu(linear(x))`` where the linear layer maps ``d_int``
    features to ``d_int`` features.

    Args:
        d_int: Number of input (and output) features.
    """

    def __init__(self, d_int: int):
        super().__init__()
        projection = nn.Linear(d_int, d_int)
        # In-place ReLU only touches the fresh output of the linear layer,
        # never the caller's tensor.
        activation = nn.ReLU(inplace=True)
        self.fc_relu = nn.Sequential(projection, activation)

    def forward(self, x):
        """Return ``x`` plus its ReLU-activated linear projection."""
        residual = self.fc_relu(x)
        return x + residual


class MBDA(nn.Module):
    def __init__(self, config):
        """Build the model around a pretrained CLIP checkpoint.

        Args:
            config: Run configuration. Attributes read here are
                ``agg_module`` (default ``'meanP'``), ``base_encoder``
                (default ``"ViT-B/32"``) and, for ``agg_module ==
                "seqTransf"``, ``num_hidden_layers``. The object is also kept
                on ``self.config`` and passed to ``CrossEn``.

        Raises:
            AssertionError: If ``base_encoder`` is not a key of ``_PT_NAME``.
            FileNotFoundError: If the checkpoint file named by ``_PT_NAME`` is
                not present in this module's directory.
        """
        super(MBDA, self).__init__()

        self.config = config
        self.agg_module = getattr(config, 'agg_module', 'meanP')
        backbone = getattr(config, 'base_encoder', "ViT-B/32")

        # The co-attention block and the GLU gates are sized with these fixed
        # values, before the checkpoint is inspected (both names are re-bound
        # from the checkpoint further down).
        embed_dim = 512
        transformer_heads = 8
        self.co_connetion_transformer_model_block = nn.Sequential(*[
            Co_attention_block(hidden_size=embed_dim,
                               num_attention_heads=transformer_heads,
                               dropout_rate=0.1)
            for _ in range(1)])

        # GLU gates for the two branches (spatial / temporal).
        self.glu_spa = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2),
            nn.GLU(dim=-1)
        )
        self.glu_tem = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2),
            nn.GLU(dim=-1)
        )

        # Identity weights for the value half, zeros for the gate half.
        # NOTE(review): ``self.apply(self.init_weights)`` below re-draws the
        # weight of every nn.Linear, including these two, so this
        # initialisation does not survive construction. Confirm whether that
        # is intended before changing the order.
        with torch.no_grad():
            for gate in (self.glu_spa, self.glu_tem):
                gate[0].weight.data = torch.cat(
                    [torch.eye(embed_dim), torch.zeros(embed_dim, embed_dim)], dim=0)
                gate[0].bias.data = torch.zeros(embed_dim * 2)

        assert backbone in _PT_NAME, f"unknown base_encoder: {backbone!r}"
        model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), _PT_NAME[backbone])
        # The original check was inverted and never raised; fail early with a
        # clear message instead of an opaque loader error.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"CLIP checkpoint not found: {model_path}")
        try:
            # loading JIT archive
            model = torch.jit.load(model_path, map_location="cpu").eval()
            state_dict = model.state_dict()
        except RuntimeError:
            # Not a JIT archive: fall back to a plain state dict.
            state_dict = torch.load(model_path, map_location="cpu")

        # Recover the CLIP architecture hyper-parameters from tensor shapes.
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len(
            [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size

        embed_dim = state_dict["text_projection"].shape[1]
        context_length = state_dict["positional_embedding"].shape[0]
        vocab_size = state_dict["token_embedding.weight"].shape[0]
        transformer_width = state_dict["ln_final.weight"].shape[0]
        transformer_heads = transformer_width // 64
        transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

        self.clip = CLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size,
                         context_length, vocab_size, transformer_width, transformer_heads, transformer_layers)

        if torch.cuda.is_available():
            convert_weights(self.clip)  # fp16

        cross_config = SimpleNamespace(**{
            "attention_probs_dropout_prob": 0.1,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "hidden_size": 512,
            "initializer_range": 0.02,
            "intermediate_size": 2048,
            "max_position_embeddings": 128,
            "num_attention_heads": 8,
            "num_hidden_layers": 4,
            "vocab_size": 512,
            "soft_t": 0.07,
        })
        # Override the two defaults that depend on the loaded checkpoint.
        cross_config.max_position_embeddings = context_length
        cross_config.hidden_size = transformer_width
        self.cross_config = cross_config

        if self.agg_module in ["seqLSTM", "seqTransf"]:
            self.frame_position_embeddings = nn.Embedding(cross_config.max_position_embeddings,
                                                          cross_config.hidden_size)
            if self.agg_module == "seqTransf":
                self.transformerClip = TransformerClip(width=transformer_width,
                                                       layers=config.num_hidden_layers,
                                                       heads=transformer_heads)
            if self.agg_module == "seqLSTM":
                self.lstm_visual = nn.LSTM(input_size=cross_config.hidden_size, hidden_size=cross_config.hidden_size,
                                           batch_first=True, bidirectional=False, num_layers=1)

        self.loss_fct = CrossEn(config)

        self.apply(self.init_weights)  # random init must before loading pretrain
        self.clip.load_state_dict(state_dict, strict=False)

        ## ===> Initialization trick [HARD CODE]
        # Seed the frame-level modules from the CLIP text tower when the
        # checkpoint carries no frame position embeddings of its own.
        new_state_dict = OrderedDict()

        if self.agg_module in ["seqLSTM", "seqTransf"]:
            contain_frame_position = any(
                key.find("frame_position_embeddings") > -1 for key in state_dict.keys())
            if contain_frame_position is False:
                for key, val in state_dict.items():
                    if key == "positional_embedding":
                        new_state_dict["frame_position_embeddings.weight"] = val.clone()
                        continue
                    if self.agg_module in ["seqTransf"] and key.find("transformer.resblocks") == 0:
                        num_layer = int(key.split(".")[2])
                        # cut from beginning
                        if num_layer < config.num_hidden_layers:
                            new_state_dict[key.replace("transformer.", "transformerClip.")] = val.clone()
                            continue

        # The modules below are created after ``self.apply(self.init_weights)``
        # and therefore keep PyTorch's default initialisation.
        self.num_captions = 30
        self.sentence_position_embeddings = nn.Embedding(self.num_captions, embed_dim)
        self.caption_transformer_layer = nn.TransformerEncoderLayer(d_model=transformer_width,
                                                                    nhead=transformer_heads,
                                                                    dim_feedforward=transformer_width, dropout=0,
                                                                    batch_first=True)
        self.caption_transformer_encoder = nn.TransformerEncoder(self.caption_transformer_layer, num_layers=2)
        self.text_position_embeddings = nn.Embedding(context_length, embed_dim)

        self.load_state_dict(new_state_dict, strict=False)  # only update new state (seqTransf/seqLSTM/tightTransf)
        ## <=== End of initialization trick

    def forward(self, text_ids, text_mask, t_data, vcap_mask, video, video_mask=None, idx=None, global_step=0):
        """Encode a batch and, in training mode, return the contrastive loss.

        Args:
            text_ids: Query token ids; flattened to 2-D over the last axis.
            text_mask: Mask for ``text_ids``; flattened the same way.
            t_data: Token ids of the per-video captions/titles; flattened to
                2-D over the last axis.
            vcap_mask: Mask for ``t_data``; flattened the same way.
            video: Frames as a 5-D ``(b, n_v, d, h, w)`` or a 7-D
                ``(b, pair, bs, ts, channel, h, w)`` array/tensor.
            video_mask: Frame mask. Despite the ``None`` default it is
                required: it is reshaped unconditionally.
            idx: Sample indices. Only gathered across workers on the CUDA
                path; it does not enter the loss.
            global_step: Unused.

        Returns:
            The scalar loss when ``self.training`` is true, otherwise ``None``.
            The loss is ``0.6 * mean(spatial/temporal t2v+v2t losses) +
            0.4 * mean(query-title t2v+v2t losses)``.
        """
        text_ids = text_ids.view(-1, text_ids.shape[-1])
        text_mask = text_mask.view(-1, text_mask.shape[-1])
        t_data = t_data.view(-1, t_data.shape[-1])
        vcap_mask = vcap_mask.view(-1, vcap_mask.shape[-1])
        video_mask = video_mask.view(-1, video_mask.shape[-1])
        # B x N_v x 3 x H x W - >  (B x N_v) x 3 x H x W
        video = torch.as_tensor(video).float()
        if len(video.size()) == 5:
            b, n_v, d, h, w = video.shape
            video = video.view(b * n_v, d, h, w)
        else:
            b, pair, bs, ts, channel, h, w = video.shape
            video = video.view(b * pair * bs * ts, channel, h, w)

        cls, text_feat, td_cls, video_feat, st_feat, video_cls = self.get_text_video_feat(
            text_ids, text_mask, t_data, vcap_mask, video, video_mask, shaped=True)

        if not self.training:
            return None

        if torch.cuda.is_available():  # batch merge here
            # The set and order of gathers is kept exactly as before so the
            # sequence of collective calls does not change.
            idx = allgather(idx, self.config)
            text_feat = allgather(text_feat, self.config)
            video_feat = allgather(video_feat, self.config)
            text_mask = allgather(text_mask, self.config)
            video_mask = allgather(video_mask, self.config)
            vcap_mask = allgather(vcap_mask, self.config)
            cls = allgather(cls, self.config)
            td_cls = allgather(td_cls, self.config)
            video_cls = allgather(video_cls, self.config)
            st_feat = allgather(st_feat, self.config)
            torch.distributed.barrier()  # force sync

        logit_scale = self.clip.logit_scale.exp()

        # Only the temporal and spatial logits feed the loss; the plain
        # text-video and "base" logits returned alongside are not used here.
        _, _, tem_logits, spa_logits, _ = self.get_similarity_logits(
            text_feat, cls, td_cls, st_feat, video_feat, video_cls,
            text_mask, video_mask, shaped=True)
        qt_logits = self.get_text_titles_similarity_logits(cls, td_cls)

        qt_loss_t2v = self.loss_fct(qt_logits * logit_scale)
        tem_loss_t2v = self.loss_fct(tem_logits * logit_scale)
        spa_loss_t2v = self.loss_fct(spa_logits * logit_scale)
        qt_loss_v2t = self.loss_fct(qt_logits.T * logit_scale)
        tem_loss_v2t = self.loss_fct(tem_logits.T * logit_scale)
        spa_loss_v2t = self.loss_fct(spa_logits.T * logit_scale)

        loss = spa_loss_t2v*0.25 + spa_loss_v2t*0.25 + tem_loss_t2v*0.25 + tem_loss_v2t*0.25
        qt_loss = (qt_loss_t2v + qt_loss_v2t) / 2
        loss = loss * 0.6 + qt_loss * 0.4

        return loss
        
    def _mean_pooling_for_similarity_visual(self, visual_output, video_mask,):
        video_mask_un = video_mask.to(dtype=torch.float).unsqueeze(-1)
        visual_output = visual_output * video_mask_un
        video_mask_un_sum = torch.sum(video_mask_un, dim=1, dtype=torch.float)
        video_mask_un_sum[video_mask_un_sum == 0.] = 1.
        video_out = torch.sum(visual_output, dim=1) / video_mask_un_sum
        return video_out

    def get_text_titles_similarity_logits(self, text_output, title_output):


        x_original = title_output
        seq_length = title_output.shape[1]
        position_ids = torch.arange(seq_length, dtype=torch.long, device=title_output.device)
        position_ids = position_ids.unsqueeze(0).expand(title_output.size(0), -1)
        sentence_position_embeddings = self.sentence_position_embeddings(position_ids)
        title_output = title_output + sentence_position_embeddings
        title_output = self.caption_transformer_encoder(title_output)


        title_ori = title_output
        text_embed = text_output.squeeze(1)
        ################# title pooling begin ##############

        title_embed_pooled = title_output.mean(dim=1)

        # elif self.text_pool_type == 'topk':
        #     bs_text, embed_dim = text_embed.shape
        #     sims = title_output @ text_embed.t()
        #     sims_topk = torch.topk(sims, self.k, dim=1)[1]
        #     title_output = title_output.unsqueeze(-1).expand(-1, -1, -1, bs_text)
        #     sims_topk = sims_topk.unsqueeze(2).expand(-1, -1, embed_dim, -1)
        #     title_embeds_topk = torch.gather(title_output, dim=1, index=sims_topk)
        #     title_embed_pooled = title_embeds_topk.sum(dim=1)
        #     title_embed_pooled = title_embed_pooled.permute(0, 2, 1)

        ################# title pooling end ##############

        text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
        title_embed_pooled = title_embed_pooled / title_embed_pooled.norm(dim=-1, keepdim=True)

        q2t_logits = torch.mm(text_embed, title_embed_pooled.t())

        retrieve_logits = q2t_logits


        return retrieve_logits

    def get_logits(self, text_feat, cls, td_cls, st_feat, video_feat, video_cls, text_mask, video_mask):
        
        # compute spa_feat & tem_feat

        td_cls1 = td_cls
        for co_layer in self.co_connetion_transformer_model_block:
            video_feat, td_cls1, co_attention_probs = co_layer(video_feat,td_cls1)

        bs,n,_,d = st_feat.shape
        st_feat = st_feat.reshape(bs, n, 7, 7, d)
        st_feat = st_feat.permute(0, 1, 4, 2, 3) # bs*frame*dim*h*w
        st_feat_reshaped = st_feat.contiguous().reshape(-1, d)
        st_feat_spa = self.glu_spa(st_feat_reshaped).reshape(bs, n, d, 7, 7)
        st_feat_tem = self.glu_tem(st_feat_reshaped).reshape(bs, n, d, 7, 7)
        
        adaptive_pool = nn.AdaptiveAvgPool2d((4, 4)) 
        pool_n = 1

        # prepare tem:
        st_feat_tem_reshaped = st_feat_tem.contiguous().reshape(bs*n*d, *st_feat_tem.shape[3:])
        pooled = adaptive_pool(st_feat_tem_reshaped)  # (bs*n*d, 4, 4)
        tem_features = pooled.reshape(bs, n, d, *pooled.shape[1:])
        tem_features = tem_features.permute(0, 1, 3, 4, 2).contiguous() # Shape: bs*frame*h*w*dim
        
        # prepare spa:
        st_feat_spa = st_feat_spa.permute(0, 1, 3, 4, 2) 
        num_spa_frames = n // pool_n 
        st_feat_spa_reshaped = st_feat_spa[:, :num_spa_frames * pool_n, :, :, :].reshape(bs, num_spa_frames, pool_n, *st_feat_spa.shape[2:])  
        spa_features = st_feat_spa_reshaped.mean(dim=2).contiguous() 

        f1,frame_f,_,_,f2 = tem_features.shape
        s1,frame_s,_,_,s2 = spa_features.shape
        tem_features = tem_features.reshape(f1,frame_f,-1,f2)
        spa_features = spa_features.reshape(s1,frame_s,-1,s2)
        st_features = st_feat.reshape(bs,n,-1,d)
        _video_feat = video_feat.unsqueeze(2) 
        
        tem_feat = torch.cat((_video_feat, tem_features), dim=2)
        tem_features = tem_feat.reshape(f1,-1,f2)
        
        video_feat_mean = _video_feat.reshape(s1, frame_s, pool_n, -1).mean(dim=2).unsqueeze(2)
        spa_feat = torch.cat((video_feat_mean, spa_features), dim=2)    
        spa_features = spa_feat.reshape(s1,-1,s2)




        base_feat = torch.cat((_video_feat, st_features), dim=2)
        base_features = base_feat.reshape(bs,-1,d)
        


        # prepare for tem logits
        tem_weight = torch.einsum('ad,bvd->abv', [cls, tem_features]) 
        tem_weight = torch.softmax(tem_weight / self.config.temp, dim=-1)
        tem_features = torch.einsum('abv,bvd->abd', [tem_weight, tem_features])
        tem_features = tem_features / tem_features.norm(dim=-1, keepdim=True)

        # prepare for spa logits
        spa_weight = torch.einsum('ad,bvd->abv', [cls, spa_features])
        spa_weight = torch.softmax(spa_weight / self.config.temp, dim=-1)
        # test_weight = torch.einsum('abv,abv->ab', [s1_weight, s_weight])
        spa_features = torch.einsum('abv,bvd->abd', [spa_weight, spa_features])
        spa_features = spa_features / spa_features.norm(dim=-1, keepdim=True)

        # prepare for base logits
        b_weight = torch.einsum('ad,bvd->abv', [cls, base_features])
        b_weight = torch.softmax(b_weight / self.config.temp, dim=-1)
        base_features = torch.einsum('abv,bvd->abd', [b_weight, base_features])
        base_features = base_features / base_features.norm(dim=-1, keepdim=True)        

        v_weight = torch.einsum('ad,bvd->abv', [cls, video_feat])
        v_weight = torch.softmax(v_weight / self.config.temp, dim=-1)
        v_weight = torch.einsum('abv,bv->abv', [v_weight, video_mask])
        video_feat = torch.einsum('abv,bvd->abd', [v_weight, video_feat])
        
        _cls = cls / cls.norm(dim=-1, keepdim=True)
        _td_cls = td_cls / td_cls.norm(dim=-1, keepdim=True)
        _v_feat = video_feat / video_feat.norm(dim=-1, keepdim=True)
        
        tem_logits = torch.einsum('ad,abd->ab', [_cls, tem_features]) 
        spa_logits = torch.einsum('ad,abd->ab', [_cls, spa_features])
        b_logits = torch.einsum('ad,abd->ab', [_cls, base_features])
        
        retrieve_logits = torch.einsum('ad,abd->ab', [_cls, _v_feat])
        return retrieve_logits, retrieve_logits.T, tem_logits, spa_logits, b_logits

    def get_text_feat(self, text_ids, text_mask, t_data, vcap_mask, shaped=False):
        if shaped is False:
            text_ids = text_ids.view(-1, text_ids.shape[-1])
            text_mask = text_mask.view(-1, text_mask.shape[-1])
            t_data = t_data.view(-1, t_data.shape[-1])
            vcap_mask = vcap_mask.view(-1, vcap_mask.shape[-1])

        bs_pair = text_ids.size(0)
        bs_pair_fast = t_data.size(0)
        cls, text_feat = self.clip.encode_text(text_ids, return_hidden=True, mask=text_mask)
        td_cls, td_text_feat = self.clip.encode_text(t_data, return_hidden=True, mask=vcap_mask)
        cls, text_feat = cls.float(), text_feat.float()
        td_cls, td_text_feat = td_cls.float(), td_text_feat.float()
        text_feat = text_feat.view(bs_pair, -1, text_feat.size(-1))
        cls = cls.view(bs_pair, -1, cls.size(-1)).squeeze(1)
        td_cls = td_cls.view(bs_pair, -1, td_cls.size(-1)).squeeze(1)
        return text_feat, cls, td_cls

    def get_video_feat(self, video, video_mask, shaped=False):
        if shaped is False:
            video_mask = video_mask.view(-1, video_mask.shape[-1])
            video = torch.as_tensor(video).float()
            if len(video.size()) == 5:
                b, n_v, d, h, w = video.shape
                video = video.view(b * n_v, d, h, w)
            else:
                b, pair, bs, ts, channel, h, w = video.shape
                video = video.view(b * pair * bs * ts, channel, h, w)

        bs_pair, n_v = video_mask.size()
        video_feat, st_feat = self.clip.encode_image(video, return_hidden=True)
        video_feat = video_feat.float()
        st_feat = st_feat.float()
        video_feat = video_feat.view(bs_pair, -1, video_feat.size(-1))
        st_feat = st_feat.view(bs_pair, -1, st_feat.size(-2), st_feat.size(-1))
        video_feat, video_cls = self.agg_video_feat(video_feat, video_mask, self.agg_module) # agg_module=seqTransf
        return video_feat, st_feat, video_cls

    def get_text_video_feat(self, text_ids, text_mask, t_data, vcap_mask, video, video_mask, shaped=False):
        if shaped is False:
            text_ids = text_ids.view(-1, text_ids.shape[-1])
            text_mask = text_mask.view(-1, text_mask.shape[-1])
            t_data = t_data.view(-1, t_data.shape[-1])
            vcap_mask = vcap_mask.view(-1, vcap_mask.shape[-1])
            video_mask = video_mask.view(-1, video_mask.shape[-1])
            video = torch.as_tensor(video).float()
            if len(video.shape) == 5:
                b, n_v, d, h, w = video.shape
                video = video.view(b * n_v, d, h, w)
            else:
                b, pair, bs, ts, channel, h, w = video.shape
                video = video.view(b * pair * bs * ts, channel, h, w)

        text_feat, cls, td_cls = self.get_text_feat(text_ids, text_mask, t_data, vcap_mask, shaped=True)
        video_feat, st_feat, video_cls = self.get_video_feat(video, video_mask, shaped=True)

        return cls, text_feat, td_cls, video_feat, st_feat, video_cls

    def get_video_avg_feat(self, video_feat, video_mask):
        video_mask_un = video_mask.to(dtype=torch.float).unsqueeze(-1)
        video_feat = video_feat * video_mask_un
        video_mask_un_sum = torch.sum(video_mask_un, dim=1, dtype=torch.float)
        video_mask_un_sum[video_mask_un_sum == 0.] = 1.
        video_feat = torch.sum(video_feat, dim=1) / video_mask_un_sum
        return video_feat

    def get_text_sep_feat(self, text_feat, text_mask):
        n_dim = text_feat.dim()
        text_feat = text_feat.contiguous()
        if n_dim == 3: 
            text_feat = text_feat[torch.arange(text_feat.shape[0]), torch.sum(text_mask, dim=-1) - 1, :]
            text_feat = text_feat.unsqueeze(1).contiguous()
        elif n_dim == 4:
            bs_pair, n_text, n_word, text_dim = text_feat.shape
            text_feat = text_feat.view(bs_pair * n_text, n_word, text_dim)
            text_mask = text_mask.view(bs_pair * n_text, n_word)
            text_feat = text_feat[torch.arange(text_feat.shape[0]), torch.sum(text_mask, dim=-1) - 1, :]
            text_feat = text_feat.view(bs_pair, n_text, text_dim)
        return text_feat

    def agg_video_feat(self, video_feat, video_mask, agg_module):
        video_feat = video_feat.contiguous()
        if agg_module == "None":
            pass
        elif agg_module == "seqLSTM":
            # Sequential type: LSTM
            video_feat_original = video_feat
            video_feat = pack_padded_sequence(video_feat, torch.sum(video_mask, dim=-1).cpu(),
                                              batch_first=True, enforce_sorted=False)
            video_feat, _ = self.lstm_visual(video_feat)
            if self.training: self.lstm_visual.flatten_parameters()
            video_feat, _ = pad_packed_sequence(video_feat, batch_first=True)
            video_feat = torch.cat(
                (video_feat, video_feat_original[:, video_feat.size(1):, ...].contiguous()), dim=1)
            video_feat = video_feat + video_feat_original
        elif agg_module == "seqTransf":
            # Sequential type: Transformer Encoder
            video_feat_original = video_feat
            seq_length = video_feat.size(1)
            position_ids = torch.arange(seq_length, dtype=torch.long, device=video_feat.device)
            position_ids = position_ids.unsqueeze(0).expand(video_feat.size(0), -1)
            frame_position_embeddings = self.frame_position_embeddings(position_ids)
            video_feat = video_feat + frame_position_embeddings
            extended_video_mask = (1.0 - video_mask.unsqueeze(1)) * -1000000.0
            extended_video_mask = extended_video_mask.expand(-1, video_mask.size(1), -1)
            video_feat = video_feat.permute(1, 0, 2)  # NLD -> LND
            video_feat = self.transformerClip(video_feat, extended_video_mask)
            video_feat = video_feat.permute(1, 0, 2)  # LND -> NLD
            video_feat = video_feat + video_feat_original
        return video_feat, video_feat_original


    def get_similarity_logits(self, text_feat, cls, td_cls, st_feat, video_feat, video_cls, text_mask, video_mask, shaped=False):
        if shaped is False:
            text_mask = text_mask.view(-1, text_mask.shape[-1])
            video_mask = video_mask.view(-1, video_mask.shape[-1])

        t2v_logits, v2t_logits, tem_logits, spa_logits, b_logits = self.get_logits(text_feat, cls, td_cls, st_feat, video_feat, video_cls, text_mask, video_mask)
        
        return t2v_logits, v2t_logits, tem_logits, spa_logits, b_logits

    @property
    def dtype(self):
        """
        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        try:
            return next(self.parameters()).dtype
        except StopIteration:
            # For nn.DataParallel compatibility in PyTorch 1.5
            def find_tensor_attributes(module: nn.Module):
                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
                return tuples

            gen = self._named_members(get_members_fn=find_tensor_attributes)
            first_tuple = next(gen)
            return first_tuple[1].dtype

    def init_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, LayerNorm):
            if 'beta' in dir(module) and 'gamma' in dir(module):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()