import math
import multiprocessing
import random
import time
from typing import List

import tensorflow as tf

from tf_agents.environments import tf_py_environment

from action_freezer import ActionFreezer
from replay_buffer import init_reverb_replay_buffer
from MultiSM import MultiStateMachine
from checkpoint import save_policy, save_checkpointer, load_checkpointer
from config import configuration, REFINE_THRESH, PROFILE_FRE
from atomic_commit.AtomicCommitEnv import AtomicCommitEnv
from distributed_locking.DistributedLockingEnv import DistributedLockingEnv
from import_functions import import_driver, import_verifier, print_states
from simple_counter.ComplexCounterEnv import ComplexCounterEnv
from math_func.MathEnv import MathEnv
from Timer import Timer
from informal_verifier import Verifier, informal_verify, sequential_verify
from init_agent import init_agent
from init_network import init_q_net, init_transformer_network
from primary_backup.PrimaryBackupEnv import PrimaryBackupEnv
from generate_combs import Info, generate_all_combs
from utils import print_digit_information, log, parse_arg


def save_and_finish(save_dir, agent, train_checkpointer, v):
    """Persist the agent, report success, and finish the verifier.

    Args:
        save_dir: Prefix of the policy export location; the policy is saved
            under ``f"{save_dir}_policy"``.
        agent: Agent whose ``policy`` is exported and whose
            ``train_step_counter`` is handed to the checkpointer.
        train_checkpointer: Object whose ``save`` method receives the
            agent's train step counter.
        v: Object whose ``finish()`` method is called last.

    Returns:
        None.
    """
    policy_path = f"{save_dir}_policy"
    save_policy(policy_path, agent.policy)

    step_counter = agent.train_step_counter
    train_checkpointer.save(step_counter)

    print("PASS ALL")
    v.finish()


def train(opt):
    """Train a shared Q-network for the selected protocol and verify it.

    The run has two phases:

    1. Phase-1 collects data once, trains for ``REFINE_THRESH`` steps and
       verifies every case id.
    2. Phase-2 repeats up to ``opt.collect_iter`` rounds of collecting data
       for the currently failing cases, training on it, and re-verifying
       the failing cases. When they all pass, every case is verified again;
       the run finishes early once that full verification passes.

    Only ``agents[0]`` owns a replay buffer and is trained; all agents are
    built around the same ``q_net`` object.

    Checkpointing and the policy export are optional: they happen only when
    ``opt.save_dir`` is set. Without it the run still trains and verifies,
    it just does not write anything.

    Args:
        opt: Options object. This function reads ``players``, ``rounds``,
            ``collect_iter``, ``protocol``, ``save_dir``, ``load_dir``,
            ``debug``, ``gpu``, ``model``, ``freeze`` and ``test_id``.

    Returns:
        ``agents[0]`` after training. In debug mode the process exits after
        the first training step instead of returning.

    Raises:
        ValueError: If ``opt.protocol`` is not one of the supported
            protocol names.
    """
    players = opt.players
    num_round = opt.rounds
    collect_iter = opt.collect_iter
    protocol = opt.protocol
    save_dir = opt.save_dir
    load_dir = opt.load_dir

    # num_parallel_calls should be smaller than batch_size
    batch_size = 1 if opt.debug else 64
    num_parallel_calls = 16 if batch_size >= 16 else 1
    if not opt.gpu:
        # Set CPU as available physical device
        tf.config.set_visible_devices([], "GPU")
    else:
        # Don't allocate all gpu memory for one process
        gpus = tf.config.experimental.list_physical_devices("GPU")
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)

    conf = configuration(opt)
    # import functions based on protocol
    collect_data = import_driver(opt.protocol)

    conf.log(opt)
    this_func = "TRAIN"
    envs = [None] * players
    tf_envs = [None] * players
    multi_sm = MultiStateMachine(players, protocol)
    for i in range(players):
        if protocol == "atomic_commit":
            envs[i] = AtomicCommitEnv(players, i, multi_sm, is_training=True)
        elif protocol == "distributed_locking":
            envs[i] = DistributedLockingEnv(players, i, multi_sm, is_training=True)
        elif protocol == "counter":
            envs[i] = ComplexCounterEnv(players, i, multi_sm, is_training=True)
        elif protocol == "math":
            envs[i] = MathEnv(players, i, multi_sm, is_training=True)
        elif protocol == "primary_backup":
            envs[i] = PrimaryBackupEnv(
                players, i, num_round, multi_sm, is_training=True
            )
        else:
            # Fail with a clear message instead of wrapping a None env below.
            raise ValueError(f"unsupported protocol: {protocol!r}")
        tf_envs[i] = tf_py_environment.TFPyEnvironment(envs[i])
    # all_combs = range(2 ** players)  # TODO: this should be case-wise
    all_combs: List[Info] = generate_all_combs(players, num_round)
    # Case ids are the positions of the combinations in all_combs.
    all_combs_id = list(range(len(all_combs)))
    agents = [None] * players
    # Only slot 0 is populated; the whole list is handed to collect_data.
    observers = [None] * players

    q_net = (
        init_q_net(envs[0]) if opt.model == "mlp" else init_transformer_network(envs[0])
    )
    for i in range(players):
        agents[i] = init_agent(tf_envs[i], q_net)

    observers[0], replay_buffer = init_reverb_replay_buffer(
        agents[0], f"uniform_table_{0}"
    )

    # Load saved checkpointer from <load_dir> to agents[0]
    if load_dir:
        load_checkpointer(load_dir, agents[0])
    # Create another checkpointer in case we want to save the model to another
    # directory. It stays None when no save_dir is given, and every save below
    # is guarded on that.
    train_checkpointer = save_checkpointer(save_dir, agents[0]) if save_dir else None

    t = Timer()
    freezer = ActionFreezer(protocol, enable_freeze=opt.freeze)
    v = Verifier(8, all_combs, freezer)

    def new_dataset_iter():
        """Return a fresh iterator over the replay buffer's dataset."""
        return iter(
            replay_buffer.as_dataset(
                num_parallel_calls=num_parallel_calls,
                sample_batch_size=batch_size,
                num_steps=1,
            )
        )

    def save_checkpoint():
        """Save a checkpoint of agents[0] if a checkpointer is configured."""
        if train_checkpointer is not None:
            train_checkpointer.save(agents[0].train_step_counter)

    def verify(case_ids):
        """Run informal_verify on ``case_ids``; returns (passed, failed ids)."""
        return informal_verify(
            opt,
            agents[0]._q_network,
            v,
            all_combs,
            case_ids,
            v.num_process,
            is_testall=True,
            is_log=False,
        )

    def finish_with_success():
        """Persist results (when configured), report success, finish ``v``."""
        if train_checkpointer is not None:
            save_and_finish(save_dir, agents[0], train_checkpointer, v)
        else:
            print("PASS ALL")
            v.finish()

    """Phase-1 Train with random data"""
    random.seed(time.perf_counter())
    if opt.debug:
        iter_num = 1
        # NOTE(review): a test_id of 0 is falsy and falls back to a random
        # case here — confirm whether 0 is meant to be selectable.
        if opt.test_id:
            test_case_id = opt.test_id
        else:
            test_case_id = random.choice(all_combs_id)
        initial_failed = [test_case_id]
    else:
        iter_num = 2000
        initial_failed = None

    ### Phase-1 Collect data ###
    print("Phase-1 collecting data...")
    collect_data(
        envs,
        agents,
        observers,
        freezer,
        iter_num,
        players,
        all_combs,
        is_refine=bool(opt.debug),
        failed_cases=initial_failed,
    )

    dataset = new_dataset_iter()

    ### Phase-1 Training ###
    print("Phase-1 training...")
    for it in range(REFINE_THRESH):
        if it % 500 == 0:
            print(f"Phase-1 training iteration: {it}")
        exp, _unused_info = next(dataset)
        log(this_func, f"exp: {exp}")
        targets = exp.reward.numpy()
        agents[0].my_train(exp, targets)
        if opt.debug:
            # Debug mode stops after a single training step.
            # NOTE(review): v.finish() is not called on this path — confirm
            # whether the verifier needs an explicit shutdown here.
            print_states(protocol)
            exit()
    replay_buffer.clear()

    ### Phase-1 Verification ###
    print("Phase-1 verification...")
    save_checkpoint()
    pass_all, failed_cases = verify(all_combs_id)
    if pass_all:
        finish_with_success()
        return agents[0]
    print(f"collect {len(failed_cases)} new failed cases")

    ### Phase-2 Training ###
    print("Phase-2 training...")
    remaining = failed_cases
    collect_round = 200
    inner_iter = math.ceil(
        collect_round * players * 2 / batch_size
    )  # Make sure every new data collected can be touched at least once
    for coll_it in range(collect_iter):
        if coll_it % 500 == 0:
            print(f"Phase-2 training iteration: {coll_it}")

        # Timing is only reported on every PROFILE_FRE-th round.
        profiling = coll_it % PROFILE_FRE == 0
        if profiling:
            t.start()

        random.seed(time.perf_counter())
        collect_data(
            envs,
            agents,
            observers,
            freezer,
            collect_round,
            players,
            all_combs,
            is_refine=True,
            failed_cases=remaining,
        )
        dataset = new_dataset_iter()
        if profiling:
            t.stop()
            print(f"[PROFILE]: {t.elapsed_time():.2f} seconds for data collection")
            t.start()

        for _ in range(inner_iter):
            exp, _unused_info = next(dataset)
            log(this_func, f"exp: {exp}")
            # tracker.track_all(exp)
            targets = exp.reward.numpy()
            agents[0].my_train(exp, targets)

        if profiling:
            t.stop()
            print(f"[PROFILE]: {t.elapsed_time():.2f} seconds for training")

        save_checkpoint()
        if profiling:
            t.start()
        pass_failed_case, remaining = verify(remaining)

        if profiling:
            t.stop()
            print(
                f"[PROFILE]: {t.elapsed_time():.2f} seconds for verifying all failed cases"
            )
        print(f"collect {len(remaining)} failed cases")
        if pass_failed_case:
            # The previously failing cases pass now; re-check every case to
            # catch regressions before declaring success.
            print("all failed cases passed")
            t.start()
            pass_all, remaining = verify(all_combs_id)
            t.stop()
            verify_seconds = t.elapsed_time()
            print(f"[PROFILE]: {verify_seconds:.2f} seconds for verifying all cases")
            print(f"collect new failed cases: {len(remaining)}")
            if pass_all:
                finish_with_success()
                return agents[0]

        replay_buffer.clear()

    # Round budget exhausted without a full pass: persist what we have.
    if save_dir:
        save_policy(f"{save_dir}_policy", agents[0].policy)
    save_checkpoint()
    v.finish()
    return agents[0]


if __name__ == "__main__":
    # Script entry point: build the options object with parse_arg(), echo it,
    # and hand it to train().
    opt = parse_arg()
    print(opt)
    # Option notes kept from the original author:
    #   num_episodes: number of episodes for data collection
    #   iteration: training iteration in each collect_iter
    #   collect_iter: iterations to collect new data
    # NOTE(review): of these three, train() only reads opt.collect_iter; the
    # other two names may be stale — confirm against parse_arg.
    train(opt)
