import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
import copy
import time
from collections import deque

import numpy as np
import torch
import torch.nn as nn

from a2c_ppo_acktr import algo, utils
from a2c_ppo_acktr.arguments import get_args
from a2c_ppo_acktr.envs import make_vec_envs, make_ven_envs_rew
from a2c_ppo_acktr.model import Policy
from a2c_ppo_acktr.storage import RolloutStorage
from evaluation import evaluate_attack
from torch import multiprocessing as mp

class Normalizer:
    """Running per-slot mean/std tracker used to normalize and denormalize values.

    One independent set of statistics (mean, std, sample count) is stored for
    each of ``num_process`` slots, each ``in_size`` features wide. New batches
    are merged into the running statistics with the parallel variance
    algorithm described at
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

    Args:
        in_size: Number of features per sample.
        num_process: Number of independent statistic slots.
        device: Accepted for interface compatibility but ignored; the
            statistics are always kept on the CPU.
        dtype: Dtype of the stored statistics.
    """

    _STATS_FNAME = "env_stats.pickle"

    def __init__(self, in_size, num_process, device='cpu', dtype=torch.float):
        # The statistics are pinned to the CPU regardless of the argument;
        # update_stats() moves incoming batches there before accumulating.
        device = 'cpu'
        self.mean = torch.zeros((num_process, in_size), device=device, dtype=dtype)
        self.std = torch.ones((num_process, in_size), device=device, dtype=dtype)
        self.num_process = num_process
        self.eps = 1e-12 if dtype == torch.double else 1e-5
        self.device = device
        # Start from a tiny pseudo-count so the first merge never divides by zero.
        self.count = self.eps + torch.zeros((num_process, in_size), device=device, dtype=dtype)

    def update_stats(self, batch_data, batch_indices):
        """Merge a batch of samples into the statistics of their slots.

        Args:
            batch_data: Tensor or ndarray holding one sample per entry of
                ``batch_indices``; it is reshaped to ``(N, -1)`` where ``N``
                is the number of indices.
            batch_indices: Tensor or ndarray of slot ids (flattened to 1-D).
                Ids outside ``range(num_process)`` are ignored.
        """
        if isinstance(batch_data, np.ndarray):
            batch_data = torch.from_numpy(batch_data)
        if isinstance(batch_indices, np.ndarray):
            batch_indices = torch.from_numpy(batch_indices)

        batch_indices = batch_indices.to(self.device).reshape(-1)
        if batch_indices.numel() == 0:
            return
        # Work in the statistics' own dtype/device so integer or GPU inputs
        # can be averaged and merged without device/dtype errors.
        batch_data = batch_data.to(device=self.device, dtype=self.mean.dtype)
        batch_data = batch_data.reshape(batch_indices.shape[0], -1)

        for i in range(self.num_process):
            # Boolean-mask row selection keeps every feature column.
            data = batch_data[batch_indices == i]
            batch_count = data.shape[0]
            if batch_count == 0:
                continue
            batch_mean = data.mean(0, keepdim=True)
            # The merge formula needs M2 = n * population variance, so the
            # biased estimator is required (it is also defined for n == 1).
            batch_var = data.var(0, unbiased=False, keepdim=True)
            self.update_from_moments(batch_mean, batch_var, batch_count, i)

    def update_from_moments(self, batch_mean, batch_var, batch_count, index):
        """Merge pre-computed batch moments into slot ``index``.

        Args:
            batch_mean: Batch mean, broadcastable to ``(1, in_size)``.
            batch_var: Batch *population* variance, same shape as ``batch_mean``.
            batch_count: Number of samples the moments were computed from.
            index: Slot whose running statistics are updated in place.
        """
        old_mean = self.mean[[index]]
        old_count = self.count[[index]]
        tot_count = old_count + batch_count
        delta = batch_mean - old_mean

        new_mean = old_mean + delta * batch_count / tot_count
        m_a = torch.square(self.std[[index]]) * old_count
        m_b = batch_var * batch_count
        m2 = m_a + m_b + torch.square(delta) * old_count * batch_count / tot_count

        self.mean[[index]] = new_mean
        self.std[[index]] = torch.sqrt(m2 / tot_count)
        self.count[[index]] = tot_count

    def normalize(self, val, index):
        """Return ``(val - mean[index]) / std[index]`` on ``val``'s device.

        The std is clamped from below by ``self.eps`` to avoid division by zero.
        """
        if isinstance(val, np.ndarray):
            val = torch.from_numpy(val).to(self.device)
        std = torch.clamp(self.std, min=self.eps)[index]
        mean = self.mean[index]
        return (val - mean.to(val.device)) / std.to(val.device)

    def denormalize(self, val, index=None):
        """Invert :meth:`normalize`: return ``std * val + mean`` on ``val``'s device.

        Args:
            val: Tensor or ndarray of normalized values.
            index: Optional slot selector, used exactly as in
                :meth:`normalize`. When omitted the full
                ``(num_process, in_size)`` statistics are broadcast against
                ``val``.
        """
        if isinstance(val, np.ndarray):
            val = torch.from_numpy(val).to(self.device)
        std = torch.clamp(self.std, min=self.eps)
        mean = self.mean
        if index is not None:
            std = std[index]
            mean = mean[index]
        # Move the CPU statistics to val's device, not the other way round.
        return std.to(val.device) * val + mean.to(val.device)
 
class Flatten(nn.Module):
    """Module that collapses every dimension after the first into one."""

    def forward(self, x):
        """Return a view of ``x`` with shape ``(x.size(0), -1)``."""
        leading = x.size(0)
        flat = x.view(leading, -1)
        return flat


def _save_checkpoint(actor_critic, envs, save_dir, algo_name, file_name):
    """Write ``[actor_critic, obs_rms]`` to ``<save_dir>/<algo_name>/<file_name>``.

    ``obs_rms`` is read from the object returned by
    ``utils.get_vec_normalize(envs)`` and is ``None`` when that object has no
    such attribute. The target directory is created if it does not exist.
    """
    save_path = os.path.join(save_dir, algo_name)
    os.makedirs(save_path, exist_ok=True)
    torch.save([
        actor_critic,
        getattr(utils.get_vec_normalize(envs), 'obs_rms', None)
    ], os.path.join(save_path, file_name))


def main():
    """Parse the command-line arguments and run evaluation or PPO training.

    With ``args.evaluate`` set, ``evaluate_attack`` is run and the function
    returns. Otherwise the vectorized envs, the policy and the PPO agent are
    built and ``num_updates`` rollout/update iterations are executed, writing
    periodic checkpoints and a ``best`` checkpoint whenever the mean of the
    recent episode rewards improves.

    Raises:
        ValueError: If ``args.algo`` is anything other than ``'ppo'``.
    """
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:{}".format(args.cuda_id) if args.cuda else "cpu")

    # Parameters handed to the env constructors and to evaluation. Key order
    # is kept stable because the dict is printed below.
    params = {
        "num_processes": args.num_processes,
        "tar_model": args.models,
        'gpu_id': args.cuda_id,
        "datasets": args.datasets,
        'seed': args.seed,
        'use_knn': args.use_knn,
        'num_actors': args.num_actors,
        'entropy_coef': args.env_entropy_coef,
        'random_init': args.random_init,
        'cuda_id': args.cuda_id,
        'env_name': args.env_name,
        'openai_key': args.openai_key,
    }

    obs_size = 1024
    # TODO: based on the embedder, change the obs_size

    # NOTE(review): both lines below print `openai_key` in clear text, so the
    # secret ends up in stdout/log files — consider redacting it.
    print('Experiment params ', params, flush=True)
    print('Experiment arguments ', args, flush=True)

    if args.evaluate:
        evaluate_attack(params, args, obs_size, args.ckpt_path, device)
        return

    if args.use_value:
        print('use value network in PPO.')
    else:
        print('Not use value network in PPO.')

    make_envs = make_ven_envs_rew if args.use_rew_model else make_vec_envs
    envs = make_envs(params['seed'], params, args.max_steps, args.num_processes,
                     args.gamma, obs_size, args.cuda_id)

    num_blocks = int(envs.observation_space.shape[0]/obs_size)
    actor_critic = Policy(
        envs.observation_space.shape,
        envs.action_space,
        args.use_attention,
        device,
        num_blocks,
        base_kwargs={'recurrent': args.recurrent_policy,
            'hidden_size': 1024})
    actor_critic.to(device)

    # Raise explicitly instead of `assert False`, which is stripped under -O.
    if args.algo != 'ppo':
        raise ValueError(
            "Unsupported algo {!r}: only 'ppo' is implemented".format(args.algo))
    agent = algo.PPO(
        actor_critic,
        args.clip_param,
        args.ppo_epoch,
        args.num_mini_batch,
        args.value_loss_coef,
        args.entropy_coef,
        lr=args.lr,
        eps=args.eps,
        max_grad_norm=args.max_grad_norm)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                            envs.observation_space.shape, envs.action_space,
                            actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    # Sliding window of the most recent episode rewards used for logging.
    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    # Per-update buffers for reward normalization; only filled when
    # args.normalize_rew is set so they cannot grow without bound otherwise.
    all_rews = []
    all_indexs = []
    if args.normalize_rew:
        # NOTE(review): 'label_dict' is not added to `params` anywhere in this
        # function; presumably the env constructor inserts it — verify.
        rew_normalizer = Normalizer(1, 16 * len(params['label_dict']))
    else:
        rew_normalizer = None
    cur_best_ep_reward = 0.0
    for j in range(num_updates):

        if args.use_linear_lr_decay:
            # decrease learning rate linearly (algo is always 'ppo' here)
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates, args.lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observation, reward and next obs
            obs, reward, done, infos = envs.step(action)

            if args.normalize_obs:
                actor_critic.base.normalizer.update_stats(obs)
            if args.normalize_rew:
                all_rews.append(copy.deepcopy(reward))

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])
            if done[0]:
                # Uses the info dict of the last env, as the original code did
                # implicitly through the leaked loop variable.
                episode_rewards.append(infos[-1]['episode_r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor(
                [[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        if args.normalize_rew:
            # NOTE(review): nothing in this function appends to `all_indexs`,
            # so np.concatenate receives an empty list here and raises. The
            # per-sample slot indices still need to be collected during the
            # rollout for reward normalization to work.
            rew_normalizer.update_stats(
                torch.cat(all_rews, dim=0),
                torch.from_numpy(np.concatenate(all_indexs, axis=0)))
            rollouts.update_rew(rew_normalizer)
            all_indexs = []
            all_rews = []

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                args.gae_lambda, rew_normalizer, args.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts, args.use_value)

        rollouts.after_update()

        # save for every interval-th update or for the last one (never at j == 0)
        is_save_step = j % args.save_interval == 0 or j == num_updates - 1
        if is_save_step and args.save_dir != "" and j > 0:
            _save_checkpoint(actor_critic, envs, args.save_dir, args.algo,
                             args.env_name + "_iter%d.pt" % j)

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(f'time elapsed now: {end-start}\n')
            # NOTE: dist_entropy, value_loss and action_loss are passed to
            # format() but the template has no placeholders for them, so they
            # are not printed.
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.3f}/{:.3f}, min/max reward {:.3f}/{:.3f}"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss), flush=True)
            ep_reward = np.mean(episode_rewards)
            if ep_reward > cur_best_ep_reward:
                print("Updates {}, new max mean reward {}".format(j, ep_reward))
                _save_checkpoint(actor_critic, envs, args.save_dir, args.algo,
                                 args.env_name + "best.pt")
                cur_best_ep_reward = ep_reward

        # evaluation TBD...


if __name__ == "__main__":
    # Force the 'spawn' start method for worker processes before anything
    # else runs (presumably because CUDA state cannot be used in forked
    # children — confirm against the env implementation).
    mp.set_start_method("spawn")
    main()
