import torch as tc
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from src.utils import grad, SamplingArgs, LinearSolveArgs
from tqdm import tqdm

def sample(model, x: tc.Tensor, L: tc.Tensor, dxdt: tc.Tensor, sampling_args: SamplingArgs):
    """
    Draws the hidden-layer parameters of `model` from the data, then freezes those layers.

    Only `x` and `sampling_args` are consumed. `L` and `dxdt` are accepted but never read here
    (presumably for signature parity with `sample_and_linear_solve`). This function does not
    fit the last linear layer.

    Args:
        x                       of shape (n_points, n_features)
        L                       Poisson matrix of shape (n_features, n_features); unused here
        dxdt                    of shape (n_points, n_features); unused here
        sampling_args           Arguments for random-feature sampling
    """
    trained = model.train()
    # No prediction exists yet, so e_pred=None. For Approximate-SWIM this is only the initial
    # sampling pass; no refinement pass follows in this function.
    trained.sample_hidden(x, sampling_args, e_pred=None)
    trained.freeze_hidden_layers()

def sample_and_linear_solve(model, x: tc.Tensor, L: tc.Tensor, dxdt: tc.Tensor,
                            sampling_args: SamplingArgs, linear_solve_args: LinearSolveArgs,
                            return_cond=False):
    """
    Samples the hidden layers of `model` and fits only its last linear layer by least squares.

    With uniform sampling, one sample-then-fit pass is done. Otherwise (Approximate-SWIM), the
    first fit's prediction `model.forward(x)` is fed back into a second sampling pass, and the
    last layer is fitted again.

    Args:
        x                       of shape (n_points, n_features)
        L                       Poisson matrix of shape (n_features, n_features)
        dxdt                    of shape (n_points, n_features)
        sampling_args           Arguments for random-feature sampling
        linear_solve_args       Arguments for the linear solver
        return_cond             If True, return whatever `fit_linear_layer` returns for the
                                condition number; only allowed with uniform sampling (asserted)

    Returns:
        The value returned by `fit_linear_layer` when `return_cond` is True, otherwise None.
    """
    model = model.train()
    # First pass: no prediction available yet.
    model.sample_hidden(x, sampling_args, e_pred=None)

    if return_cond:
        # The condition number is only defined for the single-pass (uniform) scheme.
        assert sampling_args.sample_uniformly
        return fit_linear_layer(model, x, L, dxdt, linear_solve_args, return_cond)

    fit_linear_layer(model, x, L, dxdt, linear_solve_args, return_cond)
    if sampling_args.sample_uniformly:
        return None

    # Approximate-SWIM: resample using the current model prediction, then refit.
    current_prediction = model.forward(x)
    model.sample_hidden(x, sampling_args, e_pred=current_prediction)
    fit_linear_layer(model, x, L, dxdt, linear_solve_args, return_cond)

def get_indices(arr, indices):
    """
    Returns a new list holding `arr[i]` for each `i` in `indices`, in iteration order.

    Args:
        arr         any indexable sequence (e.g. a Python list)
        indices     iterable of indices; elements of a 1-D integer tensor also work, because
                    they support `__index__`

    Returns:
        list of the selected elements; duplicates in `indices` are repeated.
    """
    return [arr[index] for index in indices]

def _energy_matrix(model, x, L, mode):
    """
    Builds the least-squares design matrix that maps last-layer weights to predicted dx/dt.

    Args:
        model   model with sampled hidden layers; must expose `width`
        x       of shape (n_points, n_features), on the same device as `model`
        L       of shape (n_features, n_features), on the same device as `model`
        mode    forwarded to `grad`

    Returns:
        Tensor of shape (n_points * n_features, model.width). Its row order matches
        `dxdt.flatten()` for a `dxdt` of shape (n_points, n_features).
    """
    # Pre-linear gradient; given the transpose/matmul below it must be
    # (n_points, model.width, n_features).
    prelinear_grad = grad(model, x, apply_linear=False, mode=mode)
    # (1, F, F) @ (n_points, F, width) -> (n_points, F, width): applies L to each unit's gradient.
    energy_part = tc.matmul(L.unsqueeze(0), prelinear_grad.transpose(1, 2))
    return energy_part.reshape(-1, model.width)


def _batched_lstsq_weight(model, x, L, dxdt, args):
    """
    Solves one least-squares problem per mini-batch and returns the unweighted mean solution.

    Relies on `model.edge_index` having one entry per point in `x`. It is temporarily narrowed
    to each batch's entries and always restored, even if a batch fails.

    Raises:
        ValueError if `args.batch_size` is not positive or no batch is produced.
    """
    if args.batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {args.batch_size}")

    edge_index_all = model.edge_index
    train_loader = DataLoader(
        TensorDataset(x, dxdt, tc.arange(start=0, end=len(edge_index_all), dtype=tc.int64)),
        batch_size=args.batch_size
    )

    total_weight = None
    try:
        for x_batch, dxdt_batch, edge_index_indices in tqdm(train_loader, desc="Training batch-wise lstsq"):
            model.edge_index = get_indices(edge_index_all, edge_index_indices)
            energy_part = _energy_matrix(model, x_batch.to(args.device), L, args.mode)
            weight = tc.linalg.lstsq(energy_part, dxdt_batch.to(args.device).flatten(),
                                     rcond=args.rcond, driver=args.driver).solution
            total_weight = weight if total_weight is None else total_weight + weight
    finally:
        # Never leave the model pointing at a single batch's edge indices.
        model.edge_index = edge_index_all

    if total_weight is None:
        raise ValueError("batch-wise lstsq received no data points")
    # Plain average over batches; a smaller final batch counts as much as a full one.
    return total_weight / len(train_loader)


def fit_linear_layer(model, x, L, dxdt, args: LinearSolveArgs, return_cond=False):
    """
    Fits the last linear layer so the model's Hamiltonian flow matches the observed dx/dt.

    Call this after the hidden-layer weights have been sampled (e.g. with SWIM). It solves
    `(L @ prelinear_grad) w ≈ dxdt` in the least-squares sense, stores `w` as
    `model.linear.weight` (shape (1, width)), zeroes the bias (it does not affect the gradient),
    and moves the model back to the CPU.

    Args:
        x of shape (n_points, n_features)
        L of shape (n_features, n_features)
        dxdt of shape (n_points, n_features)
        args for linear solver. With `args.batch_size` None the whole system is solved at once.
            Otherwise one system is solved per batch and the solutions are averaged; this needs
            `model.edge_index` with one entry per point.
        return_cond Whether to return the condition number of the matrix we solve. Only
            supported when `args.batch_size` is None.

    Returns:
        The condition number (float) when `return_cond` is True and `args.batch_size` is None;
        otherwise None.

    Raises:
        ValueError in batched mode if `args.batch_size` is not positive or there are no points.
    """
    L = L.to(args.device)
    model.to(args.device)
    model = model.train()

    cond = None
    if args.batch_size is None:
        x, dxdt = x.to(args.device), dxdt.to(args.device)
        energy_part = _energy_matrix(model, x, L, args.mode)
        if return_cond:
            cond = tc.linalg.cond(energy_part).detach().cpu().item()
        weight = tc.linalg.lstsq(energy_part, dxdt.flatten(),
                                 rcond=args.rcond, driver=args.driver).solution
    else:
        if return_cond:
            import warnings
            warnings.warn("return_cond is ignored when args.batch_size is set; returning None",
                          stacklevel=2)
        weight = _batched_lstsq_weight(model, x, L, dxdt, args)

    model.linear.weight = nn.Parameter(weight.reshape(1, -1))  # output is scalar (total energy)
    nn.init.zeros_(model.linear.bias)  # bias does not affect the solution for the gradient
    model.to('cpu')
    return cond
