


from itertools import combinations
import argparse
import wandb

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
from torch.nn.parallel import DataParallel

from policy import Policy_REINFORCEMENT, PPO_MLP, PPO_LSTM, PPO_EnsembledMLP
from environment import make_state_for_MLP_agent, make_state_for_LSTM_agent


#args parsing
def _str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string -- including ``"False"`` -- so the flag could never
    be switched off from the command line. This converter maps the usual
    spellings explicitly.

    Args:
        value: Raw string from the command line (a ``bool`` is passed through
            unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays False, matching what bool("") returned before.
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str)
parser.add_argument('--child_model_batch_size', default=64, type=int)
parser.add_argument('--child_model_lr', default=0.001, type=float)
parser.add_argument('--child_model', type=str)
parser.add_argument('--algorithm', type=str)
parser.add_argument('--training_epoch', type=int, default=100)
parser.add_argument('--agent', type=str)
parser.add_argument('--update_sample_num', type=int, default=20)
# type=bool would turn "--GPUparallel False" into True; parse it explicitly.
parser.add_argument('--GPUparallel', type=_str2bool, default=False)
parser.add_argument('--EpisodeNum', type=int, default=1)
parser.add_argument('--today', type=float)
parser.add_argument('--step_per_iter', type=int, default=50)
parser.add_argument('--type', type=str)
parser.add_argument('--checkpointFolderPath', type=str)

opt = parser.parse_args()


#visualization
wandb.init(project="BigNoise")


# Number of injection (dropout) modules for each supported child model.
# The counts match the residual-block count of each ResNet depth
# (presumably one module per block -- NOTE(review): confirm in environment).
_DROPOUT_MODULE_NUM_BY_MODEL = {
    "resnet18": 8,
    "resnet34": 16,
    "resnet50": 16,
    # Misspelling matched by earlier versions of this script; kept so existing
    # launch commands keep working.
    "resent50": 16,
    "resnet101": 33,
}

if opt.child_model not in _DROPOUT_MODULE_NUM_BY_MODEL:
    # Previously an unknown model left DROPOUT_MODULE_NUM undefined and the
    # script died with a NameError on the print below.
    raise ValueError(
        f"unsupported --child_model {opt.child_model!r}; expected one of "
        f"{sorted(_DROPOUT_MODULE_NUM_BY_MODEL)}"
    )
DROPOUT_MODULE_NUM = _DROPOUT_MODULE_NUM_BY_MODEL[opt.child_model]
print("DROPOUT_MODULE_NUM", DROPOUT_MODULE_NUM)


#hyper param tuning
# Hyper-parameters of the controller's learning algorithm, chosen by --algorithm.
if opt.algorithm == "PPO":
    GAMMA = 0.99       # discount factor handed to the PPO policy constructors
    K_epochs = 5       # update the policy for K epochs in one PPO update
    eps_clip = 0.2     # PPO clipping parameter
    gamma = 0.99       # discount factor (duplicate of GAMMA; unused in this file)
    lr_actor = 0.0003  # learning rate for the actor network
    lr_critic = 0.001  # learning rate for the critic network
elif opt.algorithm == "REINFORCEMENT":
    GAMMA = 0.99
    LEARNING_RATE = 0.0005
    WEIGHT_DECAY = 0.0001

# Per-iteration queue lengths for training and validation metrics (both taken
# from --step_per_iter), and the number of iterations of history the LSTM
# agent's queues hold.
VAL_STEP_PER_ITER = opt.step_per_iter
TRAIN_STEP_PER_ITER = VAL_STEP_PER_ITER
QUEUE_ITER_SIZE = 5


#device
# Choose the torch backend: CUDA if available, otherwise Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

def frange(start, stop, step):
    """Yield ``start``, ``start + step``, ``(start + step) + step``, ... below ``stop``.

    ``stop`` is exclusive. Each value is produced by adding ``step`` to the
    previous one (not as ``start + k * step``), so floating-point rounding
    error accumulates along the sequence. If ``step <= 0`` and
    ``start < stop`` the generator never terminates.
    """
    value = start
    while True:
        if not value < stop:
            return
        yield value
        value += step

#dropout-probability sets
# Candidate dropout probabilities the agent can pick from: 0.1 upward in
# steps of 0.005, stopping before 0.6 (built by accumulated float addition).
DROPOUT_MODULE = list(frange(0.1, 0.6, 0.005))

# Size of the discrete action space and the action-index -> probability table.
ACTION_NUM = len(DROPOUT_MODULE)
ACTION_IDX = dict(enumerate(DROPOUT_MODULE))


#Queue definition
class Queue:
    """Fixed-length FIFO window of recent values.

    The window starts as ``size`` zeros. Each ``enqueue`` removes the oldest
    entry (front of ``self.queue``) and appends the new one at the back, so for
    ``size >= 1`` the length stays ``size``.
    """

    def __init__(self, size):
        self.size = size
        # Zero-filled so the window has exactly ``size`` slots from the start.
        self.queue = [0] * size

    def enqueue(self, data):
        """Append ``data`` as the newest entry and discard the oldest one."""
        window = self.queue
        window.pop(0)
        window.append(data)

class Queues:
    """Bundle of the four metric history queues used to build an agent state.

    The four queues (train loss, train f1, val loss, val f1) are stored as
    given, without copying; in this script they are ``Queue`` instances.
    """

    def __init__(self, train_loss_queue, train_f1_queue, val_loss_queue, val_f1_queue):
        (
            self.train_loss_queue,
            self.train_f1_queue,
            self.val_loss_queue,
            self.val_f1_queue,
        ) = (train_loss_queue, train_f1_queue, val_loss_queue, val_f1_queue)


#Queue setting
# Agent types whose state is built by make_state_for_MLP_agent from a single
# iteration of metric history.
_MLP_STYLE_AGENTS = ("MLP_agent", "EnsembledMLP_agent")


def Queue_init(agent_type):
    """Create the four metric history queues used to build agent states.

    MLP-style agents get queues of one iteration (``TRAIN_STEP_PER_ITER`` /
    ``VAL_STEP_PER_ITER`` entries); the LSTM agent gets ``QUEUE_ITER_SIZE``
    iterations' worth.

    Args:
        agent_type: ``"MLP_agent"``, ``"EnsembledMLP_agent"`` or ``"LSTM_agent"``.

    Returns:
        A ``Queues`` bundle (train loss, train f1, val loss, val f1).

    Raises:
        ValueError: If ``agent_type`` is not a known agent type.
    """
    # Membership test on purpose: ``agent_type == "MLP_agent" or
    # "EnsembledMLP_agent"`` is always truthy (non-empty string), which made
    # the LSTM branch unreachable.
    if agent_type in _MLP_STYLE_AGENTS:
        history_iters = 1
    elif agent_type == "LSTM_agent":
        history_iters = QUEUE_ITER_SIZE
    else:
        raise ValueError(f"unknown agent type: {agent_type!r}")

    train_size = TRAIN_STEP_PER_ITER * history_iters
    val_size = VAL_STEP_PER_ITER * history_iters
    return Queues(
        Queue(size=train_size),  # train loss
        Queue(size=train_size),  # train f1
        Queue(size=val_size),    # val loss
        Queue(size=val_size),    # val f1
    )


#init queues
queues = Queue_init(opt.agent)

#make_state function
# Queue_init has already rejected unknown agent types, so any agent that is
# not the LSTM agent is one of _MLP_STYLE_AGENTS.
if opt.agent == "LSTM_agent":
    make_state = make_state_for_LSTM_agent
else:
    make_state = make_state_for_MLP_agent


#Environment setting
from environment import Environment
env = Environment(
    ACTION_IDX,
    TRAIN_STEP_PER_ITER,
    VAL_STEP_PER_ITER,
    opt.child_model_batch_size,
    opt.child_model_lr,
    opt.child_model,
    QUEUE_ITER_SIZE,
    opt.agent,
    opt.GPUparallel,
    opt.dataset,
    opt.today,
    opt.type,
    opt.checkpointFolderPath,
)


#policy network setting
# Width of one iteration of state: train loss + train f1 (TRAIN_STEP_PER_ITER
# each) and val loss + val f1 (VAL_STEP_PER_ITER each).
input_size = 2 * TRAIN_STEP_PER_ITER + 2 * VAL_STEP_PER_ITER

if opt.algorithm == "REINFORCEMENT":
    policy = Policy_REINFORCEMENT(input_size, ACTION_NUM, GAMMA).to(device)
    policy.optimizer = optim.Adam(
        policy.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )
elif opt.algorithm == "PPO":
    # opt.agent is one of the three known types (validated by Queue_init).
    if opt.agent == "MLP_agent":
        policy = PPO_MLP(
            input_size, ACTION_NUM, lr_actor, lr_critic, GAMMA, K_epochs, eps_clip
        )
    elif opt.agent == "LSTM_agent":
        policy = PPO_LSTM(
            input_size, ACTION_NUM, lr_actor, lr_critic, GAMMA, K_epochs, eps_clip,
            opt.update_sample_num,
        )
    elif opt.agent == "EnsembledMLP_agent":
        policy = PPO_EnsembledMLP(
            input_size, ACTION_NUM, lr_actor, lr_critic, GAMMA, K_epochs, eps_clip,
            opt.GPUparallel,
        )
else:
    # Without this, an unknown --algorithm only surfaced later as a NameError
    # on ``policy`` inside the training loop.
    raise ValueError(f"unknown algorithm: {opt.algorithm!r}")


def _initial_state():
    """Build the state the policy sees at the start of a training epoch."""
    # The LSTM state builder also needs how many iterations of history the
    # queues hold.
    if opt.agent == "LSTM_agent":
        return make_state(queues, QUEUE_ITER_SIZE)
    return make_state(queues)


def _run_training_epochs():
    """Run ``opt.training_epoch`` passes of the controller against ``env``.

    Each pass starts from the current queues' state and steps the environment
    until it reports ``done``; the policy is trained every
    ``opt.update_sample_num`` steps and once more when the pass ends.
    """
    for _ in range(opt.training_epoch):
        next_state = _initial_state()
        done = 0
        steps = 0
        while not done:
            given_state = next_state
            sampled_action_index = policy.sample_action(given_state)
            reward, next_state, done = env.step(sampled_action_index, queues)
            policy.put_reward(reward)
            policy.put_done(done)

            steps += 1
            if steps % opt.update_sample_num == 0 or done:
                policy.train_net()


if opt.EpisodeNum == 1:
    _run_training_epochs()
else:
    #episode number
    for episode_num in range(opt.EpisodeNum):
        print(episode_num)

        _run_training_epochs()

        # NOTE: the env, the queues and the saved policy are NOT reset/written
        # between episodes; the code for that is commented out below (the
        # Environment call there also predates its current argument list).

        #save policy param
        #torch.save(policy.policy.state_dict(),f"/VirtualSanghwa/pythonProject/RLNoiseInjectionPolicyBaseAgentDropout/{opt.child_model_batch_size}+{opt.child_model_lr}+{opt.child_model}+{opt.algorithm}+{opt.training_epoch}+{opt.agent}+{opt.update_sample_num}+{opt.EpisodeNum}+{opt.today}+{opt.step_per_iter}+{opt.dataset}.pth" )

        #env reset
        #env = Environment(ACTION_IDX, TRAIN_STEP_PER_ITER, VAL_STEP_PER_ITER, opt.child_model_batch_size, opt.child_model_lr, opt.child_model, QUEUE_ITER_SIZE, opt.agent, opt.GPUparallel)

        #init queues
        #queues = Queue_init(opt.agent)

