from sb3_contrib import CrossQ

import gymnasium as gym
import numpy as np

from molecule_movement.cli_parsing import create_parser, dir_path
from molecule_movement.statistics import enable_statistics_logger

import pygame


from molecule_movement.wrapper import (
    CorridorWrapper,
    ManipulateCurrentMoleculeWrapper,
    SATBasedSchedulingWrapper,
    RandomObstaclesWrapper, ObstacleType,
    )

import sys
from datetime import datetime

from loguru import logger

import pygame

# Configure loguru with a single stderr sink that shows records at INFO level
# and above.
# NOTE(review): the `.trace(...)` records emitted in `main` are below INFO, so
# they will not appear on this sink — presumably they are consumed by the
# sink(s) installed by `enable_statistics_logger`; confirm.
logger.configure(handlers=[{"sink": sys.stderr, "level": "INFO"}])
# Turn on log records originating from the `molecule_movement` package.
logger.enable("molecule_movement")

def make_env(args, corridor_width: int, human_render: bool = False):
    """Create the gymnasium environment named by ``args.env`` and wrap it.

    Args:
        args: Parsed command-line namespace. The fields read here are ``env``,
            ``render_mode``, ``scale``, ``render_grid``, ``render_sensors``,
            ``sensors``, ``max_steps``, ``seed``, ``offset`` and
            ``molecule_names``.
        corridor_width: Forwarded to ``CorridorWrapper`` as ``corridor_width``.
        human_render: When true the environment is created with
            ``args.render_mode``; otherwise the render mode is the string
            ``"none"``.

    Returns:
        The base environment wrapped, from innermost to outermost, in
        ``CorridorWrapper``, ``ManipulateCurrentMoleculeWrapper``,
        ``SATBasedSchedulingWrapper`` and ``RandomObstaclesWrapper``.
    """
    base_kwargs = dict(
        render_mode=args.render_mode if human_render else "none",
        scale=args.scale,
        render_grid=args.render_grid,
        render_sensors=args.render_sensors,
        num_sensors=args.sensors,
        max_steps=args.max_steps,
        seed=args.seed,
        origin_offset=args.offset,
        draw_names=args.molecule_names,
    )
    wrapped = gym.make(args.env, **base_kwargs)

    wrapped = CorridorWrapper(
        wrapped,
        corridor_width=corridor_width,
        parking_buffer=1.0,
        parking_distance=5,
    )

    # The same (-limit, limit, 0.3) triple is passed for both positional
    # arguments of the manipulation wrapper.
    limit = 1.5
    manipulation_range = (-limit, limit, 0.3)
    wrapped = ManipulateCurrentMoleculeWrapper(wrapped, manipulation_range, manipulation_range)

    explanation_targets = {"img": f"{args.env}", "tensorboard" : "trained_policies/sat_based_scheduling"}
    wrapped = SATBasedSchedulingWrapper(wrapped, store_explanations=explanation_targets)

    # Outermost wrapper; it is seeded with the same seed as the base environment.
    return RandomObstaclesWrapper(
        wrapped,
        percentage=0.002,
        obstacle_types=ObstacleType.MOLECULE,
        seed=args.seed,
    )

def main():
    """Evaluate a saved CrossQ policy on the wrapped molecule environment.

    The command line is parsed with the project parser extended by
    ``--model-path``. ``best_model`` is loaded from that directory and the
    policy is sampled (``deterministic=False``) in an environment built by
    ``make_env`` with human rendering enabled.

    Whenever an episode terminates or is truncated, the wrapper attribute
    ``increment_matching`` is called and its observation is used to continue.
    Once that call raises ``StopIteration`` the current run is finished: the
    per-molecule goal distances and the number of orientation mismatches are
    logged to the ``stats`` task, and the environment is reset for the next
    run. After ``num_runs`` runs, summary statistics are logged and the
    process exits.

    Raises:
        SystemExit: With status 0 after the last run, with status 1 when the
            initial reset raises ``StopIteration``, and (via argparse) with
            status 2 when ``--model-path`` is missing.
    """
    statistics = enable_statistics_logger(logging_tasks=["stats", "scheduling"], log_trace_only=False)
    parser = create_parser()
    parser.add_argument('--model-path', type=dir_path, help="Logging directory for which 'best_model.zip' should be evaluated")
    args = parser.parse_args()
    if args.model_path is None:
        # Without this check the model path silently becomes "None/best_model".
        parser.error("--model-path is required to locate 'best_model.zip'")
    logger.info(f"Starting run for {args.env}")
    env = make_env(args, corridor_width=args.corridor_width, human_render=True)
    model = CrossQ.load(f"{args.model_path}/best_model")

    try:
        obs, _ = env.reset(seed=args.seed)
    except StopIteration as e:
        logger.error(e)
        # Wait for the user to acknowledge the error, then stop: there is no
        # initial observation to feed to the policy.
        input("")
        sys.exit(1)
    # NOTE: CSV export of the collected statistics (statistics.df_to_csv) is
    # currently disabled in this script.
    logger.bind(task="stats", newline=True).trace("")

    clock = pygame.time.Clock()
    num_runs = 100
    iters = 0
    while True:
        clock.tick()
        action, _ = model.predict(obs, deterministic=False)
        # Gymnasium's step returns (obs, reward, terminated, truncated, info).
        obs, _, terminated, truncated, _ = env.step(action)
        if not (terminated or truncated):
            continue

        try:
            obs, _ = env.get_wrapper_attr('increment_matching')()
        except StopIteration:
            # The wrapper has nothing left to advance to: summarise this run.
            final_distance = 0
            final_distances = []
            final_orientations = 0
            for m in env.get_wrapper_attr("matching"):
                distance = m.molecule.center.distance(m.goal.position)
                final_distance += distance
                final_distances.append(distance)
                # Count every molecule whose orientation differs from its goal.
                if np.abs(m.molecule.orientation - m.goal.orientation) > 0:
                    final_orientations += 1
            # NOTE(review): the divisor 20 is hard-coded — presumably the
            # number of molecules in the matching; confirm.
            logger.bind(task="stats", final_distance=final_distance, final_orientations=final_orientations, final_distances=str(final_distances).replace(",",";"), newline=True).trace(f"{final_distance}/20 = {final_distance/20}")
            logger.info(f"{env.spec.id.replace('/','_')} -> {iters=}")
            iters += 1
            if iters == num_runs:
                logger.info(statistics.stats("inside_corridor"))
                logger.info(statistics.stats("movement_travelled"))
                logger.info(statistics.sum("crashed"))
                sys.exit(0)
            # Keep the fresh observation; otherwise the next prediction would
            # be based on the last observation of the finished run.
            obs, _ = env.reset(seed=args.seed)
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
