import torch
from functools import partial

# Convenience functions for manipulating weights, biases, and extracting them
# from flattened tensors
def wt_and_bias(params, layer):
    """Split one layer's flattened parameters into a weight tensor and a bias.

    Args:
        params: tensor whose last dimension holds the layer's flattened
            parameters: ``in_sz * out_sz`` weight entries followed, when the
            layer has a bias, by ``out_sz`` bias entries.
        layer: layer descriptor ``(in_sz, out_sz)`` or ``(in_sz, out_sz, bias)``.
            A missing ``bias`` is treated as ``False``, matching the default
            of ``layer_numel``. Any truthy ``bias`` means the layer has one.

    Returns:
        ``(wt, bias)``. ``wt`` is reshaped to ``(-1, out_sz, in_sz)``, with
        leading dimensions flattened into the first axis. ``bias`` is ``None``
        for bias-free layers; otherwise it is at least 2-D, because a 1-D bias
        gains a leading axis of size 1.

    Raises:
        ValueError: if ``layer`` has more than three entries.
    """
    in_sz, out_sz, *rest = layer
    if len(rest) > 1:
        raise ValueError("layer must be (in_sz, out_sz) or (in_sz, out_sz, bias)")
    has_bias = rest[0] if rest else False

    # Weights come first in the flat layout, then the optional bias.
    if has_bias:
        wt, bias = params.split((in_sz * out_sz, out_sz), dim = -1)
    else:
        wt, bias = params, None

    # Unflatten the weights: one row per output unit.
    wt = wt.view(-1, out_sz, in_sz)
    # Give a 1-D bias a leading axis so it lines up with wt's leading axis.
    if bias is not None and bias.dim() < 2:
        bias = bias.unsqueeze(0)
    return wt, bias

def layer_numel(in_sz, out_sz, bias = False):
    """Return the number of parameters in a dense layer.

    The count is one weight per (input, output) pair, plus one bias entry per
    output when ``bias`` is truthy.
    """
    n_weights = in_sz * out_sz
    return n_weights + out_sz if bias else n_weights

def total_num_param(*layers):
    """Return the combined parameter count of every given layer descriptor.

    Each descriptor is unpacked into ``layer_numel``.
    """
    total = 0
    for layer in layers:
        total += layer_numel(*layer)
    return total

def get_weights_and_biases(params, *layers):
    """Split a flat parameter tensor into per-layer ``(weight, bias)`` pairs.

    The last dimension of ``params`` is cut into consecutive chunks, one per
    layer, sized by ``layer_numel``. Each chunk is then unflattened with
    ``wt_and_bias``. Returns a list with one pair per layer, in order.
    """
    chunk_sizes = [layer_numel(*layer) for layer in layers]
    chunks = params.split(chunk_sizes, dim = -1)
    return list(map(wt_and_bias, chunks, layers))

# Forward pass of an ensemble of one-hidden-layer networks (one per parameter row), averaged over the ensemble
def one_hidden_layer_network(data, params, layer1, layer2, nonlinearity,
                             param_wts = None, hidden_nonlinearity = None):
    """Evaluate an ensemble of one-hidden-layer networks and average the outputs.

    Each row of ``params`` holds one network's flattened parameters: ``layer1``
    first, then ``layer2`` (unpacked by ``get_weights_and_biases``). Each
    network computes ``nonlinearity(W2 h(W1 x + b1) + b2)``, where ``h`` is
    ``hidden_nonlinearity`` (ReLU by default). The per-network outputs are
    then averaged, either uniformly or with the normalized ``param_wts``.

    Args:
        data: 1-D or 2-D tensor of inputs; a 1-D tensor is treated as a single
            row. Its last dimension must match ``layer1``'s ``in_sz``.
        params: 1-D or 2-D tensor with one row per network; a 1-D tensor is a
            single network.
        layer1, layer2: layer descriptors ``(in_sz, out_sz, bias)``.
        nonlinearity: callable applied to the output-layer pre-activations.
        param_wts: optional weights, one per row of ``params``, normalized to
            sum to one. Only allowed when ``params`` is 2-D.
        hidden_nonlinearity: optional callable applied to the hidden-layer
            pre-activations. ``None`` keeps the default ReLU
            (``clamp(min=0)``).

    Returns:
        Tensor of shape ``(batch, out_sz)``: the (weighted) mean over
        networks. A 1-D ``data`` input still gives a 2-D result with batch 1.

    Raises:
        ValueError: if ``data`` or ``params`` is not 1-D or 2-D, or if
            ``param_wts`` is given with 1-D ``params`` or with a length that
            does not match ``params``.
    """
    if data.dim() > 2 or params.dim() > 2 or data.dim() < 1 or params.dim() < 1:
        raise ValueError("Input data and parameter tensors must be 1D or 2D!")
    if param_wts is not None and (params.dim() < 2 or len(params) != len(param_wts)):
        raise ValueError("param_wts has incorrect dimensions or does not make sense!")
    # Promote 1-D inputs: data -> (batch, in_sz), params -> (n_nets, n_params).
    if data.dim() < 2:
        data = data.unsqueeze(0)
    if params.dim() < 2:
        params = params.unsqueeze(0)

    (layer1_wts, layer1_bias), (layer2_wts, layer2_bias) = \
        get_weights_and_biases(params, layer1, layer2)

    # (1, batch, in) @ (n_nets, in, hidden) broadcasts to (n_nets, batch, hidden).
    output = torch.matmul(data.unsqueeze(0), torch.transpose(layer1_wts, -1, -2))
    if layer1_bias is not None:
        # Bias is (n_nets, hidden); insert a batch axis so it broadcasts.
        output = output + layer1_bias.unsqueeze(1)
    if hidden_nonlinearity is None:
        # Default hidden activation: ReLU, written as clamp(min=0).
        output = output.clamp(min=0)
    else:
        output = hidden_nonlinearity(output)

    # (n_nets, batch, hidden) @ (n_nets, hidden, out) -> (n_nets, batch, out).
    output = torch.matmul(output, torch.transpose(layer2_wts, -1, -2))
    if layer2_bias is not None:
        output = output + layer2_bias.unsqueeze(1)
    output = nonlinearity(output)

    # Collapse the ensemble axis (dim 0), uniformly or with normalized weights.
    if param_wts is None:
        return output.mean(0)
    param_wts = param_wts / param_wts.sum()
    return torch.sum(output * param_wts.reshape(-1, 1, 1), 0)

# Student-teacher objective at parameters x: residuals times network output, summed over the last axis, then (optionally weighted) averaged
def student_teacher_obj(x, data, resids, network, data_wts = None):
    """Inner product of ``resids`` with ``network(data, x)``, averaged over points.

    The elementwise product is summed over the last axis. The result is then
    reduced by a plain mean, or by a weighted sum when ``data_wts`` is given
    (the weights are normalized to sum to one first).
    """
    per_point = (resids * network(data, x)).sum(dim = -1)
    if data_wts is not None:
        normalized = data_wts / data_wts.sum()
        return (per_point * normalized).sum()
    return per_point.mean()

# Gaussian transfer function exp(-x^2 / sigma^2); the default output nonlinearity of StudentTeacher
def gausstrans(x, sigma = 2.0):
    """Return ``exp(-x**2 / sigma**2)`` evaluated on ``x``."""
    exponent = -(x ** 2) / sigma ** 2
    return torch.exp(exponent)

# Sets up the student teacher problem
class StudentTeacher:
    """Student-teacher problem built on ``one_hidden_layer_network``.

    The teacher and the student are evaluated with the same network function
    on the same sampled data points. Each is represented by a tensor of
    flattened parameter rows, drawn from a sampler.

    Args:
        data_sampler: zero-argument callable returning the data points.
        teacher_sampler: zero-argument callable returning the teacher's
            parameter rows.
        in_sz, hidden_sz, out_sz: layer widths of the network.
        non_linearity: output-layer nonlinearity (default ``gausstrans``).
        iterate_bsize: value passed to every ``mu_sampler`` call. Its meaning
            is up to the sampler (presumably a sample count — confirm).
        bias: used for both layers; any falsy value (including the default
            ``None``) means the layers have no bias.
    """

    def __init__(self, data_sampler, teacher_sampler, in_sz, hidden_sz,
                 out_sz, non_linearity = gausstrans, iterate_bsize = float("inf"),
                 bias = None):
        self.network = partial(one_hidden_layer_network,
                               layer1 = (in_sz, hidden_sz, bias),
                               layer2 = (hidden_sz, out_sz, bias),
                               nonlinearity = non_linearity)
        self.data_sampler = data_sampler
        self.teacher_sampler = teacher_sampler
        self.iterate_bsize = iterate_bsize

    def _sample_preds(self, mu_sampler):
        """Draw data, teacher and student samples; return (data, t_pred, s_pred)."""
        # The samplers are called in a fixed order (data, teacher, student),
        # so any shared random stream is consumed predictably.
        data_pts = self.data_sampler()
        teacher_pts = self.teacher_sampler()
        mu_pts = mu_sampler(self.iterate_bsize)
        t_pred, s_pred = StudentTeacher.get_preds(self.network, data_pts,
                                                  teacher_pts, mu_pts)
        return data_pts, t_pred, s_pred

    def get_ifunc(self, mu_sampler):
        """Return ``student_teacher_obj`` bound to freshly sampled data.

        The returned callable takes the parameters ``x`` to evaluate. Its
        residuals are the student predictions minus the teacher predictions
        on the sampled data.
        """
        data_pts, t_pred, s_pred = self._sample_preds(mu_sampler)
        return partial(student_teacher_obj, data = data_pts,
                       resids = s_pred - t_pred, network = self.network)

    def get_mmd(self, mu_sampler):
        """Return ``mmd2`` between teacher and student predictions on fresh samples."""
        _, t_pred, s_pred = self._sample_preds(mu_sampler)
        return StudentTeacher.mmd2(t_pred, s_pred)

    def get_student_pred(self, mu_sampler):
        """Evaluate the student on freshly sampled data.

        Returns the 1-tuple ``(s_pred,)`` produced by ``get_preds``.
        """
        # NOTE(review): this returns a 1-tuple, not a tensor. It is kept that
        # way for existing callers; confirm whether they unpack it.
        data_pts = self.data_sampler()
        return StudentTeacher.get_preds(self.network, data_pts,
                                        mu_sampler(self.iterate_bsize))

    @staticmethod
    def mmd2(teacher_output, student_output):
        """Mean over leading indices of the squared L2 distance along the last axis."""
        return torch.sum((teacher_output - student_output) ** 2, -1).mean()

    @staticmethod
    def get_preds(network, data_pts, *dist_samps):
        """Return ``network(data_pts, s)`` for each ``s`` in ``dist_samps``, in order, as a tuple."""
        return tuple(network(data_pts, dist_samp) for dist_samp in dist_samps)
