import diffuser.utils as utils
from diffuser.utils.trajectory import Trajectory
from diffuser.utils.trajectory import ReplayBuffer
import torch
from ml_logger import logger, RUN
from config.locomotion_config import Config
from copy import deepcopy
import numpy as np
import pickle
import os
import gym
from diffuser.models.bisimulation_metric_model import BisimNet
from diffuser.utils.arrays import to_torch, to_np, to_device
from diffuser.models.helpers import ForwardLoss
import random


def build_bisim_target(batch_size, reward, horizon, target_next_dist, gamma, logging=False):
    """Build the regression target for every ordered pair of samples in a batch.

    For each pair the target is the absolute reward difference of the two
    samples plus ``gamma * horizon * target_next_dist`` for that pair.

    Args:
        batch_size: Number of samples in the batch.
        reward: Reward tensor of shape ``(batch_size, 1)``; it is tiled to a
            ``(batch_size, batch_size)`` grid and compared with its transpose.
        horizon: Scalar weight applied to ``target_next_dist``.
        target_next_dist: Pairwise next-state distances; must broadcast against
            a ``(batch_size ** 2, 1)`` tensor.
        gamma: Discount factor applied to the weighted next-state distance.
        logging: When true, print the mean pairwise reward difference through
            the module-level ``logger``.

    Returns:
        Tensor holding ``|r_i - r_j| + gamma * horizon * target_next_dist``,
        with the reward term laid out row-major as ``(batch_size ** 2, 1)``.
    """
    pair_shape = (batch_size ** 2, 1)

    # Entry [i, j] of the row grid is r_i; its transpose holds r_j there.
    rewards_by_row = torch.tile(reward, (1, batch_size))
    rewards_by_col = rewards_by_row.transpose(0, 1)
    reward_gap = (rewards_by_row - rewards_by_col).abs().reshape(pair_shape)

    if logging:
        logger.print("Average Reward Diff", reward_gap.mean())

    return reward_gap + gamma * (horizon * target_next_dist)


def save(online, target, bucket):
    """Write both networks' weights to ``<bucket>/<logger.prefix>/checkpoint/bisim.pt``.

    Args:
        online: Model whose ``state_dict()`` is stored under the key ``'online'``.
        target: Model whose ``state_dict()`` is stored under the key ``'target'``.
        bucket: Root directory under which the checkpoint folder is created.
    """
    payload = dict(online=online.state_dict(), target=target.state_dict())

    checkpoint_dir = os.path.join(bucket, logger.prefix, 'checkpoint')
    # Create the folder on first use; an existing folder is left untouched.
    os.makedirs(checkpoint_dir, exist_ok=True)

    torch.save(payload, os.path.join(checkpoint_dir, 'bisim.pt'))


def concat_state(states, batch_size, transpose=False):
    """Pair every state in a batch with every other state, feature-wise.

    Output row ``i * batch_size + j`` is ``[states[j], states[i]]`` by default
    and ``[states[i], states[j]]`` when ``transpose`` is true.

    Args:
        states: 2-D tensor of shape ``(batch_size, dim)``.
        batch_size: Number of rows in ``states``.
        transpose: Swap the order of the two halves of every pair.

    Returns:
        Tensor of shape ``(batch_size ** 2, 2 * dim)``.
    """
    # grid[i, j] == states[j]; grid_swapped[i, j] == states[i]
    grid = states.unsqueeze(0).tile((batch_size, 1, 1))
    grid_swapped = grid.transpose(0, 1)

    halves = (grid_swapped, grid) if transpose else (grid, grid_swapped)
    paired = torch.cat(halves, dim=2)

    feature_dim = states.shape[1]
    return paired.reshape(batch_size ** 2, feature_dim * 2)


def set_seed(seed, deterministic_torch=False):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed to every generator and exported as ``PYTHONHASHSEED``.
        deterministic_torch: Forwarded to ``torch.use_deterministic_algorithms``.
    """
    # Exported for child processes; the running interpreter fixed its own hash
    # seed at start-up and is not affected by this assignment.
    os.environ["PYTHONHASHSEED"] = str(seed)

    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)

    torch.use_deterministic_algorithms(deterministic_torch)


def load(bucket):
    """Read pickled trajectories from ``<bucket>/<logger.prefix>/checkpoint/trajectories.dat``.

    The checkpoint directory is created if it does not exist yet; a missing
    ``trajectories.dat`` file still makes ``open`` raise.

    Args:
        bucket: Root directory that contains the run's checkpoint folder.

    Returns:
        The object stored in the pickle file.
    """
    checkpoint_dir = os.path.join(bucket, logger.prefix, 'checkpoint')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # NOTE(review): unpickling can execute arbitrary code -- only point this at
    # trajectory files produced by a trusted run.
    with open(os.path.join(checkpoint_dir, 'trajectories.dat'), "rb") as handle:
        return pickle.load(handle)


def _stack_trajectories(trajectories):
    """Concatenate the per-trajectory transition arrays into a single dict of arrays."""
    keys = ("observations", "actions", "rewards", "next_observations", "terminals")
    per_trajectory = [trajectory.create_dict() for trajectory in trajectories]
    # One concatenate per key instead of re-copying the accumulated array on
    # every trajectory.
    return {
        key: np.concatenate([data[key] for data in per_trajectory], axis=0)
        for key in keys
    }


def train_bisim(**deps):
    """Train the bisimulation-metric network on stored trajectories and save it.

    The online network regresses the pairwise distance of the sampled states
    onto ``|r_i - r_j| + gamma * horizon * d_target(s'_i, s'_j)``, where
    ``d_target`` is a periodically refreshed copy of the online network.
    Targets for a state paired with itself are forced to zero.

    Args:
        **deps: Overrides applied to ``RUN`` and ``Config`` before training.
            ``Config.seed``, ``Config.bucket`` and ``Config.device`` are read.

    Side effects:
        Removes stale log files through ``logger``, logs the run parameters,
        prints the loss of every iteration, and writes the final weights with
        ``save``.
    """
    RUN._update(deps)
    Config._update(deps)

    set_seed(Config.seed)

    logger.remove('*.pkl')
    logger.remove("traceback.err")
    logger.log_params(Config=vars(Config), RUN=vars(RUN))

    load_check = True  # TODO: put it in the parameter input
    if not load_check:
        return
    trajectories = load(bucket=Config.bucket)

    # The first transition of the first trajectory fixes the state/action sizes.
    example_state, example_action, _ = trajectories[0].get_item(0)
    state_dim = example_state.shape[0]
    action_dim = example_action.shape[0]

    buffer_size = 2_000_000  # TODO: put it in the parameter input
    buffer = ReplayBuffer(state_dim=state_dim, action_dim=action_dim, buffer_size=buffer_size, device=Config.device,
                          seed=Config.seed)
    buffer.load_data(_stack_trajectories(trajectories))

    num_layer = 1
    # Creation order (target first, then online) is kept so the seeded weight
    # initialisation of both networks is unchanged.
    target_model = BisimNet(state_dim=state_dim, num_layers=num_layer).to(Config.device)
    online_model = BisimNet(state_dim=state_dim, num_layers=num_layer).to(Config.device)
    loss_func = torch.nn.MSELoss()
    online_optimizer = torch.optim.Adam(online_model.parameters(), lr=0.000075, eps=0.00015)

    epoch_num = 2000
    batch_size = 128
    gamma = 0.99
    horizon_discount = 0.95
    bisim_discount_value = 1.0
    horizon = 0
    C = 5  # number of iterations between target-network refreshes

    # Zero on the (i, i) pairs, one elsewhere; constant for the whole run, so
    # it is built once instead of on every iteration.
    diag_mask = (1.0 - torch.eye(batch_size, device=Config.device)).reshape(batch_size ** 2, 1)

    for i in range(epoch_num):
        # Indices 0, 2 and 3 are used as states, rewards and next states --
        # presumably (obs, action, reward, next_obs, done); confirm against
        # ReplayBuffer.sample.
        batch = buffer.sample(batch_size=batch_size)
        batch = [b.to(Config.device) for b in batch]

        bisim_estimate = online_model(concat_state(batch[0], batch_size=batch_size))

        # The target is a fixed regression label: evaluate the target network
        # without building an autograd graph for it.
        with torch.no_grad():
            next_target_dist = target_model(concat_state(batch[3], batch_size=batch_size))
            bisim_target = build_bisim_target(batch_size=batch_size, reward=batch[2], gamma=gamma,
                                              horizon=horizon, target_next_dist=next_target_dist)
            bisim_target = bisim_target * diag_mask

        loss = loss_func(bisim_estimate, bisim_target)
        logger.print(f"epoch {i} loss", loss)

        online_optimizer.zero_grad()
        loss.backward()
        online_optimizer.step()

        if i % C == 0:
            # No gradient ever reaches the target network (its output is only
            # used under no_grad), so stepping an optimizer on it is a no-op.
            # Refresh it by copying the online weights instead.
            target_model.load_state_dict(online_model.state_dict())
            # Grow the weight of the bootstrapped term from 0 towards 1.
            horizon = 1.0 - bisim_discount_value
            bisim_discount_value *= horizon_discount

    save(online_model, target_model, Config.bucket)