import pickle
import time
from pathlib import Path
import sys
import torch
import incense
from sorcerun.sacred_utils import run_sacred_experiment

from sorcerun.git_utils import get_repo

REPO = get_repo()
sys.path.append(f"{REPO.working_dir}")
RUNS_DIR = f"{REPO.working_dir}/file_storage/runs"

import utils
from matrix_distributions.matrix_distributions import MATRIX_DISTRIBUTIONS
from make_algorithm.actions import Action, MatrixActionInput, ACTIONS
from make_algorithm.losses import LOSSES
from globals import PROOT, SIGN


def warmup_matmul(device="cpu", runs=10, size=500):
    """Run throwaway matmul/inverse calls so later timed work starts warm.

    Args:
        device: Target device; a string such as ``"cpu"``, ``"cuda"`` or
            ``"cuda:0"``, or a ``torch.device``.
        runs: Number of warm-up iterations.
        size: Side length of the random square matrices that are multiplied
            and inverted.
    """
    # Normalise the device so that indexed strings ("cuda:0") and
    # torch.device objects are recognised as CUDA too; a plain string
    # comparison against "cuda" would skip the synchronisation for them.
    target = torch.device(device)
    on_cuda = target.type == "cuda"

    A = torch.randn(size, size, device=device)
    B = torch.randn(size, size, device=device)

    for _ in range(runs):
        torch.matmul(A, B)
        torch.linalg.inv(A)
        if on_cuda:
            # Wait for the queued kernels on this device so the warm-up has
            # really finished before the caller continues.
            torch.cuda.synchronize(target)

def matrix_loss(X, A, matrix_func, custom_loss = False):
    """Return the residual of ``X`` as an approximation of a function of ``A``.

    Args:
        X: Current iterate (square matrix).
        A: Matrix the function is applied to.
        matrix_func: ``"sign"``, ``"sqrt"`` or ``"inv"``; any other value is
            scored with the cubic residual ``X @ X @ X - A``.
        custom_loss: Only consulted for ``"sign"``. When true, the loss is
            the sum over the eigenvalues of ``X`` of
            ``max(| |Re(lambda)| - 1 | - 0.3, 0)`` instead of a norm.

    Returns:
        A scalar tensor: ``torch.linalg.norm`` of the residual matrix, or the
        eigenvalue-based penalty described above.
    """
    identity = torch.eye(A.size(0)).to(A.device)

    if matrix_func == "sqrt":
        return torch.linalg.norm(X @ X - A)
    if matrix_func == "inv":
        return torch.linalg.norm(X @ A - identity)
    if matrix_func != "sign":
        # Fallback for every other name: cubic residual.
        return torch.linalg.norm(X @ X @ X - A)

    # matrix_func == "sign" from here on.
    if not custom_loss:
        return torch.linalg.norm(X @ X - identity)

    real_parts = torch.linalg.eig(X).eigenvalues.real
    # Distance of |Re(lambda)| from 1, with a 0.3 tolerance band that is free.
    excess = torch.abs(torch.abs(real_parts) - 1) - 0.3
    return torch.sum(torch.maximum(excess, torch.zeros_like(real_parts)))

def adapter(config, _run):
    """Replay a learned algorithm on a test matrix and log loss versus time.

    Loads the ``learned_params.pkl`` artifact of the make-algorithm run named
    by ``config["make_algorithm_run_id"]``, draws a test matrix from
    ``MATRIX_DISTRIBUTIONS`` and applies each learned ``(action, theta)`` step
    in order. After every step the cumulative time reported by the action and
    the loss relative to ``norm(A)`` are logged through ``_run.log_scalar``.

    Args:
        config: Mapping with the required keys ``"precision"``, ``"device"``,
            ``"make_algorithm_run_id"``, ``"test_mat_name"``,
            ``"test_mat_config"`` and ``"matrix_function"``. Optional keys:
            ``"seed"`` (default 42) and ``"custom_loss"`` (default False).
        _run: Run object; only ``log_scalar(name, value)`` is used.

    Returns:
        0 if every step ran or the relative loss dropped below the target;
        1 if the run was aborted because the time budget was exceeded or the
        loss diverged (too large or NaN).
    """
    # Stopping thresholds. The time limit is in whatever unit
    # action.matrix_iteration reports (presumably seconds — confirm).
    max_elapsed = 15
    max_loss = 1e10
    target_loss = 1e-11

    seed = config.get("seed", 42)
    # Optional, mirroring matrix_loss's own default.
    custom_loss = config.get("custom_loss", False)
    if config["precision"] == "double":
        torch.set_default_dtype(torch.float64)
    torch.set_grad_enabled(False)
    utils.set_all_seeds(seed)

    device = config["device"]
    print(f"Using device: {device}")

    make_algorithm_run_id = config["make_algorithm_run_id"]
    exp = incense.experiment.FileSystemExperiment.from_run_dir(
        Path(f"{RUNS_DIR}/{make_algorithm_run_id}")
    )

    print(f"Make algorithm experiment found: {exp}")
    print(f"Make algorithm hash: {exp.config.make_algorithm_hash}")

    art = exp.artifacts["learned_params.pkl"]
    # NOTE: pickle.load can execute arbitrary code; only point this at
    # artifacts produced by trusted runs.
    with open(art.file, "rb") as f:
        learned_params = pickle.load(f)

    test_mat_name = config["test_mat_name"]
    test_mat_config = config["test_mat_config"]
    matrix_function = config["matrix_function"]

    A = MATRIX_DISTRIBUTIONS[test_mat_name](**test_mat_config)
    A = A.to(device)
    print(f"Matrix A shape: {A.shape}")
    I = torch.eye(A.size(0)).to(A.device)
    current_matrix = A.clone()

    # The loss is always evaluated on the first element of the state tuple:
    # it starts at I for PROOT and at (a copy of) A for everything else.
    if matrix_function != PROOT:
        current_matrices = (current_matrix, I)
    else:
        current_matrices = (I, current_matrix)

    # Loop-invariant normalisation factor for the relative loss.
    norm_A = torch.linalg.norm(A)

    def relative_loss(matrices):
        return (
            matrix_loss(matrices[0], A, matrix_function, custom_loss) / norm_A
        ).item()

    time_so_far = 0
    _run.log_scalar("time", 0)
    _run.log_scalar("loss", relative_loss(current_matrices))

    for step, (action, theta) in enumerate(learned_params, start=1):

        inp = MatrixActionInput(
            current_matrices=current_matrices,
            a_matrix=A,
            theta=theta,
        )

        warmup_matmul(device=device)
        current_matrices, t = action.matrix_iteration(inp)
        time_so_far += t
        _run.log_scalar("time", time_so_far)

        loss = relative_loss(current_matrices)
        _run.log_scalar("loss", loss)

        print(
            f"Step {step}, loss: {loss:.3e}, time so far: {time_so_far:.3e},"
            + f" action: {action.name}, theta: {theta}"
        )
        if time_so_far > max_elapsed:
            print(f"Elapsed time exceeded threshold at step {step}, stopping execution.")
            return 1

        if loss != loss:
            # NaN compares false against both thresholds below, so without
            # this check a diverged run would be reported as a success.
            print(f"Loss is NaN at step {step}, stopping execution.")
            return 1
        if loss > max_loss:
            print(f"Loss exceeded threshold at step {step}, stopping execution.")
            return 1
        if loss < target_loss:
            return 0
    return 0


# Experiment name attached to the adapter function (presumably read by
# run_sacred_experiment — confirm against sorcerun).
adapter.experiment_name = "test_algorithm"

if __name__ == "__main__":
    import importlib
    import test_algorithm.test_algo_config as test_algo_config

    # Reload right after importing so that a config module already cached in
    # this interpreter is re-read before the run starts.
    importlib.reload(test_algo_config)

    config = test_algo_config.config
    run_sacred_experiment(adapter, config)
