from tqdm import tqdm
from termcolor import colored

import sys
sys.path.append("./")
from src.mc import *
from src.query_strategies import *
from src.plotting import create_plots

# Strategy hooks used by the search loop further down:
# `sampling_strategy` is called once per stage and returns the samples together
# with IS_pdf; `query_strategy` is called once per query and returns the pair
# of points (A, B) that is shown to the oracle.
sampling_strategy = sobol
query_strategy = random_with_random

# make float64 the default floating-point dtype for newly created tensors
torch.set_default_dtype(torch.float64)

import argparse

# Command-line options of the experiment, one row per option:
# (flag, default value, value type, help text).
# The rows are registered in this order, which is also the order in --help.
_CLI_OPTIONS = (
    ('-num_runs', 50, int, 'number of runs'),
    ('-num_samples', 1000, int, 'number of samples for numerical integration'),
    ('-gamma', 5, int, 'gamma'),
    ('-num_dims', 2, int, 'embedding dim'),
    ('-threshold_backtrack', 0.92, float, 'threshold for backtracking'),
    ('-threshold_proceed', 0.85, float, 'threshold for proceeding'),
    ('-proceed_factor', 0.7, float, 'zoom in factor'),
    ('-stopping_width', 0.0005, float, 'width at which we stop the search'),
    ('-slack_width', 1.25, float, 'slack width'),
    ('-sample_width', 1.3, float, 'query region width'),
    ('-far_width', 1.5, float, 'far width'),
    ('-plotdir', '../plots/', str, 'directory for plots'),
)

parser = argparse.ArgumentParser(description='the exponential search algorithm')
for _flag, _default, _type, _help in _CLI_OPTIONS:
    # every option accepts at most one value and falls back to its default
    # when the flag is not given on the command line
    parser.add_argument(_flag, nargs="?", default=_default, type=_type, help=_help)

args = parser.parse_args()

# dict view of the parsed options (option name -> value); this is what gets
# passed to create_plots at the end of the script
config = vars(args)

# Fixed seeds for the NumPy and torch random number generators.
# NOTE(review): `np` and `torch` are not imported by name in this file; they
# are presumably re-exported by the star imports above -- confirm.
np.random.seed(0)
torch.manual_seed(2398745)
torch.autograd.set_detect_anomaly(True)

# Bring the parsed options into scope as plain module-level names.
# `args` and `config` expose the same values; attribute access is used here.

# experiment size
num_runs = args.num_runs
num_samples = args.num_samples
num_dims = args.num_dims

# gamma is handed to the likelihood and to the oracle
gamma = args.gamma

# region widths, each applied as a multiple of the current search width
slack_width = args.slack_width
sample_width = args.sample_width
far_width = args.far_width

# zooming: shrink factor when proceeding and the derived growth factor when
# backtracking, plus the probability thresholds that trigger either move
proceed_factor = args.proceed_factor
backtrack_factor = (proceed_factor * 0.5 + 2) / proceed_factor
threshold_proceed = args.threshold_proceed
threshold_backtrack = args.threshold_backtrack

# termination and output
stopping_width = args.stopping_width
plotdir = args.plotdir

# One entry per run: the list of center-to-target distances, recorded once
# before every query of that run.
all_runs = []
debug_output = False

# upper bounds on the number of stages per run and on the queries per stage
max_stages = 100000
max_queries = 100000

for run in tqdm(range(num_runs)):

    distance_from_target = []

    # at start, these regions are all the same
    center = torch.zeros((1, num_dims))
    width = 1

    # sample target
    target = unif(1, center, 0.8, num_dims)

    # this is one run
    for _stage in range(max_stages):

        # limit zooming out: the width is capped at 2 and every coordinate of
        # the center is kept inside [-1, 1]
        width = min(width, 2)
        center = torch.clamp(center, -1, 1)

        # we end this run once we have zoomed in past the width given by the
        # -stopping_width option
        if width < stopping_width:
            break

        s_width = width * slack_width
        q_width = width * sample_width
        f_width = width * far_width
        width_child = width * proceed_factor

        samples, IS_pdf = sampling_strategy(num_samples, center, f_width, num_dims)
        loglik = torch.zeros((num_samples, 1))

        # the first outcome we sample at random
        # ToDo: this might not be a good idea, better to ask one specific optimized query
        new_outcomes = generate_outcomes(target, center, q_width, num_dims, num_outcomes=1, gamma=gamma)

        # all oracle replies seen in this stage; reset at each stage
        outcomes = []

        if debug_output:
            # colour of the current node: green if the target lies in R,
            # yellow if it only lies within the slack width, red otherwise
            target_dist = torch.linalg.norm(target - center, axis=-1)
            target_in_R = target_dist < width
            target_in_E = target_dist < s_width
            color = "yellow" if target_in_E else "red"
            color = "green" if target_in_R else color
            print(colored(f"width: {width}", color))

        # Membership masks for R, R+slack, everything outside R, and far.
        # samples, center and the widths are only reassigned right before a
        # stage ends, so the masks are computed once per stage.
        # NOTE(review): relies on the helpers called below not modifying
        # `samples` in place -- confirm.
        sample_dist = torch.linalg.norm(samples - center, axis=-1)
        inside_R = sample_dist < width
        inside_RE = sample_dist < s_width
        inside_EF = sample_dist > width
        inside_F = inside_RE.logical_not()

        for _query in range(max_queries):
            distance_from_target.append(torch.linalg.norm(target - center))

            # fold the latest oracle replies into the running log-likelihood
            loglik += log_likelihood(samples, new_outcomes, gamma)

            # Per-sample weights: likelihood divided by IS_pdf, normalised to
            # sum to one over the sample axis. Subtracting the largest
            # log-likelihood before exponentiating leaves the normalised
            # weights unchanged, but prevents exp() from underflowing to all
            # zeros (which would turn the normalisation into 0/0).
            lik = torch.exp(loglik - loglik.max()) / IS_pdf
            lik /= lik.sum(dim=0)

            outcomes.extend(new_outcomes)

            # samples and probability mass in R, R+slack, outside R, and far;
            # the indexing yields fresh tensors for every query
            samples_R = samples[inside_R]
            samples_RE = samples[inside_RE]
            samples_F = samples[inside_F]
            probas_R = lik[inside_R]
            probas_RE = lik[inside_RE]
            probas_EF = lik[inside_EF]
            probas_F = lik[inside_F]

            # center of the region we'd proceed to, computed by mean_and_cov
            # from the samples in R and their weights
            # (alternative that was tried: the most likely sample in RE)
            child_center, _ = mean_and_cov(samples_R, probas_R, compute_cov=False)

            child_mask = torch.linalg.norm(samples - child_center, axis=-1) < width_child
            child_probas = lik[child_mask]

            if debug_output:
                print(f"proba best guess: {child_probas.sum()}")
                print(f"distance best guess: {torch.linalg.norm(child_center - target)}")
                print(f"proba R: {probas_R.sum()}")
                print(f"proba RE: {probas_RE.sum()}")
                print(f"proba F: {probas_F.sum()}")

            # generate query points
            A, B = query_strategy(
                samples_R, samples_RE, samples_F,
                probas_R, probas_RE, probas_F,
                center,
                width, s_width, q_width, f_width,
                num_dims, project=False)

            # observe oracle reply, this reorders the points such that A > B
            A, B = oracle(A, B, target, gamma)
            new_outcomes = [(A, B)]

            # the region inside R has become much more unlikely than the region outside
            if probas_EF.sum() > threshold_backtrack:
                if debug_output:
                    print("backtracking")

                width *= backtrack_factor
                break

            # the inside of R has become much more likely
            if child_probas.sum() > threshold_proceed:
                if debug_output:
                    print('proceeding')

                # move to the center computed from the samples in R
                center = child_center

                # reducing by a constant factor, here we could also go faster
                width *= proceed_factor

                break

    all_runs.append(distance_from_target)

create_plots(all_runs, config)
