import diffuser.utils as utils
from diffuser.utils.trajectory import Trajectory
from diffuser.utils.trajectory import ReplayBuffer
import torch
from ml_logger import logger, RUN
from config.locomotion_config import Config
import matplotlib.pyplot as plt
from diffuser.utils.arrays import batch_to_device, to_np, to_device, apply_dict
from copy import deepcopy
import numpy as np
import pickle
from diffuser.datasets.sequence import CustomSequenceDataset
import os
import gym
from diffuser.models.bisimulation_metric_model import BisimNet
from diffuser.utils.arrays import to_torch, to_np, to_device
from diffuser.models.helpers import ForwardLoss
import random


def build_bisim_target(batch_size, reward, horizon, target_next_dist, gamma, logging=False):
    """Build the regression target for every ordered pair of samples in a batch.

    The target is ``|r_i - r_j| + gamma * horizon * target_next_dist``, laid
    out as a ``(batch_size ** 2, 1)`` column.

    Args:
        batch_size: Number of samples in the batch.
        reward: Reward tensor; the tiling below assumes shape
            ``(batch_size, 1)`` -- TODO confirm against the replay buffer.
        horizon: Scalar weight applied to the next-state distance.
        target_next_dist: Next-state distances; must be broadcastable against
            a ``(batch_size ** 2, 1)`` tensor.
        gamma: Discount factor applied to the weighted next-state distance.
        logging: When true, the mean pairwise reward gap is sent to ``logger``.

    Returns:
        Tensor holding the pairwise targets.
    """
    # Column j of `tiled` repeats the reward column, so subtracting the
    # transpose produces the full table of pairwise reward differences.
    tiled = reward.tile((1, batch_size))
    pairwise_gap = (tiled - tiled.transpose(0, 1)).abs().reshape(batch_size ** 2, 1)
    if logging:
        logger.print("Average Reward Diff", pairwise_gap.mean())
    weighted_next = horizon * target_next_dist
    return pairwise_gap + gamma * weighted_next


def save(online, target, bucket):
    """Checkpoint both networks to ``<bucket>/<logger.prefix>/checkpoint/bisim.pt``.

    Args:
        online: Model whose ``state_dict()`` is stored under the key ``'online'``.
        target: Model whose ``state_dict()`` is stored under the key ``'target'``.
        bucket: Root directory under which the run's checkpoint folder lives.
    """
    payload = dict(online=online.state_dict(), target=target.state_dict())
    checkpoint_dir = os.path.join(bucket, logger.prefix, 'checkpoint')
    # The folder may not exist yet on the first save of a run.
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(payload, os.path.join(checkpoint_dir, 'bisim.pt'))



def concat_state(states, batch_size, transpose=False):
    """Pair every state in a batch with every other state.

    Args:
        states: 2-D tensor with one state per row; its row count has to equal
            ``batch_size`` for the concatenation to line up.
        batch_size: Number of rows in ``states``.
        transpose: Selects which state of the pair comes first. With the
            default, output row ``i * batch_size + j`` is
            ``[states[j], states[i]]``; with ``True`` it is
            ``[states[i], states[j]]``.

    Returns:
        Tensor of shape ``(batch_size ** 2, 2 * states.shape[1])``.
    """
    # by_col[i, j] == states[j]; by_row[i, j] == states[i]
    by_col = states[None].tile((batch_size, 1, 1))
    by_row = by_col.transpose(0, 1)
    ordered = (by_row, by_col) if transpose else (by_col, by_row)
    stacked = torch.cat(ordered, dim=2)
    feature_dim = states.shape[1]
    return stacked.reshape(batch_size ** 2, feature_dim * 2)


def set_seed(seed, deterministic_torch=False):
    """Seed Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value written to ``PYTHONHASHSEED`` (as a string) and handed to
            ``np.random.seed``, ``random.seed`` and ``torch.manual_seed``.
        deterministic_torch: Forwarded unchanged to
            ``torch.use_deterministic_algorithms``.
    """
    os.environ.update(PYTHONHASHSEED=str(seed))
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)
    torch.use_deterministic_algorithms(deterministic_torch)


def load(bucket):
    """Unpickle ``<bucket>/<logger.prefix>/checkpoint/trajectories.dat``.

    The checkpoint directory is created if it is missing before the file is
    opened, so a failed read can still leave an empty directory behind.

    NOTE: ``pickle.load`` can execute arbitrary code; only point this at
    files produced by a trusted run.

    Args:
        bucket: Root directory under which the run's checkpoint folder lives.

    Returns:
        Whatever object was pickled into the file.
    """
    checkpoint_dir = os.path.join(bucket, logger.prefix, 'checkpoint')
    os.makedirs(checkpoint_dir, exist_ok=True)
    with open(os.path.join(checkpoint_dir, 'trajectories.dat'), "rb") as handle:
        return pickle.load(handle)


def cycle(dl):
    """Yield items from ``dl`` endlessly, re-iterating it each time it runs out."""
    while True:
        yield from dl


def _extract_paths(fields):
    """Slice every stored path down to its recorded length.

    Returns a list with one ``[observations, actions, rewards,
    next_observations]`` entry per path, read from ``fields``.
    """
    paths = []
    for path_ind in range(fields.normed_observations.shape[0]):
        length = fields.path_lengths[path_ind]
        observations = fields.normed_observations[path_ind, :length]
        actions = fields.normed_actions[path_ind, :length]
        # TODO: replace the rewards with normed ones
        rewards = fields.rewards[path_ind, :length] / 4
        # The last step has no successor in the data, so its own observation
        # is reused as the "next" observation.
        next_observations = np.concatenate((observations[1:], observations[-1:]), axis=0)
        paths.append([observations, actions, rewards, next_observations])
    return paths


def _stack_transitions(paths):
    """Flatten per-path arrays into the transition dict loaded into the buffer."""
    # One concatenate per field instead of re-copying the accumulated array
    # for every path.
    observations, actions, rewards, next_observations = (
        np.concatenate(column, axis=0) for column in zip(*paths)
    )
    return {
        "observations": observations,
        "actions": actions,
        "rewards": rewards,
        "next_observations": next_observations,
        # No transition is flagged terminal; one zero per reward row.
        "terminals": np.zeros((rewards.shape[0], 1)),
    }


def train_bisim(**deps):
    """Train the bisimulation-metric network on the configured offline dataset.

    Steps: apply ``deps`` to ``RUN`` and ``Config``, seed the RNGs, load the
    dataset described by ``Config`` into a replay buffer, fit an online
    ``BisimNet`` against a periodically synced target copy, save both
    networks with ``save`` and plot the learned distance between the first
    state of the first path and every state of that path.

    Args:
        **deps: Overrides forwarded to ``RUN._update`` and ``Config._update``.
    """
    RUN._update(deps)
    Config._update(deps)

    set_seed(Config.seed)

    logger.remove('*.pkl')
    logger.remove("traceback.err")
    logger.log_params(Config=vars(Config), RUN=vars(RUN))

    # An earlier variant read pickled trajectories through load(); the
    # dataset loader below is used instead.
    dataset_config = utils.Config(
        Config.loader,
        savepath='dataset_config.pkl',
        env=Config.dataset,
        horizon=Config.horizon,
        normalizer=Config.normalizer,
        preprocess_fns=Config.preprocess_fns,
        use_padding=Config.use_padding,
        max_path_length=Config.max_path_length,
        include_returns=Config.include_returns,
        returns_scale=Config.returns_scale,
    )

    dataset = dataset_config()
    batch_size = 128

    paths = _extract_paths(dataset.fields)
    state_dim = paths[0][0].shape[1]
    action_dim = paths[0][1].shape[1]

    buffer_size = 4_000_000  # TODO: put it in the parameter input
    buffer = ReplayBuffer(state_dim=state_dim, action_dim=action_dim, buffer_size=buffer_size, device=Config.device,
                          seed=Config.seed)

    # NOTE(review): neither object is consumed below (batches come from the
    # replay buffer); kept because their constructors are not visible here.
    new_dataset = CustomSequenceDataset(paths, dataset.indices)
    new_dataloader = cycle(torch.utils.data.DataLoader(
        new_dataset, batch_size=batch_size, num_workers=0, shuffle=True, pin_memory=True
    ))

    buffer.load_data(_stack_transitions(paths))

    num_layer = 1

    # The target network is only ever overwritten from the online network, so
    # it needs neither gradients nor an optimizer.
    target_model = BisimNet(state_dim=state_dim, num_layers=num_layer).to(Config.device)
    for param in target_model.parameters():
        param.requires_grad = False

    online_model = BisimNet(state_dim=state_dim, num_layers=num_layer).to(Config.device)
    loss_func = torch.nn.MSELoss()
    online_optimizer = torch.optim.Adam(online_model.parameters(), lr=0.000075, eps=0.00015)
    epoch_num = 10000
    gamma = 0.99
    horizon_discount = 0.99
    bisim_discount_value = 1.0
    horizon = 0
    C = 500  # number of iterations between target-network syncs

    # Zero on the i == j entries of the flattened pair table, one elsewhere;
    # it forces the target distance of a sample to itself to zero. The mask
    # does not depend on the batch, so it is built once.
    diag_mask = (1.0 - torch.eye(batch_size, device=Config.device)).reshape(batch_size ** 2, 1)

    final_dist = None
    for i in range(epoch_num):
        batch = buffer.sample(batch_size=batch_size)
        batch = [b.to(Config.device) for b in batch]
        # batch[0], batch[2] and batch[3] are used as states, rewards and next
        # states respectively.
        bisim_estimate = online_model(concat_state(batch[0], batch_size=batch_size))
        final_dist = bisim_estimate.detach()
        with torch.no_grad():
            next_online_dist = target_model(concat_state(batch[3], batch_size=batch_size))
            bisim_target = build_bisim_target(batch_size=batch_size, reward=batch[2], gamma=gamma, horizon=horizon, target_next_dist=next_online_dist)
            bisim_target *= diag_mask

        loss = loss_func(bisim_estimate, bisim_target)
        logger.print(f"epoch {i} loss", loss)

        online_optimizer.zero_grad()
        loss.backward()

        if i % C == 0:
            # Sync happens before the optimizer step, so the target receives
            # the weights that produced this iteration's loss.
            target_model.load_state_dict(online_model.state_dict())
            # The weight on the bootstrapped next-state distance starts at 0
            # and grows as bisim_discount_value decays at every sync.
            horizon = 1.0 - bisim_discount_value
            bisim_discount_value *= horizon_discount

        online_optimizer.step()

    logger.print("final_dist", final_dist)
    save(online_model, target_model, Config.bucket)

    # Distance from the first state of the first path to each state of that
    # path; inference only, so no autograd graph is built.
    first_path_states = paths[0][0]
    ini_torch = torch.from_numpy(first_path_states[0]).to(Config.device)
    metrics = []
    with torch.no_grad():
        for state in first_path_states:
            state_torch = torch.from_numpy(state).to(Config.device)
            state_comb = torch.unsqueeze(torch.cat((ini_torch, state_torch), dim=-1), dim=0)
            metric = online_model(state_comb)
            print(metric)
            metrics.append(metric.detach().cpu().numpy())

    metrics = np.squeeze(np.array(metrics))
    plt.plot(metrics)
    plt.show()