import tools
import tqdm
import random
import torch
import numpy as np
import glob

def create_manual_cost_function(configuration):
    """Return an instance of a hand-written 0/1 cost function built from ``configuration``.

    The returned object is an instance of a ``tools.base.Function`` subclass
    whose ``beta`` and ``discount_factor`` class attributes are copied from
    the configuration. Calling it on a ``(state, action)`` pair yields ``1.``
    when ``configuration["cost_condition"](state, action)`` is truthy and
    ``0.`` otherwise.
    """

    class ManualCostFunction(tools.base.Function):
        # Read once, when the class body executes.
        beta = configuration["beta"]
        discount_factor = configuration["discount_factor"]

        def __call__(self, sa):
            state, action = sa
            # Looked up on every call, so it always uses the configuration's
            # current "cost_condition" entry.
            is_costly = configuration["cost_condition"](state, action)
            return 1. if is_costly else 0.

    return ManualCostFunction()

def get_tuples(x):
    """Flatten an arbitrarily nested container into a flat list of its tuples.

    Tuples are leaves and are never descended into. Any other object that has
    ``__len__`` (lists, numpy arrays, ...), except strings, is iterated
    recursively.

    Args:
        x: a tuple, or a (possibly nested) sized container of tuples.

    Returns:
        A list of the tuples found, in depth-first order.

    Raises:
        TypeError: if a leaf is neither a tuple nor a non-string sized container.
    """
    if isinstance(x, tuple):
        return [x]
    if hasattr(x, "__len__") and not isinstance(x, str):
        tuples = []
        for item in x:
            tuples.extend(get_tuples(item))
        return tuples
    # Fail loudly: terminating the process with exit status 0 here would hide
    # malformed input from every caller and report the run as successful.
    raise TypeError(
        f"get_tuples: expected a tuple or a non-string container, got {x!r}")

def train(config_file, arch, seed, epochs=None):
    """Fit a neural cost network to the hand-written ("manual") cost function.

    The network is trained full-batch with binary cross-entropy on every
    state-action pair enumerated from the environment's state-action space,
    using the manual cost function's outputs as targets. Afterwards, the
    learned and manual cost values over the state-action space are passed to
    ``configuration["cost_comparison"]``.

    Args:
        config_file: path to a JSON configuration read through
            ``tools.data.Configuration.from_json``.
        arch: layer specification passed to ``configuration["t"].nn``. The
            first entry's input width is replaced by ``configuration["i"]``.
            The caller's list is not modified.
        seed: forwarded into the configuration as ``"seed"``.
        epochs: number of full-batch gradient steps. When ``None``, defaults
            to ``2 * outer_epochs * updates_per_epoch`` from the configuration.

    Returns:
        Whatever ``configuration["cost_comparison"]`` returns for the manual
        and learned cost values.
    """
    update_params = {"seed": seed}
    configuration = tools.data.Configuration.from_json(config_file, update_params)
    t = configuration["t"]
    state_action_space = tools.environments.get_state_action_space(
        configuration["env_type"], configuration["env_id"])
    configuration.update({"state_action_space": state_action_space})
    manual_cost = create_manual_cost_function(configuration)
    reduced_space = [item[0] for item in state_action_space]
    state_action_pairs = get_tuples(reduced_space)

    cost = tools.functions.CostFunction(
        configuration, i=configuration["i"], h=1, o=1)  # network replaced below
    # Copy the layer spec so the caller's (possibly shared) ``arch`` is not mutated.
    arch = [list(layer) if isinstance(layer, list) else layer for layer in arch]
    arch[0][0] = configuration["i"]
    cost.Cost = t.nn(arch)
    cost.Opt = t.adam(cost.Cost, configuration["learning_rate"])

    manualcostvalues, _ = manual_cost.outputs(configuration["state_action_space"])
    manualcostvalues = np.array(manualcostvalues).squeeze()

    if epochs is None:
        epochs = 2 * configuration["outer_epochs"] * configuration["updates_per_epoch"]

    # The training set is identical on every epoch (all state-action pairs),
    # so build the input batch and targets once.
    # NOTE(review): assumes manual_cost, cost.state_reduction,
    # cost.action_reduction and cost.input_format do not depend on the
    # network's training state — confirm.
    processed = []
    ground_truths = []
    for s, a in state_action_pairs:
        ground_truths.append(manual_cost((s, a)))
        s = cost.state_reduction(s)
        a = cost.action_reduction(a)
        processed.append(t.f(cost.input_format(s, a)))
    batch = torch.stack(processed).to(t.device)
    ground_batch = t.f(ground_truths)
    loss_fn = torch.nn.BCELoss()

    for _ in range(epochs):
        predicted = cost.Cost(batch).view(-1)
        cost.Opt.zero_grad()
        loss = loss_fn(predicted, ground_batch)
        loss.backward()
        cost.Opt.step()

    costvalues, _ = cost.outputs(configuration["state_action_space"])
    costvalues = np.array(costvalues).squeeze()
    return configuration["cost_comparison"](manualcostvalues, costvalues)

# Experiment driver: for every environment config and every candidate
# cost-network architecture, train with five seeds and record the comparison
# score as "mean ± std".
# Wildcard globs are sorted because glob.glob returns paths in arbitrary
# (filesystem) order, which would make the run order machine-dependent.
configs = (
    sorted(glob.glob("code_ppo_pen/configs/gridworld*.json"))
    + sorted(glob.glob("code_ppo_pen/configs/cartpole*.json"))
    + glob.glob("code_ppo_lag/configs/ant.json")
    + glob.glob("code_ppo_lag/configs/hc.json")
)
# Layer specs consumed by configuration["t"].nn. The leading None is the
# input width, which train() fills in. "r"/"s" are presumably ReLU/sigmoid
# activations — confirm against tools.
archs = {
    "A": [[None, 32], "r", [32, 32], "r", [32, 1], "s"],
    "B": [[None, 64], "r", [64, 64], "r", [64, 1], "s"],
    "C": [[None, 128], "r", [128, 128], "r", [128, 1], "s"],
    "D": [[None, 64], "r", [64, 1], "s"],
    "E": [[None, 64], "r", [64, 64], "r", [64, 64], "r", [64, 1], "s"],
}
d = {}
for config in configs:
    d[config] = {}
    for archname, arch in archs.items():
        scores = [train(config, arch, seed=seed, epochs=1000)
                  for seed in [1, 2, 3, 4, 5]]
        d[config][archname] = "%.2f ± %.2f" % (np.mean(scores), np.std(scores))
import pprint
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(d)