from pathlib import Path
import pandas as pd
import itertools
import logging
import sys
import argparse

import vvcgym
from vvcgym.planes.f16_plane import F16Plane
from vvcgym.planes.utils.f16Classes import Guide, ControlLaw, PlaneModel
from vvcgym.env import VVCGymEnv
from vvcgym.utils.my_log import get_logger

# Repository root: three directory levels above this file.
PROJECT_ROOT_DIR: Path = Path(__file__).parent.parent.parent


def _append_to_sys_path(directory: Path) -> None:
    """Append *directory* (as an absolute path string) to ``sys.path`` once."""
    entry = str(directory.absolute())
    if entry in sys.path:
        return
    sys.path.append(entry)


# NOTE: the ``vvcgym`` imports at the top of this file execute before this
# call, so the entry added here only helps imports performed afterwards.
_append_to_sys_path(PROJECT_ROOT_DIR)


# Actuator command limits taken from the F16 plane definition. The scalar
# bounds below are used by the rollout loop to report and clamp the Guide's
# ``p`` / ``nz`` / ``pla`` outputs.
action_mins = F16Plane.get_action_lower_bounds()
action_maxs = F16Plane.get_action_higher_bounds()

# (lower, upper) bound pairs, one line per guidance channel.
P_MIN, P_MAX = action_mins.p, action_maxs.p
NZ_MIN, NZ_MAX = action_mins.nz, action_maxs.nz
PLA_MIN, PLA_MAX = action_mins.pla, action_maxs.pla

class Rollout:
    """Simulate the F16 guidance -> control-law -> plane-model loop for one target.

    The ``Guide`` is driven toward ``target_v`` / ``target_mu`` /
    ``target_chi`` starting from altitude ``h0`` and speed ``v0``. Every
    simulation step is recorded in ``self.logs``. Once the tracking error has
    stayed below the thresholds for a whole integration window, the
    trajectory is written to a CSV file.
    """

    def __init__(self, 
        target_v, target_mu, target_chi, 
        h0=5000, v0=200, 
        v_threshold=10., mu_threshold=1., chi_threshold=1., 
        integral_time_length=1, step_frequence=100, 
        max_rollout_time=120, 
        traj_save_dir: Path=PROJECT_ROOT_DIR / "data" / "tmp",
        trajectory_save_prefix: str="traj",
        my_logger: logging.Logger=None
    ) -> None:
        """
        Args:
            target_v, target_mu, target_chi: commanded values passed to the
                Guide (see ``wCmds``) and checked by ``is_terminated``.
            h0, v0: initial altitude / speed handed to ``PlaneModel``.
            v_threshold, mu_threshold, chi_threshold: largest allowed mean
                absolute tracking error per step over the integration window.
            integral_time_length: integration window length, in units of
                ``1 / step_frequence`` steps per unit (seconds when
                ``step_frequence`` is in Hz).
            step_frequence: simulation steps per time unit.
            max_rollout_time: simulated time before the rollout gives up.
            traj_save_dir: directory for the trajectory CSV. A ``str`` is
                accepted. Missing directories are created on save.
            trajectory_save_prefix: file-name prefix of the trajectory CSV.
            my_logger: optional logger for outcome messages; ``None`` disables
                them.
        """
        self.target_v = target_v
        self.target_mu = target_mu
        self.target_chi = target_chi

        self.h0 = h0
        self.v0 = v0

        self.v_threshold = v_threshold
        self.mu_threshold = mu_threshold
        self.chi_threshold = chi_threshold
        self.integral_time_length = integral_time_length
        self.step_frequence = step_frequence

        self.max_rollout_time = max_rollout_time

        # Normalise so ``/`` and ``mkdir`` work for str arguments too.
        self.traj_save_dir = Path(traj_save_dir)
        self.trajectory_save_prefix = trajectory_save_prefix

        self.gtmp = Guide()
        self.f16cl = ControlLaw(stepTime=1./step_frequence)
        self.f16model = PlaneModel(h0, v0, stepTime=1./step_frequence)

        self.log_state_keys = ["phi", "theta", "psi", "v", "mu", "chi", "p", "h"]
        self.log_action_keys = ["p", "nz", "pla", "rud"]
        self.logs = {}

        self.init_log()

        self.logger = my_logger

    @property
    def wCmds(self) -> dict:
        """Target command dict passed to ``Guide.step``."""
        return {
            "v": self.target_v,
            "mu": self.target_mu,
            "chi": self.target_chi
        }

    @property
    def sim_interval(self) -> float:
        """Duration of one simulation step (``1 / step_frequence``)."""
        return 1. / self.step_frequence

    @property
    def integral_window_length(self) -> int:
        """Number of steps in the termination window (always >= 1).

        This is rounded to an int so that a fractional
        ``integral_time_length`` can still be used to slice the logs. The
        lower bound of 1 avoids ``[-0:]``, which would select the whole log.
        """
        return max(1, int(round(self.integral_time_length * self.step_frequence)))

    @property
    def v_integral_threshold(self):
        """Bound on the summed |v error| over one window."""
        return self.v_threshold * self.integral_window_length

    @property
    def mu_integral_threshold(self):
        """Bound on the summed |mu error| over one window."""
        return self.mu_threshold * self.integral_window_length

    @property
    def chi_integral_threshold(self):
        """Bound on the summed |chi error| over one window."""
        return self.chi_threshold * self.integral_window_length

    @property
    def max_rollout_length(self) -> int:
        """Maximum number of simulation steps (rounded so ``range`` accepts it)."""
        return int(round(self.max_rollout_time * self.step_frequence))

    def init_log(self):
        """Reset ``self.logs`` to empty ``time`` / ``s_<key>`` / ``a_<key>`` columns."""
        self.logs = {}
        self.logs["time"] = []
        for k in self.log_state_keys:
            self.logs[f"s_{k}"] = []
        for k in self.log_action_keys:
            self.logs[f"a_{k}"] = []

    def log(self, state, action, time:float):
        """Append one step: the rounded time, selected state values and action values."""
        self.logs["time"].append(round(time, 2))
        for k in self.log_state_keys:
            self.logs[f"s_{k}"].append(state[k])
        for k in self.log_action_keys:
            self.logs[f"a_{k}"].append(action[k])

    def save(self):
        """Write ``self.logs`` to ``<prefix>_<v>_<mu>_<chi>.csv`` in ``traj_save_dir``.

        The directory (and any missing parents) is created first.
        """
        save_dir = Path(self.traj_save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        file_name = f"{self.trajectory_save_prefix}_{self.target_v}_{self.target_mu}_{self.target_chi}.csv"
        df = pd.DataFrame(self.logs)
        df.to_csv(str((save_dir / file_name).absolute()), index=False)

    def _window_error(self, key: str, target) -> float:
        """Sum of |target - x| over the last window of logged ``s_<key>`` values."""
        window = self.logs[f"s_{key}"][-self.integral_window_length:]
        return sum(abs(target - item) for item in window)

    def is_terminated(self) -> bool:
        """Return True once v, mu and chi have all tracked their targets for a full window.

        Each channel passes when its summed absolute error over the last
        ``integral_window_length`` logged states is strictly below the
        corresponding ``*_integral_threshold``.
        """
        if len(self.logs["time"]) < self.integral_window_length:
            return False
        # NOTE(review): the chi error is a plain difference and is not wrapped
        # around +/-180; targets near the wrap point may never satisfy it.
        # Confirm the angle convention before relying on chi targets there.
        return bool(
            self._window_error("v", self.target_v) < self.v_integral_threshold
            and self._window_error("mu", self.target_mu) < self.mu_integral_threshold
            and self._window_error("chi", self.target_chi) < self.chi_integral_threshold
        )

    @staticmethod
    def _clip(value, lower, upper):
        """Clamp *value* into [lower, upper]; comparisons mirror the original (NaN passes through)."""
        if value > upper:
            value = upper
        if value < lower:
            value = lower
        return value

    def _clip_guidance(self, gout: dict) -> None:
        """Print and clamp, in place, any ``p`` / ``nz`` / ``pla`` command outside its bounds."""
        for key, lower, upper in (
            ("p", P_MIN, P_MAX),
            ("nz", NZ_MIN, NZ_MAX),
            ("pla", PLA_MIN, PLA_MAX),
        ):
            value = gout[key]
            if not (lower <= value <= upper):
                print(f"Invalid {key}: ", value)
            gout[key] = self._clip(value, lower, upper)

    def rollout(self) -> int:
        """Simulate until the targets are tracked, the plane crashes, or time runs out.

        Each step does the following:
        1. The Guide computes commands from ``wCmds`` and the current state.
        2. ``p`` / ``nz`` / ``pla`` are clamped to the actuator limits.
        3. The control law and plane model are stepped.
        4. The state the commands were computed from is logged together with
           the clamped commands.

        Returns:
            The number of logged steps if ``is_terminated`` became true. In
            that case the trajectory is also saved to CSV. Otherwise returns
            0: either the logged altitude went negative, or
            ``max_rollout_length`` steps elapsed without termination. Nothing
            is saved in those cases.
        """
        stsDict = self.f16model.getPlaneState()  # stsDict: lef, npos, epos, h, alpha, beta, phi, theta, psi, p, q, r, v, vn, ve, vh, nx, ny, nz, ele, ail, rud, thrust, lon, lat, mu, chi

        for i in range(self.max_rollout_length):
            self.gtmp.step(self.wCmds, stsDict)
            gout = self.gtmp.getOutputDict()
            self._clip_guidance(gout)

            self.f16cl.step(gout, stsDict)
            clout = self.f16cl.getOutputDict()
            self.f16model.step(clout)

            # stsDict is still the pre-step state that the commands were
            # computed from; it is refreshed at the end of the iteration.
            self.log(state=stsDict, action=gout, time=i * self.sim_interval)

            if stsDict["h"] < 0.:
                if self.logger is not None:
                    self.logger.info(f'\033[1;31mcrashed!!!!!!!!!!!\033[0m v = {self.target_v}, mu = {self.target_mu}, chi = {self.target_chi}')
                return 0

            if self.is_terminated():
                if self.logger is not None:
                    self.logger.warning(f"\033[1;32m Terminated. \033[0m v = {self.target_v}, mu = {self.target_mu}, chi = {self.target_chi}, len = {len(self.logs['time'])}")
                self.save()
                return len(self.logs["time"])

            stsDict = self.f16model.getPlaneState()

        # Time-out: the trajectory is intentionally not saved.
        if self.logger is not None:
            self.logger.info(f'\033[1;31mreach max length!!!!!!!!!!!\033[0m v = {self.target_v}, mu = {self.target_mu}, chi = {self.target_chi}')
        return 0


class ScheduleForRollout:
    """Run one rollout per (v, mu, chi) target on a regular integer grid.

    Each axis is sampled as ``range(lo, hi + 1, interval)``, so ``hi`` is
    included whenever ``hi - lo`` is a multiple of ``interval``. For every
    grid point, the value returned by ``rollout_class(...).rollout()`` is
    recorded. For ``Rollout``, that is the episode length, or 0 on crash or
    time-out. All results are written to ``save_dir / "res.csv"``.
    """

    def __init__(self, 
            rollout_class=Rollout, 
            v_range: tuple=(100, 300), v_interval: int=10, 
            mu_range: tuple=(-85, 85), mu_interval: int=5, 
            chi_range: tuple=(-180, 180), chi_interval: int=5,
            step_frequence: int=10,
            save_dir: "str | Path"="data/tmp"
        ) -> None:
        """
        Args:
            rollout_class: class constructed per grid point with keyword
                arguments ``target_v``, ``target_mu``, ``target_chi``,
                ``my_logger``, ``traj_save_dir`` and ``step_frequence``. It
                must expose ``rollout()``.
            v_range, mu_range, chi_range: ``(lo, hi)`` bounds. Any indexable
                pair (list or tuple) works. The defaults are tuples to avoid
                shared mutable defaults.
            v_interval, mu_interval, chi_interval: grid step per axis.
            step_frequence: forwarded to every rollout.
            save_dir: output directory, relative to ``PROJECT_ROOT_DIR``. An
                absolute path is used as-is (``Path`` joining semantics).
        """
        self.rollout_class = rollout_class
        self.v_range = v_range
        self.v_interval = v_interval
        self.mu_range = mu_range
        self.mu_interval = mu_interval
        self.chi_range = chi_range
        self.chi_interval = chi_interval
        self.step_frequence = step_frequence
        self.save_dir = PROJECT_ROOT_DIR / save_dir

    @staticmethod
    def _axis(bounds, interval) -> range:
        """Integer samples ``lo, lo + interval, ...`` that are <= ``hi``."""
        return range(bounds[0], bounds[1] + 1, interval)

    def work(self):
        """Roll out every grid point and write ``res.csv`` into ``save_dir``.

        ``save_dir`` and any missing parents are created first. A single
        logger named ``"ucav"`` (file ``my_sys_logs.log`` in ``save_dir``) is
        shared by all rollouts.
        """
        log = {
            "v": [],
            "mu": [],
            "chi": [],
            "length": []
        }

        # parents=True: the CLI target demonstrations/data/<run> may not exist
        # yet. exist_ok=True replaces the racy exists()-then-mkdir() check.
        self.save_dir.mkdir(parents=True, exist_ok=True)

        my_logger = get_logger(logger_name="ucav", log_file_dir=str(self.save_dir / 'my_sys_logs.log'))

        grid = itertools.product(
            self._axis(self.v_range, self.v_interval),
            self._axis(self.mu_range, self.mu_interval),
            self._axis(self.chi_range, self.chi_interval),
        )
        for v, mu, chi in grid:
            rollout_worker = self.rollout_class(
                target_v=v, target_mu=mu, target_chi=chi, 
                my_logger=my_logger,
                traj_save_dir=self.save_dir,
                step_frequence=self.step_frequence
            )
            episode_length = rollout_worker.rollout()

            log["v"].append(v)
            log["mu"].append(mu)
            log["chi"].append(chi)
            log["length"].append(episode_length)

        df = pd.DataFrame(log)
        df.to_csv(self.save_dir / "res.csv", index=False)


if __name__ == "__main__":
    # Sweep a grid of (v, mu, chi) targets. Outputs go to
    # demonstrations/data/<freq>hz_<v_int>_<mu_int>_<chi_int>_<suffix>/:
    # one trajectory CSV per successful rollout plus a res.csv summary.
    parser = argparse.ArgumentParser(
        description="Roll out the F16 guidance law over a grid of (v, mu, chi) targets and save the trajectories."
    )
    parser.add_argument("--data-dir-suffix", type=str, default="v1", help="suffix of the output data directory")
    parser.add_argument("--step-frequence", type=int, default=10, help="simulation frequency (steps per second)")
    parser.add_argument("--v-min", type=int, default=100, help="minimum value of speed")
    parser.add_argument("--v-max", type=int, default=300, help="maximum value of speed")
    parser.add_argument("--v-interval", type=int, default=10, help="sample interval of speed")
    parser.add_argument("--mu-min", type=int, default=-85, help="minimum value of flight path elevation angle")
    parser.add_argument("--mu-max", type=int, default=85, help="maximum value of flight path elevation angle")
    parser.add_argument("--mu-interval", type=int, default=5, help="sample interval of flight path elevation angle")
    parser.add_argument("--chi-min", type=int, default=-170, help="minimum value of flight path azimuth angle")
    parser.add_argument("--chi-max", type=int, default=170, help="maximum value of flight path azimuth angle")
    parser.add_argument("--chi-interval", type=int, default=5, help="sample interval of flight path azimuth angle")
    args = parser.parse_args()

    s = ScheduleForRollout(
        rollout_class=Rollout, 
        v_range=[args.v_min, args.v_max], 
        v_interval=args.v_interval, 
        mu_range=[args.mu_min, args.mu_max], 
        mu_interval=args.mu_interval, 
        chi_range=[args.chi_min, args.chi_max], 
        chi_interval=args.chi_interval,
        step_frequence=args.step_frequence,
        save_dir=PROJECT_ROOT_DIR / "demonstrations" / "data" / f"{args.step_frequence}hz_{args.v_interval}_{args.mu_interval}_{args.chi_interval}_{args.data_dir_suffix}"
    )
    s.work()