import copy
import glob
import os
import time
from collections import deque

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from a2c_ppo_acktr import algo, utils
# from a2c_ppo_acktr.algo import gail
from a2c_ppo_acktr.arguments import get_args
from a2c_ppo_acktr.envs import make_vec_envs, make_vec_envs_eval, make_vec_envs_fseval, get_num_test
from a2c_ppo_acktr.model import Policy
from a2c_ppo_acktr.storage import RolloutStorage
from evaluation import evaluate, evaluate_lm, evaluate_fs_lm
from utils import setup_roberta
from torch import multiprocessing as mp

class Normalizer:
    """Running mean/std tracker with one independent set of statistics per slot.

    The tracker holds ``num_process`` rows of statistics, each ``in_size``
    wide. Batches are merged into the running statistics with the parallel
    variance algorithm:
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    """

    _STATS_FNAME = "env_stats.pickle"

    def __init__(self, in_size, num_process, device='cpu', dtype=torch.float):
        """Create zero-mean / unit-std statistics for every slot.

        Args:
            in_size: Number of features tracked per slot.
            num_process: Number of independent slots (rows of statistics).
            device: Accepted for interface compatibility; the statistics are
                always stored on the CPU (see below).
            dtype: Floating point dtype of the stored statistics.
        """
        # The statistics are deliberately pinned to the CPU: update_stats()
        # also moves every incoming batch to the CPU before merging it.
        device = 'cpu'
        self.mean = torch.zeros((num_process, in_size), device=device, dtype=dtype)
        self.std = torch.ones((num_process, in_size), device=device, dtype=dtype)
        self.num_process = num_process
        self.eps = 1e-12 if dtype == torch.double else 1e-5
        self.device = device
        # Start the counts at eps (not 0) so the first merge never divides by zero.
        self.count = self.eps + torch.zeros((num_process, in_size), device=device, dtype=dtype)

    def update_stats(self, batch_data, batch_indices):
        """Merge a batch of samples into the per-slot running statistics.

        Args:
            batch_data: Tensor or ndarray whose first dimension enumerates
                samples.
            batch_indices: Tensor or ndarray giving, for each sample position,
                the slot the sample belongs to. It is flattened before use.

        Slots that receive fewer than two samples in this batch are left
        untouched, because the batch variance is undefined for them.
        """
        if isinstance(batch_data, np.ndarray):
            batch_data = torch.from_numpy(batch_data).float()
        batch_data = batch_data.to('cpu')
        if isinstance(batch_indices, np.ndarray):
            batch_indices = torch.from_numpy(batch_indices)
        batch_indices = batch_indices.to('cpu').reshape(-1)

        for slot in range(self.num_process):
            # Select whole rows by position so that every feature column is
            # kept (a gather with a (K, 1) index would only read column 0).
            positions = (batch_indices == slot).nonzero(as_tuple=True)[0]
            data = batch_data[positions]
            if data.shape[0] > 1:
                batch_mean = data.mean(0, keepdim=True)
                # NOTE(review): var() is the unbiased estimator while the merge
                # below weights it by the raw sample count — kept as in the
                # original to preserve numerical behaviour.
                batch_var = data.var(0, keepdim=True)
                self.update_from_moments(batch_mean, batch_var, data.shape[0], slot)

    def update_from_moments(self, batch_mean, batch_var, batch_count, index):
        """Fold one batch's moments into the statistics of slot ``index``.

        Args:
            batch_mean: Mean of the batch, broadcastable to ``(1, in_size)``.
            batch_var: Variance of the batch, broadcastable to ``(1, in_size)``.
            batch_count: Number of samples the batch moments were computed from.
            index: Slot whose running statistics are updated in place.
        """
        # Advanced indexing returns copies, so these stay valid while the
        # stored tensors are overwritten below.
        old_mean = self.mean[[index]]
        old_count = self.count[[index]]
        tot_count = old_count + batch_count

        delta = batch_mean - old_mean
        new_mean = old_mean + delta * batch_count / tot_count
        m_a = torch.square(self.std[[index]]) * old_count
        m_b = batch_var * batch_count
        m2 = m_a + m_b + torch.square(delta) * old_count * batch_count / tot_count

        self.mean[[index]] = new_mean
        self.std[[index]] = torch.sqrt(m2 / tot_count)
        self.count[[index]] = tot_count

    def normalize(self, val, index):
        """Return ``(val - mean[index]) / std[index]`` on ``val``'s device.

        The std is clamped from below by ``eps`` to avoid division by zero.
        """
        if isinstance(val, np.ndarray):
            val = torch.from_numpy(val).to(self.device)
        mean = self.mean[index]
        std = torch.clamp(self.std, min=self.eps)[index]
        return (val - mean.to(val.device)) / std.to(val.device)

    def denormalize(self, val):
        """Return ``std * val + mean`` on ``val``'s device, using all slots.

        The std is clamped from below by ``eps``, mirroring normalize().
        """
        if isinstance(val, np.ndarray):
            val = torch.from_numpy(val).to(self.device)
        std = torch.clamp(self.std, min=self.eps)
        # Move the statistics (not val) so that non-CPU inputs work too.
        return std.to(val.device) * val + self.mean.to(val.device)

    # Disabled persistence helpers, kept for reference. They would need
    # `pathlib` and `pickle` to be imported before being re-enabled.
    '''
    def load(self, results_dir):
        with open(pathlib.Path(results_dir) / self._STATS_FNAME, "rb") as f:
            stats = pickle.load(f)
            self.mean = torch.from_numpy(stats["mean"]).to(self.device)
            self.std = torch.from_numpy(stats["std"]).to(self.device)

    def save(self, save_dir):
        save_dir = pathlib.Path(save_dir)
        with open(save_dir / self._STATS_FNAME, "wb") as f:
            pickle.dump(
                {"mean": self.mean.cpu().numpy(), "std": self.std.cpu().numpy()}, f
            )
    '''

class Flatten(nn.Module):
    """Module that collapses every dimension after the first into one."""

    def forward(self, x):
        """Return a view of ``x`` with shape ``(x.size(0), -1)``."""
        leading = x.size(0)
        flattened = x.view(leading, -1)
        return flattened


def main():
    """Train a policy with A2C / PPO / ACKTR and evaluate it periodically.

    Reads the configuration from ``get_args()``, builds the training,
    few-shot-evaluation and per-actor evaluation environments, then runs
    ``num_updates`` rollout/update iterations. Every ``args.save_interval``
    updates the policy is saved, every ``args.log_interval`` updates training
    statistics are printed, and every ``args.eval_interval`` updates the
    policy is evaluated (in-process via ``evaluate_fs_lm`` and in spawned
    worker processes via ``evaluate_lm``).

    Raises:
        ValueError: If ``args.models`` or ``args.algo`` is not supported.
    """
    ctx = mp.get_context('spawn')
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    # Environment / language-model parameters handed to the env factories.
    params = {
        'conditioned_on_correct_classes': True,
        'api_num_log_prob': args.api_num_log_prob,
        'approx': args.approx,
        'bs': 1,
        'model': args.models,
        'dataset': args.datasets,
        'seed': args.seed,
        'num_shots': args.num_shots,
        'expr_name': "",
        'subsample_test_set': args.subsample_test_set,
        'env_name': args.env_name,
        'verbalizer': args.verbalizer,
        'rew_type': args.rew_type,
        'example_pool_size': args.example_pool_size,
        'use_knn': args.use_knn,
        'sub_sample': args.sub_sample,
        'num_actors': args.num_actors,
        'entropy_coef': args.env_entropy_coef,
        'random_init': args.random_init,
        'k': args.num_k_shots,
    }

    # Width of one observation block for each supported language model.
    obs_size_by_model = {
        'gpt2-xl': 2048,
        'gpt2-large': 1280,
        'gpt2-medium': 1024,
        'roberta-large': 1024,
        't5-large': 1024,
        't5-11b': 1024,
        't5-3b': 1024,
    }
    if args.models not in obs_size_by_model:
        raise ValueError('unsupported model: {!r}'.format(args.models))
    obs_size = obs_size_by_model[args.models]

    print('params ', params, flush=True)
    print('arguments ', args, flush=True)

    envs = make_vec_envs(params['seed'], params, args.max_steps, args.num_processes, args.gamma, obs_size, 0)
    envs_fseval = make_vec_envs_fseval(params['seed'], params, args.max_steps, args.num_processes, args.gamma, obs_size, 0)
    # Guard the round-robin GPU assignment against machines without a GPU,
    # where device_count() is 0 and the modulo would raise ZeroDivisionError.
    num_gpus = max(torch.cuda.device_count(), 1)
    num_test_samples = get_num_test(params['seed'], params, args.max_steps, args.num_processes, args.gamma, obs_size, 0, 0 % num_gpus)
    eval_envs = []
    for i in range(params['num_actors']):
        eval_env, num_test_samples = make_vec_envs_eval(params['seed'], params, args.max_steps, args.num_processes, args.gamma, obs_size, False, i, i % num_gpus)
        eval_envs.append(eval_env)

    num_blocks = int(envs.observation_space.shape[0]/obs_size)
    # NOTE(review): 'cuda' is hard-coded here even when args.cuda is False;
    # confirm what Policy does with this argument before running on CPU.
    actor_critic = Policy(
        envs.observation_space.shape,
        envs.action_space,
        args.use_attention,
        'cuda',
        num_blocks,
        base_kwargs={'recurrent': args.recurrent_policy,
            'hidden_size': 1024})
    print(actor_critic, flush=True)
    actor_critic.to(device)

    if args.exploration:
        hidden_dim = 256

        def _make_embedding_net():
            # Three-layer MLP mapping one observation block to an embedding.
            return nn.Sequential(nn.Linear(obs_size, hidden_dim),
                nn.ELU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ELU(),
                nn.Linear(hidden_dim, hidden_dim))

        # Built in this order so random initialisation matches the seed.
        pred_net = _make_embedding_net()
        target_net = _make_embedding_net()
        pred_net.to(device)
        target_net.to(device)
        pred_optimizer = optim.Adam(pred_net.parameters(), lr=args.lr, eps=args.eps)

    # CPU copy of the policy in shared memory, handed to evaluation workers.
    eval_actor_critic = Policy(
        envs.observation_space.shape,
        envs.action_space,
        args.use_attention,
        'cpu',
        num_blocks,
        base_kwargs={'recurrent': args.recurrent_policy,
            'hidden_size': 1024})
    eval_actor_critic.load_state_dict(actor_critic.state_dict())
    eval_actor_critic.share_memory()

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(
            actor_critic,
            args.value_loss_coef,
            args.entropy_coef,
            lr=args.lr,
            eps=args.eps,
            alpha=args.alpha,
            max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(
            actor_critic,
            args.clip_param,
            args.ppo_epoch,
            args.num_mini_batch,
            args.value_loss_coef,
            args.entropy_coef,
            lr=args.lr,
            eps=args.eps,
            max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(
            actor_critic, args.value_loss_coef, args.entropy_coef, acktr=True)
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError('unsupported algo: {!r}'.format(args.algo))

    if args.exploration:
        agent.pred_net = pred_net
        agent.target_net = target_net
        agent.pred_optimizer = pred_optimizer
        agent.obs_size = obs_size

    if args.gail:
        # GAIL is disabled in this script: the `gail` import at the top of the
        # file is commented out and the second assert aborts unconditionally.
        assert len(envs.observation_space.shape) == 1
        assert False
        discr = gail.Discriminator(
            envs.observation_space.shape[0] + envs.action_space.shape[0], 100,
            device)
        file_name = os.path.join(
            args.gail_experts_dir, "trajs_{}.pt".format(
                args.env_name.split('-')[0].lower()))

        expert_dataset = gail.ExpertDataset(
            file_name, num_trajectories=4, subsample_frequency=20)
        drop_last = len(expert_dataset) > args.gail_batch_size
        gail_train_loader = torch.utils.data.DataLoader(
            dataset=expert_dataset,
            batch_size=args.gail_batch_size,
            shuffle=True,
            drop_last=drop_last)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                            envs.observation_space.shape, envs.action_space,
                            actor_critic.recurrent_hidden_state_size)

    # TODO: get the mean and variance of
    '''
    obs = torch.cat([envs.venv.envs[0].current_prompt_embedding_pool.reshape(-1, obs_size),
        envs.venv.envs[0].add_current_prompt_embedding_pool.reshape(-1, obs_size),
        envs.venv.envs[0].current_verbalizer_embedding_pool.reshape(-1, obs_size),
        envs.venv.envs[0].add_current_verbalizer_embedding_pool.reshape(-1, obs_size)], dim=0)
    obs_mean = torch.mean(obs, dim=0, keepdims=True)
    obs_std = torch.std(obs, dim=0, keepdims=True)
    block_size = int(actor_critic.base.normalizer.mean.shape[-1] / obs_size)
    obs_mean = torch.cat([obs_mean for _ in range(block_size)], dim=-1)
    obs_std = torch.cat([obs_std for _ in range(block_size)], dim=-1)
    actor_critic.base.normalizer.mean[:, :block_size * obs_size] = obs_mean
    actor_critic.base.normalizer.std[:, :block_size * obs_size] = obs_std
    '''

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    # Rewards and their slot indices collected during the current rollout;
    # both are fed to the reward normalizer and reset after every update.
    all_rews = []
    all_indexs = []
    # NOTE(review): params['label_dict'] is not set anywhere in this function;
    # presumably one of the env factories above adds it — verify.
    rew_normalizer = Normalizer(1, params['k'] * len(params['label_dict'].keys()))
    for j in range(num_updates):

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(
                agent.optimizer, j, num_updates,
                agent.optimizer.lr if args.algo == "acktr" else args.lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs. The subset indices are read from
            # the first underlying env *before* stepping.
            subset_idxs = envs.venv.envs[0].subset_idxs
            all_indexs.append(copy.deepcopy(subset_idxs))
            obs, reward, done, infos = envs.step(action)
            if args.normalize_obs:
                actor_critic.base.normalizer.update_stats(obs)
            # Keep the raw reward; the exploration bonus below is not
            # included in the reward-normalizer statistics.
            all_rews.append(copy.deepcopy(reward))
            if args.exploration:
                # Exploration bonus: squared distance between the predictor
                # and target embeddings of the first observation block.
                norm_obs = actor_critic.base.normalizer.normalize(obs)
                pred_emb = agent.pred_net(norm_obs[:, :obs_size].to(device))
                target_emb = agent.target_net(norm_obs[:, :obs_size].to(device))
                reward = reward + args.exp_coef * torch.sum((target_emb - pred_emb)**2, dim=-1).reshape(reward.shape).detach().cpu().numpy()

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])
            if done[0]:
                # done[0] refers to the first env, so read that env's info
                # rather than whatever the loop above happened to leave behind.
                episode_rewards.append(infos[0]['episode_r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor(
                [[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, torch.Tensor(subset_idxs).unsqueeze(-1), value, reward, masks, bad_masks)

        if args.normalize_rew:
            rew_normalizer.update_stats(torch.cat(all_rews, dim=0), torch.from_numpy(np.concatenate(all_indexs, axis=0)))
        all_indexs = []
        all_rews = []

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        if args.gail:
            if j >= 10:
                envs.venv.eval()

            gail_epoch = args.gail_epoch
            if j < 10:
                gail_epoch = 100  # Warm up
            for _ in range(gail_epoch):
                discr.update(gail_train_loader, rollouts,
                            utils.get_vec_normalize(envs)._obfilt)

            for step in range(args.num_steps):
                rollouts.rewards[step] = discr.predict_reward(
                    rollouts.obs[step], rollouts.actions[step], args.gamma,
                    rollouts.masks[step])

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                args.gae_lambda, rew_normalizer, args.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            os.makedirs(save_path, exist_ok=True)

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'obs_rms', None)
            ], os.path.join(save_path, args.env_name + ".pt"))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            # NOTE(review): the format string has no placeholders for the last
            # three arguments (dist_entropy, value_loss, action_loss), so they
            # are not printed. Left unchanged to keep the log format stable.
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss), flush=True)

        if (args.eval_interval is not None and j % args.eval_interval == 0):
            # Few-shot evaluation runs in-process on the CPU; afterwards the
            # policy goes back to the training device (not a fixed 'cuda:0').
            actor_critic.to('cpu')
            evaluate_fs_lm(actor_critic, None, envs_fseval, args.seed,
                    args.num_processes, 64, params, args, obs_size)
            actor_critic.to(device)

            # Refresh the shared CPU copy, including the normalizer statistics.
            eval_actor_critic.load_state_dict(actor_critic.state_dict())
            eval_actor_critic.base.normalizer.mean = copy.deepcopy(actor_critic.base.normalizer.mean.to('cpu'))
            eval_actor_critic.base.normalizer.std = copy.deepcopy(actor_critic.base.normalizer.std.to('cpu'))
            results = ctx.Queue()
            orig_results = ctx.Queue()
            env_queue = ctx.Queue()
            evaluate_processes = []
            for i in range(params['num_actors']):
                eval_proc = ctx.Process(
                    target=evaluate_lm,
                    args=(i, eval_actor_critic, None, eval_envs[i], args.seed,
                    args.num_processes, num_test_samples, orig_results, results, env_queue, params, args, obs_size))
                eval_proc.start()
                evaluate_processes.append(eval_proc)
            # NOTE(review): the workers are joined before the queues are
            # drained; the multiprocessing docs warn this can deadlock if a
            # worker has put large items on a queue. Order kept as original.
            for eval_proc in evaluate_processes:
                eval_proc.join()
            results_list = []
            orig_results_list = []
            for _ in range(results.qsize()):
                results_list.append(results.get())
                orig_results_list.append(orig_results.get())
            print('Evaluation mean reward {:.5f}, original mean reward {:.5f}'.format(sum(results_list)/len(results_list), sum(orig_results_list)/len(orig_results_list)), flush=True)

            if not args.load_ckpt and args.env_name != 'lmall':
                file_path = 'checkpoints/'+str(args.models)+'_'+str(args.datasets)+'_'+str(args.seed)+'/'
                os.makedirs(file_path, exist_ok=True)
                # Concatenate each embedding pool across all evaluation envs,
                # then write one file per pool.
                pool_names = (
                    'current_prompt_embedding_pool',
                    'add_current_prompt_embedding_pool',
                    'current_verbalizer_embedding_pool',
                    'add_current_verbalizer_embedding_pool',
                )
                merged_pools = {
                    pool_name: torch.cat(
                        [getattr(eval_env.envs[0], pool_name) for eval_env in eval_envs],
                        dim=0)
                    for pool_name in pool_names
                }
                for pool_name in pool_names:
                    torch.save(merged_pools[pool_name], file_path + pool_name + '.pth')


if __name__ == "__main__":
    # Select the 'spawn' start method before any worker process is created;
    # `mp` is the torch.multiprocessing module imported at the top of the file.
    mp.set_start_method('spawn')
    main()
