import copy

import numpy as np
from preprocessing.training_reward import train
import multiprocessing
import torch


def reward_learning(dataset, source_dataset, state_dim, action_dim, var_coeff, writer, ensemble=1,
                    batch_size=1024, lr=3e-4, num_epochs=50, device=None):
    """Train a reward model through ``train`` and return what it produces.

    Args:
        dataset: Dataset forwarded to ``train`` as its first argument
            (presumably the reward-labelled data -- confirm against ``train``).
        source_dataset: Dataset forwarded to ``train`` as its second argument;
            callers in this file treat the result as rewards for this data.
        state_dim: Observation dimensionality passed to ``train``.
        action_dim: Action dimensionality passed to ``train``.
        var_coeff: Forwarded unchanged to ``train`` as ``var_coeff``.
        writer: Logging writer forwarded to ``train`` (may be ``None``).
        ensemble: Forwarded unchanged to ``train`` as ``ensemble``.
        batch_size: Batch size handed to ``train``.
        lr: Learning rate handed to ``train``.
        num_epochs: Number of epochs handed to ``train``.
        device: ``torch.device`` to train on. ``None`` selects ``cuda:0`` when
            CUDA is available and the CPU otherwise.

    Returns:
        Whatever ``train`` returns; callers in this file flatten it with
        ``reshape(-1)`` and use it as a reward array.
    """
    if device is None:
        # Prefer the first GPU but stay usable on CPU-only machines.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Cap the worker count so many-core hosts do not spawn excessive loaders.
    num_workers = min(64, multiprocessing.cpu_count())
    return train(dataset, source_dataset, state_dim, action_dim, batch_size, num_epochs, lr, writer,
                 num_workers, device=device, var_coeff=var_coeff, ensemble=ensemble)


# Per-transition arrays that are appended to ``dataset`` for every sharing strategy.
_SHARED_FIELDS = ("observations", "next_observations", "actions", "dones_float", "masks")

# Ensemble size handed to ``reward_learning`` for each learning-based strategy.
_ENSEMBLE_BY_STRATEGY = {"learn": 1, "pess": 10, "learn_all": 1, "pess_all": 10}


def _concat_fields(dataset, source_dataset, fields):
    """Return ``{field: concat(dataset.field, source_dataset.field)}`` without mutating either input."""
    return {
        name: np.concatenate([getattr(dataset, name), getattr(source_dataset, name)])
        for name in fields
    }


def _pool_datasets(dataset, source_dataset):
    """Return a deep copy of ``dataset`` whose arrays (rewards included) also contain the source data."""
    pooled = copy.deepcopy(dataset)
    combined = _concat_fields(dataset, source_dataset, _SHARED_FIELDS + ("rewards",))
    for name, values in combined.items():
        setattr(pooled, name, values)
    pooled.size = dataset.size + source_dataset.size
    return pooled


def _learn_rewards(dataset, target_dataset, var_coeff, writer, ensemble):
    """Run ``reward_learning`` for ``target_dataset`` and return the result flattened to 1-D."""
    state_dim = dataset.observations.shape[1]
    action_dim = dataset.actions.shape[1]
    learned = reward_learning(dataset, target_dataset, state_dim, action_dim, var_coeff, writer,
                              ensemble=ensemble)
    return learned.reshape(-1)


def merge_dataset(env, dataset, source_dataset, strategy="none", var_coeff=3, writer=None):
    """Append ``source_dataset`` to ``dataset`` in place, labelling rewards per ``strategy``.

    Strategies:
        ``"none"``      -- do not share data; ``dataset`` is left untouched.
        ``"all"``       -- append the source rewards exactly as stored.
        ``"oracle"``    -- relabel the source rewards with
                           ``env.compute_reward_general(obs)`` (this also
                           overwrites ``source_dataset.rewards``) and append.
        ``"learn"``     -- append rewards from ``reward_learning`` (ensemble=1).
        ``"pess"``      -- same as ``"learn"`` with ensemble=10.
        ``"learn_all"`` -- replace the rewards of the whole merged dataset by
                           the output of ``reward_learning`` (ensemble=1).
        ``"pess_all"``  -- same as ``"learn_all"`` with ensemble=10.
        ``"zero"``      -- append zero rewards for the source data (tagged
                           "uds" in the original code).

    Args:
        env: Only used by ``"oracle"``; must provide ``compute_reward_general``.
        dataset: Target dataset, mutated in place.
        source_dataset: Dataset whose transitions are appended.
        strategy: One of the strategy names above.
        var_coeff: Forwarded to ``reward_learning`` for learning strategies.
        writer: Forwarded to ``reward_learning`` for learning strategies.

    Returns:
        None. ``dataset`` is updated in place.

    Raises:
        NotImplementedError: If ``strategy`` is not recognised.
    """
    print(dataset.size, source_dataset.size)
    print(writer)
    if strategy == "none":
        # don't share data
        return

    if strategy == "all":
        # share the rewards already stored in the source dataset
        merged_rewards = np.concatenate([dataset.rewards, source_dataset.rewards])
    elif strategy == "oracle":
        source_dataset.rewards = np.array(
            [env.compute_reward_general(obs) for obs in source_dataset.observations])
        merged_rewards = np.concatenate(
            [dataset.rewards.reshape(-1), source_dataset.rewards.reshape(-1)])
    elif strategy in ("learn", "pess"):
        # learned rewards label only the source transitions
        learned = _learn_rewards(dataset, source_dataset, var_coeff, writer,
                                 _ENSEMBLE_BY_STRATEGY[strategy])
        merged_rewards = np.concatenate([dataset.rewards.reshape(-1), learned])
    elif strategy in ("learn_all", "pess_all"):
        # learn rewards and apply to all data
        pooled = _pool_datasets(dataset, source_dataset)
        merged_rewards = _learn_rewards(dataset, pooled, var_coeff, writer,
                                        _ENSEMBLE_BY_STRATEGY[strategy])
    elif strategy == "zero":
        # uds
        merged_rewards = np.concatenate([dataset.rewards, np.zeros_like(source_dataset.rewards)])
    else:
        message = f"Strategy {strategy} not found"
        print(message)
        raise NotImplementedError(message)

    # Build every merged array before assigning anything, so a shape mismatch
    # raises without leaving ``dataset`` half-updated.
    merged_fields = _concat_fields(dataset, source_dataset, _SHARED_FIELDS)
    merged_size = dataset.size + source_dataset.size

    dataset.rewards = merged_rewards
    for name, values in merged_fields.items():
        setattr(dataset, name, values)
    dataset.size = merged_size
