import os
import sys
from tqdm import tqdm
import torch
import pickle
from torch.nn import functional as F
from collections import OrderedDict
from utils.plotting import *
from statistics import mean
from utils.logger import logkv_mean, logkv
from mpi4py import MPI
from utils.utils import get_space_dim, GradUpdateType
from utils.pytorch_modules import DiscountedSumForward
from utils.mpi_adam import MpiAdam
from sampling.sampling import generate_samples
from sampling.sample_processing import calculate_discounted_returns, calculate_gae_advantages


class LVC(object):
    """Meta-learner that adapts a policy per task and accumulates outer gradients.

    For each task, one parameter group of the policy (``policy.inner_params`` by
    default, ``policy.outer_params`` when ``opts.maml`` is set) is updated
    ``opts.lookaheads`` times with ``policy.sgd_update``. The loss of the adapted
    policy on fresh rollouts is then back-propagated into the initial parameters.
    :meth:`epoch` repeats this for every task assigned to this rank and applies a
    single optimizer step.
    """

    def __init__(self, env, policy, baseline, optimizer, opts):
        """
        Args:
            env: task environment; used via ``set_task``, ``sample_task``,
                ``envs_to_sample``, ``n_envs``, ``one_hot``, ``observation_space``
                and ``plot_params``.
            policy: policy whose forward returns option, termination and action
                outputs; must expose ``inner_params``, ``outer_params``,
                ``sgd_update``, ``lr_params`` and ``temperature_actions``.
            baseline: value baseline with ``update``/``predict`` over trajectory data.
            optimizer: optimizer over the policy's initial parameters (may be ``MpiAdam``).
            opts: experiment options namespace.
        """
        self.env = env
        self.policy = policy
        self.baseline = baseline
        self.grad_update_type = opts.grad_update_type

        self.opts = opts
        self.policy_optimizer = optimizer

    def epoch(self):
        """Run one meta-training epoch: accumulate per-task gradients, then step once.

        With ``GradUpdateType.MULTI`` this rank handles every ``world_size``-th entry of
        ``env.envs_to_sample``, starting at ``opts.rank``. Otherwise it handles
        ``opts.envs_per_process`` tasks. Each of those is ``opts.fixed_env`` when that
        is non-negative, else a freshly sampled task. Every task's loss is divided by
        the total number of tasks processed across all ranks before ``backward()``.
        """
        self.policy_optimizer.zero_grad()

        world_size = 1 if MPI is None else MPI.COMM_WORLD.Get_size()
        accumulate_grads = self.env_grads_maml if self.opts.maml else self.env_grads

        if self.grad_update_type == GradUpdateType.MULTI:
            n_envs_to_sample = len(self.env.envs_to_sample)
            for index in range(self.opts.rank, n_envs_to_sample, world_size):
                accumulate_grads(task=self.env.envs_to_sample[index], env_count_normalizer=n_envs_to_sample)
        else:
            # Every rank processes the same number of tasks, so the global count is fixed.
            n_envs = world_size * self.opts.envs_per_process
            for _ in tqdm(range(self.opts.envs_per_process), desc=f"Process {self.opts.rank}",
                          position=self.opts.rank, file=sys.stdout, disable=True):
                task = self.opts.fixed_env if self.opts.fixed_env >= 0 else self.env.sample_task()
                accumulate_grads(task=task, env_count_normalizer=n_envs)

        # Update the initial policy params; MpiAdam additionally receives the communicator
        # (presumably to combine gradients across ranks — see MpiAdam).
        if isinstance(self.policy_optimizer, MpiAdam):
            self.policy_optimizer.step(comm=None if MPI is None else MPI.COMM_WORLD)
        else:
            self.policy_optimizer.step()

    def env_grads(self, task, env_count_normalizer):
        """Accumulate gradients for ``task`` after adapting ``policy.inner_params``.

        Args:
            task: task identifier passed to ``env.set_task``.
            env_count_normalizer: divisor applied to the post-adaptation loss
                before ``backward()``.
        """
        self._adapt_and_accumulate(task, env_count_normalizer, adapted="inner_params", frozen="outer_params")

    def env_grads_maml(self, task, env_count_normalizer):
        """Same as :meth:`env_grads`, but adapts ``policy.outer_params`` instead.

        Args:
            task: task identifier passed to ``env.set_task``.
            env_count_normalizer: divisor applied to the post-adaptation loss
                before ``backward()``.
        """
        self._adapt_and_accumulate(task, env_count_normalizer, adapted="outer_params", frozen="inner_params")

    def _adapt_and_accumulate(self, task, env_count_normalizer, adapted, frozen):
        """Adapt one parameter group on ``task``, then back-propagate the adapted loss.

        Args:
            task: task identifier passed to ``env.set_task``.
            env_count_normalizer: divisor applied to the final loss.
            adapted: name of the policy attribute holding the group that is updated.
            frozen: name of the policy attribute holding the group kept as-is.
                It is re-read after every step and merged with the adapted tensors
                to form the full parameter dict.
        """
        self.env.set_task(task)
        current_params = OrderedDict(self.policy.named_parameters())
        adapted_params = getattr(self.policy, adapted)
        for lookahead in range(self.opts.lookaheads):
            loss, traj_data = self.get_update_data(params=current_params)

            adapted_params = self.policy.sgd_update(loss=loss, lr=self._inner_lr(), params=adapted_params)
            current_params = getattr(self.policy, frozen).copy()
            current_params.update(adapted_params)
            self._log_traj_data(traj_data=traj_data, task=task, prefix=f"Step{lookahead:02}")

        loss, traj_data = self.get_update_data(params=current_params)
        # Out-of-place so we never mutate a tensor that is part of the autograd graph.
        loss = loss / env_count_normalizer

        if self.opts.save_trajs:
            # NOTE(review): every task writes the same path, so only the last task's
            # trajectories survive (presumably also across ranks sharing base_dir).
            with open(self.opts.base_dir + "/logs/trajs.pickle", 'wb') as handle:
                pickle.dump(traj_data, handle, protocol=pickle.HIGHEST_PROTOCOL)

        # Accumulate grads into the initial parameters' .grad (summed across tasks).
        loss.backward()
        self._log_traj_data(traj_data=traj_data, task=task, prefix=f"Step{self.opts.lookaheads:02}")

    def _inner_lr(self):
        """Return the inner-loop step size.

        This is the policy's ``lr_params`` when ``opts.learn_lr_inner`` is set,
        otherwise ``opts.lr_inner``.
        """
        return self.policy.lr_params if self.opts.learn_lr_inner else self.opts.lr_inner

    def loaded_dice(self, log_probs, advantages, lam, entropies, beta):
        """Surrogate loss for one trajectory built from DiCE-style "magic box" terms.

        ``torch.exp(x - x.detach())`` evaluates to 1 in the forward pass but carries the
        gradient of ``x``. The loss value is therefore the (negated) sum of the weights
        it multiplies, while its gradient also flows through the log-probabilities.

        Args:
            log_probs: per-step action log-probabilities.
            advantages: per-step advantages aligned with ``log_probs``.
            lam: discount forwarded to ``DiscountedSumForward``.
            entropies: per-step policy entropies.
            beta: entropy-regularisation weight.

        Returns:
            Negated sum over dim 0 (the summed dim kept with size 1), for minimisation.
        """
        # Presumably a lam-discounted running sum of log-probs up to each step;
        # the exact semantics live in DiscountedSumForward.
        cumulative = DiscountedSumForward.apply(log_probs, lam)
        # The running sum with the current step's own log-prob removed.
        preceding = cumulative - log_probs
        magic_cumulative = torch.exp(cumulative - cumulative.detach())
        magic_preceding = torch.exp(preceding - preceding.detach())
        # Without entropy regularisation this reduces to
        # magic_cumulative * advantages - magic_preceding * advantages.
        total = magic_cumulative * advantages + magic_preceding * (beta * entropies - advantages)
        return -torch.sum(total, dim=0, keepdim=True)

    def get_update_data(self, params=None):
        """Roll out the policy, fit the baseline and build the surrogate loss.

        Args:
            params: parameter dict forwarded unchanged to ``generate_samples`` and
                ``policy.get_update_data``.

        Returns:
            ``(loss, traj_data)``.

            - ``loss`` is the mean of the concatenated per-trajectory
              :meth:`loaded_dice` losses.
            - ``traj_data`` is the sample dict. This method adds
              ``"discounted_returns"`` and ``"loss"``; the policy, baseline and
              ``calculate_gae_advantages`` add further keys (the loss reads
              ``"action_log_probs"``, ``"advantages"`` and ``"entropies"``).
        """
        traj_data = generate_samples(env=self.env, policy=self.policy, episodes=self.opts.episodes, params=params)
        traj_data["discounted_returns"] = calculate_discounted_returns(traj_data=traj_data,
                                                                       return_discount=self.opts.return_discount)

        self.policy.get_update_data(traj_data, params=params)

        # The baseline is fitted on this batch before predicting on it.
        self.baseline.update(traj_data)
        self.baseline.predict(traj_data)
        calculate_gae_advantages(traj_data=traj_data, return_discount=self.opts.return_discount,
                                 gae_discount=self.opts.gae_discount)

        traj_data["loss"] = [
            self.loaded_dice(log_probs, advantages, lam=self.opts.dice_discount,
                             entropies=entropies, beta=self.opts.entropy_reg)
            for log_probs, advantages, entropies in zip(traj_data["action_log_probs"],
                                                        traj_data["advantages"],
                                                        traj_data["entropies"])
        ]

        loss = torch.mean(torch.cat(traj_data["loss"]))
        return loss, traj_data

    def plot_epoch(self, name_prefix="", plot_updates=True):
        """Save and plot option/termination/action probabilities on one-hot inputs.

        Rank 0 saves and plots the pre-update policy. When ``plot_updates`` is set, each
        rank also adapts ``policy.inner_params`` for ``opts.lookaheads`` steps on its
        share of tasks ``0 .. env.n_envs - 1`` and saves/plots the adapted policy.
        Output paths are ``name_prefix`` followed by a suffix; ``name_prefix`` itself is
        created as a directory when non-empty.

        Returns ``None`` immediately, without plotting, when ``opts.env`` contains "ant".

        Raises:
            NotImplementedError: if the environment does not use one-hot observations.
        """
        if "ant" in self.opts.env:
            return None
        # os.makedirs("") raises FileNotFoundError, so skip it for the default prefix.
        if self.opts.rank == 0 and name_prefix:
            os.makedirs(name_prefix, exist_ok=True)

        if MPI is not None:
            # Keep other ranks from writing before rank 0 has created the directory.
            MPI.COMM_WORLD.Barrier()
        if not self.env.one_hot:
            raise NotImplementedError("TODO implement the coordinate input gen")
        inputs = torch.eye(int(get_space_dim(self.env.observation_space)))

        if self.opts.rank == 0:
            self._dump_probs(*self._policy_probs(inputs),
                             data_path=name_prefix + "pre_update_data",
                             plot_prefix=name_prefix + "pre_update", no_blocks=True)

        if plot_updates:
            processes = 1 if MPI is None else MPI.COMM_WORLD.Get_size()
            for current_task in range(self.opts.rank, self.env.n_envs, processes):
                self.env.set_task(current_task)

                # NOTE(review): always adapts inner_params, even when opts.maml is set.
                current_params = OrderedDict(self.policy.named_parameters())
                inner_params = self.policy.inner_params
                for _step in range(self.opts.lookaheads):
                    loss, _traj_data = self.get_update_data(params=current_params)
                    inner_params = self.policy.sgd_update(loss=loss, lr=self._inner_lr(), params=inner_params,
                                                          create_graph=False)
                    current_params = self.policy.outer_params.copy()
                    current_params.update(inner_params)

                # Apply the same temperature softmax as the pre-update plot; previously the
                # raw action outputs were saved and plotted as "policy_probs" here.
                self._dump_probs(*self._policy_probs(inputs, params=current_params),
                                 data_path=name_prefix + "env{}_data".format(current_task),
                                 plot_prefix=name_prefix + "env{}".format(current_task), no_blocks=False)
        plt.close("all")

    def _policy_probs(self, inputs, params=None):
        """Run the policy without gradients and softmax its action outputs.

        ``params=None`` calls the policy with its own parameters. Action outputs are
        divided by ``policy.temperature_actions`` before the softmax.
        """
        with torch.no_grad():
            if params is None:
                option_probs, termination_probs, policy_params = self.policy(inputs)
            else:
                option_probs, termination_probs, policy_params = self.policy(inputs, params=params)
            policy_probs = F.softmax(policy_params / self.policy.temperature_actions, dim=-1)
        return option_probs, termination_probs, policy_probs

    def _dump_probs(self, option_probs, termination_probs, policy_probs, data_path, plot_prefix, no_blocks):
        """Save the three probability tensors with ``torch.save`` and plot them via the env."""
        torch.save({"option_probs": option_probs, "termination_probs": termination_probs,
                    "policy_probs": policy_probs}, data_path)
        self.env.plot_params(option_probs=option_probs, termination_probs=termination_probs,
                             policy_probs=policy_probs, name_prefix=plot_prefix, no_blocks=no_blocks)

    def _log_traj_data(self, traj_data, task, prefix=""):
        """Log per-batch trajectory statistics when ``opts.log_level >= 1``.

        Each statistic is averaged over the trajectories in ``traj_data``. It is
        recorded with ``logkv_mean`` under both ``{prefix}{key}/Avg`` and
        ``{prefix}{key}/Env{task}``.
        """
        if self.opts.log_level < 1:
            return

        log_dict = {
            "EpisodeLength": mean(map(len, traj_data["observations"])),
            "DiscountedReturn": mean(ret[0].item() for ret in traj_data["discounted_returns"]),
            "Return": mean(rew.sum().item() for rew in traj_data["rewards"]),
            "Terminations": mean(term.sum().item() / len(term) for term in traj_data["terminations"]),
        }
        for option in range(self.opts.options):
            # Fraction of steps in each trajectory spent in this option, averaged over trajectories.
            log_dict[f"Option{option}Usage"] = mean(
                (taken == option).sum().float().item() / len(taken) for taken in traj_data["options"])

        for key, val in log_dict.items():
            logkv_mean(f"{prefix}{key}/Avg", val)
            logkv_mean(f"{prefix}{key}/Env{task}", val)
