import numpy as np
import fixed_env as env
import load_trace
import abr
import os
import importlib.util
import sys

# --- Simulation configuration -------------------------------------------
# The observation matrix built in main() has S_INFO rows and S_LEN columns.
S_INFO = 6  # rows: bit_rate, buffer_size, throughput, download time, next_chunk_sizes, chunk_til_video_end
S_LEN = 8  # number of past chunks kept per row (history length)
A_DIM = 6  # number of selectable bitrate levels; main() asserts it equals len(VIDEO_BIT_RATE)
VIDEO_BIT_RATE = [300,750,1200,1850,2850,4300]  # Kbps, indexed by the bit_rate action
M_IN_K = 1000.0  # factor of 1000 used for the unit conversions in main() (e.g. Kbps -> Mbps)
# Weight applied to `rebuf` in the reward. 4.3 equals the top VIDEO_BIT_RATE entry
# expressed in Mbps. NOTE(review): rebuf is presumably in seconds - confirm against fixed_env.
REBUF_PENALTY = 4.3
SMOOTH_PENALTY = 1  # weight applied to the absolute bitrate change (in Mbps) in the reward
DEFAULT_QUALITY = 1  # index into VIDEO_BIT_RATE used for the first chunk of every video
RANDOM_SEED = 42  # passed to np.random.seed() in main()
RAND_RANGE = 1000000  # not referenced in this file
RESEVOIR = 5  # BB; not referenced in this file (presumably a buffer-based algorithm parameter)
CUSHION = 10  # BB; not referenced in this file (presumably a buffer-based algorithm parameter)
BUFFER_NORM_FACTOR = 10  # divisor applied to buffer_size and delay when filling the state
SUMMARY_DIR = './results'  # not referenced in this file; result paths below are built from literals
CHUNK_TIL_VIDEO_END_CAP = 48.0  # cap and normaliser for the remaining-chunk count in the state
TOTAL_VIDEO_CHUNKS = 48  # not referenced in this file

# Command line argument processing
#   argv[1]  test_trace   trace set name; appended verbatim to './test/' and './results/new'
#   argv[2]  alg          algorithm name (also embedded in the log file name)
#   argv[3]  module_path  optional path of a Python file providing the algorithm
#   argv[4]  process_id   optional identifier supplied by the launcher
test_trace = sys.argv[1]
alg = sys.argv[2]  # Receive algorithm name as second command line parameter

# The module path is optional.
module_path = sys.argv[3] if len(sys.argv) > 3 else None

# The process id is optional as well. It used to be read unconditionally
# whenever a module path was present, which raised IndexError when exactly
# three arguments were passed.
process_id = sys.argv[4] if len(sys.argv) > 4 else None

# Dynamically import the ABR implementation and bind it to the name `abr`,
# which main() calls as `abr.abr(...)`.
if module_path and os.path.exists(module_path):
    # Import module from specified path.
    # splitext() removes only the trailing extension; str.replace('.py', '')
    # would also strip a '.py' that occurs in the middle of the file name.
    module_name = os.path.splitext(os.path.basename(module_path))[0]
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        # spec_from_file_location() yields None when no loader matches the
        # file (e.g. an unrecognised suffix); fail with a readable message
        # instead of an AttributeError further down.
        print(f"Cannot load module from: {module_path}", file=sys.stderr)
        sys.exit(1)
    abr = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(abr)
else:
    if module_path:
        # A path was given but does not exist: keep the fallback to the
        # named algorithm, but say so (on stderr, so stdout stays unchanged).
        print(f"Module path not found, falling back to algorithm name: {module_path}",
              file=sys.stderr)
    # Dynamically import corresponding module based on passed algorithm name
    if alg in ("bb", "mpc", "hyb", "bola", "pitree", "bb2"):
        abr = importlib.import_module(alg)
    else:
        print(f"Unknown algorithm: {alg}")
        sys.exit(1)

# Prefix of every per-trace log file; main() appends '_' + <trace file name>.
LOG_FILE = './results/new' + test_trace + 'log_sim_3' + alg  # +'_'+process_id
# Folder the network traces are loaded from.
COOKED_TRACE_FOLDER = './test/' + test_trace
# log in format of time_stamp bit_rate buffer_size rebuffer_time chunk_size download_time reward
chunk_length = 4
# exist_ok=True makes directory creation safe when several simulator
# processes start at the same time; a separate exists() check followed by
# makedirs() can raise FileExistsError in that situation.
os.makedirs('./results/new' + test_trace, exist_ok=True)

def serial(state):
    """Flatten the observation matrix into a flat feature list.

    The result is laid out as:
      * row 0, last column
      * row 1, last column
      * row 2, columns 0 .. S_LEN-1
      * row 3, columns 0 .. S_LEN-1
      * row 4, columns 0 .. A_DIM-1
      * row 5, last column

    giving ``2 * S_LEN + A_DIM + 3`` entries in total.

    Args:
        state: 2-D indexable (``state[row, col]``) with at least six rows.

    Returns:
        A list holding the selected entries in the order above.
    """
    newest = -1
    flat = [state[0, newest], state[1, newest]]
    flat.extend(state[2, col] for col in range(S_LEN))
    flat.extend(state[3, col] for col in range(S_LEN))
    flat.extend(state[4, level] for level in range(A_DIM))
    flat.append(state[5, newest])
    return flat

def main():
    """Replay the traces in COOKED_TRACE_FOLDER through the selected ABR algorithm.

    For every downloaded chunk the loop:
      1. asks the environment to fetch the chunk at the current bitrate index,
      2. computes the reward (quality - rebuffer penalty - smoothness penalty),
      3. appends a tab-separated record to the per-trace log file
         (time_stamp, bit_rate, buffer_size, rebuf, chunk_size, delay, reward),
      4. rebuilds the (S_INFO, S_LEN) observation matrix, and
      5. calls ``abr.abr`` to pick the bitrate index for the next chunk.

    When the environment reports the end of a video the log file is closed
    and the mean chunk reward of that video (first chunk excluded) is stored.
    After one video per trace file has been served, the mean and standard
    deviation of those per-video rewards are printed.
    """
    np.random.seed(RANDOM_SEED)

    assert len(VIDEO_BIT_RATE) == A_DIM

    all_cooked_time, all_cooked_bw, all_file_names = load_trace.load_trace(COOKED_TRACE_FOLDER)

    net_env = env.Environment(all_cooked_time=all_cooked_time,
                              all_cooked_bw=all_cooked_bw)

    log_path = LOG_FILE + '_' + all_file_names[net_env.trace_idx]
    log_file = open(log_path, 'w')

    time_stamp = 0

    last_bit_rate = DEFAULT_QUALITY
    bit_rate = DEFAULT_QUALITY
    r_batch_all = []  # one mean reward per finished video
    r_batch = []      # chunk rewards of the video in progress
    # Only the most recent observation is ever read, so keep just that one
    # instead of an ever-growing list of states.
    prev_state = np.zeros((S_INFO, S_LEN))
    video_count = 0

    # try/finally guarantees the currently open log file is closed even when
    # the environment or the algorithm raises. Closing an already closed file
    # is a no-op, so the normal exit path is unaffected.
    try:
        while True:  # one iteration per downloaded chunk
            # the action is from the last decision
            # this is to make the framework similar to the real
            (delay, sleep_time, buffer_size, rebuf,
             video_chunk_size, next_video_chunk_sizes,
             end_of_video, video_chunk_remain) = net_env.get_video_chunk(bit_rate)

            time_stamp += delay  # in ms
            time_stamp += sleep_time  # in ms

            # reward is video quality - rebuffer penalty - smoothness penalty
            reward = (VIDEO_BIT_RATE[bit_rate] / M_IN_K
                      - REBUF_PENALTY * rebuf
                      - SMOOTH_PENALTY * np.abs(VIDEO_BIT_RATE[bit_rate] -
                                                VIDEO_BIT_RATE[last_bit_rate]) / M_IN_K)
            r_batch.append(reward)

            last_bit_rate = bit_rate

            # log time_stamp, bit_rate, buffer_size, rebuf, chunk_size, delay, reward
            record = (time_stamp / M_IN_K, VIDEO_BIT_RATE[bit_rate], buffer_size,
                      rebuf, video_chunk_size, delay, reward)
            log_file.write('\t'.join(str(field) for field in record) + '\n')
            log_file.flush()

            # Shift the history one step left (np.roll returns a new array, so
            # prev_state is not modified) and overwrite the newest column.
            state = np.roll(prev_state, -1, axis=1)

            # this should be S_INFO number of terms
            state[0, -1] = VIDEO_BIT_RATE[bit_rate] / float(np.max(VIDEO_BIT_RATE))  # last quality
            state[1, -1] = buffer_size / BUFFER_NORM_FACTOR  # 10 sec
            state[2, -1] = float(video_chunk_size) / float(delay) / M_IN_K  # kilo byte / ms
            state[3, -1] = float(delay) / M_IN_K / BUFFER_NORM_FACTOR  # 10 sec
            state[4, :A_DIM] = np.array(next_video_chunk_sizes) / M_IN_K / M_IN_K  # mega byte
            state[5, -1] = np.minimum(video_chunk_remain, CHUNK_TIL_VIDEO_END_CAP) / float(CHUNK_TIL_VIDEO_END_CAP)

            speed = video_chunk_size / delay * 8

            # Use at most the five newest throughput samples and drop leading
            # zeros, i.e. history slots not yet filled for this video. The
            # newest sample was just written above and is non-zero for any
            # non-empty chunk, which is what ends the loop.
            past_bandwidths = state[2, -5:]
            while past_bandwidths[0] == 0.0:
                past_bandwidths = past_bandwidths[1:]
            past_bandwidths = [sample * 1000 for sample in past_bandwidths]

            if alg != 'pitree':
                bit_rate = abr.abr(speed, buffer_size, next_video_chunk_sizes,
                                   past_bandwidths, video_chunk_remain, bit_rate)
            else:
                # pitree consumes the flattened observation instead.
                bit_rate = abr.abr(serial(state))

            prev_state = state

            if end_of_video:
                log_file.write('\n')
                log_file.close()

                last_bit_rate = DEFAULT_QUALITY
                bit_rate = DEFAULT_QUALITY  # use the default action here
                # The first reward of a video is excluded from its mean.
                r_batch_all.append(np.mean(r_batch[1:]))
                r_batch = []
                prev_state = np.zeros((S_INFO, S_LEN))  # fresh history for the next video

                video_count += 1

                # Stop after exactly one video per trace file. The previous
                # `>` comparison served one extra video, which reopened a log
                # in 'w' mode and added a surplus entry to the statistics.
                if video_count >= len(all_file_names):
                    break

                log_path = LOG_FILE + '_' + all_file_names[net_env.trace_idx]
                log_file = open(log_path, 'w')
    finally:
        log_file.close()

    print(np.mean(r_batch_all),np.std(r_batch_all))


# Run the simulation only when this file is executed as a script.
# Note that the command line parsing and algorithm import above happen at
# module level, i.e. also when the file is merely imported.
if __name__ == '__main__':
    main()