import argparse
import os
import sys

import jax.random
from gym import spaces
from tqdm import tqdm

from compositional_envs import RoomMetaEnv, RoomEnv
from rsm_utils import (
    lipschitz_l1_jax,
    triangular,
    pretty_time,
    pretty_number,
    lipschitz_linf_jax,
)
from rsm_learner import RSMLearner
from rsm_verifier import RSMVerifier, get_n_for_bound_computation
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import jax.numpy as jnp

from vrl_environments import (
    vLDSEnv,
    vHumanoidBalance2,
    vInvertedPendulum,
    vCollisionAvoidanceEnv,
    vDroneEnv,
    vMazeEnv,
    vDebugEnv,
)


class RSMLoop:
    """
    Learner-verifier loop: alternates between training the networks held by the
    learner and checking the resulting candidate with the verifier.
    """

    def __init__(
        self,
        learner,
        verifier,
        env,
        plot,
        soft_constraint,
        train_p,
        min_iters,
    ):
        """
        Implementation of the Learner-Verifier loop
        :param learner: The Learner module
        :param verifier: The Verifier module
        :param env: The environment
        :param plot: If truthy, ``run`` saves a plot to ``loop/`` after every iteration
            that did not terminate the loop.
        :param soft_constraint: If True, only the expected decrease condition will be checked,
            reach-avoid probability will be optimized in the training and computed but not enforced,
            i.e., the loop will terminate even if the reach-avoid threshold specified in the verifier
            module is not achieved.
        :param train_p: Integer indicating after which loop iteration the policy parameters will be
            optimized. Note that it is recommended to not optimize the policy parameters for the
            first few iterations. This is because the policy is pre-trained using PPO but the RSM is
            randomly initialized.
            If None (or not positive), the policy is frozen
        :param min_iters: Integer specifying the minimum number of iterations that the loop should
            run. Note that it may be useful in some cases to continue the loop for a few iterations
            to further improve the reach-avoid probability.
        """
        self.env = env
        self.learner = learner
        self.verifier = verifier
        self.train_p = train_p
        self.min_iters = min_iters
        self.soft_constraint = soft_constraint

        # Output directories used by this class and by the surrounding script.
        os.makedirs("saved", exist_ok=True)
        os.makedirs("plots", exist_ok=True)
        os.makedirs("loop", exist_ok=True)
        self._current_lipschitz_k = 1.0
        self.plot = plot
        self.prefill_delta = 0
        self.iter = 0
        # Key/value record of the most recent metrics, filled through ``log``.
        self.info = {}

    def _should_train_policy(self):
        """
        Return True if the policy parameters should be optimized in the current iteration.

        The policy is frozen when ``train_p`` is None or not positive; otherwise it is trained
        from iteration ``train_p`` onwards.
        """
        return bool(
            self.train_p is not None
            and self.train_p > 0
            and self.iter >= self.train_p
        )

    def learn(self):
        """
        Run one training phase on the verifier's train buffer.

        Trains for 200 epochs in the very first iteration and 50 epochs afterwards, then
        records the buffer size in ``self.info["ds_size"]`` and prints a summary.
        """
        train_ds = self.verifier.train_buffer.as_tfds(batch_size=2 * 4096)
        # NOTE(review): neither ``self.prefill_delta`` nor ``self._current_lipschitz_k`` is fed
        # into training at the moment; both arguments are fixed to 0. Confirm this is intended.
        current_delta = 0
        lipschitz_k = 0
        # we always train the RSM
        train_v = True
        # The iteration counter does not change during a training phase, so decide once.
        train_p = self._should_train_policy()
        # in the first iteration we train a bit longer
        num_epochs = 50 if self.iter > 0 else 200

        start_metrics = None
        start_time = time.time()
        # The context manager closes the progress bar even if an epoch raises.
        with tqdm(total=num_epochs, unit="epochs") as pbar:
            for _ in range(num_epochs):
                metrics = self.learner.train_epoch(
                    train_ds, current_delta, lipschitz_k, train_v, train_p
                )
                if start_metrics is None:
                    start_metrics = metrics
                pbar.update(1)
                pbar.set_description_str(
                    f"Train [v={train_v}, p={train_p}]: loss={metrics['loss']:0.3g}, dec_loss={metrics['dec_loss']:0.3g}, violations={metrics['train_violations']:0.3g}"
                )
        self.info["ds_size"] = len(self.verifier.train_buffer)

        training_time = pretty_time(time.time() - start_time)

        print(
            f"Trained on {pretty_number(len(self.verifier.train_buffer))} samples, start_loss={start_metrics['loss']:0.3g}, end_loss={metrics['loss']:0.3g}, start_violations={start_metrics['train_violations']:0.3g}, end_violations={metrics['train_violations']:0.3g} in {training_time}"
        )

    def _refine_grid(self):
        """Increase the verifier's grid size by a factor that depends on the state dimension."""
        dim = self.env.observation_space.shape[0]
        if dim == 2:
            # in 2D -> double grid size
            self.verifier.grid_size *= 2
        elif dim == 3:
            # in 3D -> grow by 50% (10% if fail_check_fast is set), capped at 1024
            expand_factor = 1.1 if self.verifier.fail_check_fast else 1.5
            self.verifier.grid_size = min(
                int(expand_factor * self.verifier.grid_size), 1024
            )
        else:
            # increase grid size by 30%
            self.verifier.grid_size = int(1.3 * self.verifier.grid_size)
        print(f"Increasing grid resolution -> {self.verifier.grid_size}")

    def check_decrease_condition(self):
        """
        Ask the verifier to check the decrease condition for the current networks.

        Combines the Lipschitz constants of the environment (K_f), the policy (K_p) and the
        RSM network (K_l) in the verifier's norm, logs them together with the verifier's
        results, and refines the grid when there are violations but no hard violations
        (only after iteration 2).

        :return: True if the verifier reports zero violations, False otherwise
        """
        if self.verifier.norm == "l1":
            K_f = self.env.lipschitz_constant
            K_p = lipschitz_l1_jax(self.learner.p_state.params).item()
            K_l = lipschitz_l1_jax(self.learner.v_state.params).item()
        else:
            K_f = self.env.lipschitz_constant_linf
            K_p = lipschitz_linf_jax(self.learner.p_state.params).item()
            K_l = lipschitz_linf_jax(self.learner.v_state.params).item()

        lipschitz_k = float(K_l * K_f * (1 + K_p) + K_l)
        self._current_lipschitz_k = lipschitz_k
        self.log(lipschitz_k=lipschitz_k, K_p=K_p, K_f=K_f, K_l=K_l)

        (
            violations,
            hard_violations,
            max_decrease,
            max_decay,
        ) = self.verifier.check_dec_cond(lipschitz_k)
        self.log(
            violations=int(violations),
            hard_violations=int(hard_violations),
            max_decrease=max_decrease,
            max_decay=max_decay,
        )

        if violations == 0:
            return True
        if hard_violations == 0 and self.iter > 2:
            self._refine_grid()
        return False

    def verify(self):
        """
        Check the decrease condition and, if it holds, compute reach-avoid bounds.

        Three bounds are computed from the normalized RSM value ``lb_unsafe`` on the unsafe
        set and logged as ``actual_reach_prob``/``old_reach_prob``, ``other_reach_prob`` and
        ``improved_bound``; details are appended to the file ``log_new_bound``.

        :return: The largest of the three bounds if the loop may terminate (decrease
            condition fulfilled, ``min_iters`` reached, and either ``soft_constraint`` is set
            or the bound reaches the verifier's ``reach_prob``); None otherwise.
        """
        if not self.check_decrease_condition():
            return None

        print("Decrease condition fulfilled!")
        n = get_n_for_bound_computation(self.env.observation_dim)
        _, ub_init = self.verifier.compute_bound_init(n)
        lb_unsafe, _ = self.verifier.compute_bound_unsafe(n)
        lb_domain, _ = self.verifier.compute_bound_domain(n)
        self.log(ub_init=ub_init, lb_unsafe=lb_unsafe, lb_domain=lb_domain)

        if lb_unsafe < ub_init:
            print(
                "WARNING: RSM is lower at unsafe than in init. No Reach-avoid guarantees can be obtained."
            )
        # normalize to lb_domain -> 0
        ub_init = ub_init - lb_domain
        lb_unsafe = lb_unsafe - lb_domain
        # normalize to ub_init -> 1
        lb_unsafe = lb_unsafe / ub_init
        clipped_lb_unsafe = np.clip(lb_unsafe, 1e-9, None)

        # Bound 1: 1 - 1 / lb_unsafe
        actual_reach_prob = 1 - 1 / clipped_lb_unsafe
        self.log(old_reach_prob=actual_reach_prob)
        self.log(actual_reach_prob=actual_reach_prob)

        # Bound 2: 1 - exp(num / denom) with lambda = lb_unsafe, LV = K_l,
        # epsilon = -max_decrease and delta = env.delta
        K_l = self.info["K_l"]
        max_decrease = self.info["max_decrease"]
        max_decay = self.info["max_decay"]
        num = -2 * (-max_decrease) * (lb_unsafe - 1)
        denom = np.square(K_l) * np.square(self.env.delta)
        frac = num / np.clip(denom, 1e-9, None)
        exp_frac = np.exp(frac)
        other_reach_prob = 1 - exp_frac
        self.log(other_reach_prob=other_reach_prob)
        best_reach_bound = np.maximum(actual_reach_prob, other_reach_prob)

        # Bound 3: bound 1 tightened by the factor max_decay ** N
        N = np.floor((lb_unsafe - 1) / (K_l * self.env.delta))
        improved_bound = 1 - (1 / clipped_lb_unsafe) * (max_decay**N)
        self.log(improved_bound=improved_bound)
        best_reach_bound = np.maximum(best_reach_bound, improved_bound)

        with open("log_new_bound", "a") as f:
            f.write(f"\n#### {self.env.name} ####\n")
            f.write(f"orig_bound     = {actual_reach_prob}\n")
            f.write(f"lambda  = {lb_unsafe:0.4g}\n")
            f.write(f"epsilon = {-max_decrease:0.4g}\n")
            f.write(f"LV      = {K_l:0.4g}\n")
            f.write(f"delta   = {self.env.delta:0.4g}\n")
            f.write(f"num     = {num:0.4g}\n")
            f.write(f"denom   = {denom:0.4g}\n")
            f.write(f"frac    = {frac:0.4g}\n")
            f.write(f"exp     = {exp_frac:0.4g}\n")
            f.write(f"bound   = {other_reach_prob:0.4g}\n")
            f.write("------------------------------\n")
            f.write(f"max_decay      = {max_decay:0.4g}\n")
            f.write(f"N              = {N}\n")
            f.write(f"improved_bound = {improved_bound}\n")

        threshold_met = (
            self.soft_constraint or best_reach_bound >= self.verifier.reach_prob
        )
        if threshold_met and self.iter >= self.min_iters:
            return best_reach_bound
        return None

    def log(self, **kwargs):
        """Store every keyword argument in ``self.info``, overwriting existing keys."""
        self.info.update(kwargs)

    def run(self, timeout):
        """
        Run the learner-verifier loop until it succeeds or the time budget is used up.

        :param timeout: Time budget in seconds; it is checked at the start of each iteration.
        :return: True if ``verify`` returned a positive bound, False on timeout
        """
        start_time = time.time()
        self.prefill_delta = self.verifier.prefill_train_buffer()
        while True:
            runtime = time.time() - start_time
            self.log(runtime=runtime, iter=self.iter)

            if runtime > timeout:
                print("Timeout!")
                return False

            print(
                f"\n#### Iteration {self.iter} (runtime: {pretty_time(runtime)}) #####"
            )
            self.learn()

            actual_reach_prob = self.verify()

            print("Log=", str(self.info))
            sys.stdout.flush()

            if actual_reach_prob is not None and actual_reach_prob > 0:
                print(
                    f"Probability of reaching the target safely is at least {actual_reach_prob * 100:0.3f}% (higher is better)"
                )
                return True

            if self.plot:
                self.plot_l(f"loop/{self.env.name}_{self.iter:04d}.png")
            self.iter += 1

    def rollout(self, seed=None):
        """
        Simulate the current policy for 200 steps from a state sampled from the observation space.

        :param seed: Seed of the JAX key used for the noise; a random one is drawn if None.
            NOTE(review): the initial state comes from ``space.sample()``, which is not
            seeded here.
        :return: Array stacking the initial state and the 200 successor states along axis 0
        """
        if seed is None:
            seed = np.random.default_rng().integers(0, 10000)
        key = jax.random.PRNGKey(seed)
        space = spaces.Box(
            low=self.env.observation_space.low,
            high=self.env.observation_space.high,
            dtype=np.float32,
        )

        state = space.sample()
        trace = [np.array(state)]
        for _ in range(200):
            action = self.learner.p_state.apply_fn(self.learner.p_state.params, state)
            next_state = self.env.next(state, action)
            # Draw the noise from a fresh subkey: a key that is split again later must
            # not also be consumed for sampling.
            key, noise_key = jax.random.split(key)
            noise = triangular(noise_key, (self.env.observation_dim,))
            noise = noise * self.env.noise
            state = next_state + noise
            trace.append(np.array(state))
        return np.stack(trace, axis=0)

    @staticmethod
    def _draw_boxes(ax, boxes, color):
        """Outline the first two dimensions of each box (``low``/``high`` bounds) on ``ax``."""
        for box in boxes:
            xs = [box.low[0], box.high[0], box.high[0], box.low[0], box.low[0]]
            ys = [box.low[1], box.low[1], box.high[1], box.high[1], box.low[1]]
            ax.plot(xs, ys, color=color, alpha=0.5, zorder=7)

    def plot_l(self, filename):
        """
        Save a 2D plot of the RSM network output, 30 policy rollouts and the environment's sets.

        Does nothing if the environment has more than two observation dimensions.

        :param filename: Path of the image file to write
        """
        if self.env.observation_dim > 2:
            return
        grid, _, _ = self.verifier.get_unfiltered_grid(n=50)
        values = self.learner.v_state.apply_fn(
            self.learner.v_state.params, grid
        ).flatten()
        values = np.array(values)

        sns.set()
        trace_color = sns.color_palette()[0]
        fig, ax = plt.subplots(figsize=(6, 6))
        sc = ax.scatter(
            grid[:, 0], grid[:, 1], marker="s", c=values, zorder=1, alpha=0.7
        )
        fig.colorbar(sc)
        ax.set_title(f"L at iter {self.iter} for {self.env.name}")

        # Overlay rollouts and mark where each of them ends.
        terminals_x, terminals_y = [], []
        for i in range(30):
            trace = self.rollout(seed=i)
            ax.plot(trace[:, 0], trace[:, 1], color=trace_color, zorder=2, alpha=0.3)
            ax.scatter(
                trace[:, 0], trace[:, 1], color=trace_color, zorder=2, marker="."
            )
            terminals_x.append(float(trace[-1, 0]))
            terminals_y.append(float(trace[-1, 1]))
        ax.scatter(terminals_x, terminals_y, color="white", marker="x", zorder=5)

        # Points recorded by the verifier, if any.
        hard_violations = self.verifier.hard_constraint_violation_buffer
        if hard_violations is not None:
            ax.scatter(
                hard_violations[:, 0],
                hard_violations[:, 1],
                color="green",
                marker="s",
                alpha=0.7,
                zorder=6,
            )
        debug_violations = self.verifier._debug_violations
        if debug_violations is not None:
            ax.scatter(
                debug_violations[:, 0],
                debug_violations[:, 1],
                color="cyan",
                marker="s",
                alpha=0.7,
                zorder=6,
            )

        self._draw_boxes(ax, self.env.init_spaces, "cyan")
        self._draw_boxes(ax, self.env.unsafe_spaces, "magenta")
        self._draw_boxes(ax, self.env.target_spaces, "green")

        ax.set_xlim(
            [self.env.observation_space.low[0], self.env.observation_space.high[0]]
        )
        ax.set_ylim(
            [self.env.observation_space.low[1], self.env.observation_space.high[1]]
        )
        fig.tight_layout()
        fig.savefig(filename)
        plt.close(fig)


def interpret_size_arg(cmd):
    """
    Converts a string with multiplications and optional binary unit letters into an integer
    e.g., "2*8*1" -> 16, "16k" -> 16384, "8M" -> 8388608

    Each "*"-separated factor may contain one unit letter: k (1024), M (1024**2) or
    G (1024**3), in upper or lower case.

    :param cmd: The size expression as a string
    :return: The product of all factors as an int
    :raises ValueError: If a factor is not an integer after removing its unit letter
    """
    # Checked in this order; only the first matching unit of a factor is applied.
    units = (("kK", 1024), ("mM", 1024**2), ("gG", 1024**3))
    total = 1
    for part in cmd.split("*"):
        multiplier = 1
        for letters, unit_value in units:
            if any(letter in part for letter in letters):
                for letter in letters:
                    part = part.replace(letter, "")
                multiplier = unit_value
                break
        total *= multiplier * int(part)
    return total


if __name__ == "__main__":
    # Command-line entry point: build the environment, learner and verifier from the
    # arguments, optionally pre-train the policy with PPO, run the learner-verifier loop
    # and append the outcome to several result files.
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", default="lds0")
    parser.add_argument("--timeout", default=8 * 60, type=int)  # in minutes
    parser.add_argument("--reach_prob", default=0.3, type=float)
    parser.add_argument("--eps", default=0.001, type=float)
    parser.add_argument("--lip_lambda", default=0.001, type=float)
    parser.add_argument("--p_lip", default=3.0, type=float)
    parser.add_argument("--v_lip", default=8.0, type=float)
    parser.add_argument("--hidden", default=128, type=int)
    parser.add_argument("--min_iters", default=0, type=int)
    parser.add_argument("--num_layers", default=2, type=int)
    parser.add_argument("--batch_size", default="16k")
    parser.add_argument("--ppo_iters", default=50, type=int)
    parser.add_argument("--n_step", default=1, type=int)
    parser.add_argument("--v_act", default="relu")
    parser.add_argument("--norm", default="linf")
    parser.add_argument("--ds_type", default="hard")
    parser.add_argument("--policy", default="policies/lds0_zero.jax")
    parser.add_argument("--debug_k0", action="store_true")
    parser.add_argument("--gen_plot", action="store_true")
    parser.add_argument("--no_refinement", action="store_true")
    parser.add_argument("--plot", action="store_true")
    parser.add_argument("--load_sb", default="")
    parser.add_argument("--skip_ppo", action="store_true")
    parser.add_argument("--continue_ppo", action="store_true")
    parser.add_argument("--only_ppo", action="store_true")
    parser.add_argument("--small_mem", action="store_true")
    parser.add_argument("--continue_rsm", type=int, default=0)
    parser.add_argument("--train_p", type=int, default=5)
    parser.add_argument("--fail_check_fast", type=int, default=0)
    parser.add_argument("--soft_constraint", type=int, default=1)
    parser.add_argument("--normalize_r", type=int, default=0)
    parser.add_argument("--normalize_a", type=int, default=1)
    parser.add_argument("--grid_size", default="8M")
    parser.add_argument("--std_start", default=1.0, type=float)
    parser.add_argument("--std_end", default=0.05, type=float)
    parser.add_argument("--lip_cheat", default=1, type=float)
    parser.add_argument("--p_lr", default=0.00005, type=float)
    parser.add_argument("--c_lr", default=0.0005, type=float)
    parser.add_argument("--c_ema", default=0.9, type=float)
    parser.add_argument("--v_lr", default=0.0005, type=float)
    parser.add_argument("--v_ema", default=0.0, type=float)
    args = parser.parse_args()

    # The environment is selected by the prefix of --env; the first matching prefix wins.
    if args.env.startswith("vlds"):
        env = vLDSEnv()
    elif args.env.startswith("vdebug"):
        env = vDebugEnv()
    elif args.env.startswith("vpend"):
        env = vInvertedPendulum()
    elif args.env.startswith("vdrone"):
        env = vDroneEnv()
    elif args.env.startswith("vmaze"):
        env = vMazeEnv()
    elif args.env.startswith("vcavoid"):
        env = vCollisionAvoidanceEnv()
    elif args.env.startswith("human2"):
        env = vHumanoidBalance2()
    elif args.env.startswith("rooms"):
        # "rooms<i>_<j>[_...]" selects the two room indices passed to RoomEnv.
        rooms = args.env.replace("rooms", "").split("_")
        if len(rooms) < 2:
            raise ValueError(
                f"Expected an environment of the form 'rooms<i>_<j>', got '{args.env}'"
            )
        room1, room2 = int(rooms[0]), int(rooms[1])
        if room1 == room2:
            raise ValueError(f"The two rooms must differ, got '{args.env}'")
        env = RoomEnv(room1, room2, init_radius=0.1)
    elif args.env.startswith("room1"):
        env = RoomEnv(0, 3, init_radius=0.1)
    elif args.env.startswith("room2"):
        env = RoomEnv(0, 3, init_radius=0.15)
    elif args.env.startswith("room3"):
        env = RoomEnv(0, 3, init_radius=0.2)
    elif args.env.startswith("room0"):
        meta_env = RoomMetaEnv()
        env = meta_env.get_subtask(0, 1)
    else:
        raise ValueError(f"Unknown environment '{args.env}'")
    env.name = args.env

    norm = args.norm.lower()
    if norm not in ("l1", "linf"):
        raise ValueError("L1 and Linf norm are allowed")
    os.makedirs("checkpoints", exist_ok=True)
    learner = RSMLearner(
        [args.hidden] * args.num_layers,
        [128, 128],
        env,
        p_lip=args.p_lip,
        v_lip=args.v_lip,
        lip_lambda=args.lip_lambda,
        eps=args.eps,
        reach_prob=args.reach_prob,
        v_activation=args.v_act,
        norm=norm,
        n_step=args.n_step,
    )
    # Restore a previous PPO checkpoint unless a policy is loaded through --load_sb.
    if (args.skip_ppo or args.continue_ppo) and args.load_sb == "":
        learner.load(f"checkpoints/{args.env}_ppo.jax", force_load_all=False)
        txt_return, _ = learner.evaluate_rl()
        print(f"Restored policy with {txt_return} return")
    if args.load_sb != "":
        learner.load_sb(args.load_sb)

    if not args.skip_ppo:
        learner.pretrain_policy(
            args.ppo_iters,
            lip_start=0.05 / 10,
            lip_end=0.05,
            save_every=None,
            std_start=args.std_start,
            std_end=args.std_end,
            normalize_r=args.normalize_r > 0,
            normalize_a=args.normalize_a > 0,
        )
        learner.save(f"checkpoints/{args.env}_ppo.jax")
        print("[SAVED]")

    verifier = RSMVerifier(
        learner,
        env,
        batch_size=interpret_size_arg(args.batch_size),
        reach_prob=args.reach_prob,
        fail_check_fast=bool(args.fail_check_fast),
        target_grid_size=interpret_size_arg(args.grid_size),
        dataset_type=args.ds_type,
        lip_cheat=args.lip_cheat,
        norm=norm,
    )

    if args.continue_rsm > 0:
        # Resume from a saved loop state; --continue_rsm also scales the grid size.
        learner.load(f"saved/{args.env}_loop.jax")
        verifier.grid_size *= args.continue_rsm

    loop = RSMLoop(
        learner,
        verifier,
        env,
        plot=args.plot,
        train_p=args.train_p,
        min_iters=args.min_iters,
        soft_constraint=bool(args.soft_constraint),
    )

    txt_return, _ = learner.evaluate_rl()

    loop.plot_l(f"plots/{args.env}_start.png")
    with open("ppo_results.txt", "a") as f:
        f.write(f"{args.env}: {txt_return}\n")

    if args.only_ppo:
        sys.exit(0)

    # --timeout is given in minutes, RSMLoop.run expects seconds.
    sat = loop.run(args.timeout * 60)
    loop.plot_l(f"plots/{args.env}_end.png")

    os.makedirs("study_results", exist_ok=True)
    # Group result files by the first two "_"-separated parts of the environment name.
    name_parts = args.env.split("_")
    if len(name_parts) > 2:
        env_name = name_parts[0] + "_" + name_parts[1]
    else:
        env_name = args.env
    cmd_line = " ".join(sys.argv)
    # Shared body of the per-environment and the global result record.
    run_record = (
        "    args=" + str(vars(args)) + "\n"
        + "    return =" + txt_return + "\n"
        + "    info=" + str(loop.info) + "\n"
        + "    sat=" + str(sat) + "\n"
        + "\n\n"
    )
    with open(f"study_results/info_{env_name}.log", "a") as f:
        f.write(f"python3 {cmd_line}\n")
        f.write(run_record)
    with open("global_summary.txt", "a") as f:
        f.write(f"{cmd_line}\n")
        f.write(run_record)

    if args.env.startswith(("rooms", "nlrooms")):
        with open("rooms_summary.txt", "a") as f:
            reachprob = loop.info.get("actual_reach_prob", "0")
            f.write(
                f"{args.env}, {str(sat)}, {str(reachprob)}, (runtime {pretty_time(loop.info.get('runtime',0))}) {txt_return}"
            )
            # The closing bracket was missing from the bracketed group before.
            f.write(
                f" [old_reachprob={str(loop.info.get('old_reach_prob'))}, other_reachprob={str(loop.info.get('other_reach_prob'))}, improved_bound={str(loop.info.get('improved_bound'))}]"
            )
            f.write("\n")