import os
import sys
import environment as env
import KnapsackBaseline
import OracleIntLPbaseline
import FrankWolfeBaseline
import DeepScheduler
from torch import cuda




#~~~~~~HYPER-PARAMETERS OF ALGORITHMS ~~~~~~~~~
def Define_HyperParameters(Method):
    """Return the hyper-parameter dictionary of the selected algorithm.

    Parameters
    ----------
    Method : str
        Name of the algorithm. Supported values: 'Knapsack', 'KnapsackCheat',
        'IntegerLPoracle', 'FrankWolfe' and 'DeepScheduler'.

    Returns
    -------
    dict
        A freshly built dictionary with the hyper-parameters of ``Method``.

    Raises
    ------
    ValueError
        If ``Method`` is not one of the supported names (previously this
        surfaced as an opaque UnboundLocalError on the return statement).
    """
    if Method == 'Knapsack':
        hyper_param = {
            'N_parallel_env_test': 3
        }
    elif Method == 'KnapsackCheat':
        hyper_param = {
            'N_parallel_env_test': 6
        }
    elif Method == "IntegerLPoracle":
        hyper_param = {
            'N_parallel_env_test': 2,
            # How many time steps ahead the oracle foresees. Has to be greater
            # than 0; for 0 use 'Knapsack' instead.
            'OracleTimesteps': 4,
        }
    elif Method == "FrankWolfe":
        hyper_param = {
            'N_parallel_env_test': 2,
            # The way the channel of a user is assumed to evolve per time slot.
            # Choices: "iid", "Constant".
            # Careful: this affects only the math used to define the objective
            # function and doesn't determine the real channel a user experiences.
            'AssumeChannel': "iid",
            # Be careful: the complexity grows exponentially with this variable,
            # depending on the number of classes and how big the max latency
            # of a class is.
            'T_greedysteps': 4,
            # There are a lot of local optima because the probability of success
            # is not smooth (it is more like a Heaviside function).
            'N_DifferentInitializations': 20,
            # Choices: ["absGaussian", mean, scale],
            #          ["Gamma", 0.3, 1] -> good for iid channel (shape = 0.3, scale = 1),
            #          ["Uniform"]       -> good for constant channel.
            'InitializationType': ["Gamma", 0.3, 1],
            # Maximum number of steps the algorithm is allowed to take to find
            # the local optimum.
            'MaxSteps': 15,
        }
    elif Method == 'DeepScheduler':
        hyper_param = {
            # Fixed to CPU (GPU auto-selection via cuda.is_available() is disabled).
            'device': "cpu",

            # Architecture of NNs
            'N_Quantiles': 50,
            # How many actions of previous time steps to be stored.
            'MemorizedActions': 0,

            # Training hyper-parameters
            'TradeOff_between_lossesCritic': 1.0,
            'TradeOff_reward_punish': 1.0,
            'Gamma': 0.95,
            'BatchSize': 32,
            # Approximate memory needed = 50 Bytes per user * Kusers * CapacityPool.
            'CapacityPool': 5*1000,
            'LearningRate_Critic': 1e-3,
            'LearningRate_Actor': 1e-3,
            # If it is an integer, the target NN is synced after that number of
            # iterations. If it is a float in (0, 1.0], the target NN is synced
            # partially in every iteration.
            'SyncProcess_act': 0.5e-2,
            'SyncProcess_crt': 0.5e-2,
            # Bernoulli probability (annealing not implemented yet).
            'ExploreProbability': 0.25,
            'N_parallel_env_train': 16,

            # Monitoring progress
            'CreateWriter': True,
            'AfterSamplesToWrite_train': 10*1000,
            # i.e. in 'AfterSamplesToTest'/'BatchSize' iterations.
            'AfterSamplesToTest': 20*1000,
            'N_parallel_env_test': 10,

            # Termination
            'MaxIter': 1000*1000,
            'MaxNsamples': 6*1000*1000
        }
    else:
        # Fail loudly with a helpful message instead of hitting an unbound
        # local variable on the return statement below.
        raise ValueError(
            "Unknown Method {!r}; expected one of 'Knapsack', 'KnapsackCheat', "
            "'IntegerLPoracle', 'FrankWolfe', 'DeepScheduler'".format(Method)
        )
    return hyper_param






if __name__ == "__main__":
    # ~~~~~ SYSTEM PARAMETERS ~~~~~~~
    # PROTOCOL
    # Other kinds of retransmission protocols could be implemented.
    RetransProtocolAvailable = ["HARQ-type I"]
    CSIestimation_Available = ["Full", "Statistical"]
    PROTOCOL = dict(
        CSIestimation=CSIestimation_Available[1],
        RetransProtocol=RetransProtocolAvailable[0],
    )

    # GEOMETRY: the disk the base station is covering (radii in Km).
    GEOMETRY = dict(Rmin=0.05, Rmax=1.0)

    # CHANNEL
    channel_Available = ["ExpMarkovTimeCorr"]
    CHANNEL = dict(
        ChannelType=channel_Available[0],
        # For "ExpMarkovTimeCorr" this is "r" in the equation
        #   h_t = r*h_{t-1} + CN(0, sqrt(1-r^2)/2),
        # where CN(0, sigma) is a complex Gaussian random variable whose real
        # and imaginary parts are independent with variance = sigma^2.
        ChannelInfo=0.0,
        PathLoss=3.7,
        # ConstLoss/sigma_noise^2 (for example 10^(-0.1*120.9)/(-114dbm) = 204).
        ConstLoss_div_Noise=204,
    )

    # TRAFFIC
    # One row per user class; THE ZERO CLASS IS ALWAYS LAST.
    TRAFFIC = dict(
        ClassesInfo=[
            dict(zip(('Max_Latency', 'Data', 'Importance', 'Arrival_Prob'), class_row))
            for class_row in (
                # Max_Latency, Data,       Importance, Arrival_Prob
                (2,            32*8*256,   1.0,        0.15),
                (2,            32*8*256,   2.0,        0.05),
                (10,           32*8*2048,  1.0,        0.3),
                (10,           32*8*2048,  2.0,        0.05),
                (1,            0,          0,          0.45),
            )
        ],
        Kusers=100,
    )

    # RESOURCES
    RESOURCES = dict(
        # Bandwidth in Hz.
        BW=2.0e6,
        # Energy per symbol per Hz. No power allocation is assumed here, but it
        # can be added since the environment is designed to accept users being
        # allocated different powers.
        EnergyPerSymbol=1.0/320,
    )

    # ~~~~~~ RUNNING THE ALGORITHMS ~~~~~~~~~
    # ~~~~~~~ METHOD ~~~~~~~~~
    methodAvailable = ["Knapsack", "IntegerLPoracle", "FrankWolfe", "DeepScheduler"]
    MethodUsed = methodAvailable[3]  # Define your method
    # Only the DeepScheduler has an NN model to load; set the flag to True to
    # load a saved model. Every other method gets None.
    LOAD_MODEL = False if MethodUsed == "DeepScheduler" else None

    # ~~~ Creating the test-environment & printing general info ~~~~
    HYPER_PARAMETERS = Define_HyperParameters(MethodUsed)
    env_test = env.Environment(
        PROTOCOL, GEOMETRY, CHANNEL, TRAFFIC, RESOURCES, HYPER_PARAMETERS, 'test'
    )
    print("Method Used:  ", MethodUsed, "\nChannel corr.: ", CHANNEL['ChannelInfo'])
    print("BW: ", RESOURCES['BW'], "/// Power: ", RESOURCES['EnergyPerSymbol'])
    print("Protocol Used: ", PROTOCOL["CSIestimation"])

    # ~~~~~ Run the algorithms and test them ~~~~~
    if MethodUsed == "DeepScheduler":
        # Tag name under which the NNs are saved and the TensorBoard curves
        # are named. Put the tag name you desire.
        TagName = 'example'

        # Define the saving path and warn about possible overwriting.
        save_path = os.path.join("saves", MethodUsed)
        if LOAD_MODEL and os.path.exists(save_path):
            print("Model if there is with tag name: ", TagName)
            print("will be loaded and if its performance is improved previous version will be overwritten. Press Enter")
            input()
        else:
            print("Be sure that a useful model will not be lost due to same TagName")
            input()
            os.makedirs(save_path, exist_ok=True)

        # Run
        env_train = env.Environment(
            PROTOCOL, GEOMETRY, CHANNEL, TRAFFIC, RESOURCES, HYPER_PARAMETERS, 'train'
        )
        DeepScheduler.train_test_DeterDistribDueling(
            env_train, env_test, RESOURCES, HYPER_PARAMETERS, save_path, LOAD_MODEL, TagName
        )

    elif MethodUsed == "FrankWolfe":
        FrankWolfeBaseline.BaselineTest_FrankWolfeOpt(
            HYPER_PARAMETERS, env_test, RESOURCES, GEOMETRY, CHANNEL
        )

    elif MethodUsed == "IntegerLPoracle":
        OracleIntLPbaseline.BaselineTest_intLPoracle(env_test, RESOURCES)

    elif MethodUsed == "Knapsack":
        KnapsackBaseline.BaselineTest_Knapsack(env_test, RESOURCES)


