import gymnasium as gym
import numpy as np
import pandas as pd

from sb3_contrib import CrossQ, TQC
from stable_baselines3 import PPO

from datetime import datetime

from molecule_movement.cli_parsing import create_parser, dir_path
from molecule_movement.statistics import enable_statistics_logger
from molecule_movement.exceptions import InfeasibleError

from molecule_movement.wrapper import (
    CorridorWrapper,
    SATBasedSchedulingWrapper,
    RandomObstaclesWrapper, ObstacleType,
    ManipulateCurrentMoleculeWrapper,
    )
import sys

from loguru import logger

# Route log records of level INFO and above to stderr, then switch on the
# records emitted under the "molecule_movement" logger name.
_stderr_handler = {"sink": sys.stderr, "level": "INFO"}
logger.configure(handlers=[_stderr_handler])
logger.enable("molecule_movement")

def make_env(args, corridor_width: int, percentage: float, human_render: bool = False):
    """Create the environment ``args.env`` and apply the evaluation wrappers.

    The base environment is wrapped, innermost first, in ``CorridorWrapper``,
    ``SATBasedSchedulingWrapper``, ``ManipulateCurrentMoleculeWrapper`` and
    ``RandomObstaclesWrapper``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``env``, ``render_mode``, ``scale``, ``render_grid``,
            ``render_sensors``, ``sensors``, ``molecule_names`` and ``seed``.
        corridor_width: Forwarded to ``CorridorWrapper``.
        percentage: Forwarded to ``RandomObstaclesWrapper``.
        human_render: When true, ``args.render_mode`` is passed to the
            environment; otherwise the render mode is the string ``"none"``.

    Returns:
        The fully wrapped environment.
    """
    if human_render:
        render_mode = args.render_mode
    else:
        render_mode = "none"

    base_env = gym.make(
        args.env,
        render_mode=render_mode,
        scale=args.scale,
        render_grid=args.render_grid,
        render_sensors=args.render_sensors,
        num_sensors=args.sensors,
        max_steps=3000,
        draw_names=args.molecule_names,
    )

    corridor_env = CorridorWrapper(
        base_env,
        corridor_width=corridor_width,
        parking_distance=5,
        parking_buffer=1.0,
    )

    explanation_targets = {
        "img": f"{args.env}",
        "tensorboard": "trained_policies/sat_based_scheduling",
    }
    scheduled_env = SATBasedSchedulingWrapper(corridor_env, store_explanations=explanation_targets)

    # Both positional triples are (-bound, bound, 0.3); their meaning is
    # defined by ManipulateCurrentMoleculeWrapper.
    bound = 1.5
    manipulated_env = ManipulateCurrentMoleculeWrapper(
        scheduled_env,
        (-bound, bound, 0.3),
        (-bound, bound, 0.3),
    )

    return RandomObstaclesWrapper(
        manipulated_env,
        percentage=percentage,
        obstacle_types=ObstacleType.MOLECULE,
        seed=args.seed,
    )


def _final_placement_stats(matching):
    """Accumulate the end-of-run placement statistics over ``matching``.

    Every entry must expose ``molecule.center``, ``molecule.orientation``,
    ``goal.position`` and ``goal.orientation``.

    Returns:
        A ``(total_distance, distances, orientation_mismatches)`` tuple: the
        summed molecule-to-goal distance, the individual distances in
        iteration order, and the number of entries whose molecule and goal
        orientations differ.
    """
    total_distance = 0
    distances = []
    orientation_mismatches = 0
    for m in matching:
        distance = m.molecule.center.distance(m.goal.position)
        # Accumulate manually (not with sum()) so the floating-point result
        # stays identical to the previous implementation.
        total_distance += distance
        distances.append(distance)
        if np.abs(m.molecule.orientation - m.goal.orientation) > 0:
            orientation_mismatches += 1
    return total_distance, distances, orientation_mismatches


def main(env, corridor_width, percentage, model):
    """Evaluate ``model`` on environment ``env`` for seeds ``0..args.seed``.

    Relies on the module-level ``args`` namespace and ``statistics`` object
    that the ``__main__`` section creates; ``args.env`` is overwritten with
    ``env``.

    For each seed the policy is rolled out deterministically. Whenever a step
    reports ``terminated`` or ``truncated``, the wrapper attribute
    ``increment_matching`` is called to obtain the next observation. Once it
    raises ``StopIteration`` the final distances/orientations are logged under
    task ``"stats"``, the statistics are written to a CSV file below
    ``assembly_evaluation/`` and cleared, and the next seed starts.

    Args:
        env: Environment id passed on to ``make_env`` via ``args.env``.
        corridor_width: Corridor width passed to ``make_env``.
        percentage: Obstacle percentage passed to ``make_env``.
        model: Policy object providing ``predict(obs, deterministic=True)``.

    Raises:
        Nothing for ``InfeasibleError`` or ``AttributeError`` raised during a
        seed: both are logged and the seed is skipped. Other exceptions
        propagate after the environment has been closed.
    """
    logger.success(f"Running {env} with {corridor_width} and {percentage=}")
    args.env = env
    logger.info(f"Running for {args}")
    sim = make_env(args, corridor_width, percentage=percentage, human_render=False)

    # main() is called once per environment id; close each environment when
    # done (also on unexpected errors) instead of leaking it.
    try:
        for seed in range(args.seed + 1):
            logger.bind(task="scheduling", seed=int(seed)).trace(f"{seed=}")
            try:
                obs, _ = sim.reset(seed=seed)
                now = datetime.now().isoformat(timespec="seconds")
                while True:
                    action, _ = model.predict(obs, deterministic=True)
                    # Gymnasium order: obs, reward, terminated, truncated, info.
                    obs, _, terminated, truncated, _ = sim.step(action)
                    if not (terminated or truncated):
                        continue
                    try:
                        obs, _ = sim.get_wrapper_attr('increment_matching')()
                    except StopIteration:
                        # No further matching to advance to: record the final
                        # placement quality and persist this seed's statistics.
                        final_distance, final_distances, final_orientations = (
                            _final_placement_stats(sim.get_wrapper_attr("matching"))
                        )
                        # NOTE(review): the message renders as "<value> = <value>";
                        # probably meant f"{final_distance=}". Left unchanged in
                        # case a log sink consumes the message text.
                        logger.bind(
                            task="stats",
                            final_distance=final_distance,
                            final_orientations=final_orientations,
                            final_distances=str(final_distances).replace(",", ";"),
                            newline=True,
                        ).trace(f"{final_distance} = {final_distance}")
                        statistics.df_to_csv(filename=f"assembly_evaluation/{now}_{sim.spec.id.replace('/','_')}_cw{corridor_width}_percentage{percentage}.csv")
                        statistics.clear()
                        break
            except (InfeasibleError, AttributeError) as e:
                # NOTE(review): statistics gathered for the failed seed are not
                # cleared here (unchanged behaviour) - confirm this is intended.
                logger.error(e)
    finally:
        sim.close()

    statistics.clear()

if __name__ == "__main__":
    # Script entry point: parse the CLI, load the policy and evaluate it on
    # every benchmark environment. ``args`` and ``statistics`` must stay
    # module-level names because main() reads them as globals.
    parser = create_parser()
    parser.add_argument('--model-path', type=dir_path, help="Logging directory for which 'best_model.zip' should be evaluated")
    args = parser.parse_args()
    if args.seed is None:
        args.seed = 0
    statistics = enable_statistics_logger(logging_tasks=["scheduling", "stats"])

    # --model-path used to be parsed but ignored; honour it when given and
    # fall back to the previously hard-coded checkpoint directory otherwise.
    model_dir = args.model_path if args.model_path is not None else "policies/cross_high_variance"
    fepc_high_var_model = CrossQ.load(f"{model_dir}/best_model.zip")

    # Other policies that were evaluated with this script (currently disabled):
    #   CrossQ.load("policies/cross/best_model.zip") on the ids below without
    #   the "HighVar-" part, and TQC.load("policies/circ/best_model.zip") on the
    #   corresponding "CiRc-..." ids.
    percentages = [0.003]

    # Environments are evaluated group by group; within a group every
    # percentage is run over all of the group's ids before the next one.
    env_groups = (
        ("MoleculeMovement/FePc-Au111-HighVar-MockUp-Uniform-Square40N36-90x90-v0",),
        ("MoleculeMovement/FePc-Au111-HighVar-MockUp-Uniform-Square110N96-180x180-v0",),
        ("MoleculeMovement/FePc-Au111-HighVar-MockUp-Uniform-Square220N180-310x310-v0",),
        ("MoleculeMovement/FePc-Au111-HighVar-MockUp-Uniform-Square360N288-470x470-v0",),
        ("MoleculeMovement/FePc-Au111-HighVar-MockUp-Uniform-Square520N420-640x640-v0",),
        ("MoleculeMovement/FePc-HighVar-Quantum-Corral-250x250",),
        (
            "MoleculeMovement/FePc-Au111-HighVar-MockUp-Uniform-HoneyComb-1-v0",
            "MoleculeMovement/FePc-Au111-HighVar-MockUp-Uniform-HoneyComb-2-v0",
            "MoleculeMovement/FePc-Au111-HighVar-MockUp-Uniform-HoneyComb-3-v0",
            "MoleculeMovement/FePc-Au111-HighVar-MockUp-Uniform-HoneyComb-4-v0",
            "MoleculeMovement/FePc-Au111-HighVar-MockUp-Uniform-HoneyComb-5-v0",
        ),
    )

    for cw in [6]:
        for env_group in env_groups:
            for p in percentages:
                for env_id in env_group:
                    main(env_id, corridor_width=cw, percentage=p, model=fepc_high_var_model)

