"""
Based on https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb#scrollTo=UKLabE9qCjLh
"""
import functools
import gc
from typing import List, Optional

import numpy as np
import torch
from torch import nn
from torch.utils import data as torchdata


class ConstantNet(nn.Module):
    """Module that ignores its input values and outputs a constant tensor.

    For an input ``x`` the output has shape ``(x.shape[0], *output_dims)``,
    every entry equals ``value``, and it lives on the same device as ``x``.

    Args:
        output_dims: Per-sample output shape; an int, or a list/tuple of ints.
        value: Constant fill value (default ``0.``).
    """

    def __init__(self, output_dims, value=0.):
        super().__init__()
        if isinstance(output_dims, list):
            self.output_dims = output_dims
        elif isinstance(output_dims, tuple):
            # Tuples (including torch.Size) are shapes as well; wrapping them
            # in a list would create a nested, invalid shape in forward().
            self.output_dims = list(output_dims)
        else:
            self.output_dims = [output_dims]
        self.value = value

    def forward(self, x):
        # Allocate on x's device so the result can be combined with other
        # tensors on that device (torch.ones defaults to the CPU otherwise).
        shape = (x.shape[0], *self.output_dims)
        return torch.ones(shape, device=x.device) * self.value


def garbage_collect(fun):
    """Decorate ``fun`` so that ``gc.collect()`` runs after every call.

    The collection also runs when ``fun`` raises, so memory held by a failed
    call is released too. The wrapped function keeps ``fun``'s name and
    docstring.

    Args:
        fun: Callable to wrap.

    Returns:
        The wrapping callable, which returns whatever ``fun`` returns.
    """

    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        try:
            return fun(*args, **kwargs)
        finally:
            gc.collect()

    return wrapper


@garbage_collect
def batch_predict(model, dataloader):
    """Run ``model`` over every batch of ``dataloader`` and collect results.

    Each batch is an ``(inputs, targets)`` pair. Both are moved to
    ``model.device``. Predictions come from ``model.velocity`` and
    ``model.sqrt_density``. Targets are split into a velocity part
    (``targets[:, 1:]``) and a density part (``targets[:, [1]]``).

    Args:
        model: Object exposing ``device``, ``velocity(x)`` and
            ``sqrt_density(x)``.
        dataloader: Iterable yielding ``(inputs, targets)`` tensor pairs.

    Returns:
        Tuple ``(pred_velocities, pred_densities, true_velocities,
        true_densities)`` of NumPy arrays, each concatenated along axis 0.

    Raises:
        ValueError: From ``np.concatenate`` if ``dataloader`` yields no batches.
    """
    pred_velocities = []
    pred_densities = []
    true_velocities = []
    true_densities = []
    # Autograd is deliberately left enabled: model.velocity may rely on it.
    for x_, vf_true in dataloader:
        x_, vf_true = x_.to(model.device), vf_true.to(model.device)

        pred_velocities.append(to_numpy(model.velocity(x_)))
        pred_densities.append(to_numpy(model.sqrt_density(x_)))

        true_velocities.append(to_numpy(vf_true[:, 1:]))
        # NOTE(review): column 1 is also the first column of the velocity
        # slice above. If the density is stored in column 0 this should be
        # vf_true[:, [0]] -- confirm against the dataset layout.
        true_densities.append(to_numpy(vf_true[:, [1]]))

    return (
        np.concatenate(pred_velocities, 0),
        np.concatenate(pred_densities, 0),
        np.concatenate(true_velocities, 0),
        np.concatenate(true_densities, 0),
    )

@garbage_collect
def batch_predict_densityonly(model, dataloader):
    """Run ``model.density`` over every batch and collect predictions/targets.

    Each batch is an ``(inputs, targets)`` pair. Both are moved to
    ``model.device`` before prediction.

    Args:
        model: Object exposing ``device`` and ``density(x)``.
        dataloader: Iterable yielding ``(inputs, targets)`` tensor pairs.

    Returns:
        Tuple ``(pred_densities, true_densities)`` of NumPy arrays, each
        concatenated along axis 0.

    Raises:
        ValueError: From ``np.concatenate`` if ``dataloader`` yields no batches.
    """
    pred_densities = []
    true_densities = []
    for x_, y_ in dataloader:
        x_, y_ = x_.to(model.device), y_.to(model.device)
        pred_densities.append(to_numpy(model.density(x_)))
        true_densities.append(to_numpy(y_))

    return np.concatenate(pred_densities, 0), np.concatenate(true_densities, 0)


@garbage_collect
def batch_predict2(model, dataloader):
    """Predict velocities and square-root densities for input-only batches.

    Each batch from ``dataloader`` is a 1-tuple ``(inputs,)``. The inputs are
    moved to ``model.device`` and passed to ``model.velocity`` and
    ``model.sqrt_density``.

    Returns:
        Tuple ``(pred_velocities, pred_densities)`` of NumPy arrays, each
        concatenated along axis 0.
    """
    velocity_chunks, density_chunks = [], []
    for (inputs,) in dataloader:
        inputs = inputs.to(model.device)
        velocity_chunks.append(model.velocity(inputs).cpu().detach().numpy())
        density_chunks.append(model.sqrt_density(inputs).cpu().detach().numpy())

    stacked_velocities = np.concatenate(velocity_chunks, 0)
    stacked_densities = np.concatenate(density_chunks, 0)
    return stacked_velocities, stacked_densities

def to_numpy(input):
    """Return ``input`` as a NumPy array.

    NumPy arrays are returned unchanged (same object). Torch tensors are
    detached, moved to the CPU and converted.

    Raises:
        TypeError: If ``input`` is neither a ``torch.Tensor`` nor an
            ``np.ndarray``.
    """
    if isinstance(input, np.ndarray):
        return input
    if isinstance(input, torch.Tensor):
        return input.detach().cpu().numpy()
    raise TypeError(
        'Unknown type of input, expected torch.Tensor or '
        f'np.ndarray, but got {type(input)}'
    )


def to_tensor(input, dtype=torch.float32, device="cuda", **kwargs):
    if isinstance(input, torch.Tensor):
        return input
    else:
        return torch.tensor(input, dtype=dtype, **kwargs).to(device)




def dataloader_from_np(
    arr_list: List[np.ndarray], mb_size: int, **kwargs
) -> torchdata.DataLoader:
    """Wrap equally-long NumPy arrays in a ``DataLoader``.

    Each array is converted with ``torch.Tensor`` (default float dtype) and
    zipped into a ``TensorDataset``, so every batch is a list with one tensor
    per input array.

    Args:
        arr_list: Arrays sharing the same first dimension.
        mb_size: Mini-batch size.
        **kwargs: Forwarded to ``DataLoader`` (e.g. ``shuffle``).

    Returns:
        The constructed ``DataLoader``.
    """
    dataset = torchdata.TensorDataset(*[torch.Tensor(arr) for arr in arr_list])
    return torchdata.DataLoader(dataset, batch_size=mb_size, **kwargs)
