import argparse
import logging
import time
import gym
import torch
import numpy as np
from itertools import count
from sac.replay_memory import ReplayMemory
from sac.sac import SAC
from model import Ensemble_Model, PriorModel, NewPriorModel
from predict_env import PredictEnv
from sample_env import EnvSampler
from tf_models.constructor import construct_model, format_samples_for_training
import d4rl
from logger import Logger
import os
from video import VideoRecorder


class MOPO:
    def __init__(self, args, log_frequency=10000, log_save_tb=True):
        """Create the environment, offline dataset, SAC agent, dynamics model and replay pools.

        Args:
            args: Parsed command-line namespace produced by ``readParser``.
            log_frequency: Passed to ``Logger`` as ``log_frequency``.
            log_save_tb: Passed to ``Logger`` as ``save_tb``.
        """
        self.work_dir = os.getcwd()
        print(f'workspace: {self.work_dir}')
        self.logger = Logger(self.work_dir,
                             save_tb=log_save_tb,
                             log_frequency=log_frequency,
                             agent="sac")
        # Initial environment
        self.env = gym.make(args.env_name)
        self.env_name = args.env_name
        # evaluate() only checkpoints the agent once the average return exceeds this value.
        self.best_reward = -100
        # Set random seed
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        self.env.seed(args.seed)

        # Load the offline dataset exactly once: from D4RL by default, or from an
        # HDF5 file when --data_path is given.
        if not args.data_path:
            self.dataset = d4rl.qlearning_dataset(self.env.env)
        else:
            # main() changes into exp/<date>/<run>, so a relative path is resolved
            # three levels up; os.path.join leaves an absolute path untouched.
            data_path = os.path.join(os.pardir, os.pardir, os.pardir, args.data_path)
            self.load_data(data_path)

        # Initial agent
        self.agent = SAC(self.env.observation_space.shape[0], self.env.action_space, args, logger=self.logger)
        # --load_prior defaults to None; only warm-start the policy when a directory
        # was given, otherwise os.path.join(None, ...) raises before train() can fit
        # the prior itself.
        if args.load_prior:
            self.agent.policy.load_state_dict(torch.load(os.path.join(args.load_prior, 'prior.pt')))

        self.save_models = args.save_models
        # Initial ensemble model
        self.state_size = np.prod(self.env.observation_space.shape)
        self.action_size = np.prod(self.env.action_space.shape)

        # Prior Model
        self.prior_model = NewPriorModel(self.state_size, self.action_size, hidden_dim=256)
        self.agent.setup_prior(self.prior_model)

        if args.model_type == 'pytorch':
            print("hello false")
            self.env_model = Ensemble_Model(args.num_networks, args.num_elites, self.state_size, self.action_size, args.reward_size, args.pred_hidden_size, separate_mean_var=False)
        else:
            self.env_model = construct_model(obs_dim=self.state_size, act_dim=self.action_size, hidden_dim=args.pred_hidden_size, num_networks=args.num_networks, num_elites=args.num_elites)
        print("constructed model")
        # Predict environments
        self.predict_env = PredictEnv(self.env_model, args.env_name, args.model_type, logger=self.logger, penalty_coeff=args.coeff)
        self.env_sampler = EnvSampler(self.predict_env)
        print("constructed prediction environment")
        # Initial pool for env
        self.env_pool = ReplayMemory(args.replay_size)
        # Initial pool for model: sized for one-step rollouts retained over
        # args.model_retain_epochs epochs.
        self.rollouts_per_epoch = args.rollout_batch_size * args.epoch_length / args.model_train_freq
        self.model_steps_per_epoch = int(1 * self.rollouts_per_epoch)
        self.new_pool_size = args.model_retain_epochs * self.model_steps_per_epoch
        self.model_pool = ReplayMemory(self.new_pool_size)
        self.epoch_length = args.epoch_length
        self.model_train_freq = args.model_train_freq
        self.video_recorder = VideoRecorder(
            self.work_dir if args.save_video else None)
        self.save_video = args.save_video

    def setup_data(self, args):
        """Swap in the halfcheetah medium-expert dataset, dropping its first 1000 trajectories.

        A trajectory boundary is any index ``i`` where ``next_observations[i]``
        differs from ``observations[i + 1]``. Everything before the boundary with
        index 1000 is discarded from every array in the dataset.

        Args:
            args: Parsed command-line namespace; ``args.env_name`` must be
                ``'halfcheetah-expert-v0'`` and is rewritten in place.

        Raises:
            IndexError: If the dataset has fewer than 1001 trajectory boundaries.
        """
        assert args.env_name == 'halfcheetah-expert-v0'
        if args.env_name == 'halfcheetah-expert-v0':
            args.env_name = 'halfcheetah-medium-expert-v0'
        self.env = gym.make(args.env_name)
        self.env_name = args.env_name
        dataset = d4rl.qlearning_dataset(self.env)

        observations = np.asarray(dataset['observations'])
        next_observations = np.asarray(dataset['next_observations'])
        num_pairs = len(observations) - 1
        if num_pairs > 0:
            # Vectorized form of "not all(next_obs[i] == obs[i + 1])" for every i.
            mismatch = next_observations[:num_pairs] != observations[1:]
            dones = np.flatnonzero(mismatch.reshape(num_pairs, -1).any(axis=1))
        else:
            dones = np.array([], dtype=int)
        if len(dones) <= 1000:
            raise IndexError(
                f'expected more than 1000 trajectory boundaries, found {len(dones)}')
        stop = int(dones[1000])
        print("THIS IS STOP", stop)
        for key in dataset:
            dataset[key] = dataset[key][stop:]
        self.dataset = dataset

    def train(self, args, train_model_epochs=25, load_prior=None, load=None):
        total_step = 0
        reward_sum = 0
        self.rollout_length = args.rollout
        self.add_to_pool(self.dataset)
        #self.mixup_data(self.dataset)
        if load_prior:
            self.prior_model.load(load_prior)
        else:
            obss = self.dataset['observations']
            acss = self.dataset['actions']
            if args.data_path:
                obss = obss.value 
                acss = acss.value
            # print("WHAT IS THIS", len(obss))
            self.prior_model.train(60, obss, acss)
            self.prior_model.optim = torch.optim.Adam(self.prior_model.parameters(), lr=1e-4)
            self.prior_model.train(70, obss, acss)
            save_dir = 'saved_models'
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            self.prior_model.save('saved_models')
        if load:
            self.predict_env.model.load(load)
            #self.agent.prior_model.load(load)
        else:
            import itertools
            epoch_iter = itertools.count()
            save_dir = 'saved_models'
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            for i in epoch_iter:#range(train_model_epochs):
                print("step", i)
                break_train = self.train_predict_model_offline(args, self.env_pool, self.predict_env, epoch=i, save=True)
                if break_train:
                    print("breaking")
                    break
            #self.predict_env.model.load(save_dir)
        #self.exploration_before_start(args, self.env_pool)
        self.evaluate(self.env, self.agent, total_step)
        for epoch_step in range(args.num_epoch):
            start_step = total_step
            train_policy_steps = 0
            for i in count():
                cur_step = total_step - start_step

                if cur_step >= self.epoch_length and len(self.env_pool) > args.min_pool_size:
                    break

                if total_step > 0 and cur_step % self.model_train_freq == 0 and args.real_ratio < 1.0:
                    #train_predict_model(args, env_pool, predict_env)

                    # new_rollout_length = self.set_rollout_length(args, epoch_step)
                    # if self.rollout_length != new_rollout_length:
                    #     self.rollout_length = new_rollout_length
                    #     self.model_pool = self.resize_model_pool(args, self.rollout_length, self.model_pool)

                    self.rollout_model(args, self.predict_env, self.agent, self.model_pool, self.env_pool, self.rollout_length, total_step=total_step)

                # cur_state, action, next_state, reward, done, info = env_sampler.sample(agent)
                # env_pool.push(cur_state, action, reward, next_state, done)

                if len(self.env_pool) >= args.min_pool_size:
                    train_policy_steps += self.train_policy_repeats(args, total_step, train_policy_steps, cur_step, self.env_pool, self.model_pool, self.agent)

                total_step += 1

            if total_step % 1000 == 0:
                self.logger.dump(total_step)
                self.evaluate(self.env, self.agent, total_step)
                '''
                avg_reward_len = min(len(env_sampler.path_rewards), 5)
                avg_reward = sum(env_sampler.path_rewards[-avg_reward_len:]) / avg_reward_len
                logging.info("Step Reward: " + str(total_step) + " " + str(env_sampler.path_rewards[-1]) + " " + str(avg_reward))
                print(total_step, env_sampler.path_rewards[-1], avg_reward)
                '''
                # env_sampler.current_state = None
                # sum_reward = 0
                # done = False
                # while not done:
                #     cur_state, action, next_state, reward, done, info = env_sampler.sample(agent, eval_t=True)
                #     sum_reward += reward
                # logging.info("Step Reward: " + str(total_step) + " " + str(sum_reward))
                # print(total_step, sum_reward)

    def evaluate(self, env, agent, total_step):
        """Roll out ``agent`` in ``env`` for 10 episodes, log the mean return, and checkpoint on improvement.

        Args:
            env: Environment exposing ``reset()`` and ``step(action)``.
            agent: Policy exposing ``select_action(obs, eval=True)`` and
                ``save_model(name)``.
            total_step: Training step used as the logging index and video file name.
        """
        num_eval_episodes = 10
        average_episode_reward = 0
        # Evaluation inside the learned model is disabled, so these two stay at 0;
        # they are still logged so the set of logged keys does not change.
        average_pred_episode_reward = 0
        average_unpenalized_pred_reward = 0
        for episode in range(num_eval_episodes):
            obs = env.reset()
            if self.save_video:
                # Only the first episode of each evaluation is recorded.
                self.video_recorder.init(enabled=(episode == 0))
            done = False
            episode_reward = 0
            while not done:
                with torch.no_grad():
                    action = agent.select_action(obs, eval=True)
                # Step the same environment that was reset above.
                obs, reward, done, _ = env.step(action)
                if self.save_video:
                    self.video_recorder.record(env)
                episode_reward += reward
            average_episode_reward += episode_reward
            if self.save_video:
                self.video_recorder.save(f'{total_step}.mp4')
        average_episode_reward /= num_eval_episodes
        average_pred_episode_reward = average_pred_episode_reward / num_eval_episodes
        average_unpenalized_pred_reward = average_unpenalized_pred_reward / num_eval_episodes
        logging.info("Step Reward: " + str(total_step) + " " + str(average_episode_reward))
        self.logger.log('eval/episode_reward', average_episode_reward,
                        total_step)
        self.logger.log('eval/pred_episode_reward', average_pred_episode_reward, total_step)
        self.logger.log('eval/unpenalized_pred_reward', average_unpenalized_pred_reward, total_step)
        self.logger.dump(total_step)
        if self.save_models and average_episode_reward > self.best_reward:
            self.best_reward = average_episode_reward
            agent.save_model(self.env_name)

    def add_to_pool(self, data):
        """Push every transition of ``data`` into the environment replay pool.

        ``data`` is indexed with the keys ``observations``, ``actions``,
        ``next_observations``, ``rewards`` and ``terminals``; the five sequences are
        walked in lockstep and stored as (state, action, reward, next_state, done).
        """
        field_names = ('observations', 'actions', 'next_observations', 'rewards', 'terminals')
        columns = [data[name] for name in field_names]
        for state, action, next_state, reward, terminal in zip(*columns):
            self.env_pool.push(state, action, reward, next_state, terminal)
        print("done adding to replay buffer")

    def exploration_before_start(self, args, predict_env):
        for i in range(args.init_exploration_steps):
            cur_state, action, next_state, reward, done, info = env_sampler.sample(agent)
            self.env_pool.push(cur_state, action, reward, next_state, done)

    def set_rollout_length(self, args, epoch_step):
        rollout_length = (min(max(args.rollout_min_length + (epoch_step - args.rollout_min_epoch)
            / (args.rollout_max_epoch - args.rollout_min_epoch) * (args.rollout_max_length - args.rollout_min_length),
            args.rollout_min_length), args.rollout_max_length))
        return int(rollout_length)

    def mixup_data(self, data):
        old_obs, ac, obs, rew, done = data['observations'], data['actions'], data['next_observations'], data['rewards'], data['terminals']
        betas = np.expand_dims(np.random.beta(0.4, 0.4, size=len(obs) - 1), 1)
        mixup_old_obs = betas * old_obs[:-1] + (1 - betas) * old_obs[1:]
        mixup_ac = betas * ac[:-1] + (1 - betas) * ac[1:]
        mixup_obs = betas * obs[:-1] + (1 - betas) * obs[1:]
        rew = np.expand_dims(rew, 1)
        mixup_rew = betas * rew[:-1] + (1 - betas) * rew[1:]
        done= np.expand_dims(done, 1)
        mixup_done = betas * done[:-1] + (1 - betas) * done[1:]
        zipped_data = zip(mixup_old_obs, mixup_ac, mixup_obs, mixup_rew, mixup_done)
        print("done calculating this thing")
        for old_obs, ac, obs, rew, done in zipped_data:
            #print("SHAPES", old_obs.shape, ac.shape, obs.shape, rew.shape, done.shape)
            self.env_pool.push(old_obs, ac, rew[0], obs, False)
        print("done adding mixup")

    def train_predict_model_offline(self, args, env_pool, predict_env, epoch, save=False):
        # Get all samples from environment
        batch = env_pool.buffer
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        #state, action, reward, next_state, done = env_pool.sample(len(env_pool))
        delta_state = next_state - state
        inputs = np.concatenate((state, action), axis=-1)
        labels = np.concatenate((np.reshape(reward, (reward.shape[0], -1)), delta_state), axis=-1)

        if args.model_type == 'pytorch':
            num_train = int(len(inputs) * 0.8)
            train_inputs, val_inputs = inputs[:num_train], inputs[num_train:]
            inds = np.arange(len(train_inputs))
            np.random.shuffle(inds)
            train_inputs, val_inputs = train_inputs[inds], val_inputs# shuffle together
            train_labels, val_labels = labels[:num_train][inds], labels[num_train:]
            return predict_env.model.train(train_inputs, train_labels, val_inputs, val_labels, batch_size=256, epoch=epoch)
        else:
            return predict_env.model.train(inputs, labels, batch_size=256, holdout_ratio=0.2)
        # if save:
        #     save_dir = 'saved_models'
        #     if not os.path.exists(save_dir):
        #         os.makedirs(save_dir)
            #predict_env.model.save(save_dir)


    def resize_model_pool(self, args, rollout_length, model_pool):
        """Return a new model pool sized for ``rollout_length``, filled with the old pool's contents.

        Capacity is ``model_retain_epochs * int(rollout_length * rollouts_per_epoch)``
        where ``rollouts_per_epoch = rollout_batch_size * epoch_length / model_train_freq``.
        """
        per_epoch_rollouts = args.rollout_batch_size * args.epoch_length / args.model_train_freq
        per_epoch_steps = int(rollout_length * per_epoch_rollouts)
        capacity = args.model_retain_epochs * per_epoch_steps

        existing = model_pool.return_all()
        resized_pool = ReplayMemory(capacity)
        resized_pool.push_batch(existing)
        return resized_pool

    def rollout_model(self, args, predict_env, agent, model_pool, env_pool, rollout_length, total_step=None):
        state, action, reward, next_state, done = env_pool.sample_all_batch(args.rollout_batch_size)
        for i in range(rollout_length):
            # TODO: Get a batch of actions
            action = agent.select_action(state)
            # with torch.no_grad():
            #     state_tc, action_tc = torch.from_numpy(state).float().cuda(), torch.from_numpy(action).float().cuda()
            #     q1, q2 = agent.critic(state_tc, action_tc)
            #     self.logger.log('train/q1', q1[0], total_step)
            #     self.logger.log('train/q2', q2[0], total_step)
            #     q_diff = np.abs(((q1 - q2) ** 2).detach().cpu().numpy() / q1.detach().cpu().numpy())
            #     #print(q_diff[0])
                #print("QDIFF SHAPE", q_diff.shape, action.shape, state.shape)
            q_diff = None
            next_states, rewards, terminals, info = predict_env.step(state, action, q_diff, total_step=total_step)
            # TODO: Push a batch of samples
            model_pool.push_batch([(state[j], action[j], rewards[j], next_states[j], terminals[j]) for j in range(state.shape[0])])
            nonterm_mask = ~terminals.squeeze(-1)
            if nonterm_mask.sum() == 0:
                break
            state = next_states[nonterm_mask]

    def train_policy_repeats(self, args, total_step, train_step, cur_step, env_pool, model_pool, agent):
        if total_step % args.train_every_n_steps > 0:
            return 0

        if train_step > args.max_train_repeat_per_step * total_step:
            return 0

        for i in range(args.num_train_repeat):
            env_batch_size = int(args.policy_train_batch_size * args.real_ratio)
            model_batch_size = args.policy_train_batch_size - env_batch_size

            env_state, env_action, env_reward, env_next_state, env_done = env_pool.sample(int(env_batch_size))

            if model_batch_size > 0 and len(model_pool) > 0:
                model_state, model_action, model_reward, model_next_state, model_done = model_pool.sample_all_batch(int(model_batch_size))
                batch_state, batch_action, batch_reward, batch_next_state, batch_done = np.concatenate((env_state, model_state), axis=0), \
                    np.concatenate((env_action, model_action), axis=0), np.concatenate((np.reshape(env_reward, (env_reward.shape[0], -1)), model_reward), axis=0), \
                    np.concatenate((env_next_state, model_next_state), axis=0), np.concatenate((np.reshape(env_done, (env_done.shape[0], -1)), model_done), axis=0)
            else:
                batch_state, batch_action, batch_reward, batch_next_state, batch_done = env_state, env_action, env_reward, env_next_state, env_done

            batch_reward, batch_done = np.squeeze(batch_reward), np.squeeze(batch_done)
            batch_done = (~batch_done).astype(int)
            with torch.autograd.detect_anomaly():
                agent.update_parameters((batch_state, batch_action, batch_reward, batch_next_state, batch_done), args.policy_train_batch_size, i, step=cur_step)

        return args.num_train_repeat

    def load_data(self, dataset_path):
        """Open the HDF5 file at ``dataset_path`` and expose its transition arrays as ``self.dataset``.

        ``self.dataset`` maps ``observations``, ``actions``, ``next_observations``,
        ``rewards`` and ``terminals`` to whatever ``File.get`` returns for that key.
        """
        import h5py

        # NOTE(review): the file is never closed here; presumably the returned
        # objects are lazy handles that need it open - confirm before adding a close.
        h5_file = h5py.File(dataset_path, 'r')
        wanted = ('observations', 'actions', 'next_observations', 'rewards', 'terminals')
        self.dataset = {name: h5_file.get(name) for name in wanted}


def main():
    """Parse arguments, create a fresh experiment directory, and run MOPO training.

    The run directory is ``exp/<YYYY.MM.DD>/<HH_MM>_<experiment>``. Any existing
    directory of that name is removed, the process changes into it, and the parsed
    arguments are written to ``args.txt`` as JSON before training starts.
    """
    import json
    import shutil

    args = readParser()

    # Take the timestamp once so the date and time parts cannot disagree
    # when the run starts right around midnight.
    now = time.localtime()
    filedir = os.path.join('exp', time.strftime("%Y.%m.%d", now),
                           time.strftime("%H_%M", now) + "_" + args.experiment)
    shutil.rmtree(filedir, ignore_errors=True)
    # makedirs also creates the 'exp' and date directories when they are missing.
    os.makedirs(filedir)
    os.chdir(filedir)
    with open('args.txt', 'w') as outfile:
        json.dump(vars(args), outfile)
    mopo = MOPO(args)

    print("starting training...")
    # Checkpoint locations for the dynamics model and the prior come from
    # --load and --load_prior.
    mopo.train(args, load_prior=args.load_prior, load=args.load, train_model_epochs=40)


def _str2bool(value):
    """Convert a command-line string such as ``true``/``False``/``1``/``0`` to a bool."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, matching what bool('') used to give.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def readParser(argv=None):
    """Build the command-line parser and return the parsed arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with every option defined below.
    """
    parser = argparse.ArgumentParser(description='MBPO')
    parser.add_argument('--env-name', default="hopper-medium-expert-v0",
        help='Mujoco Gym environment (default: hopper-medium-expert-v0)')
    parser.add_argument('--seed', type=int, default=123456, metavar='N',
        help='random seed (default: 123456)')
    parser.add_argument('--experiment', default="mopo-hopper")
    parser.add_argument('--load', default=None)
    parser.add_argument('--load_prior', default=None)
    parser.add_argument('--coeff', type=float, default=1.0)
    parser.add_argument('--rollout', type=int, default=5)
    parser.add_argument('--save_video', default=False, action='store_true')
    # NOTE: default=True combined with store_true means this flag cannot be turned off.
    parser.add_argument('--save_models', default=True, action='store_true')
    parser.add_argument('--data_path', default=None, type=str)
    parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
        help='discount factor for reward (default: 0.99)')
    parser.add_argument('--tau', type=float, default=0.005, metavar='G',
        help='target smoothing coefficient(τ) (default: 0.005)')
    parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
                    help='Temperature parameter α determines the relative importance of the entropy\
                            term against the reward (default: 0.2)')
    parser.add_argument('--policy', default="Gaussian",
                    help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
    parser.add_argument('--target_update_interval', type=int, default=1, metavar='N',
                    help='Value target update per no. of updates per step (default: 1)')
    # type=bool would turn any non-empty string (including "False") into True.
    parser.add_argument('--automatic_entropy_tuning', type=_str2bool, default=True, metavar='G',
                    help='Automatically adjust α (default: True)')
    parser.add_argument('--hidden_size', type=int, default=256, metavar='N',
                    help='hidden size (default: 256)')
    parser.add_argument('--lr', type=float, default=3e-4, metavar='G',
                    help='learning rate (default: 0.0003)')
    parser.add_argument('--kl', type=float, default=1, help='target kl divergence')
    parser.add_argument('--num_networks', type=int, default=7, metavar='E',
                    help='ensemble size (default: 7)')
    parser.add_argument('--num_elites', type=int, default=5, metavar='E',
                    help='elite size (default: 5)')
    parser.add_argument('--pred_hidden_size', type=int, default=200, metavar='E',
                    help='hidden size for predictive model')
    parser.add_argument('--reward_size', type=int, default=1, metavar='E',
                    help='environment reward size')

    parser.add_argument('--replay_size', type=int, default=1000000, metavar='N',
                    help='size of replay buffer (default: 1000000)')

    parser.add_argument('--model_retain_epochs', type=int, default=5, metavar='A', # old: 1
                    help='retain epochs')
    parser.add_argument('--model_train_freq', type=int, default=1000, metavar='A', #old: 250
                    help='frequency of training')
    parser.add_argument('--rollout_batch_size', type=int, default=50000, metavar='A', # old: 100000
                    help='rollout number M')
    parser.add_argument('--epoch_length', type=int, default=1000, metavar='A',
                    help='steps per epoch')
    parser.add_argument('--rollout_min_epoch', type=int, default=20, metavar='A',
                    help='rollout min epoch')
    parser.add_argument('--rollout_max_epoch', type=int, default=150, metavar='A',
                    help='rollout max epoch')
    parser.add_argument('--rollout_min_length', type=int, default=1, metavar='A',
                    help='rollout min length')
    parser.add_argument('--rollout_max_length', type=int, default=15, metavar='A',
                    help='rollout max length')
    parser.add_argument('--num_epoch', type=int, default=1000, metavar='A',
                    help='total number of epochs')
    parser.add_argument('--min_pool_size', type=int, default=1000, metavar='A',
                    help='minimum pool size')
    parser.add_argument('--real_ratio', type=float, default=0.05, metavar='A',
                    help='ratio of env samples / model samples')
    parser.add_argument('--train_every_n_steps', type=int, default=1, metavar='A',
                    help='frequency of training policy')
    parser.add_argument('--num_train_repeat', type=int, default=1, metavar='A', # old: 20
                    help='times to training policy per step')
    parser.add_argument('--max_train_repeat_per_step', type=int, default=5, metavar='A',
                    help='max training times per step')
    parser.add_argument('--policy_train_batch_size', type=int, default=256, metavar='A',
                    help='batch size for training policy')
    parser.add_argument('--init_exploration_steps', type=int, default=5000, metavar='A',
                    help='exploration steps initially')

    parser.add_argument('--model_type', default='tensorflow', metavar='A',
                    help='predict model -- pytorch or tensorflow')

    # NOTE: default=True combined with store_true means this flag cannot be turned off.
    parser.add_argument('--cuda', default=True, action="store_true",
                    help='run on CUDA (default: True)')
    return parser.parse_args(argv)

# Run training only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
