import diffuser.utils as utils
from diffuser.utils.trajectory import Trajectory
from diffuser.utils.trajectory import ReplayBuffer
import torch
from ml_logger import logger, RUN
from config.locomotion_config import Config
from copy import deepcopy
import numpy as np
import pickle
import os
import gym
from diffuser.models.forward_dynamics import ForwardDynamics
from diffuser.utils.arrays import to_torch, to_np, to_device
from diffuser.models.helpers import ForwardLoss
import random


def save(models, bucket):
    """Store the state dicts of ``models`` in a single checkpoint file.

    The checkpoint is a dict mapping ``"model{i}"`` to the state dict of the
    i-th entry of ``models``. It is written to
    ``<bucket>/<logger.prefix>/checkpoint/forward.pt``; the directory is
    created when it does not exist yet.
    """
    state_dicts = {
        f"model{idx}": net.state_dict() for idx, net in enumerate(models)
    }

    checkpoint_dir = os.path.join(bucket, logger.prefix, 'checkpoint')
    os.makedirs(checkpoint_dir, exist_ok=True)

    torch.save(state_dicts, os.path.join(checkpoint_dir, 'forward.pt'))


def set_seed(seed, deterministic_torch=False):
    """Seed Python's ``random``, NumPy and PyTorch with the same ``seed``.

    ``PYTHONHASHSEED`` is also exported to ``os.environ`` as ``str(seed)``,
    and ``deterministic_torch`` is forwarded to
    ``torch.use_deterministic_algorithms``.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    # The three generators are independent of each other; seed them in turn.
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)
    torch.use_deterministic_algorithms(deterministic_torch)


def load(bucket):
    """Return the object pickled in the run's ``trajectories.dat`` file.

    The file is looked up at
    ``<bucket>/<logger.prefix>/checkpoint/trajectories.dat``. The checkpoint
    directory is created first if it is missing, so a missing file surfaces
    as an error from ``open`` rather than from a missing directory.
    """
    checkpoint_dir = os.path.join(bucket, logger.prefix, 'checkpoint')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # NOTE(review): unpickling can execute arbitrary code; only point this at
    # trajectory files produced by a trusted run.
    with open(os.path.join(checkpoint_dir, 'trajectories.dat'), "rb") as handle:
        return pickle.load(handle)


# TODO: add model ensemble


def _collect_paths(fields):
    """Slice every stored path of ``fields`` down to its recorded length.

    Returns one ``(observations, actions, rewards)`` tuple per path, taken
    from ``normed_observations``, ``normed_actions`` and ``rewards`` and cut
    at ``path_lengths[path_ind]``.
    """
    paths = []
    for path_ind in range(fields.normed_observations.shape[0]):
        path_length = fields.path_lengths[path_ind]
        paths.append((
            fields.normed_observations[path_ind, :path_length],
            fields.normed_actions[path_ind, :path_length],
            fields.rewards[path_ind, :path_length],
        ))
    return paths


def _build_trajectories(paths, traj_length, state_dim, action_dim):
    """Copy every non-empty path into a freshly created ``Trajectory``."""
    trajectories = []
    for observations, actions, rewards in paths:
        if len(observations) == 0:
            # Empty paths carry no transitions; skip them.
            continue
        trajectory = Trajectory(length=traj_length, state_dim=state_dim, action_dim=action_dim)
        for step in range(observations.shape[0]):
            trajectory.add_item(input_state=observations[step], input_action=actions[step],
                                input_reward=rewards[step], idx=step)
        trajectories.append(trajectory)
    return trajectories


def _merge_trajectory_data(trajectories):
    """Concatenate the ``create_dict()`` output of all trajectories per key.

    Each key is concatenated once along axis 0, which avoids re-copying the
    growing arrays for every trajectory.
    """
    keys = ("observations", "actions", "rewards", "next_observations", "terminals")
    per_trajectory = [trajectory.create_dict() for trajectory in trajectories]
    return {
        key: np.concatenate([data[key] for data in per_trajectory], axis=0)
        for key in keys
    }


def _fit_forward_model(model, buffer, state_dim, device, epoch_num, batch_size):
    """Run ``epoch_num`` Adam updates of ``model`` on batches from ``buffer``.

    Every "epoch" here is a single gradient step on one sampled batch. Only
    elements 0 and 3 of the batch are used, as the state and the next state.
    The loss is logged at every step.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)   # lr=3e-4
    loss_func = ForwardLoss().to(device)

    for epoch in range(epoch_num):
        batch = [b.to(device) for b in buffer.sample(batch_size=batch_size)]
        state, next_state = batch[0], batch[3]
        comb = model(state)
        loss = loss_func(comb, state_dim, next_state, device)
        logger.print(f"loss at epoch {epoch}", loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def train_forward(**deps):
    """Train seven ``ForwardDynamics`` models on the configured dataset and save them.

    ``deps`` overrides entries of ``RUN`` and ``Config``. The dataset named by
    ``Config`` is loaded, its non-empty paths are copied into ``Trajectory``
    objects and merged into one ``ReplayBuffer``. Seven models are then
    trained one after another on that buffer, each after re-seeding, and
    their state dicts are written with ``save(models, Config.bucket)``.

    Raises:
        ValueError: if the dataset contains no non-empty path.
    """
    RUN._update(deps)
    Config._update(deps)

    set_seed(Config.seed)

    logger.remove('*.pkl')
    logger.remove("traceback.err")
    logger.log_params(Config=vars(Config), RUN=vars(RUN))

    dataset_config = utils.Config(
        Config.loader,
        savepath='dataset_config.pkl',
        env=Config.dataset,
        horizon=Config.horizon,
        normalizer=Config.normalizer,
        preprocess_fns=Config.preprocess_fns,
        use_padding=Config.use_padding,
        max_path_length=Config.max_path_length,
        include_returns=Config.include_returns,
        returns_scale=Config.returns_scale,
    )

    render_config = utils.Config(
        Config.renderer,
        savepath='render_config.pkl',
        env=Config.dataset,
    )

    dataset = dataset_config()
    # The renderer instance is not used below. The call is kept so that
    # whatever its construction does still happens as before.
    render_config()

    device = Config.device
    traj_length = 2000

    # A single environment is enough to read the space dimensions.
    env = gym.make(Config.dataset)
    env_state_dim = env.observation_space.shape[0]
    env_action_dim = env.action_space.shape[0]

    paths = _collect_paths(dataset.fields)
    traj = _build_trajectories(paths, traj_length, env_state_dim, env_action_dim)
    if not traj:
        raise ValueError("dataset contains no non-empty path to train on")

    # The buffer and the models are sized from the data itself, not the env.
    state_dim = paths[0][0].shape[1]
    action_dim = paths[0][1].shape[1]

    buffer_size = 20_000_000    # TODO: put it in the parameter input
    buffer = ReplayBuffer(state_dim=state_dim, action_dim=action_dim, buffer_size=buffer_size,
                          device=device, seed=Config.seed)
    buffer.load_data(_merge_trajectory_data(traj))

    epoch_num = 2500    # TODO: put it in the parameter input
    hidden_dim = 200
    batch_size = 256

    # NOTE(review): current_seed is incremented after every model AND the model
    # index is added on top, so the effective seeds are
    # Config.seed + [2, 2, 4, 6, 6, 10, 12]; models 0/1 and 3/4 are seeded
    # identically. Preserved as-is -- confirm whether this is intended.
    current_seed = Config.seed
    models = []
    for model_idx in range(7):
        if model_idx in (0, 4):
            set_seed(current_seed + 2)
        else:
            set_seed(current_seed + model_idx)

        model = ForwardDynamics(state_dim=state_dim, hidden_dim=hidden_dim).to(device)
        _fit_forward_model(model, buffer, state_dim, device, epoch_num, batch_size)

        current_seed += 1
        models.append(model)

    save(models, Config.bucket)
