import diffuser.utils as utils
from diffuser.utils.trajectory import Trajectory
from diffuser.utils.trajectory import ReplayBuffer
import torch
from ml_logger import logger, RUN
from config.locomotion_config import Config
from copy import deepcopy
import numpy as np
import pickle
import os
import gym
from diffuser.models.forward_dynamics import ForwardDynamics
from diffuser.utils.arrays import to_torch, to_np, to_device
from diffuser.models.helpers import ForwardLoss
import random


def save(model, bucket):
    """Store the model weights as ``forward.pt`` in the run's checkpoint folder.

    The target directory is ``<bucket>/<logger.prefix>/checkpoint``; it is
    created when missing.  The saved object is a dict whose only key,
    ``'model'``, maps to ``model.state_dict()``.

    Args:
        model: Object exposing ``state_dict()``.
        bucket: Root directory under which the run's folder is located.
    """
    payload = {'model': model.state_dict()}

    checkpoint_dir = os.path.join(bucket, logger.prefix, 'checkpoint')
    os.makedirs(checkpoint_dir, exist_ok=True)

    torch.save(payload, os.path.join(checkpoint_dir, 'forward.pt'))


def set_seed(seed, deterministic_torch=False):
    """Seed the NumPy, Python and PyTorch RNGs and set torch's determinism flag.

    Args:
        seed: Value handed to every seeding function; its string form is
            also written to the ``PYTHONHASHSEED`` environment variable.
        deterministic_torch: Passed straight to
            ``torch.use_deterministic_algorithms``.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Same seed for every generator, applied in a fixed order.
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)

    torch.use_deterministic_algorithms(deterministic_torch)


def load(bucket):
    """Unpickle the trajectories file from the run's checkpoint folder.

    Reads ``<bucket>/<logger.prefix>/checkpoint/trajectories.dat``.

    Args:
        bucket: Root directory under which the run's folder is located.

    Returns:
        The object that was pickled into ``trajectories.dat``.

    Raises:
        FileNotFoundError: If the file (or any parent directory) is missing.
    """
    # A loader must not create directories: doing so only hides a wrong
    # ``bucket`` behind a freshly created, empty checkpoint tree before the
    # open() below fails anyway.
    savepath = os.path.join(
        bucket, logger.prefix, 'checkpoint', 'trajectories.dat'
    )

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point this at trajectory files produced by a trusted run.
    with open(savepath, "rb") as f:
        return pickle.load(f)


# TODO: add model ensemble


def _stack_trajectories(trajectories):
    """Merge every trajectory's ``create_dict()`` output into one dict, key by key."""
    keys = ("observations", "actions", "rewards", "next_observations", "terminals")
    per_traj = [traj.create_dict() for traj in trajectories]

    if len(per_traj) == 1:
        # Nothing to merge: hand the arrays through untouched (no copy).
        return {key: per_traj[0][key] for key in keys}

    # One concatenate per key instead of one per trajectory: growing the
    # arrays pairwise re-copies all accumulated data on every iteration.
    return {
        key: np.concatenate([data[key] for data in per_traj], axis=0)
        for key in keys
    }


def train_forward(**deps):
    """Fit a ``ForwardDynamics`` model on stored trajectories and save it.

    Steps:
      1. Apply ``deps`` to ``RUN`` and ``Config``, seed the RNGs and log the
         parameters.
      2. Load the pickled trajectories from ``Config.bucket`` and copy all of
         their transitions into a ``ReplayBuffer``.
      3. Run ``epoch_num`` optimisation steps, each on one freshly sampled
         batch, printing the loss every step.
      4. Write the trained weights via ``save``.

    Args:
        **deps: Overrides forwarded to ``RUN._update`` and ``Config._update``.

    Returns:
        None.
    """
    RUN._update(deps)
    Config._update(deps)

    set_seed(Config.seed)

    logger.remove('*.pkl')
    logger.remove("traceback.err")
    logger.log_params(Config=vars(Config), RUN=vars(RUN))

    load_check = True   # TODO: put it in the parameter input
    if not load_check:
        return
    trajectories = load(bucket=Config.bucket)

    # State/action sizes are read off the first transition of the first
    # trajectory.
    example_state, example_action, _ = trajectories[0].get_item(0)
    state_dim = example_state.shape[0]
    action_dim = example_action.shape[0]

    buffer_size = 2_000_000    # TODO: put it in the parameter input
    buffer = ReplayBuffer(
        state_dim=state_dim,
        action_dim=action_dim,
        buffer_size=buffer_size,
        device=Config.device,
        seed=Config.seed,
    )
    buffer.load_data(_stack_trajectories(trajectories))

    epoch_num = 650    # TODO: put it in the parameter input
    hidden_dim = 200
    batch_size = 256
    model = ForwardDynamics(state_dim=state_dim, hidden_dim=hidden_dim).to(Config.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    loss_func = ForwardLoss().to(Config.device)

    # Each "epoch" here is a single gradient step on one sampled batch.
    for step in range(epoch_num):
        batch = [b.to(Config.device) for b in buffer.sample(batch_size=batch_size)]
        # NOTE(review): presumably the batch is ordered (observations, actions,
        # rewards, next_observations, terminals) -- confirm against ReplayBuffer.
        state = batch[0]
        next_state = batch[3]

        comb = model(state)
        loss = loss_func(comb, state_dim, next_state, Config.device)
        logger.print(f"loss at epoch {step}", loss)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    save(model, Config.bucket)