from sorcerun.sacred_utils import run_sacred_experiment
from functools import partial
import os
import pickle
from sorcerun.git_utils import get_repo
import sys

import torch


# Disable autograd for the whole process at import time; adapter() repeats
# this call on entry.
torch.set_grad_enabled(False)
import numpy as np

from torch.profiler import ProfilerActivity, profile, record_function


# Working directory of the repository returned by sorcerun's get_repo().
ROOT = get_repo().working_dir
# Must run before the imports below: presumably `globals`, `make_algorithm`,
# `matrix_distributions` and `utils` live under ROOT — confirm.
sys.path.append(ROOT)
# Scratch directory; adapter() writes learned_params.pkl here and the profiling
# branch of the __main__ driver writes trace.json here.
os.makedirs(f"{ROOT}/temp", exist_ok=True)
from globals import MCTS
from make_algorithm.losses import LOSSES
from matrix_distributions.matrix_distributions import MATRIX_DISTRIBUTIONS
from globals import _sign, _sqrt, _inv, _proot

from make_algorithm.actions import (
    ACTIONS,
    estimate_relative_times,
)
import utils
from make_algorithm.mcts import (
    extract_policy,
    run_mcts,
    Tree,
    populate_tree,
    add_rollout_baseline_to_tree,
)
from make_algorithm.baselines import (
    baselines,
    adaptive_baselines,
    DEFAULT_BASELINE_CONFIGS,
)


def _collect_actions(action_specs):
    """Resolve configured ``(action_name, theta_bounds)`` pairs into action lists.

    Each name is looked up in ``ACTIONS``. When ``theta_bounds`` is truthy it
    overwrites ``theta_bounds`` on that shared ``ACTIONS`` entry in place, so
    the override is visible to every later user of ``ACTIONS``.

    Args:
        action_specs: Iterable of ``(action_name, theta_bounds)`` pairs.

    Returns:
        A tuple ``(actions, single_actions, couplingaction)``:
        all selected actions sorted by name; the actions (in configuration
        order) with ``num_matrices == 1`` or that are a coupling action; and
        the last coupling action encountered, or ``None`` if there was none.
    """
    # Assumes _sqrt/_proot are pure name builders, so computing them once is
    # equivalent to recomputing them for every action.
    coupling_names = (_sqrt("couple"), _proot("couple"))

    actions = []
    single_actions = []
    couplingaction = None
    for action_name, theta_bounds in action_specs:
        action = ACTIONS[action_name]
        if theta_bounds:  # falsy (e.g. None) keeps the bounds already on the action
            action.theta_bounds = np.array(theta_bounds)
        actions.append(action)

        is_coupling = action.name in coupling_names
        if action.num_matrices == 1 or is_coupling:
            single_actions.append(action)
        if is_coupling:
            couplingaction = action

    actions.sort(key=lambda a: a.name)
    return actions, single_actions, couplingaction


def _seed_tree_with_baselines(tree, actions):
    """Add every default baseline whose actions are all available to ``tree``.

    Baselines that use an action missing from ``actions`` are skipped with a
    printed message.
    """
    available_action_names = {a.name for a in actions}
    for baseline_name, baseline_config in DEFAULT_BASELINE_CONFIGS.items():
        algorithm = baselines(baseline_name, modify_time=False, **baseline_config)
        baseline_action_names = {step[0].name for step in algorithm}
        if not baseline_action_names.issubset(available_action_names):
            print(
                f"Skipping {baseline_name} because not all actions are in the tree"
            )
            continue

        # Swap in the ACTIONS entries so the baseline uses the same action
        # objects (and hence the same times) as the actions in the tree.
        algorithm = [(ACTIONS[a.name], theta) for a, theta in algorithm]
        add_rollout_baseline_to_tree(tree, algorithm)


def _learn_with_mcts(config, matrix_function, algorithm_config, A, _run):
    """Run MCTS on the spectrum of ``A`` and return the best root policy found.

    Prints the three candidate root policies returned by ``run_mcts`` and
    returns the one with the highest estimated reward.
    """
    device = algorithm_config["device"]
    A = A.to(device)
    # eigh returns (eigenvalues, eigenvectors); only the eigenvalues are kept,
    # as a float64 NumPy array.
    spec0 = torch.linalg.eigh(A)[0].cpu().numpy().astype(np.float64)

    # Resolve the configured actions (applying any theta-bound overrides).
    actions, single_actions, couplingaction = _collect_actions(
        algorithm_config["actions"]
    )
    print(f"Actions: {[a.name for a in actions]}")

    # Estimate action times for this problem size on the target device.
    estimate_relative_times(
        actions,
        size=spec0.shape[0],
        device=device,
    )

    # MCTS settings.
    epsilon = algorithm_config["epsilon"]
    early_termination_epsilon = algorithm_config["early_termination_epsilon"]

    def termination_predicate(current_state):
        # A state is terminal once its loss w.r.t. the initial spectrum drops
        # below epsilon; "custom_loss" is looked up lazily on each call.
        loss = LOSSES[matrix_function](
            current_state[0], spec0, config["custom_loss"]
        )
        return loss < epsilon

    alpha_pw = algorithm_config["alpha_pw"]
    budget = algorithm_config["budget"]
    print_every = algorithm_config["print_every"]
    max_termination_count = algorithm_config["max_termination_count"]
    tree_initial_capacity = algorithm_config["tree_initial_capacity"]
    c_ucb = algorithm_config["c_ucb"]
    EXPLORE_K = algorithm_config["EXPLORE_K"]
    initialize_with_baselines = algorithm_config["initialize_with_baselines"]

    # The "proot" matrix function uses the two state components in the
    # opposite order from every other matrix function.
    ones = np.ones_like(spec0)
    state0 = (ones, spec0) if matrix_function == "proot" else (spec0, ones)
    tree = Tree(
        capacity=tree_initial_capacity,
        state0=state0,
        term_pred=termination_predicate,
        alpha_pw=alpha_pw,
        actions=actions,
        single_actions=single_actions,
        couplingaction=couplingaction,
    )

    if initialize_with_baselines:
        _seed_tree_with_baselines(tree, actions)

    (
        root_policy_visits,
        root_policy_value,
        root_policy_best_value,
        visits_reward,
        value_reward,
        best_value_reward,
    ) = run_mcts(
        tree,
        budget=budget,
        print_every=print_every,
        max_termination_count=max_termination_count,
        termination_epsilon=early_termination_epsilon,
        c_ucb=c_ucb,
        EXPLORE_K=EXPLORE_K,
        _run=_run,
    )

    candidates = [
        ("Visits", root_policy_visits, visits_reward),
        ("Value", root_policy_value, value_reward),
        ("Best Value", root_policy_best_value, best_value_reward),
    ]

    print("=== RESULT ===")
    for label, policy, reward in candidates:
        print(f"By {label}: estimated reward: {reward}")
        for act, theta in policy:
            print(act.name, theta)

    # Stable sort + last element: on equal rewards the later candidate wins.
    best_root_policy = sorted(
        [(policy, reward) for _, policy, reward in candidates],
        key=lambda pair: pair[1],
    )[-1][0]

    print("Best value of root node:")
    print(tree.best_value[0])

    print("Best root policy:")
    for act, theta in best_root_policy:
        print(act.name, theta)
    print("=== END ===")

    return best_root_policy


def adapter(config, _run):
    """Build an algorithm for one experiment and attach it to the run.

    Samples the initial matrix, obtains the algorithm parameters either by
    MCTS, from a fixed baseline, or from an adaptive baseline, pickles them to
    ``{ROOT}/temp/learned_params.pkl`` and registers that file as the
    ``learned_params.pkl`` artifact of ``_run``.

    Args:
        config: Mapping with the keys ``precision``, ``seed``,
            ``matrix_function``, ``init_mat_name``, ``init_mat_config``,
            ``algorithm_name`` and ``algorithm_config``; the MCTS path also
            reads ``custom_loss``.
        _run: Run object; it is forwarded to ``run_mcts`` and its
            ``add_artifact`` method is called here.

    Returns:
        Always ``0``.

    Raises:
        KeyError: If a required configuration key is missing.
    """
    torch.set_grad_enabled(False)
    if config["precision"] == "double":
        torch.set_default_dtype(torch.float64)
    # The seed is required; there is deliberately no default.
    seed = config["seed"]
    utils.set_all_seeds(seed)

    matrix_function = config["matrix_function"]
    init_mat_name = config["init_mat_name"]
    init_mat_config = config["init_mat_config"]

    A = MATRIX_DISTRIBUTIONS[init_mat_name](**init_mat_config)
    print(f"Matrix distribution: {init_mat_name}")
    print(f"Matrix distribution config: {init_mat_config}")
    print(f"Initial matrix shape: {A.shape}")

    algorithm_name = config["algorithm_name"]
    algorithm_config = config["algorithm_config"]

    print(f"Algorithm name: {algorithm_name}")
    print(f"Algorithm config: {algorithm_config}")

    if MCTS in algorithm_name:
        learned_params = _learn_with_mcts(
            config, matrix_function, algorithm_config, A, _run
        )
    elif "greedy" not in algorithm_name and algorithm_name != _sqrt("scaled_db"):
        # Fixed baselines do not need the matrix.
        learned_params = baselines(algorithm_name, **algorithm_config)
    else:
        # "greedy" variants and sqrt scaled_db are computed from the matrix itself.
        learned_params = adaptive_baselines(algorithm_name, A, **algorithm_config)

    # NOTE(review): the path is fixed, so concurrent runs in the same checkout
    # overwrite each other's file before it is attached.
    path_to_learned_params = f"{ROOT}/temp/learned_params.pkl"
    with open(path_to_learned_params, "wb") as f:
        pickle.dump(learned_params, f)
    _run.add_artifact(path_to_learned_params, name="learned_params.pkl")
    print("Finished making algorithm")

    return 0


# Name attached to the adapter function object; presumably read by
# run_sacred_experiment to label the experiment — verify against sorcerun.
adapter.experiment_name = "make_algorithm"

if __name__ == "__main__":
    # Script entry point: run the adapter once with the configuration from
    # make_algorithm.make_algo_config, optionally under the torch profiler.
    import importlib
    import time

    import make_algorithm.make_algo_config as make_algo_config

    # Re-execute the config module so the freshest `config` is used
    # (presumably for interactive sessions where it was already imported).
    importlib.reload(make_algo_config)
    config = make_algo_config.config

    # Flip to True to capture a profiler trace instead of timing the run.
    PROFILE = False

    if PROFILE:
        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            profile_memory=True,  # capture allocs
            record_shapes=False,
            # with_stack=True,  # <-   off
            # schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
        ) as prof:
            _run = run_sacred_experiment(adapter, config)

        print("Profiler results:")
        ke = prof.key_averages()
        prof.export_chrome_trace(f"{ROOT}/temp/trace.json")
        _run.add_artifact(f"{ROOT}/temp/trace.json", name="trace.json")
        print("Trace exported to temp/trace.json")
        print(ke.table(sort_by="self_cuda_time_total", row_limit=20))

    else:
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments during the run.
        start = time.perf_counter()
        _run = run_sacred_experiment(adapter, config)
        elapsed = time.perf_counter() - start

        # NOTE(review): adapter() takes the MCTS path on `MCTS in
        # algorithm_name`, whereas this is an exact match on the literal
        # "mcts" — confirm the difference is intended.
        if config["algorithm_name"] == "mcts":
            # The path is relative to the current working directory.
            with open("../times.txt", "a") as f:
                f.write(f"Finished making algorithm in {elapsed} seconds\n")