import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributions.categorical import Categorical
import numpy as np
import warnings




class insCreator(nn.Module):
    """Sample sequences of items (rows of ``x``) from a categorical distribution.

    The module holds no parameters. Each call draws ``seq_len`` indices into
    ``x`` independently for every process and gathers the matching rows.
    """

    def __init__(self):
        super(insCreator, self).__init__()

    def forward(self, x, seq_len, num_processes, distribution=None, deterministic=False, random_mode=False, continuous=False, num_mini_batch=None):
        """Draw ``seq_len`` rows of ``x`` for each of ``num_processes`` processes.

        Args:
            x: Tensor whose first dimension indexes the candidate items.
            seq_len: Number of items drawn per process.
            num_processes: Number of independent sample sequences.
            distribution: Optional tensor of shape ``(num_processes, x.size(0))``
                with non-negative per-process category weights. ``Categorical``
                normalizes each row. When ``None``, a random distribution is
                generated for every process.
            deterministic: Accepted but not used by this method.
            random_mode: Accepted but not used by this method.
            continuous: If True, cast the samples to float32 and add the noise
                ``round(U[0, 1), 3 decimals) - 0.5`` to every element, in place
                and without gradient tracking.
            num_mini_batch: If given, repeat each process's sequence
                ``num_mini_batch`` times consecutively along the batch dimension.

        Returns:
            A tuple ``(samples, None)``. ``samples`` has shape
            ``(bsz, seq_len, -1)``, where ``bsz`` is ``num_processes``, or
            ``num_processes * num_mini_batch`` when ``num_mini_batch`` is given.
            The second element is always ``None``.

        Raises:
            AssertionError: If ``distribution`` does not have shape
                ``(num_processes, x.size(0))``.
        """
        if distribution is not None:
            assert distribution.size(1) == x.size(0), "distribution needs one column per row of x"
            assert num_processes == distribution.size(0), "distribution needs one row per process"
        else:
            distribution = torch.rand((num_processes, x.size(0))).to(x.device)
            distribution = distribution / torch.sum(distribution, dim=1).view((-1, 1))

        # sample() returns shape (seq_len, num_processes). Transpose so that,
        # once flattened, each process's sequence is contiguous. The transpose
        # is a non-contiguous view, so it must be flattened with reshape():
        # view() raises when both dims are > 1.
        sample_prob = Categorical(distribution)
        sample_idx = sample_prob.sample((seq_len,))
        sample_idx = sample_idx.transpose(0, 1).reshape(-1).to(x.device)

        if num_mini_batch is not None:
            # Duplicate each process's whole sequence num_mini_batch times in a row.
            sample_idx = sample_idx.view((num_processes, -1)).repeat_interleave(num_mini_batch, dim=0).reshape(-1)
            bsz = num_processes * num_mini_batch
        else:
            bsz = num_processes

        # Advanced indexing copies, so the in-place noise below never touches x.
        output_sample = x[sample_idx].view((bsz, seq_len, -1))
        if continuous:
            output_sample = output_sample.float()
            with torch.no_grad():
                output_sample += torch.round(torch.rand(output_sample.size(), device=output_sample.device), decimals=3) - 0.5
        return output_sample, None



