from operator import attrgetter
import torch

import numpy
import matplotlib.pyplot as plt
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
# import mujoco_py
# uncomment this line for mujoco
import argparse
import gym
#import mujoco_py
import numpy as np
from gym.spaces import Box, Discrete
import setup
import math
from poison_rl.memory import Memory
from poison_rl.agents.vpg import VPG
from poison_rl.agents.ppo import PPO
from poison_rl.attackers.wb_attacker import WbAttacker
from poison_rl.attackers.fgsm_attacker import FGSMAttacker
from poison_rl.attackers.targ_attacker import TargAttacker
from poison_rl.attackers.bb_attacker import BbAttacker
from poison_rl.attackers.rand_attacker import RandAttacker
from torch.distributions import Categorical, MultivariateNormal
import random
import logging
from datetime import datetime
from tqdm import tqdm

import copy
import pdb

# Wall-clock timestamp captured once at import time (month-day hour:minute:second).
now = datetime.now()
current_time = now.strftime("%m-%d %H:%M:%S")

# Command-line configuration. The parsed namespace `args` is module-global and
# is read directly by every function below.
parser = argparse.ArgumentParser()

parser.add_argument('--device', type=str, default="cuda:0")
parser.add_argument('--run', type=int, default=-1)
# Environment settings: gym environment id and the per-episode step cap.
parser.add_argument('--env', type=str, default="CartPole-v0", help = "CartPole-v0, LunarLander-v2, Hopper-v2, HalfCheetah-v2, Walker2d-v2")
parser.add_argument('--steps', type=int, default=300)
# Number of independent experiment repetitions run by the __main__ block.
parser.add_argument('--num_runs', type=int, default=1)

# Learner settings.
# NOTE(review): the help text lists "sac", but only "vpg" and "ppo" are
# constructed by the functions in this file.
parser.add_argument('--learner', type=str, default="vpg", help="vpg, ppo, sac")
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type = float, default = 0.99)

# Federated settings
parser.add_argument('--agents',type = int, default=10) # number of all agents
parser.add_argument('--Magents', type = int, default =1) # number of malicious agents
parser.add_argument('--rounds', type = int, default = 3) # number of communication rounds
parser.add_argument('--episodes', type=int, default=50) # local training episodes per agent per round
parser.add_argument('--seed', type=int, default=100)

# Attack settings (forwarded to the attacker constructed during local training).
parser.add_argument('--norm', type=str, default="l2")
parser.add_argument('--stepsize', type=float, default=0.05)
parser.add_argument('--maxiter', type=int, default=10)
parser.add_argument('--radius', type=float, default=1)
parser.add_argument('--radius-s', type=float, default=0.1)
parser.add_argument('--radius-a', type=float, default=0.3)
parser.add_argument('--radius-r', type=float, default=1)
parser.add_argument('--frac', type=float, default=1)
parser.add_argument('--type', type=str, default="wb", help="wb, bb, rand, semirand,targ")
parser.add_argument('--REround', type = int, default = 0)
# Defense settings: when non-zero, aggregation weights each model by its
# evaluated mean reward instead of taking a plain average.
parser.add_argument('--defense', type = int, default = 0, help = "0 for False, 1 for True")

# Which quantity the attacker perturbs.
parser.add_argument('--aim', type=str, default="reward", help="reward, obs, action")

# --attack / --no-attack toggle one boolean (default: True).
parser.add_argument('--attack', dest='attack', action='store_true')
parser.add_argument('--no-attack', dest='attack', action='store_false')
parser.set_defaults(attack=True)

# --compute / --no-compute toggle one boolean (default: False).
parser.add_argument('--compute', dest='compute', action='store_true')
parser.add_argument('--no-compute', dest='compute', action='store_false')
parser.set_defaults(compute=False)

# File settings. Directory values are concatenated with file names as plain
# strings elsewhere in this file, so they are expected to end with "/".
parser.add_argument('--logdir', type=str, default="logs/")
parser.add_argument('--resdir', type=str, default="results_NT/", help = "results_NT/, results/")
parser.add_argument('--moddir', type=str, default="models_NT/")
parser.add_argument('--weightdir', type=str, default="weights_NT/")
parser.add_argument('--loadfile', type=str, default="")
# Overwritten by Poison_Fed_RL with a name derived from the run configuration.
parser.add_argument('--filename', type = str, default = "")

# Early-stop settings: with no malicious agents, training stops once the
# coordinator's mean reward has been below score_thresh in rd_thresh rounds.
parser.add_argument('--score_thresh', type=int, default=0)
parser.add_argument('--rd_thresh', type=int, default=1000)

# Parsed at import time; importing this module therefore consumes sys.argv.
args = parser.parse_args()

# def get_log(file_name):
#     logger = logging.getLogger('train') 
#     logger.setLevel(logging.INFO) 

#     fh = logging.FileHandler(file_name, mode='a') 
#     fh.setLevel(logging.INFO) 
    
#     formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
#     fh.setFormatter(formatter)
#     logger.addHandler(fh)  
#     return logger

def Train_Local_Agent(attack_flag, random_seed, private_critic, agg_policy, flag_r0):
    """Run one round of local training for a single federated agent.

    A fresh learner (VPG or PPO, per ``args.learner``) is created, optionally
    initialised from the coordinator's broadcast weights, and trained for
    ``args.episodes`` episodes of at most ``args.steps`` steps each. The
    policy is updated once at the end of every episode.

    When ``attack_flag`` is truthy, ``args.type`` is not "fgsm" and
    ``args.aim`` is "reward", the episode's rewards in memory are replaced by
    the attacker's poisoned rewards before the policy update. For PPO
    attackers the learner's value layer is additionally replaced after each
    update by the critic of a second attacker that is trained on the poisoned
    memory.

    Args:
        attack_flag: Truthy if this agent poisons its own training.
        random_seed: When truthy, seeds both torch and the environment.
        private_critic: Attacker object returned for this agent in the
            previous round. Only ``private_critic.critic.v_net`` is read, and
            only when ``flag_r0`` is falsy, the learner is PPO and
            ``attack_flag`` is truthy.
        agg_policy: State dict handed to ``set_model_state_dict`` when
            ``flag_r0`` is falsy; ignored in the first round.
        flag_r0: Truthy in the first communication round.

    Returns:
        For ``args.type == "targ"``: the list of per-episode target metrics.
        NOTE(review): that is a bare list, not a dict, so a caller that
        indexes the result with ``["policy"]`` will fail in targ mode.
        Otherwise a dict with key ``"policy"`` (the trained learner) and, for
        a poisoning PPO agent, ``"private_critic"`` (the clean attacker).

    Raises:
        ValueError: If ``args.learner`` is not "vpg" or "ppo", or if target
            tracking is requested for an action space that is neither
            ``Discrete`` nor ``Box``.
    """
    learner = args.learner
    if learner not in ("vpg", "ppo"):
        # Previously this produced policy_net = None and an opaque
        # AttributeError further down.
        raise ValueError(
            "Unsupported learner '{}'; expected 'vpg' or 'ppo'".format(learner))

    attack_type = args.type
    max_episodes = args.episodes
    max_steps = args.steps
    gamma = args.gamma
    lr = args.lr
    # Fall back to CPU when CUDA was requested but is not available.
    device = 'cuda:0' if torch.cuda.is_available() and 'cuda' in args.device else 'cpu'

    track_target = attack_type in ("targ", "fgsm")
    poison_critic = bool(learner == "ppo" and attack_flag)

    memory = Memory()
    targ_metrix = []
    targ_file = None

    env = gym.make(args.env)
    try:
        if random_seed:
            torch.manual_seed(random_seed)
            env.seed(random_seed)

        ########## initialize learner
        if learner == "vpg":
            policy_net = VPG(env.observation_space, env.action_space, gamma=gamma,
                             device=device, learning_rate=lr)
        else:
            policy_net = PPO(env.observation_space, env.action_space, gamma=gamma,
                             device=device, learning_rate=lr)

        if not flag_r0:
            # Not the first round: start from the coordinator's broadcast model.
            policy_net.set_model_state_dict(agg_policy)

        ########## initialize attacker(s)
        attacker_kwargs = dict(
            maxat=int(args.frac * max_episodes), maxeps=max_episodes,
            gamma=gamma, learning_rate=lr, maxiter=args.maxiter,
            radius=args.radius, stepsize=args.stepsize, device=device,
        )
        # Attacker holding the clean critic.
        attack_net = WbAttacker(env.observation_space, env.action_space,
                                policy_net, **attacker_kwargs)
        # Attacker whose critic is trained on the poisoned memory (PPO only).
        if learner == "ppo":
            attack_net_poisoned = WbAttacker(env.observation_space, env.action_space,
                                             policy_net, **attacker_kwargs)

        # A poisoning PPO agent carries its clean critic over from last round.
        if not flag_r0 and poison_critic:
            attack_net.critic.v_net = copy.deepcopy(private_critic.critic.v_net)

        ########## target-action bookkeeping (targ / fgsm only)
        if track_target:
            # Same target definition as Evaluate_model: the last discrete
            # action, or the zero vector for continuous actions.
            if isinstance(env.action_space, Discrete):
                target_policy = env.action_space.n - 1
            elif isinstance(env.action_space, Box):
                target_policy = torch.zeros(env.action_space.shape[0])
            else:
                raise ValueError(
                    "Target tracking needs a Discrete or Box action space, got {}"
                    .format(type(env.action_space).__name__))
            # NOTE(review): the original code wrote to an undefined
            # `targ_file`. The path and append mode here are an assumption,
            # chosen so that several agents and rounds accumulate in one log.
            targ_file = open(args.resdir + args.filename + "_train_targ.txt", "a")

        if attack_flag:
            print("Poisoning seed {} ...".format(random_seed))
        else:
            print("Clean Local Training {} ...".format(random_seed))

        ######### training
        # Here an episode corresponds to the local step in the paper.
        for episode in range(max_episodes):
            state = env.reset()
            total_targ_actions = 0

            for steps in range(max_steps):
                # Sample an action from the policy network.
                state_tensor, action_tensor, log_prob_tensor = policy_net.act(state)

                if isinstance(env.action_space, Discrete):
                    action = action_tensor.item()
                    if track_target and action == target_policy:
                        total_targ_actions += 1
                else:
                    action = action_tensor.cpu().data.numpy().flatten()
                    if track_target:
                        total_targ_actions += np.linalg.norm(action - target_policy.numpy()) ** 2

                # Interact with the environment.
                new_state, reward, done, _ = env.step(action)

                if attack_type == "fgsm":
                    new_state = attack_net.attack(new_state)

                memory.add(state_tensor, action_tensor, log_prob_tensor, reward, done)
                state = new_state

                if not (done or steps == max_steps - 1):
                    continue

                # ---- end of episode: (optionally) poison, then update ----
                if attack_flag and attack_type != "fgsm" and args.aim == "reward":
                    # attack_r_general is called on the unpoisoned memory; its
                    # result replaces the rewards the learner will see.
                    attack_r = attack_net.attack_r_general(memory)
                    memory.rewards = attack_r.copy()
                    # Train the second attacker on the poisoned rewards.
                    if learner == "ppo":
                        attack_net_poisoned.attacker_update(memory)

                if track_target:
                    mean_targ = float(total_targ_actions) / (steps + 1)
                    if isinstance(env.action_space, Discrete):
                        targ_file.write(str(mean_targ) + "\n")
                    else:
                        # NOTE(review): the file receives the square root of
                        # the mean while the list receives the mean itself;
                        # kept as in the original.
                        targ_file.write(str(math.sqrt(total_targ_actions / (steps + 1))) + "\n")
                    targ_metrix.append(mean_targ)

                # Update the learner's actor on the (possibly poisoned) memory.
                policy_net.update_policy(memory)

                # Replace the learner's critic with the poisoned one.
                if poison_critic:
                    policy_net.policy.value_layer = copy.deepcopy(attack_net_poisoned.critic.v_net)

                memory.clear_memory()
                break
    finally:
        # Release the log file and the environment even if training fails.
        if targ_file is not None:
            targ_file.close()
        env.close()

    if attack_type == "targ":
        return targ_metrix

    if poison_critic:
        return {
            "policy": policy_net,  # poisoned actor and critic, sent to the server
            "private_critic": attack_net,  # clean critic, kept private
        }
    return {"policy": policy_net}



def aver_model_dict(dict_list, weight, df_flag):
    """Aggregate a list of model state dicts into a single state dict.

    Args:
        dict_list: Non-empty list of state dicts mapping parameter names to
            tensors. The keys of the first dict define the output keys, and
            every dict must provide a same-shaped tensor for each of them.
        weight: One score per model (list, array or 1-D tensor). Only used
            when ``df_flag`` is truthy.
        df_flag: If truthy, return the weighted sum of the parameters with
            weights ``weight / (sum(weight) + 1e-4)``; otherwise return the
            plain element-wise mean.

    Returns:
        dict mapping each parameter name to the aggregated tensor.

    Raises:
        ValueError: If ``dict_list`` is empty, or if ``df_flag`` is truthy
            and ``weight`` does not hold exactly one value per model.

    NOTE(review): the normalisation divides by the sum of the weights, so a
    zero or negative sum (e.g. negative mean rewards) gives unbounded or
    sign-flipped coefficients. The scheme is kept unchanged here.
    """
    if not dict_list:
        raise ValueError("dict_list must contain at least one state dict")

    param_names = list(dict_list[0].keys())

    if not df_flag:
        return {
            name: torch.mean(torch.stack([sd[name] for sd in dict_list]), dim=0)
            for name in param_names
        }

    # as_tensor accepts lists, arrays and tensors without the copy-construct
    # warning that torch.tensor(tensor) emits.
    scores = torch.as_tensor(weight, dtype=torch.float64)
    if scores.dim() != 1 or scores.numel() != len(dict_list):
        # Fewer weights than models used to drop the surplus models silently.
        raise ValueError(
            "expected one weight per model ({}), got a weight of shape {}"
            .format(len(dict_list), tuple(scores.shape)))

    # Small epsilon keeps the division defined when the scores sum to zero.
    coeffs = scores / (scores.sum() + 0.0001)

    res = {}
    for name in param_names:
        stacked = torch.stack([sd[name] for sd in dict_list])
        # Scale each model by its 0-dim coefficient, so the result keeps the
        # parameter's dtype and device.
        res[name] = torch.stack(
            [c * p for c, p in zip(coeffs, stacked)]
        ).sum(dim=0)
    return res

    

def Poison_Fed_RL(random_seed=args.seed, df_flag=args.defense):
    """Run the federated training loop and save the coordinator's scores.

    In each of ``args.rounds`` communication rounds every agent is trained
    locally via ``Train_Local_Agent``; the first ``args.Magents`` agents are
    the malicious ones. The submitted policy state dicts are aggregated with
    ``aver_model_dict`` and the aggregate is evaluated for 50 episodes.

    Side effects:
        * ``args.filename`` is overwritten with a name built from the run
          configuration.
        * The directory ``args.resdir + args.filename + "/"`` is created if
          missing, and ``CO_<filename>_mean.npy`` / ``CO_<filename>_std.npy``
          (one entry per completed round) are written into it.

    Args:
        random_seed: Seed forwarded to every local training call. Note that
            the default is bound when the function is defined.
            NOTE(review): all agents receive the same seed; confirm that
            identical seeding across agents is intended.
        df_flag: If truthy, each submitted model is evaluated for 10
            episodes and its mean reward is used as its aggregation weight.

    Returns:
        None.
    """
    CO_mean_rs, CO_std_rs = [], []

    # Clean critics of the malicious PPO agents, indexed by agent id and
    # carried from one round to the next.
    private_critics = []

    args.filename = args.env + "_" + args.learner + "_A" + str(args.agents) + "_M" + str(args.Magents)+ "_C"+ str(args.rounds)+ "_n" + str(args.episodes) +\
                    "_" + args.type + "_" + args.aim  + "_r" + str(args.radius) + "_f" + str(args.frac) + "_s" + str(args.seed)
    filename = args.filename
    CO_filename = "CO_" + filename

    out_dir = args.resdir + filename + "/"
    # makedirs also creates a missing args.resdir, and exist_ok avoids the
    # check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(out_dir, exist_ok=True)

    early_bad = 0
    agg_policy = None  # nothing to broadcast before the first round

    for i_round in tqdm(range(args.rounds)):
        print("\n\nFederated Round", i_round+1)
        first_round = i_round == 0

        state_dicts_policy = []
        weights = []

        # ---- local training ----
        for i_agent in range(args.agents):
            is_malicious = i_agent < args.Magents
            keeps_private_critic = args.learner == "ppo" and is_malicious

            # Hand a malicious PPO agent its own clean critic from last round.
            pri_cri = None
            if keeps_private_critic and not first_round:
                pri_cri = private_critics[i_agent]

            i_model = Train_Local_Agent(
                is_malicious,
                random_seed,
                pri_cri,
                None if first_round else agg_policy,
                1 if first_round else 0,
            )

            # Submit the policy model to the coordinator.
            state_dicts_policy.append(i_model["policy"].policy.state_dict())

            # Keep the private critic so it can be returned next round.
            if keeps_private_critic:
                if first_round:
                    private_critics.append(i_model["private_critic"])
                else:
                    private_critics[i_agent] = i_model["private_critic"]

            if df_flag:
                # Defense: weight each model by its own evaluated mean reward.
                i_rewards = Evaluate_model(i_model["policy"].policy.state_dict(), episodes = 10, random_seed = args.seed)
                weights.append(np.mean(i_rewards))

        # ---- aggregation (policy models only, never the private critics) ----
        agg_policy = aver_model_dict(dict_list = state_dicts_policy, weight = weights, df_flag = df_flag)

        # ---- evaluate the aggregated model ----
        i_CO_rs = Evaluate_model(agg_policy = agg_policy, episodes = 50, random_seed = args.seed)
        i_CO_mean_r, i_CO_std_r = np.mean(i_CO_rs), np.std(i_CO_rs)
        CO_mean_rs.append(i_CO_mean_r)
        CO_std_rs.append(i_CO_std_r)
        print("Coordinator mean reward:", i_CO_mean_r)

        # Early stop applies only to clean runs: count rounds scoring below
        # the threshold and stop once the count reaches rd_thresh.
        if args.Magents == 0:
            if i_CO_mean_r < args.score_thresh:
                early_bad += 1
            if early_bad >= args.rd_thresh:
                break

    np.save(out_dir + CO_filename  + "_mean.npy", np.array(CO_mean_rs))
    np.save(out_dir + CO_filename  + "_std.npy", np.array(CO_std_rs))



      
def Evaluate_model(agg_policy, episodes, random_seed):
    """Roll out a policy without training and collect per-episode results.

    A fresh learner (per ``args.learner``) is created in a new ``args.env``
    environment, loaded with ``agg_policy`` and run for ``episodes`` episodes
    of at most ``args.steps`` steps each. No policy update is performed.

    Args:
        agg_policy: State dict passed to the learner's
            ``set_model_state_dict``.
        episodes: Number of evaluation episodes.
        random_seed: When truthy, seeds both torch and the environment.

    Returns:
        For ``args.type`` other than "targ"/"fgsm": a 1-D numpy array with
        the summed reward of every episode. For "targ"/"fgsm": ``None``; the
        per-episode target metric is instead written, one value per line, to
        ``args.resdir + args.filename + "_targ.txt"``.
        NOTE(review): callers that feed the result to ``np.mean`` will fail
        in the target modes; the return contract is kept as it was.

    Raises:
        ValueError: If ``args.learner`` is not "vpg" or "ppo", or if target
            tracking is requested for an action space that is neither
            ``Discrete`` nor ``Box``.
    """
    if args.learner not in ("vpg", "ppo"):
        # Previously policy_net was simply left undefined in this case.
        raise ValueError(
            "Unsupported learner '{}'; expected 'vpg' or 'ppo'".format(args.learner))

    # The rollout loop tracks target actions for both modes, so the setup
    # must cover both as well (it used to cover "targ" only).
    track_target = args.type in ("targ", "fgsm")

    # Honour args.device, but fall back to CPU when CUDA was requested and is
    # not available (as local training does).
    if 'cuda' in args.device and not torch.cuda.is_available():
        device = 'cpu'
    else:
        device = args.device

    all_rewards = []
    i_targ_file = None

    ##### activate environment
    env = gym.make(args.env)
    try:
        if random_seed:
            torch.manual_seed(random_seed)
            env.seed(random_seed)

        if track_target:
            # Target is the last discrete action, or the zero vector for
            # continuous action spaces.
            if isinstance(env.action_space, Discrete):
                target_policy = env.action_space.n - 1
            elif isinstance(env.action_space, Box):
                target_policy = torch.zeros(env.action_space.shape[0])
            else:
                raise ValueError(
                    "Target tracking needs a Discrete or Box action space, got {}"
                    .format(type(env.action_space).__name__))

        ########## create learner
        if args.learner == "vpg":
            policy_net = VPG(env.observation_space, env.action_space, gamma=args.gamma,
                             device=device, learning_rate=args.lr)
        else:
            policy_net = PPO(env.observation_space, env.action_space, gamma=args.gamma,
                             device=device, learning_rate=args.lr)

        policy_net.set_model_state_dict(agg_policy)

        ##### result file for the target-action metric
        if track_target:
            # `filename` was an undefined name here; args.filename holds the
            # run name set by Poison_Fed_RL.
            i_targ_file = open(args.resdir + args.filename + "_targ" + ".txt", "w")

        ##### rollouts
        for _ in range(episodes):
            state = env.reset()
            rewards = []
            total_targ_actions = 0

            for steps in range(args.steps):
                _, action_tensor, _ = policy_net.act(state)

                if isinstance(env.action_space, Discrete):
                    action = action_tensor.item()
                    if track_target and action == target_policy:
                        total_targ_actions += 1
                else:
                    action = action_tensor.cpu().data.numpy().flatten()
                    if track_target:
                        # Accumulate the squared distance to the target action.
                        total_targ_actions += np.linalg.norm(action - target_policy.numpy()) ** 2

                new_state, reward, done, _ = env.step(action)
                rewards.append(reward)

                if done or steps == args.steps - 1:
                    if track_target:
                        if isinstance(env.action_space, Discrete):
                            # Fraction of steps that chose the target action.
                            i_targ_file.write(str(float(total_targ_actions) / (steps+1)) + "\n")
                        else:
                            # Root of the mean squared distance to the target.
                            i_targ_file.write(str(math.sqrt(total_targ_actions / (steps+1))) + "\n")
                    else:
                        all_rewards.append(np.sum(rewards))
                    break

                state = new_state
    finally:
        # Release the file and the environment even if a rollout fails.
        if i_targ_file is not None:
            i_targ_file.close()
        env.close()

    if track_target:
        return None
    return np.array(all_rewards)
        



def Evaluate_Coordinator_plot(CO_mean_rs, CO_std_rs, write_loc):
    """Plot the coordinator's per-round mean score with a +/- std band.

    The figure is saved as ``<write_loc>/plot.jpg`` and then closed.

    Args:
        CO_mean_rs: Sequence of per-round mean values.
        CO_std_rs: Sequence of per-round standard deviations, same length as
            ``CO_mean_rs``.
        write_loc: Directory the image is written to.

    Returns:
        None.
    """
    CO_mean_rs, CO_std_rs = np.array(CO_mean_rs), np.array(CO_std_rs)

    # The title is the run name, prefixed with "DF_" for defended runs.
    CO_filename = "CO_" + args.filename
    if args.defense:
        CO_filename = "DF_" + CO_filename

    n_round = len(CO_mean_rs)
    x_axis = np.linspace(1, n_round, n_round)  # rounds are numbered from 1

    fig, ax = plt.subplots()
    try:
        ax.plot(x_axis, CO_mean_rs, '*', alpha=0.9, label="Coordinator: mean")
        ax.fill_between(x_axis, CO_mean_rs - CO_std_rs, CO_mean_rs + CO_std_rs,
                        alpha=0.2, color="green", label="+- std")

        if args.type == "targ":
            # The y axis is fixed to [0, 1] for the targeted attack.
            ax.set_ylim([0, 1])
            ax.set_ylabel("Fraction of Target Actions")
        else:
            ax.set_ylabel("Mean Reward Per Episode")
        ax.set_xlabel("Communication Rounds (" + str(args.episodes) + " test episodes per round)")

        ax.legend()
        ax.set_title(CO_filename)
        fig.savefig("{}/plot.jpg".format(write_loc))
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # close it here to avoid leaking one figure per call.
        plt.close(fig)
        



if __name__ == '__main__':
    # Run args.num_runs independent experiments, each with the next seed.
    runs_done = 0
    while runs_done < args.num_runs:
        run_seed = args.seed

        # Re-seed torch, Python's random and numpy before every run.
        for seed_rng in (torch.manual_seed, random.seed, np.random.seed):
            seed_rng(run_seed)

        Poison_Fed_RL(run_seed)

        # The next run (and the file name derived from it) uses seed + 1.
        args.seed = run_seed + 1
        runs_done += 1

    print("finish!")
