import argparse
from lib2to3.pgen2.grammar import opmap_raw
from statistics import mode
import gym
import numpy as np
import os
import wandb
import math
import utils
import json
from custom_callbacks import CustomCallbacks, MultiAgentCallbacks
from copy import deepcopy

import ray
import ray.rllib.agents.ppo as ray_ppo
from ray.rllib.examples.models.centralized_critic_models import \
    CentralizedCriticModel, TorchCentralizedCriticModel
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.models import ModelCatalog
from pfl_hypernet import PFL_Hypernet
from ccppo import CCPPOTorchPolicy, CCTrainer

from gym_socialgame.envs.socialgame_env import (SocialGameEnvRLLib)
from gym_microgrid.envs.microgrid_env import (MicrogridEnvRLLib)
from gym_microgrid.envs.multiagent_env import (MultiAgentSocialGameEnv, MultiAgentMicrogridEnv)

import torch
import torch.optim as optim
from collections import OrderedDict


hnet = None
hnet_optimizer = None
device = "cuda" if torch.cuda.is_available() else "cpu"
#device = "cpu"
agg_data_dict = {}

def get_agent(args):
    global hnet, hnet_optimizer, device
    """
    Purpose: Import the algorithm and policy to create an agent. 
    Returns: Agent
    Exceptions: Malformed args to create an agent. 
    """
    ### Setup for MultiAgent ###
    #Create a dummy environment to get action and observation space for our settings
    dummy_env = environments[args.gym_env](vars(args))
    # TODO: HACK
    # dummy_env = environments["socialgame"](vars(args))
    obs_space = dummy_env.observation_space
    act_space = dummy_env.action_space

    config = {}
    if args.num_gpus > 0:
        device = "cuda"
    else:
        device = 'cpu'
    if args.gym_env in ["socialgame_multi", "microgrid_multi"]:
        
        if args.algo=="ppo" or "afl_ppo":
            #pass
            #p {k: {k2: v2.shape for k2, v2 in v.items()} for k, v in agent.get_weights().items()}
            if args.use_agg_data:
                for i, e in enumerate(dummy_env.envs):
                    if args.hnet_embedding_dim == 0:
                        agg_data = []
                    else:
                        agg_data = [i]
                    agg_data += [pv / 200 for pv in e.pv_sizes] + [b / 200 for b in e.battery_sizes]
                    agg_data_dict[str(i)] = torch.tensor(agg_data).to(device=device)
            config["multiagent"] = {
                "policies": {str(i): (ray_ppo.PPOTorchPolicy, obs_space, act_space, {"fcnet_hidden": [args.sizes] * args.n_layers,
                                                                                     "fcnet_activation": args.fcnet_activation,
                                                                                     "number_of_agents": dummy_env.n_nodes}) for i, scenario in enumerate(range(dummy_env.n_nodes))},
                "policy_mapping_fn": lambda agent_id, episode=None: agent_id
            }
            
        elif args.algo=="single_ppo":
            print("USING SINGLE PPO MODEL")
            config["multiagent"] = {
                "policies": {'0': (ray_ppo.PPOTorchPolicy, obs_space, act_space, {"number_of_agents": dummy_env.n_nodes}) },
                "policy_mapping_fn": lambda agent_id, episode=None: '0'
            }
        elif args.algo=="ccppo":
            config["multiagent"] = {
                "policies": {str(i): (CCPPOTorchPolicy, obs_space, act_space, {"number_of_agents": dummy_env.n_nodes}) for i, scenario in enumerate(range(dummy_env.n_nodes))},
                "policy_mapping_fn": lambda agent_id, episode=None: agent_id
            }
        else:
            raise NotImplementedError

    #### Algorithm: PPO ####
    if args.algo in ["ppo", "ccppo", "single_ppo", "afl_ppo"]:
        # Modify the default configs for PPO
        default_config = ray_ppo.DEFAULT_CONFIG.copy()
        #merge config with default config (with ours overwriting in case of clashes)
        config = {**default_config, **config}
        config["framework"] = "torch"
        config["train_batch_size"] = args.batch_size
        config["sgd_minibatch_size"] = 16
        config["lr"] = args.learning_rate
        config["lambda"] = 0
        config["clip_param"] = args.ppo_clip_param
        config["vf_clip_param"]=5000
        config["num_sgd_iter"] = args.ppo_num_sgd_iter
        config["num_gpus"] =  args.num_gpus # this may throw an error
        config["num_workers"] = args.num_workers
        config["env_config"] = vars(args)
        config["env"] = environments[args.gym_env]
        #########################################################
        # Specify network parameters for policy functions #
        #########################################################
        model_params = {}
        model_params["fcnet_hiddens"] = [args.sizes] * args.n_layers
        model_params["fcnet_activation"] = args.fcnet_activation

        config["model"] = model_params
        # obs_dim = np.sum([args.energy_in_state, args.price_in_state])
        obs_dim = math.prod(obs_space.shape)
            
        out_path = os.path.join(args.log_path, "bulk_data.h5")
        if args.gym_env in ["socialgame_multi", "microgrid_multi"]:
            callbacks = MultiAgentCallbacks(log_path=out_path, save_interval=args.bulk_log_interval, 
                                            obs_dim=obs_dim, num_agents=dummy_env.n_nodes)
        else:
            callbacks = CustomCallbacks(log_path=out_path, save_interval=args.bulk_log_interval, obs_dim=obs_dim)

        config["callbacks"] = lambda: callbacks
        logger_creator = utils.custom_logger_creator(args.log_path)

        callbacks.save()
        print("Saved first callback")

        if args.wandb:
            wandb.save(out_path)
        if args.algo in ["ppo", "single_ppo", "afl_ppo"]:
            
            ret = ray_ppo.PPOTrainer(
                config = config, 
                env = environments[args.gym_env],
                logger_creator = logger_creator
            )
            
        else:
            print("RUNNING CCPPO")
            ModelCatalog.register_custom_model(
                "cc_model", TorchCentralizedCriticModel
                if config["framework"] == "torch" else CentralizedCriticModel)
            config["model"] = {"custom_model": "cc_model",}
            ret =  CCTrainer(
                config = config, 
                env = environments[args.gym_env],
                logger_creator = logger_creator
            )
        if args.use_hnet:
            shapes = {k:  v.shape for k, v in list(ret.get_weights().values())[0].items()}
            input_size = len(agg_data_dict['0']) if args.use_agg_data else 1
            hnet = PFL_Hypernet(n_nodes = dummy_env.n_nodes,
                        embedding_dim = args.hnet_embedding_dim, 
                        num_layers = args.hnet_num_layers,
                        num_hidden = args.hnet_num_hidden,
                        out_params_shapes = shapes,
                        lr = args.hnet_lr,
                        device = device,
                        input_size=input_size,
                        use_layernorm=args.hnet_use_layernorm,
                        dropout = args.hnet_dropout)
            if args.hnet_load_ckpt is not None:
                print("Loading HNET weights from {}".format(args.hnet_load_ckpt))
                hnet.load_state_dict(torch.load(args.hnet_load_ckpt))
            hnet_optimizer = optim.Adam(hnet.parameters(), lr=args.hnet_lr, weight_decay = args.hnet_l2_reg)
            wandb.watch(hnet)
        return ret

    # Add more algorithms here. 

def normalize(x):
    return x / torch.norm(x)

def pfl_hnet_update(agent, result, args, old_weights):
    # Get list of weights of each worker, including remote replicas
    #curr_weights = trainer.workers.foreach_worker(lambda ev: ev.get_policy().get_weights())
    curr_weights = agent.get_weights()
    # set gradients to 0 to start w
    hnet_optimizer.zero_grad()
    # update hnet weights
    num_agents = len(curr_weights)

    total_dot = 0
    for agent_id in curr_weights.keys():
        # calculating delta theta
        
        delta_theta = OrderedDict({k: old_weights[agent_id][k].to(device) - torch.tensor(curr_weights[agent_id][k], device=device) for k in curr_weights[agent_id].keys()})

        # calculating phi gradient
        hnet_grads = torch.autograd.grad(
            [v.to(device) for v in old_weights[agent_id].values()], hnet.parameters(), grad_outputs=list(delta_theta.values())
        )
        # update hnet weights
        for p, g in zip(hnet.parameters(), hnet_grads):
            # normalize so we end up with a grad that is the mean of all the updates from each agent
            p.grad = g / num_agents 
    torch.nn.utils.clip_grad_norm_(hnet.parameters(), 50)
    hnet_optimizer.step()
    if args.hnet_save_ckpt is not None:
        torch.save(hnet.state_dict(), args.hnet_save_ckpt)
    new_weight_dict = {}
    weight_dots = {}

    for agent_id in curr_weights.keys():
        if args.use_agg_data:
            new_weights = hnet(agg_data_dict[agent_id])
        else:
            new_weights = hnet(int(agent_id))
        new_weight_dict[agent_id] = new_weights
        # Calculate the dot product of weights and average weight
        for k in curr_weights[agent_id].keys():
            if ".weight" in k: # exclude biases and batch norm params
                curr_w = torch.flatten(torch.from_numpy(curr_weights[agent_id][k]).to(device))
                old_w = torch.flatten(old_weights[agent_id][k].to(device))
                delta_h = torch.flatten(new_weights[k]) - old_w
                delta_l = curr_w - old_w
                layer_dot = torch.dot(
                    normalize(delta_h), 
                    normalize(delta_l))
                key = "dots/" + k
                weight_dots[key] = (layer_dot + weight_dots.get(key, 0)) / num_agents
    avg_dot = sum(weight_dots.values()) / len(weight_dots)
    weight_dots["dots/avg"] = avg_dot
    weight_dots["dots/max"] = max(weight_dots.values())
    #wandb.log({"weight_dot": total_dot}, commit=False)
    # Set weights of each worker, including remote replicas
    for w in agent.workers.remote_workers():
        w.set_weights.remote(new_weight_dict)
    #agent.set_weights(detach_weights(new_weight_dict))
    return result, new_weight_dict, weight_dots

def afl_update(agent, result, args, initial_update=False):
    curr_weights = agent.get_weights()
    avg_weights = OrderedDict()
    normalizer = 1 if initial_update else len(curr_weights) 

    with torch.no_grad():
        total_dot = 0
        for agent_id in curr_weights.keys():
            for k in curr_weights[agent_id].keys():
                curr_weight_tensor = torch.tensor(curr_weights[agent_id][k], device=device)
                if k not in avg_weights:
                    avg_weights[k] = curr_weight_tensor / normalizer
                elif not initial_update:
                    avg_weights[k] += curr_weight_tensor / normalizer
            # Calculate the dot product of weights and average weight
            for k in curr_weights[agent_id].keys():
                total_dot += torch.dot(
                    torch.flatten(torch.from_numpy(curr_weights[agent_id][k])), 
                    torch.flatten(avg_weights[k]))
        print("AFL total dot: {}".format(total_dot))
        wandb.log({"weight_dot": total_dot}, commit=False)
    new_weight_dict = {}
    for agent_id in curr_weights.keys():
        new_weight_dict[agent_id] = {k: v.clone().detach() for k, v in avg_weights.items()}#deepcopy(avg_weights)
    # Set weights of each worker, including remote replicas
    for w in agent.workers.remote_workers():
        w.set_weights.remote(new_weight_dict)
    #agent.set_weights(detach_weights(new_weight_dict))
    return result, new_weight_dict, {}

def detach_weights(weights):
    new_weights = {}
    for agent_id, agent_params in weights.items():
        agent_param_dict = {}
        for k, v in agent_params.items():
            agent_param_dict[k] = v.detach()
        new_weights[agent_id] = agent_param_dict
    return new_weights

def train(agent, args):
    """
    Purpose: Train agent in environment. 
    """
    library = args.library
    algo = args.algo
    env = args.gym_env
    num_steps = args.num_steps

    if library == "rllib":
        print("Initializing Ray")
        # ray.init()
        print("Ray Initialized")

    ## Beginning Training ##
    print("Beginning training.")
    to_log = []
    training_steps = 0
    if args.gym_env in ["socialgame_multi", "microgrid_multi"]:
        if args.use_hnet:
            weights = {}
            for agent_id in range(hnet.n_nodes):
                if args.use_agg_data:
                    weights[str(agent_id)] = hnet(agg_data_dict[str(agent_id)])
                else:
                    weights[str(agent_id)] = hnet(agent_id)
            # Set weights of each worker, including remote replicas
            #agent.workers.foreach_worker(lambda ev: ev.get_policy().set_weights(detach_weights(weights)))
            agent.set_weights(detach_weights(weights))
        if algo == "afl_ppo":
            result, weights = afl_update(agent, {}, args, initial_update=True)
    #breakpoint()
    num_train_steps = 0
    while training_steps < num_steps:
        result = agent.train()
        training_steps = result["timesteps_total"]
        num_train_steps += 1
        dots = None
        if args.gym_env in ["socialgame_multi", "microgrid_multi"]:
            if args.use_hnet and num_train_steps % args.hnet_num_local_steps == 0:
                print("****\nHNET UPDATE AT TIMESTEP {}\n****".format(training_steps))
                result, weights, dots = pfl_hnet_update(agent, result, args, old_weights=weights)
            if algo == "afl_ppo" and num_train_steps % args.afl_num_local_steps == 0:
                print("****\nAFL UPDATE AT TIMESTEP {}\n****".format(training_steps))
                result, weights, dots = afl_update(agent, result, args, initial_update=False)
        
        log = result['custom_metrics']
        if dots:
            log.update(dots)
        log["ray_num_steps_sampled"] = result["info"]["num_steps_sampled"]
        wandb.log(log)
        

    #list of different local agents
    #for each local agent, we call .train()
    #at the end of .train, that local agent should
    #have second order gradient updates ready for the hypernetwork
    #We should access those from out here and use them to update the hypernetwork
    #And then set the new local agent model weights

    #TODO: create a custom RLLib Policy and Trainer class that will do the optimization
    # on the local agent and compute second order gradient updates at the end of .train

    # callbacks.save() TODO: Implement callbacks
######################################
#### Arguments and Configurations ####
######################################

# Add environments here to be included when configuring an agent
environments = {
    "socialgame": SocialGameEnvRLLib,
    "microgrid": MicrogridEnvRLLib,
    "microgrid_multi": MultiAgentMicrogridEnv,
    "socialgame_multi": MultiAgentSocialGameEnv
}

parser = argparse.ArgumentParser()
# Custom Agents Arguments
parser.add_argument(
    "--custom_config",
    # type=argparse.FileType("r"),
    type=str,
    default = None,
    help="Path to custom config file"
)
# Library Argument
parser.add_argument(
    "--library",
    help = "What RL Library backend is in use",
    type = str,
    default = "rllib",
    choices = ["rllib", "tune"]
)
# Algorithm Arguments
parser.add_argument(
    "--algo",
    help="RL Algorithm",
    type=str,
    default="ppo",
    choices=["sac", "ppo", "maml", "uc_bandit", "ccppo", "single_ppo", "afl_ppo"]
)
parser.add_argument(
    "--num_workers",
    help = "Number of cores you want to \
            sacrifice to RLLib to run envs in parallel",
    type=int,
    default=1
)
parser.add_argument(
    "--num_mg_workers", 
    help = "Number of cores each environment should dedicate to optimizing \
            building responses in microgrid",
    type=int,
    default=2
)
parser.add_argument(
    "--action_space_string",
    help="action space for algorithm",
    default="continuous",
)
parser.add_argument(
    "--policy_type",
    help="Type of Policy (e.g. MLP, LSTM) for algo",
    default="mlp",
    choices=["mlp", "lstm"],
)
parser.add_argument(
    "--reward_function",
    help="reward function to test",
    type=str,
    default="profit_maximizing",
    choices=["scaled_cost_distance", "log_cost_regularized", "log_cost", "scd", "lcr", "lc", "market_solving", "profit_maximizing"],
)

# Environment Arguments
parser.add_argument(
    "--gym_env", 
    help="Which Gym Environment you wish to use",
    type=str,
    choices=["socialgame", "microgrid", "microgrid_multi", "socialgame_multi"],
    default="microgrid_multi"
)
parser.add_argument(
    "--num_mg_optim_steps", 
    help = "Number of optim steps for mg to optimize building response",
    type=int,
    default=10
)
parser.add_argument(
    "--max_ep_steps", 
    help = "Number of steps in each episode",
    type=int,
    default=7 # week-long trajectory
)
parser.add_argument("--starting_day",
    help = "Which day for each environment to start on on resets",
    type=int,
    default=15)
parser.add_argument(
    "--env_id",
    help="Environment ID for Gym Environment",
    type=str,
    choices=["v0", "monthly"],
    default="v0",
)
parser.add_argument(
    "--response_type_string",
    help="Player response function (l = linear, t = threshold_exponential, s = sinusoidal",
    type=str,
    default="l",
    choices=["l", "t", "s"],
)
parser.add_argument(
    "--pv_batt_config_path",
    help="Path to json containing pv_batt_scenarios",
    type=str,
    default="",
)
# Experiment Arguments
parser.add_argument(
    "--exp_name",
    help="experiment_name",
    type=str,
    default="experiment"
)
parser.add_argument(
    "--num_steps",
    help="Number of timesteps to train algo",
    type=int,
    default=50000,
)
parser.add_argument(
    "--energy_in_state",
    help="Whether to include energy in state (default = F)",
    action="store_true"
)
parser.add_argument(
    "--price_in_state",
    help="Whether to include price in state (default = F)",
    action="store_false"
)
parser.add_argument(
    "--points_multiplier",
    help="points multiplier for SocialGame simulated agents (default = 10)",
    default=10,
    type=float
)
parser.add_argument(
    "--batch_size",
    help="Batch Size for sampling from replay buffer",
    type=int,
    default=256,
    #choices=[i for i in range(1, 30)],
)
parser.add_argument(
    "--ppo_num_sgd_iter",
    help="num_sgd_iter parameter for PPO",
    type=int,
    default=30)

parser.add_argument(
    "--ppo_clip_param",
    help="clip_param parameter for PPO",
    type=float,
    default=30)
parser.add_argument(
    "--one_day",
    help="Specific Day of the year to Train on (default = 15, train on day 15)",
    type=int,
    default=15,
    choices=[i for i in range(365)],
)
parser.add_argument(
    "--manual_tou_magnitude",
    help="Magnitude of the TOU during hours 5,6,7. Sets price in normal hours to 0.103.",
    type=float,
    default=.4
)
parser.add_argument(
    "--pricing_type",
    help="time of use or real time pricing",
    type=str,
    choices=["TOU", "RTP"],
    default="TOU",
)
parser.add_argument(
    "--number_of_participants",
    help="Number of players ([1, 20]) in social game",
    type=int,
    default=10,
    choices=[i for i in range(1, 21)],
)
parser.add_argument(
    "--learning_rate",
    help="learning rate of the the agent",
    type=float,
    default=3e-4,
)
# Logging Arguments
parser.add_argument(
    "-w",
    "--wandb",
    help="Whether to upload results to wandb. must have wandb key.",
    action="store_true"
)
parser.add_argument(
    "--tag",
    help="What tag to put on wandb run",
    type=str,
    default=None
)
parser.add_argument(
    # "--base_log_dir",
    "--log_path",
    help="Base directory for tensorboard logs",
    type=str,
    default="./logs/"
)
parser.add_argument(
    "--bulk_log_interval",
    help="Interval at which to save bulk log information",
    type=int,
    default=50000
)
parser.add_argument(
    "--bin_observation_space",
    help= "Whether to bin the observations.",
    action="store_true"
)
parser.add_argument(
    "--smirl_weight",
    help="Whether to run with SMiRL. When using SMiRL you must specify a weight.",
    type = float,
    default=None,
)
parser.add_argument(
    "--circ_buffer_size",
    help="Size of circular smirl buffer to use. Will use an unlimited size buffer in None",
    type = float,
    default=None,
)
# Machine Arguments
parser.add_argument(
    "-l",
    "--local_mode",
    help="Init Ray in local mode for easier debugging.",
    action="store_true"
)
parser.add_argument(
    "--num_inner_steps",
    help="Number of local model training steps",
    type= int,
    default=1
)
parser.add_argument(
    "--num_gpus",
    help="Number of gpus",
    type= int,
    default=0
)
parser.add_argument(
    "--scenarios",
    nargs='+',
    help="List of complex_batt_pv_scenarios to have separate agents for",
    default=[]
)
# Hypernetwork parameters
parser.add_argument(
    "--use_hnet",
    help="Whether or not to use the hypernetwork for MARL optimization",
    action="store_true",
    )
parser.add_argument(
    "--hnet_embedding_dim",
    help="Size of the embedding",
    type= int,
    default=1
)
parser.add_argument(
    "--hnet_num_layers",
    help="Number of linear layers",
    type= int,
    default=1
)
parser.add_argument(
    "--hnet_num_hidden",
    help="Size of the hidden layers",
    type= int,
    default=1
)
parser.add_argument(
    "--hnet_out_params_path",
    help="Path to a file containing a dict specifying the size of the output weights",
    type= str,
    default="./socialgame_param_dict.txt"
)
parser.add_argument(
    "--hnet_lr",
    help="Hnet learning rate",
    type= float,
    default=3e-4
)
parser.add_argument(
    "--hnet_num_local_steps",
    help = "Number of local steps between hnet updates",
    type=int,
    default=1
)
parser.add_argument(
    "--afl_num_local_steps",
    help = "Number of local steps between afl updates",
    type=int,
    default=1
)
parser.add_argument(
    "--hnet_use_layernorm",
    help = "should the hypernetwork use batchnorm",
    action="store_true"
)
parser.add_argument(
    "--hnet_l2_reg",
    type=float,
    default=0
)
parser.add_argument(
    "--hnet_dropout",
    type=float,
    default=0
)
parser.add_argument(
    "--hnet_load_ckpt",
    type=str,
    default=None,
    help="Checkpoint path to load hnet weights from"
)
parser.add_argument(
    "--hnet_save_ckpt",
    type=str,
    default=None,
    help="Checkpoint path to save hnet weights to"
)
parser.add_argument(
    "--use_agg_data",
    action='store_true',
    help="include environment data",
)


# Local network parameters
parser.add_argument(
    "--sizes",
    help="number of nodes in policy networks",
    default=256,
    type=int
)
parser.add_argument(
    "--n_layers",
    help="number of neural layers",
    type=int,
    default=2
)
parser.add_argument(
    "--fcnet_activation",
    type=str,
    help="activation function of policy neurons",
    choices = ["relu", "tanh"],
    default="tanh"
)

#
# Call get_agent and train to recieve the agent and then train it. 
#
if __name__ == "__main__":
    args = parser.parse_args()
    print(f"Running with following options: {args}")

    # Uploading logs to wandb
    if args.wandb:
        tags = []
        if args.tag is not None:
            tags.append(args.tag)
        wandb.init(project="energy-demand-response-game", entity="social-game-rl", tags=tags)
        wandb.tensorboard.patch(root_logdir=args.log_path) # patching the logdir directly seems to work
        wandb.config.update(args)

    # Get Agent
    agent = get_agent(args)
    print("Agent initialized.")

    # Training
    print(f'Beginning Testing! Logs are being saved somewhere')
    train(agent, args)


