import torch
import numpy as np
import time
from termcolor import colored
import inspect
import random
import sys
import concurrent.futures
import types
from collections import OrderedDict

def set_random_seed(seed=None):
    """Seed Python's `random`, NumPy and PyTorch from a single value.

    Args:
        seed: integer seed. When None, one is derived from the current time
            (microseconds modulo 1e8).

    The seed actually used is stored in the module-level `_random_seed`.
    CUDA is seeded as well when it is available.
    """
    global _random_seed
    chosen = int((time.time() * 1e6) % 1e8) if seed is None else seed
    _random_seed = chosen
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(chosen)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(chosen)

# Seed all RNGs at import time with a seed derived from the current time;
# call set_random_seed(seed) explicitly for reproducible runs.
set_random_seed()

# Minimum number of seconds between two redraws in progress_bar_update.
_print_refresh_rate = 0.25  

# Progress-bar state shared by progress_bar_init / progress_bar_update /
# progress_bar_end. All four are (re)set by progress_bar_init.
progress_bar_num_iters = None  # total number of items
progress_bar_len_str_num_iters = None  # digit count of the total, used for column widths
progress_bar_time_start = None  # time.time() at init
progress_bar_prev_duration = None  # elapsed seconds at the last redraw


def progress_bar(i, len, bar_len=20):
    """Return a fixed-width text bar: '#' for completed, '-' for remaining.

    Args:
        i: number of completed items.
        len: total number of items (must be non-zero).
        bar_len: width of the bar in characters (default 20).

    Returns:
        A string of exactly `bar_len` characters, even if `i` is outside
        [0, len].
    """
    filled_len = int(round(bar_len * i / len))
    # Clamp so out-of-range progress cannot change the bar width and break
    # the alignment of the surrounding progress table.
    filled_len = max(0, min(bar_len, filled_len))
    return '#' * filled_len + '-' * (bar_len - filled_len)


def progress_bar_init(message, num_iters, iter_name='Items', rejections=False):
    """Start a progress table: reset timing state, print `message` and a header.

    Must be called before progress_bar_update / progress_bar_end, which read
    the module-level state set here.

    Args:
        message: line printed above the table.
        num_iters: total number of items; must be >= 1.
        iter_name: label used for the item-count and rate columns.
        rejections: if true, the header gets an extra 'Rejected Samples' column.

    Raises:
        ValueError: if num_iters < 1.
    """
    global progress_bar_num_iters, progress_bar_len_str_num_iters
    global progress_bar_time_start, progress_bar_prev_duration
    if num_iters < 1:
        raise ValueError('num_iters must be a positive integer')
    progress_bar_num_iters = num_iters
    progress_bar_time_start = time.time()
    progress_bar_prev_duration = 0
    digits = len(str(num_iters))
    progress_bar_len_str_num_iters = digits
    print(message)
    sys.stdout.flush()
    # The count column shows "i/total", hence twice the digit count plus '/'.
    header = 'Time spent  | Time remain.| Progress             | {} | {}/sec'.format(
        iter_name.ljust(digits * 2 + 1), iter_name)
    if rejections:
        header += ' | Rejected Samples'
    print(header)


def progress_bar_update(iter, rejections=None):
    """Redraw the progress line started by progress_bar_init (throttled).

    The line is redrawn when more than `_print_refresh_rate` seconds passed
    since the previous redraw, or unconditionally once `iter` reaches the
    last index (num_iters - 1). It ends with a carriage return, so the next
    redraw overwrites it in place.

    Args:
        iter: zero-based index of the item just completed.
        rejections: optional number shown in an extra 'Rejected Samples'
            column (pair with progress_bar_init(..., rejections=True)).
    """
    global progress_bar_prev_duration
    duration = time.time() - progress_bar_time_start
    is_last = iter >= progress_bar_num_iters - 1
    if duration - progress_bar_prev_duration <= _print_refresh_rate and not is_last:
        return
    progress_bar_prev_duration = duration
    # A zero elapsed time is possible (coarse clock, or the final forced
    # redraw right after init); report an infinite rate instead of dividing
    # by zero. The remaining-time estimate then becomes 0.
    traces_per_second = (iter + 1) / duration if duration > 0 else float('inf')
    columns = [
        days_hours_mins_secs_str(duration),
        days_hours_mins_secs_str((progress_bar_num_iters - iter) / traces_per_second),
        progress_bar(iter, progress_bar_num_iters),
        str(iter).rjust(progress_bar_len_str_num_iters),
        progress_bar_num_iters,
        traces_per_second,
    ]
    # Trailing spaces blank out leftovers from a previously longer line.
    if rejections is None:
        line = '{} | {} | {} | {}/{} | {:,.2f}       '.format(*columns)
    else:
        line = '{} | {} | {} | {}/{} | {:,.2f} |  {:,.2f}     '.format(*(columns + [rejections]))
    print(line, end='\r')
    sys.stdout.flush()



def progress_bar_end(message=None):
    """Finish the progress table.

    Draws a final, complete progress line by passing iter=num_iters (which
    always passes the redraw throttle), ends that line with a newline, and
    then prints `message` if one is given.
    """
    progress_bar_update(progress_bar_num_iters)
    print()
    if message is None:
        return
    print(message)


def days_hours_mins_secs_str(total_seconds):
    """Format a duration in seconds as 'Dd:HH:MM:SS'.

    Each field is truncated to an integer, so fractional seconds are dropped.
    """
    fields = []
    remainder = total_seconds
    for unit_seconds in (86400, 3600, 60):
        whole, remainder = divmod(remainder, unit_seconds)
        fields.append(int(whole))
    fields.append(int(remainder))
    days, hours, minutes, seconds = fields
    return '{0}d:{1:02}:{2:02}:{3:02}'.format(days, hours, minutes, seconds)


def has_nan_or_inf(value):
    """Return True if `value` is, or contains, a NaN or an infinity.

    Args:
        value: a torch tensor of any shape, or anything accepted by float().

    Returns:
        bool.
    """
    import math

    if torch.is_tensor(value):
        # Check elementwise. Summing first can overflow large finite values
        # to inf and report a false positive.
        return not bool(torch.isfinite(value).all())
    # NaN never compares equal to anything (including NaN), so it has to be
    # detected with isfinite/isnan rather than with ==.
    return not math.isfinite(float(value))


class LogProbError(Exception):
    """Error type for log-probability computation failures.

    The meaning is inferred from the name; the type is raised by callers
    outside this helper code.
    """


def flatten(model):
    """Concatenate all parameters of `model` into one 1-D tensor.

    The order follows model.parameters(). The result is built with torch.cat
    from the parameters themselves, so it remains connected to autograd.
    """
    pieces = tuple(param.reshape(-1) for param in model.parameters())
    return torch.cat(pieces)

def unflatten_dict(model, flattened_params):
    """Split a flat parameter vector into a {parameter name: tensor} dict.

    Args:
        model: the nn.Module whose parameter layout defines the split.
        flattened_params: 1-D tensor, e.g. as produced by flatten(model).

    Returns:
        dict mapping each name from model.named_parameters() to the
        corresponding reshaped slice of `flattened_params`.

    Names come from named_parameters(), which enumerates exactly the tensors
    that unflatten() walks via model.parameters(). state_dict() cannot be
    used here because it also lists buffers (e.g. BatchNorm running stats),
    which would shift every later name onto the wrong tensor.
    """
    params_list = unflatten(model, flattened_params)
    names = [name for name, _ in model.named_parameters()]
    return dict(zip(names, params_list))


def unflatten(model, flattened_params):
    """Cut a 1-D parameter vector into tensors shaped like model's parameters.

    Args:
        model: nn.Module whose parameters (in model.parameters() order) give
            the size and shape of each piece.
        flattened_params: 1-D tensor; consecutive slices become the pieces.
            Trailing extra elements are ignored.

    Returns:
        list of views of `flattened_params`, one per model parameter.

    Raises:
        ValueError: if `flattened_params` is not 1-D.
    """
    if flattened_params.dim() != 1:
        raise ValueError('Expecting a 1d flattened_params')
    templates = list(model.parameters())
    bounds = [0]
    for template in templates:
        bounds.append(bounds[-1] + template.nelement())
    return [
        flattened_params[bounds[k]:bounds[k + 1]].view_as(template)
        for k, template in enumerate(templates)
    ]


def update_model_params_in_place(model, params):
    """Replace each parameter's `.data` with the matching tensor from `params`.

    Tensors are matched positionally against model.parameters(). zip stops
    at the shorter side, so surplus entries on either side are ignored. The
    given tensors are assigned directly; this function makes no copy.
    """
    pairs = zip(model.parameters(), params)
    for target, source in pairs:
        target.data = source



def gradient(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False):
    """Return d(outputs)/d(inputs) as a single flat 1-D tensor.

    Per-input gradients are flattened and concatenated in input order.
    autograd is called with allow_unused=True, so an input that `outputs`
    does not depend on contributes zeros instead of raising.

    Args:
        outputs: tensor(s) to differentiate.
        inputs: a tensor or an iterable of tensors.
        grad_outputs, retain_graph, create_graph: forwarded unchanged to
            torch.autograd.grad.
    """
    input_list = [inputs] if torch.is_tensor(inputs) else list(inputs)
    raw_grads = torch.autograd.grad(outputs, input_list, grad_outputs,
                                    allow_unused=True,
                                    retain_graph=retain_graph,
                                    create_graph=create_graph)
    pieces = []
    for grad, inp in zip(raw_grads, input_list):
        if grad is None:
            grad = torch.zeros_like(inp)
        pieces.append(grad.contiguous().view(-1))
    return torch.cat(pieces)


def hessian(output, inputs, out=None, allow_unused=False, create_graph=False, return_inputs=False):
    """Compute the Hessian of a scalar `output` with respect to `inputs`.

    All inputs are flattened and concatenated in the given order into one
    vector of length n. The result is the n x n matrix of second derivatives.
    Only the upper triangle (from the diagonal rightward) is obtained from
    autograd; its off-diagonal entries are mirrored into the lower triangle,
    i.e. the Hessian is assumed symmetric.

    Args:
        output: 0-dim tensor (checked with an assert).
        inputs: a tensor or an iterable of tensors.
        out: optional preallocated (n, n) tensor. Results are accumulated
            into it with add_, so it should be zero-filled. When None, a zero
            tensor is created with output.new_zeros(n, n).
        allow_unused: forwarded to the first-order torch.autograd.grad call.
            When True, an input that does not affect `output` is given a zero
            first-order gradient. (The second-order calls go through
            gradient(), which always tolerates unused inputs.)
        create_graph: forwarded to the second-order gradient() calls, so the
            Hessian entries can themselves be differentiated.
        return_inputs: if True, also return the list `inputs` was normalised to.

    Returns:
        out, or (out, inputs) when return_inputs is True.
    """
    assert output.ndimension() == 0

    # Normalise `inputs` to a list of tensors.
    if torch.is_tensor(inputs):
        inputs = [inputs]
    else:
        inputs = list(inputs)

    # Total number of scalar input elements = side length of the Hessian.
    n = sum(p.numel() for p in inputs)
    if out is None:
        out = output.new_zeros(n, n)

    # ai: index of the current scalar input element in the flattened,
    # concatenated input vector (the current row and its diagonal column).
    ai = 0
    for i, inp in enumerate(inputs):
        # First-order gradient w.r.t. this input tensor. create_graph=True so
        # it can be differentiated a second time below.
        [grad] = torch.autograd.grad(output, inp, create_graph=True, allow_unused=allow_unused)
        grad = torch.zeros_like(inp) if grad is None else grad
        grad = grad.contiguous().view(-1)

        for j in range(inp.numel()):
            if grad[j].requires_grad:
                # d(grad[j])/d(inputs[i:]) spans columns from the start of
                # input i to the end. Dropping the first j entries leaves row
                # ai from the diagonal (column ai) onward: the upper triangle.
                row = gradient(grad[j], inputs[i:], retain_graph=True, create_graph=create_graph)[j:]
            else:
                # grad[j] is not connected to the graph: treat its second
                # derivatives as zero (same length as the branch above).
                row = grad[j].new_zeros(sum(x.numel() for x in inputs[i:]) - j)

            # Accumulate row ai from the diagonal rightward, cast to out's type.
            out[ai, ai:].add_(row.type_as(out))  
            if ai + 1 < n:
                # Mirror the off-diagonal part into column ai below the diagonal.
                out[ai + 1:, ai].add_(row[1:].type_as(out))  
            # Drop references early (presumably to release intermediates sooner).
            del row
            ai += 1
        del grad
    
    if return_inputs:
        return out, inputs
    else:
        return out

def jacobian(outputs, inputs, create_graph=False, return_inputs=False):
    """Compute the Jacobian of `outputs` with respect to `inputs`.

    Every output tensor is flattened. Each scalar output element yields one
    row containing its gradient with respect to all inputs, flattened and
    concatenated (see gradient()). Rows are stacked in output order.

    Args:
        outputs: a tensor or an iterable of tensors.
        inputs: a tensor or an iterable of tensors.
        create_graph: build a graph for the rows so the Jacobian itself can
            be differentiated.
        return_inputs: if True, also return the list `inputs` was normalised to.

    Returns:
        Tensor of shape (total output elements, total input elements), or
        (jacobian, inputs) when return_inputs is True.
    """
    outputs = [outputs] if torch.is_tensor(outputs) else list(outputs)
    inputs = [inputs] if torch.is_tensor(inputs) else list(inputs)

    rows = []
    for output in outputs:
        # reshape (not view) so non-contiguous outputs such as x.t() work.
        output_flat = output.reshape(-1)
        for i in range(output_flat.numel()):
            # Use a fresh one-hot seed for every row. Mutating one shared buffer
            # in place corrupts the tensors autograd saves when
            # create_graph=True, so a later backward through the Jacobian
            # would fail.
            seed = torch.zeros_like(output_flat)
            seed[i] = 1
            rows.append(gradient(output_flat, inputs, seed, True, create_graph))
    jac = torch.stack(rows)
    if return_inputs:
        return jac, inputs
    return jac
        
# True under Python 2; _make_functional then takes `.__func__` of the class's
# forward to obtain the plain function.
PY2 = sys.version_info[0] == 2
# Keys of module.__dict__ that _make_functional does NOT copy onto the
# functional Scope: the parameter/buffer/child registries and hook tables.
# Parameters are instead supplied on every call of the functional module.
_internal_attrs = {'_backend', '_parameters', '_buffers', '_backward_hooks', '_forward_hooks', '_forward_pre_hooks', '_modules'}

# Helper-method names that _make_functional re-binds onto the Scope object, so
# a module's forward() can still call them on `self` (names presumably taken
# from torch.nn conv / RNN / normalisation layers).
_new_methods = {'conv2d_forward','_forward_impl', '_check_input_dim', '_conv_forward',
                'check_forward_args', 'check_input', 'check_hidden_size', "get_expected_hidden_size",
                "get_expected_cell_size", "permute_hidden"}

class Scope(object):
    """Bare attribute container that stands in for `self` of a wrapped module.

    `_modules` is an ordered mapping of child name -> functional child,
    mirroring the attribute of the same name on nn.Module.
    """

    def __init__(self):
        super(Scope, self).__init__()
        self._modules = OrderedDict()

def _make_functional(module, params_box, params_offset):
    """Recursively build a functional (parameter-injected) version of `module`.

    Args:
        module: the nn.Module to wrap.
        params_box: one-element list; make_functional stores the flat list of
            parameter tensors in params_box[0] before each call.
        params_offset: index in that list of this module's first parameter.

    Returns:
        (next_offset, fmodule): the index just past the parameters consumed
        by this module and all its descendants, and a callable that loads this
        module's parameters from params_box[0] and runs its forward.

    Parameters registered as None on the module (e.g. a layer built with
    bias=False) occupy no slot in the flat list, since model.parameters()
    skips them. They are exposed to forward as None.
    """
    scope = Scope()
    # Registered parameters split into those the caller supplies (tensors)
    # and those absent on the original module (None).
    supplied_names = [name for name, p in module._parameters.items() if p is not None]
    absent_names = [name for name, p in module._parameters.items() if p is None]
    num_params = len(supplied_names)

    module_type = type(module)
    forward = module_type.forward.__func__ if PY2 else module_type.forward
    # Exact type check: Sequential subclasses keep their own forward.
    if module_type == torch.nn.modules.container.Sequential:
        forward = Sequential_forward_patch
    if 'BatchNorm' in module.__class__.__name__:
        forward = bn_forward_patch

    # Copy plain instance state (e.g. `training`, layer hyper-parameters), but
    # not the registries listed in _internal_attrs.
    for name, attr in module.__dict__.items():
        if name in _internal_attrs:
            continue
        setattr(scope, name, attr)
    for name in absent_names:
        setattr(scope, name, None)

    # Re-bind helper methods that forward() calls on `self` so they operate
    # on the scope instead of the original module.
    for name in dir(module):
        if name in _new_methods:
            setattr(scope, name, types.MethodType(getattr(module_type, name), scope))

    child_params_offset = params_offset + num_params
    for name, child in module.named_children():
        child_params_offset, fchild = _make_functional(child, params_box, child_params_offset)
        scope._modules[name] = fchild
        setattr(scope, name, fchild)

    def fmodule(*args, **kwargs):
        # Read this module's slice without mutating the caller's list, so
        # repeated calls (and modules used twice per forward) stay aligned.
        own_params = params_box[0][params_offset:params_offset + num_params]
        for name, param in zip(supplied_names, own_params):
            setattr(scope, name, param)
        return forward(scope, *args, **kwargs)

    return child_params_offset, fmodule


def make_functional(module):
    """Return a stateless callable that runs `module` with supplied parameters.

    Usage: ``fmodel = make_functional(model); y = fmodel(x, params=tensors)``,
    where `tensors` holds the parameter tensors in the order the wrapped
    modules consume them (own parameters first, then children, depth first;
    presumably the order of model.parameters(), which is what unflatten()
    produces).

    The `params` keyword is required (KeyError otherwise). Remaining
    positional and keyword arguments are passed on to the wrapped forward.
    """
    params_box = [None]
    _, fmodule_internal = _make_functional(module, params_box, 0)

    def fmodule(*args, **kwargs):
        # Copy so the caller's sequence is never mutated (the internal
        # wrappers may insert placeholders into the stored list). This also
        # makes tuples and generators acceptable.
        params_box[0] = list(kwargs.pop('params'))
        return fmodule_internal(*args, **kwargs)

    return fmodule


def Sequential_forward_patch(self, input):
    """Forward for a functional Sequential.

    Feeds `input` through each entry of self._modules in insertion order and
    returns the last result.
    """
    result = input
    for child in self._modules.values():
        result = child(result)
    return result


def bn_forward_patch(self, input):
    """Functional BatchNorm forward used by make_functional.

    No running statistics are passed here, so normalisation always uses the
    statistics of the current batch. This matches nn.BatchNorm's own
    behaviour when it has no running stats. Requesting eval-mode
    normalisation (training=False) without running stats makes torch raise,
    so batch statistics are used regardless of self.training.

    Reads self.weight, self.bias, self.momentum and self.eps.
    """
    # nn.BatchNorm allows momentum=None (cumulative average), but
    # F.batch_norm requires a float. The value is unused without running
    # stats anyway.
    momentum = 0.0 if self.momentum is None else self.momentum
    return torch.nn.functional.batch_norm(
                input, running_mean = None, running_var = None,
                weight = self.weight, bias = self.bias,
                training = True,
                momentum = momentum, eps = self.eps)
