from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from time import localtime, strftime, time

from dotmap import DotMap
from scipy.io import savemat
from tqdm import trange
import numpy as np
import pandas as pd
import torch

import utils
from Agent import Agent
from DotmapUtils import get_required_argument


class MBExperiment:
    """Drives a model-based RL experiment: collects rollouts, trains the policy, logs results."""

    def __init__(self, params):
        """Initializes class instance.

        Argument:
            params (DotMap): A DotMap containing the following:
                .sim_cfg:
                    .env (gym.env): Environment for this experiment
                    .task_hor (int): Task horizon
                    .stochastic (bool): (optional) Must be False or omitted; stochastic
                        (noisy-action) agents are currently not supported.

                .exp_cfg:
                    .ntrain_iters (int): Number of training iterations to be performed.
                    .nrollouts_per_iter (int): (optional) Number of rollouts done between training
                        iterations. Defaults to 1.
                    .ninit_rollouts (int): (optional) Number of initial rollouts. Defaults to 1.
                    .policy (controller): Policy that will be trained.

                .log_cfg:
                    .logdir (str): Directory where experiment data will be saved.
                    .nrecord (int): (optional) Stored on the instance but not used by this
                        class. Defaults to 0.
                    .neval (int): (optional) Number of rollouts for performance evaluation.
                        Defaults to 1.
        """
        # Reject options that are currently not supported.
        assert not params.sim_cfg.get("stochastic", False), \
            "Stochastic agents are not supported."

        self.env = get_required_argument(params.sim_cfg, "env", "Must provide environment.")
        self.task_hor = get_required_argument(params.sim_cfg, "task_hor", "Must provide task horizon.")
        self.agent = Agent(DotMap(env=self.env, noisy_actions=False))

        self.ntrain_iters = get_required_argument(
            params.exp_cfg, "ntrain_iters", "Must provide number of training iterations."
        )
        self.nrollouts_per_iter = params.exp_cfg.get("nrollouts_per_iter", 1)
        self.ninit_rollouts = params.exp_cfg.get("ninit_rollouts", 1)
        self.policy = get_required_argument(params.exp_cfg, "policy", "Must provide a policy.")

        # Read logdir the same way as the other required arguments so that a
        # missing value is reported here instead of failing later on disk access.
        self.logdir = get_required_argument(params.log_cfg, "logdir", "Must provide log directory.")
        self.nrecord = params.log_cfg.get("nrecord", 0)
        self.neval = params.log_cfg.get("neval", 1)

    def _sample_rollouts(self, count):
        """Collect `count` rollouts of length `task_hor` with the current policy."""
        return [self.agent.sample(self.task_hor, self.policy) for _ in range(count)]

    def _train_policy(self, samples):
        """Train the policy on observations, actions and rewards of `samples`; return its result."""
        return self.policy.train(
            [sample["obs"] for sample in samples],
            [sample["ac"] for sample in samples],
            [sample["rewards"] for sample in samples]
        )

    def run_experiment(self):
        """Perform experiment.

        Creates `logdir`, performs the initial rollouts and trains on them, then for
        every training iteration: collects max(neval, nrollouts_per_iter) rollouts,
        dumps the policy logs into a per-iteration directory, rewrites
        `logdir/logs.mat` with all trajectories gathered so far and trains the policy
        on the first nrollouts_per_iter rollouts. Per-iteration statistics are
        written to `logdir/training.csv` at the end.
        """
        os.makedirs(self.logdir, exist_ok=True)

        traj_obs, traj_acs, traj_rets, traj_rews = [], [], [], []
        ep_stats = []

        # Perform initial rollouts
        samples = self._sample_rollouts(self.ninit_rollouts)
        traj_obs.extend(sample["obs"] for sample in samples)
        traj_acs.extend(sample["ac"] for sample in samples)
        traj_rews.extend(sample["rewards"] for sample in samples)

        if self.ninit_rollouts > 0:
            self._train_policy(samples)

        # Training loop
        for i in trange(self.ntrain_iters):
            print("####################################################################")
            print("Starting training iteration %d." % (i + 1))

            # Per-iteration directory handed to the policy for its own logs.
            iter_dir = os.path.join(self.logdir, "train_iter%d" % (i + 1))
            os.makedirs(iter_dir, exist_ok=True)

            start_time = time()
            samples = self._sample_rollouts(max(self.neval, self.nrollouts_per_iter))
            collection_time = time() - start_time

            # The first `neval` rollouts are used for evaluation; the reported
            # reward is the return of the first one (NaN when neval is 0).
            eval_returns = [sample["reward_sum"] for sample in samples[:self.neval]]
            ep_reward = eval_returns[0] if eval_returns else float("nan")
            total_steps = self.policy.train_in.shape[0]

            # Only the first `nrollouts_per_iter` rollouts become training data.
            samples = samples[:self.nrollouts_per_iter]
            traj_obs.extend(sample["obs"] for sample in samples)
            traj_acs.extend(sample["ac"] for sample in samples)
            traj_rets.extend(eval_returns)
            traj_rews.extend(sample["rewards"] for sample in samples)

            self.policy.dump_logs(self.logdir, iter_dir)
            savemat(
                os.path.join(self.logdir, "logs.mat"),
                {
                    "observations": traj_obs,
                    "actions": traj_acs,
                    "returns": traj_rets,
                    "rewards": traj_rews
                }
            )
            # Delete iteration directory if not used
            if not os.listdir(iter_dir):
                os.rmdir(iter_dir)

            # Train after every iteration, including the last one.
            losses = np.array([self._train_policy(samples)])
            losses_mean, losses_std = losses.mean(), losses.std()

            # stats
            ep_stats.append({
                'ep_reward': ep_reward,
                'total_steps': total_steps,
                'losses_mean': losses_mean,
                'losses_std': losses_std,
            })

            print("total_steps %d" % total_steps, "Rewards obtained:", ep_reward, "loss: %.7f+-%.5f" % (losses_mean, losses_std),
                "Collection time: %.2fs" % collection_time, "")

        # save train results
        pd.DataFrame(ep_stats).to_csv(os.path.join(self.logdir, 'training.csv'), index=False)

    def save(self, train_iteration=None, path=None):
        """Save the policy's replay buffer and model weights.

        Arguments:
            train_iteration: (optional) If truthy, its string form is appended to the
                replay buffer file name (`replay_buffer.pkl<train_iteration>`).
                NOTE(review): 0 is treated like None, i.e. no suffix is added.
            path (str): (optional) Target directory. Defaults to `self.logdir`.

        The model weights are always written to `<path>/model.ptorch`, independent
        of `train_iteration`.
        """
        if path is None:
            path = self.logdir

        # Save replay buffer
        rb = {
            "train_in": self.policy.train_in,
            "train_targs": self.policy.train_targs,
            "raw_observations": self.policy.raw_observations,
            "raw_actions": self.policy.raw_actions,
            "maneuvers": self.policy.maneuvers,
        }
        suffix = str(train_iteration) if train_iteration else ''
        filepath = os.path.join(path, 'replay_buffer.pkl' + suffix)
        utils.to_pickle(obj=rb, file=filepath, verbose=True)

        # save model weights
        model_file = os.path.join(path, 'model.ptorch')
        torch.save(self.policy.model.state_dict(), model_file)
        print('saved model in:', model_file)

    def initialize_model(self, maneuver_index):
        """Perform the initial rollouts and train the policy on them.

        Arguments:
            maneuver_index: Value used to fill one array per rollout, shaped like that
                rollout's rewards, which is passed to the policy as the last
                training argument.

        Nothing is sampled or trained when `ninit_rollouts` is 0.
        NOTE(review): this calls policy.train with five lists while run_experiment
        passes three - confirm the policy accepts both forms.
        """
        # Perform initial rollouts
        samples = self._sample_rollouts(self.ninit_rollouts)

        if self.ninit_rollouts > 0:
            self.policy.train(
                [sample["obs"] for sample in samples],
                [sample["ac"] for sample in samples],
                [sample["rewards"] for sample in samples],
                [sample["obs_"] for sample in samples],
                [np.full_like(sample["rewards"], maneuver_index) for sample in samples]
            )

    def load_model(self, model_path, buffer_number=''):
        """Restore the policy's replay buffer and model weights from `model_path`.

        Arguments:
            model_path (str): Directory containing `replay_buffer.pkl<buffer_number>`
                and `model.ptorch`.
            buffer_number: (optional) Suffix of the replay buffer file; converted with
                str(), so the iteration number given to save() can be passed directly.
        """
        # load replay buffer
        d = utils.read_pickle(os.path.join(model_path, 'replay_buffer.pkl' + str(buffer_number)))
        (
            self.policy.train_in,
            self.policy.train_targs,
            self.policy.raw_observations,
            self.policy.raw_actions,
            self.policy.maneuvers,
        ) = d["train_in"], d["train_targs"], d["raw_observations"], d["raw_actions"], d["maneuvers"]
        # load weights
        # NOTE: torch.load unpickles the file - only load checkpoints from trusted sources.
        self.policy.model.load_state_dict(torch.load(os.path.join(model_path, 'model.ptorch')))
        print('loaded rb with %d samples' % len(self.policy.train_in))
