import numpy as np
from analysis.utils import flatten_parameters, assign_params
from tqdm import tqdm
import torch


def calculate_model_interpolation(
    model1, model2, dataloader, eval_fn, device, num_points=41
):
    """Runs the model interpolation analysis.

    Linearly interpolates from the parameters of model1 to the parameters of model2,
    evaluating on a dataset at every point.

    model1 is used as the scratch model that the interpolated parameters are written
    into. It is moved with ``model1.to(device=device)``; for a ``torch.nn.Module``
    that move is in place and is not undone. Its original parameters, however, are
    written back before this function returns, even if ``eval_fn`` raises.

    Args:
        model1: Starting point of interpolation.
        model2: Ending point of interpolation.
        dataloader: Dataloader for the dataset to evaluate on.
        eval_fn: A function that takes a model, a dataloader, and a device, and returns
            a dictionary with two metrics: "loss" and "accuracy".
        device: Device that the model and data should be moved to for evaluation.
        num_points: Number of evenly spaced interpolation coefficients in [0, 1],
            both endpoints included. The default of 41 gives a step of 0.025.

    Returns:
        A tuple ``(losses, accuracies, ts)`` of equal-length lists. Entry k holds the
        metrics obtained with parameters ``w1 + ts[k] * (w2 - w1)``, where w1 and w2
        are the flattened parameters of model1 and model2.
    """
    losses, accuracies, ts = [], [], []
    w1 = flatten_parameters(model1).to(device=device)
    w2 = flatten_parameters(model2).to(device=device)
    model1 = model1.to(device=device)
    delta = w2 - w1
    try:
        # np.linspace instead of np.arange with a float step: the number of points
        # is exact and the endpoint t=1.0 is guaranteed to be included.
        for t in tqdm(np.linspace(0.0, 1.0, num_points)):
            ts.append(t)
            assign_params(model1, w1 + t * delta)
            metrics = eval_fn(model1, dataloader, device)
            losses.append(metrics["loss"])
            accuracies.append(metrics["accuracy"])
    finally:
        # Undo the in-place parameter overwrites so the caller's model1 is not left
        # holding interpolated (or, after t=1, model2's) weights.
        assign_params(model1, w1)
    return losses, accuracies, ts


def calculate_loss_contours(
    model1, model2, model3, dataloader, eval_fn, device, granularity=20, margin=0.2
):
    """Runs the loss contour analysis.

    Creates a plane based on the parameters of 3 models, and computes loss and
    accuracy contours on that plane. The origin is model1. The x axis is the unit
    vector from model1 towards model3. The y axis is the unit vector along the part
    of (model2 - model1) orthogonal to the x axis. Metrics are computed at points of
    a regular grid on that plane.

    model1 is used as the scratch model that grid parameters are written into. It is
    moved with ``model1.to(device=device)``; for a ``torch.nn.Module`` that move is
    in place and is not undone. Its original parameters are written back before this
    function returns, even if ``eval_fn`` raises.

    Args:
        model1: Origin of plane.
        model2: Model used to define y axis of plane.
        model3: Model used to define x axis of plane.
        dataloader: Dataloader for the dataset to evaluate on.
        eval_fn: A function that takes a model, a dataloader, and a device, and returns
            a dictionary with two metrics: "loss" and "accuracy".
        device: Device that the model and data should be moved to for evaluation.
        granularity: How many points to place along each axis. The model will be
            evaluated at granularity*granularity points.
        margin: How much margin around models to create evaluation plane, as a
            fraction of the model1->model3 distance (x) and of model2's orthogonal
            offset (y).

    Returns:
        A dict of nested lists:
            "grid": granularity x granularity x 2 plane coordinates of each
                evaluated point, in parameter-space distance units.
            "coords": 3 x 2 plane coordinates of model1, model2 and model3.
            "losses", "accuracies": granularity x granularity metric values, where
                index [i][j] corresponds to the i-th x step and the j-th y step.

    Raises:
        ValueError: If model3 has the same parameters as model1 (no x axis), or if
            model2 lies exactly on the line through model1 and model3 (no y axis).
    """
    w1 = flatten_parameters(model1).to(device=device)
    w2 = flatten_parameters(model2).to(device=device)
    w3 = flatten_parameters(model3).to(device=device)
    model1 = model1.to(device=device)

    # Define x axis: unit vector from model1 towards model3.
    u = w3 - w1
    dx = torch.norm(u).item()
    if dx == 0.0:
        raise ValueError(
            "model1 and model3 have identical parameters; cannot define the x axis"
        )
    u = u / dx

    # Define y axis: one Gram-Schmidt step removes the x component of (w2 - w1).
    v = w2 - w1
    v = v - torch.dot(u, v) * u
    dy = torch.norm(v).item()
    if dy == 0.0:
        raise ValueError(
            "model2 lies on the line through model1 and model3; cannot define the y axis"
        )
    v = v / dy

    # Define grid representing parameters that will be evaluated.
    # np.stack needs a sequence; recent NumPy versions raise TypeError for a generator.
    coords = np.stack([get_xy(p, w1, u, v) for p in (w1, w2, w3)])
    alphas = np.linspace(0.0 - margin, 1.0 + margin, granularity)
    betas = np.linspace(0.0 - margin, 1.0 + margin, granularity)
    losses = np.zeros((granularity, granularity))
    accuracies = np.zeros((granularity, granularity))
    grid = np.zeros((granularity, granularity, 2))

    # Evaluate parameters at every point on grid.
    try:
        with tqdm(total=granularity * granularity) as progress:
            for i, alpha in enumerate(alphas):
                for j, beta in enumerate(betas):
                    p = w1 + alpha * dx * u + beta * dy * v
                    assign_params(model1, p)
                    metrics = eval_fn(model1, dataloader, device)
                    grid[i, j] = [alpha * dx, beta * dy]
                    losses[i, j] = metrics["loss"]
                    accuracies[i, j] = metrics["accuracy"]
                    progress.update()
    finally:
        # Put model1 back at its original parameters instead of the last grid point.
        assign_params(model1, w1)
    return {
        "grid": grid.tolist(),
        "coords": coords.tolist(),
        "losses": losses.tolist(),
        "accuracies": accuracies.tolist(),
    }


def get_xy(point, origin, vector_x, vector_y):
    """Express ``point`` in a 2-D coordinate system given by an origin and two axes.

    Each coordinate is the dot product of ``point - origin`` with one axis vector.
    These are ordinary Cartesian coordinates when the two axis vectors are
    orthonormal.

    Args:
        point: 1-D tensor whose coordinates are computed.
        origin: 1-D tensor marking the origin of the new coordinate system.
        vector_x: 1-D tensor giving the x axis of the new coordinate system.
        vector_y: 1-D tensor giving the y axis of the new coordinate system.

    Returns:
        A NumPy array ``[x, y]`` holding the two coordinates.
    """
    offset = point - origin
    x_coord = torch.dot(offset, vector_x).item()
    y_coord = torch.dot(offset, vector_y).item()
    return np.array([x_coord, y_coord])
