from train.behavioral_cloning.datasets.minerl_dataset import MineRLDataset, MineRLBaseDataset
from train.reinforcment_learning.rudder.rudder import RR_LSTM_ARCH, RR_LSTM
from train.reinforcment_learning.rudder.buffer import LessonBuffer
from train.common.config import Config
from train.reinforcment_learning.utils.utils import safemean, get_device, import_module
from train.reinforcment_learning.save_video import save_video
from train.envs.minerl_env import make_minerl
from train.reinforcment_learning.runner import Runner
from torch.utils.tensorboard import SummaryWriter

import tqdm
import sys
import time
import torch
import numpy as np


def pad_sequence(episode_length, data):
    """
    Fit a time-major array to exactly ``episode_length`` steps.

    A sequence that is too short gets zero-filled steps appended after its data. A sequence that is too long
    keeps only its last ``episode_length`` steps. A sequence that already has the right length is returned
    as-is (same object, no copy).
    :param episode_length: Number of timesteps the result must have
    :param data: Array of shape (timesteps, *data_shape)
    :return: Array of shape (episode_length, *data_shape)
    """
    missing = episode_length - data.shape[0]
    if missing == 0:
        return data
    if missing < 0:
        # too long: drop the earliest steps and keep the tail
        return data[-episode_length:]
    # too short: append zero steps (np.zeros defaults to float64, so the result may be upcast)
    filler = np.zeros((missing, *data.shape[1:]))
    return np.concatenate((data, filler), axis=0)


def break_sequence(episode_length, data, return_lengths=False):
    """
    Cut a time-major array into consecutive chunks of ``episode_length`` steps.

    Data that fits into a single episode yields one chunk. Longer data is sliced front to back. Any chunk
    shorter than ``episode_length`` is zero-padded at its end via ``pad_sequence``.
    :param episode_length: Number of timesteps every returned chunk has
    :param data: Array of shape (timesteps, *data_shape)
    :param return_lengths: If True, also return the unpadded length of every chunk
    :return: List of chunks, or the tuple (list of lengths, list of chunks) when ``return_lengths`` is set
    """
    total_steps = data.shape[0]
    if total_steps <= episode_length:
        # fits into one episode: keep the array as a single piece
        pieces = [data]
    else:
        pieces = [data[begin:begin + episode_length] for begin in range(0, total_steps, episode_length)]

    # full-length pieces are passed through untouched; only short ones get padded
    chunks = [piece if piece.shape[0] == episode_length else pad_sequence(episode_length, piece)
              for piece in pieces]
    if not return_lengths:
        return chunks

    true_lengths = [piece.shape[0] for piece in pieces]
    return true_lengths, chunks


def get_episode_from_traj(b_rewards, b_dones, pov, bin_actions, camera_actions, args):
    """
    Cut batched rollout data into fixed-length, zero-padded episodes.

    All inputs are swapped so that axis 0 and axis 1 change places; the function then iterates over the new
    first axis (presumably one entry per environment, with time on the second axis - TODO confirm against the
    runner). Each entry is split at the indices where ``b_dones`` is set, every resulting segment is chunked to
    ``args.rudder_episode_len`` steps with ``break_sequence``, and short chunks are zero-padded.

    When ``args.change_reward_structure`` is set, every step before the first reward equal to 1 in a chunk is
    overwritten with ``args.negative_reward_const`` (all real steps if the chunk holds no such reward).
    Padding steps always stay at zero.
    :param b_rewards: Sequence of reward arrays; stacked along a new first axis before the swap
    :param b_dones: Sequence of done-flag arrays; stacked along a new first axis before the swap
    :param pov: Array of observations with at least 3 axes
    :param bin_actions: Array of binary actions with at least 3 axes
    :param camera_actions: Array of camera actions with at least 3 axes
    :param args: Config object providing ``negative_reward_const``, ``rudder_episode_len`` and
        ``change_reward_structure``
    :return: Tuple (pov, bin_actions, camera_actions, rewards, lengths) with one row per extracted episode;
        ``lengths`` holds the number of unpadded steps of each episode
    """
    neg_rew = args.negative_reward_const
    episode_length = args.rudder_episode_len
    b_rewards = np.stack(b_rewards).swapaxes(0, 1)
    b_dones = np.stack(b_dones).swapaxes(0, 1)
    pov = pov.swapaxes(0, 1)
    bin_actions = bin_actions.swapaxes(0, 1)
    camera_actions = camera_actions.swapaxes(0, 1)
    print(pov.shape, bin_actions.shape, camera_actions.shape, b_rewards.shape, b_dones.shape)

    agent_pov, agent_bin_actions, agent_camera_actions, agent_b_rewards, agent_lengths = [], [], [], [], []
    for i in range(pov.shape[0]):
        # explicit comparison kept on purpose: for non-boolean done arrays only values equal to 1 count
        epi_ends = np.where(b_dones[i] == True)[0]
        print(epi_ends)
        if epi_ends.size == 0:
            print("complete episode, breaking sequence")
            segments = [(pov[i], bin_actions[i], camera_actions[i], b_rewards[i])]
        else:
            # split episode
            print("splitting episodes acc to dones")
            # NOTE(review): episode_length is used as an additional split index here; it is unclear whether
            # the trajectory length was meant instead - confirm with the author
            epi_ends = np.sort(np.append(epi_ends, episode_length))
            segments = zip(np.split(pov[i], epi_ends),
                           np.split(bin_actions[i], epi_ends),
                           np.split(camera_actions[i], epi_ends),
                           np.split(b_rewards[i], epi_ends))

        for p_, b_, c_, r_ in segments:
            if r_.shape[0] == 0:
                # np.split yields empty pieces for a split index of 0, repeated indices or indices at/after
                # the end of the data; they would otherwise become all-padding episodes of length 0
                continue
            split_pov = break_sequence(episode_length, p_)
            split_bin_actions = break_sequence(episode_length, b_)
            split_camera_actions = break_sequence(episode_length, c_)
            split_length, split_rewards = break_sequence(episode_length, r_, True)
            # break_sequence already returns chunks of exactly episode_length steps, no further padding needed
            for p, b, c, r, l in zip(split_pov, split_bin_actions, split_camera_actions, split_rewards,
                                     split_length):
                if args.change_reward_structure:
                    first_reward = np.where(r == 1)[0]
                    # without a reward only the l real steps are penalised, so the padding stays zero
                    penalised_steps = first_reward[0] if first_reward.size else l
                    r[:penalised_steps] = neg_rew
                agent_pov.append(p)
                agent_bin_actions.append(b)
                agent_camera_actions.append(c)
                agent_b_rewards.append(r)
                agent_lengths.append(l)

    # keep only the last entry along axis 2 of observations and actions
    # (presumably the most recent step of the stacked input sequence - TODO confirm)
    agent_pov = np.stack(agent_pov)[:, :, -1]
    agent_bin_actions = np.stack(agent_bin_actions)[:, :, -1]
    agent_camera_actions = np.stack(agent_camera_actions)[:, :, -1]
    agent_b_rewards = np.stack(agent_b_rewards)
    agent_lengths = np.stack(agent_lengths)

    return agent_pov, agent_bin_actions, agent_camera_actions, agent_b_rewards, agent_lengths


if __name__ == "__main__":
    # Script entry point: collects MineRL dataset sequences and fresh agent rollouts, stores them in a
    # LessonBuffer and runs the RUDDER pre-training update on them.
    import os  # imported locally because only the script entry point needs it

    # parse config
    args = Config()

    # import model and dataset
    Network = import_module(args.model).Network
    dataset = import_module(args.dataset)
    # get behavior policy
    print("Initializing network!")
    model = Network()
    # load pre-trained model parameters
    if args.load_model:
        print("Loading behaviour cloned model!")
        model.load_state_dict(torch.load(args.pre_trained_params))

    device = get_device(args.gpu)
    model = model.to(device)
    neg_rew = args.negative_reward_const
    episode_length = args.rudder_episode_len

    # TODO: store everything as float32
    if not args.load_rudder_data:
        # load dataset
        root_chop = "/publicdata/minerl/official_v1"
        root_log = "/publicdata/minerl/subtasks/log"
        data_log = MineRLDataset(root=root_log, sequence_len=dataset.SEQ_LENGTH, train=True, prepare=False,
                                 experiment="MineRLObtainDiamond-v0",
                                 input_space=dataset.INPUT_SPACE,
                                 action_space=dataset.ACTION_SPACE,
                                 data_transform=dataset.DATA_TRANSFORM)

        print("Number of log sequences", data_log.num_sequences())
        num_seq_log = data_log.num_sequences()
        # preprocessing
        pov, binary_actions, camera_actions, rewards, length = [], [], [], [], []
        for i in range(num_seq_log):
            # get a sequence
            _pov, _binary_actions, _camera_actions, _rewards = data_log._prepare_sequence(i, step=None)
            pov.append(pad_sequence(episode_length, _pov))
            binary_actions.append(pad_sequence(episode_length, _binary_actions))
            camera_actions.append(pad_sequence(episode_length, _camera_actions))
            # every log sequence gets a reward of 1 on its final step; all earlier steps get the negative
            # constant or 0, depending on the configured reward structure
            _rewards[:] = neg_rew if args.change_reward_structure else 0
            _rewards[-1] = 1
            # pad_sequence keeps the tail of over-long sequences, so the final reward is preserved
            rewards.append(pad_sequence(episode_length, _rewards))
            length.append(min(episode_length, _rewards.shape[0]))
        log_pov = np.stack(pov)
        log_binary_actions = np.stack(binary_actions)
        log_camera_actions = np.stack(camera_actions)
        log_rewards = np.stack(rewards)
        log_lengths = np.stack(length)

        del data_log, pov, binary_actions, camera_actions, rewards, length

        print("Loading tree chop sequences")
        data_chop = MineRLDataset(root=root_chop, sequence_len=dataset.SEQ_LENGTH, train=True, prepare=False,
                                  experiment="MineRLTreechop-v0",
                                  input_space=dataset.INPUT_SPACE,
                                  action_space=dataset.ACTION_SPACE,
                                  data_transform=dataset.DATA_TRANSFORM)

        num_seq_chop = data_chop.num_sequences()
        print("number of tree chop sequences", num_seq_chop)

        pov, binary_actions, camera_actions, rewards, length = [], [], [], [], []
        for i in range(num_seq_chop):
            # get a sequence
            _pov, _binary_actions, _camera_actions, _rewards = data_chop._prepare_sequence(i, step=None)
            # tree chop sequences are cut into several episode_length chunks instead of being truncated
            pov.append(break_sequence(episode_length, _pov))
            binary_actions.append(break_sequence(episode_length, _binary_actions))
            camera_actions.append(break_sequence(episode_length, _camera_actions))
            l, b_r = break_sequence(episode_length, _rewards, True)
            rewards.append(b_r)
            length.append(l)

        # TODO: test if the transformation aligns all the arrays in correct manner
        # flatten the per-sequence chunk lists into one list of chunks
        flat_pov, flat_binary_actions, flat_camera_actions, flat_rewards, chop_lengths = [], [], [], [], []
        for p_, b_, c_, r_, l_ in zip(pov, binary_actions, camera_actions, rewards, length):
            for p_1, b_1, c_1, r_1, l_1 in zip(p_, b_, c_, r_, l_):
                flat_pov.append(p_1)
                flat_binary_actions.append(b_1)
                flat_camera_actions.append(c_1)
                if args.change_reward_structure:
                    first_reward = np.where(r_1 == 1)[0]
                    if first_reward.size == 0:
                        # r_1 is already padded to episode_length: penalise only its l_1 real steps so that
                        # the padding stays zero, as it does for the log sequences above
                        r_1[:l_1] = neg_rew
                    else:
                        r_1[:first_reward[0]] = neg_rew
                flat_rewards.append(r_1)
                chop_lengths.append(l_1)

        chop_pov = np.stack(flat_pov)
        chop_binary_actions = np.stack(flat_binary_actions)
        chop_camera_actions = np.stack(flat_camera_actions)
        chop_rewards = np.stack(flat_rewards)
        chop_lengths = np.stack(chop_lengths)

        del pov, binary_actions, camera_actions, rewards, data_chop, flat_pov, flat_binary_actions, flat_camera_actions
        del flat_rewards, length

        # appending everything and creating a new data set.
        all_pov = np.concatenate((chop_pov, log_pov), axis=0)
        all_binary_actions = np.concatenate((chop_binary_actions, log_binary_actions), axis=0)
        all_camera_actions = np.concatenate((chop_camera_actions, log_camera_actions), axis=0)
        all_rewards = np.concatenate((chop_rewards, log_rewards), axis=0)
        all_lengths = np.concatenate((chop_lengths, log_lengths), axis=0)
        print(all_rewards.shape, all_pov.shape, all_binary_actions.shape, all_camera_actions.shape, all_lengths.shape)

        if args.store_rudder_samples:
            print("saving data")
            # np.save does not create missing directories
            os.makedirs("rudder_data", exist_ok=True)
            np.save("rudder_data/pov", all_pov)
            np.save("rudder_data/bin_actions", all_binary_actions)
            np.save("rudder_data/camera_actions", all_camera_actions)
            np.save("rudder_data/rewards", all_rewards)
            np.save("rudder_data/lengths", all_lengths)

    # instantiate environment
    env = make_minerl(args.env, n_cpu=args.num_env, seq_len=dataset.SEQ_LENGTH,
                        transforms=dataset.DATA_TRANSFORM, input_space=dataset.INPUT_SPACE,
                        env_server=True)

    runner = Runner(model=model, dataset=dataset, nenv=args.num_env, env=env, nsteps=args.nsteps,
                    seq_len=dataset.SEQ_LENGTH, gamma=args.gamma, lam=args.lam, episode_length=args.episode_length,
                    reward_scale=args.reward_scale, episodic_interactions=True, device=device)

    # TODO: Get conv layer output from the policy networks for the entire dataset
    # number of additional rollouts performed by the loop below (one rollout is always done first)
    rudder_rollouts = args.rudder_get_samples * args.rudder_episode_len // (args.episode_length * args.num_env)

    # first rollout: initialises the agent_* arrays that the loop below extends
    with torch.no_grad():
        b_states, b_rewards, b_dones, _, _, _, _, _ = runner.rollout()
        pov, bin_actions, camera_actions = runner.transform_state_dict(b_states)

    agent_pov, agent_bin_actions, agent_camera_actions, agent_b_rewards, agent_lengths = get_episode_from_traj(
        b_rewards,
        b_dones, pov,
        bin_actions,
        camera_actions,
        args)

    for i in range(rudder_rollouts):
        with torch.no_grad():
            b_states, b_rewards, b_dones, _, _, _, _, _ = runner.rollout()
            pov, bin_actions, camera_actions = runner.transform_state_dict(b_states)

        agent_pov_, agent_bin_actions_, agent_camera_actions_, agent_b_rewards_, agent_lengths_ = get_episode_from_traj(
                                                                                                    b_rewards,
                                                                                                    b_dones, pov,
                                                                                                    bin_actions,
                                                                                                    camera_actions,
                                                                                                    args)

        agent_pov = np.concatenate((agent_pov, agent_pov_), axis=0)
        agent_bin_actions = np.concatenate((agent_bin_actions, agent_bin_actions_), axis=0)
        agent_camera_actions = np.concatenate((agent_camera_actions, agent_camera_actions_), axis=0)
        agent_b_rewards = np.concatenate((agent_b_rewards, agent_b_rewards_), axis=0)
        agent_lengths = np.concatenate((agent_lengths, agent_lengths_), axis=0)

        print(agent_pov.shape, agent_bin_actions.shape, agent_camera_actions.shape, agent_b_rewards.shape)

    if args.store_rudder_samples:
        print("saving data")
        # np.save does not create missing directories
        os.makedirs("rudder_data", exist_ok=True)
        np.save("rudder_data/agent_pov", agent_pov)
        np.save("rudder_data/agent_bin_actions", agent_bin_actions)
        np.save("rudder_data/agent_camera_actions", agent_camera_actions)
        np.save("rudder_data/agent_rewards", agent_b_rewards)
        np.save("rudder_data/agent_lengths", agent_lengths)

    if args.load_rudder_data:
        # reuse the dataset arrays cached by an earlier run instead of the ones built above
        all_pov = np.load('rudder_data/pov.npy')
        all_binary_actions = np.load("rudder_data/bin_actions.npy")
        all_camera_actions = np.load("rudder_data/camera_actions.npy")
        all_rewards = np.load("rudder_data/rewards.npy")
        all_lengths = np.load("rudder_data/lengths.npy")

    # merge dataset episodes with the episodes generated by the agent
    all_pov = np.concatenate((all_pov, agent_pov), axis=0)
    all_binary_actions = np.concatenate((all_binary_actions, agent_bin_actions), axis=0)
    all_camera_actions = np.concatenate((all_camera_actions, agent_camera_actions), axis=0)
    all_rewards = np.concatenate((all_rewards, agent_b_rewards), axis=0)
    all_lengths = np.concatenate((all_lengths, agent_lengths), axis=0)

    print(all_rewards.shape, all_pov.shape, all_binary_actions.shape, all_camera_actions.shape, all_lengths.shape)

    # binary and camera actions are joined along their last axis into one action vector per step
    all_actions = np.concatenate((all_binary_actions, all_camera_actions), axis=-1)

    num_total_seq = all_pov.shape[0]
    num_epochs = args.rudder_epochs
    mini_batch_size = num_total_seq // args.num_mini_batch
    inds = np.arange(num_total_seq)

    lessons_buffer = LessonBuffer(3000, 512, (3, 48, 48), 10)

    # adding everything to the lessons buffer
    for i in range(num_total_seq):
        lessons_buffer.add(states=all_pov[i], actions=all_actions[i], rewards=all_rewards[i], length=all_lengths[i])
    print(lessons_buffer.curr_size)
    # average summed reward per episode, passed to the model as bias_mean
    mean_return = np.mean(np.sum(all_rewards, axis=1))

    rr_model = RR_LSTM_ARCH(input_lstm=288, lstm_size=128, num_actions=10, duplication=10,
                            bias_mean=mean_return, device=device)
    rr_model = rr_model.to(device)

    rudder = RR_LSTM(buffer=lessons_buffer, model=rr_model, num_actions=10)

    writer = SummaryWriter(comment=args.experiment_name)
    rudder.update(writer, rudder_pretraining=True, batch_size=mini_batch_size)
    # the trained weights are stored next to the tensorboard logs of this run
    torch.save(rr_model.state_dict(), writer.get_logdir() + '/rudder')
    print("Completed training the RUDDER!")

    # for _ in range(1):
    #     np.random.shuffle(inds)
    #     for start in range(0, num_total_seq, mini_batch_size):
    #         end = start + mini_batch_size
    #         mbinds = inds[start:end]
    #         slices = (arr[mbinds] for arr in (all_pov, all_actions))
    #         # TODO: give to rudder for optimization
    #         o = rr_model.forward(*slices)
    #         print(o.shape)
    # # TODO: generate data or use old data and redistribute reward
