from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
import opts

# Module-level options object, built once at import time by opts.parse_opt().
# LanguageModelCriterion reads opt_.gamma from it.
# NOTE(review): presumably parse_opt() parses command-line arguments, so merely
# importing this module has that side effect — confirm against opts.py.
opt_ = opts.parse_opt()

def if_use_att(caption_model):
    """Return whether attention features should be loaded for ``caption_model``.

    The models 'show_tell', 'all_img' and 'fc' do not use attention
    features; every other model name does.
    """
    models_without_att = ('show_tell', 'all_img', 'fc')
    return caption_model not in models_without_att

# Input: seq, an N*D integer torch tensor with elements in 0 .. vocab_size; 0 is the END token.
def decode_sequence(ix_to_word, seq):
    """Convert a batch of token-index sequences into space-separated strings.

    Args:
        ix_to_word: mapping from the *string* form of a token index to its
            word (the lookup key is ``str(ix.item())``).
        seq: 2-D tensor of shape (N, D) holding token indices. Decoding of a
            row stops at its first element that is not positive (0 is END).

    Returns:
        A list of N strings, one per row of ``seq``.
    """
    N, D = seq.size()
    out = []
    for i in range(N):
        words = []
        for j in range(D):
            ix = seq[i, j]
            if ix > 0:
                words.append(ix_to_word[str(ix.item())])
            else:
                # END token: ignore everything after it in this row.
                break
        # join avoids repeated string concatenation and separator bookkeeping.
        out.append(' '.join(words))
    return out

def to_contiguous(tensor):
    """Return ``tensor`` itself if it is contiguous, otherwise a contiguous copy."""
    return tensor if tensor.is_contiguous() else tensor.contiguous()

class RewardCriterion(nn.Module):
    """Reward-weighted, masked negative sum of ``input``.

    ``forward`` returns ``-sum(input * reward * mask) / sum(mask)``, where a
    position is counted (mask = 1) when it is the first position of its row
    or the preceding token of ``seq`` is positive.
    """

    def __init__(self):
        super(RewardCriterion, self).__init__()

    def forward(self, input, seq, reward):
        """Compute the masked, reward-weighted loss.

        Args:
            input: per-position values (presumably log-probabilities of the
                sampled tokens) with as many elements as ``seq``.
            seq: (B, L) token indices.
            reward: per-position rewards with as many elements as ``seq``.

        Returns:
            A scalar tensor.
        """
        flat_input = input.contiguous().view(-1)
        flat_reward = reward.contiguous().view(-1)
        # Shift the "token is non-zero" mask right by one and force the first
        # column to 1, so the position holding the first 0 is still counted.
        valid = (seq > 0).float()
        first_col = valid.new_ones((valid.size(0), 1))
        flat_mask = torch.cat([first_col, valid[:, :-1]], 1).contiguous().view(-1)
        weighted = -flat_input * flat_reward * flat_mask
        return weighted.sum() / flat_mask.sum()

class LanguageModelCriterion(nn.Module):
    """Masked NLL loss computed on a reweighted model distribution.

    At each position, probability mass on tokens that have *not* appeared
    earlier in the target sentence is multiplied by ``gamma``, while tokens
    that already appeared keep their mass. The distribution is then
    renormalised, and the masked negative log-likelihood of the target is
    returned.

    Args:
        gamma: reweighting factor for not-yet-seen tokens. When ``None`` (the
            default), ``opt_.gamma`` from the module-level options is read on
            every forward call, as before.
    """

    def __init__(self, gamma=None):
        super(LanguageModelCriterion, self).__init__()
        self.gamma = gamma

    def getNovelMask(self, target, vocab_size):
        """Return a float mask marking tokens not yet seen in each target prefix.

        Args:
            target: (b, l) int64 tensor of target indices in ``[0, vocab_size)``.
            vocab_size: size of the last dimension of the returned mask.

        Returns:
            (b, l, vocab_size) float tensor on ``target.device``. Entry
            ``[n, t, v]`` is 1 if token ``v`` does not occur in
            ``target[n, :t]`` and 0 otherwise, so row ``t = 0`` is all ones.
            Index 0 is treated like any other token.
        """
        b, l = target.size()
        # One-hot encode every target token; the running sum along time then
        # counts the occurrences of each token in target[:, :t+1].
        one_hot = torch.zeros(b, l, vocab_size, device=target.device)
        one_hot.scatter_(2, target.unsqueeze(2), 1.0)
        counts = one_hot.cumsum(dim=1)
        # Remove the current token to get the counts over the strict prefix
        # target[:, :t].
        seen_before = counts - one_hot
        return (seen_before < 1.).float()

    def forward(self, input_, target, mask):
        """Return the masked NLL of ``target`` under the reweighted distribution.

        Args:
            input_: (b, l, V) log-probabilities.
            target: (b, L) int64 target indices; only the first ``l`` columns
                are used (presumably L >= l).
            mask: (b, L) per-position weights, truncated like ``target``.

        Returns:
            Scalar tensor ``sum(nll * mask) / sum(mask)``.
        """
        # truncate to the same size
        seq_len = input_.size(1)
        target = target[:, :seq_len]
        mask = mask[:, :seq_len]

        gamma = opt_.gamma if self.gamma is None else self.gamma
        # input_ is log-probability; reweight in probability space.
        probs = input_.exp()
        novel_mask = self.getNovelMask(target, probs.size(-1))
        rep_mask = 1 - novel_mask
        # Scale not-yet-seen tokens by gamma and keep repeated ones. The small
        # epsilon keeps log() away from exact zeros; then renormalise (L1).
        new_probs = probs * novel_mask * gamma + probs * rep_mask + 1e-8
        new_probs = F.normalize(new_probs, p=1, dim=-1)
        log_probs = new_probs.log()

        output = -log_probs.gather(2, target.unsqueeze(2)).squeeze(2) * mask
        return torch.sum(output) / torch.sum(mask)



def set_lr(optimizer, lr):
    """Overwrite the learning rate of every parameter group of ``optimizer``."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def clip_gradient(optimizer, grad_clip):
    """Clamp each parameter gradient of ``optimizer`` in place to [-grad_clip, grad_clip].

    Parameters whose ``.grad`` is ``None`` (e.g. they were not used in the
    last backward pass) are skipped instead of raising ``AttributeError``.
    """
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is None:
                continue
            param.grad.data.clamp_(-grad_clip, grad_clip)

def build_optimizer(params, opt):
    """Construct the torch optimizer named by ``opt.optim``.

    Supported names are:
        'rmsprop': RMSprop with ``alpha=opt.optim_alpha``, ``eps=opt.optim_epsilon``.
        'adagrad': Adagrad.
        'sgd':     plain SGD.
        'sgdm':    SGD with momentum ``opt.optim_alpha``.
        'sgdmom':  Nesterov SGD with momentum ``opt.optim_alpha``.
        'adam':    Adam with ``betas=(opt.optim_alpha, opt.optim_beta)``
                   and ``eps=opt.optim_epsilon``.

    All of them use ``opt.learning_rate`` and ``opt.weight_decay``.

    Raises:
        Exception: if ``opt.optim`` is not one of the supported names.
    """
    # Lambdas keep construction lazy: only the selected optimizer is built.
    factories = {
        'rmsprop': lambda: optim.RMSprop(
            params, lr=opt.learning_rate, alpha=opt.optim_alpha,
            eps=opt.optim_epsilon, weight_decay=opt.weight_decay),
        'adagrad': lambda: optim.Adagrad(
            params, lr=opt.learning_rate, weight_decay=opt.weight_decay),
        'sgd': lambda: optim.SGD(
            params, lr=opt.learning_rate, weight_decay=opt.weight_decay),
        'sgdm': lambda: optim.SGD(
            params, lr=opt.learning_rate, momentum=opt.optim_alpha,
            weight_decay=opt.weight_decay),
        'sgdmom': lambda: optim.SGD(
            params, lr=opt.learning_rate, momentum=opt.optim_alpha,
            weight_decay=opt.weight_decay, nesterov=True),
        'adam': lambda: optim.Adam(
            params, lr=opt.learning_rate,
            betas=(opt.optim_alpha, opt.optim_beta),
            eps=opt.optim_epsilon, weight_decay=opt.weight_decay),
    }
    factory = factories.get(opt.optim)
    if factory is None:
        raise Exception("bad option opt.optim: {}".format(opt.optim))
    return factory()
    