import diffuser.utils as utils
from diffuser.utils.trajectory import Trajectory
from diffuser.utils.trajectory import ReplayBuffer
import torch
from ml_logger import logger, RUN
from config.locomotion_config import Config
from copy import deepcopy
import numpy as np
import pickle
import os
import gym
from diffuser.models.value_func_model import ValueMLP
from diffuser.utils.arrays import to_torch, to_np, to_device
from diffuser.models.helpers import BellmanLoss
import random


def save_traj(traj, bucket):
    """Pickle *traj* to ``<bucket>/<logger.prefix>/checkpoint/trajectories.dat``.

    Missing parent directories are created first; an existing file at that
    path is overwritten.
    """
    target = os.path.join(bucket, logger.prefix, 'checkpoint', 'trajectories.dat')
    os.makedirs(os.path.dirname(target), exist_ok=True)
    with open(target, "wb") as out:
        pickle.dump(traj, out)


def save(model1, model2, bucket):
    """Write both value networks to ``<bucket>/<logger.prefix>/checkpoint/V_model.pt``.

    The saved object is a dict whose ``'model1'`` / ``'model2'`` entries hold
    the respective ``state_dict()``; the checkpoint directory is created when
    missing.
    """
    payload = dict(model1=model1.state_dict(), model2=model2.state_dict())
    target = os.path.join(bucket, logger.prefix, 'checkpoint', 'V_model.pt')
    os.makedirs(os.path.dirname(target), exist_ok=True)
    torch.save(payload, target)


def set_seed(seed, deterministic_torch=False):
    """Seed the Python, NumPy and PyTorch RNGs with *seed*.

    ``PYTHONHASHSEED`` is exported as well; note it only influences child
    processes, since the running interpreter fixes its hash seed at startup.
    ``torch.use_deterministic_algorithms`` is switched on or off according to
    *deterministic_torch*.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)
    torch.use_deterministic_algorithms(deterministic_torch)


def _merge_trajectories(trajectories):
    """Concatenate the transition arrays of every trajectory along axis 0.

    Each element of *trajectories* must provide ``create_dict()`` returning a
    mapping with ``observations``, ``actions``, ``rewards``,
    ``next_observations`` and ``terminals``. Parts are gathered per key and
    concatenated once, instead of re-concatenating the growing result for
    every trajectory (which copies O(n^2) data overall).

    Raises:
        ValueError: if *trajectories* is empty.
    """
    keys = ("observations", "actions", "rewards", "next_observations", "terminals")
    parts = {key: [] for key in keys}
    for trajectory in trajectories:
        data = trajectory.create_dict()
        for key in keys:
            parts[key].append(data[key])
    if not parts["observations"]:
        raise ValueError("train_val requires at least one trajectory")
    return {key: np.concatenate(chunks, axis=0) for key, chunks in parts.items()}


def train_val(trajectories, state_dim, action_dim, device, seed, bucket, *,
              batch_size=256, epoch_num=1000, lr=3e-4, buffer_size=2_000_000):
    """Fit two state-value networks on *trajectories* and save them under *bucket*.

    Args:
        trajectories: sequence of objects exposing ``create_dict()``; all of
            their transitions are loaded into a single ``ReplayBuffer``.
        state_dim: observation size, used as the input size of each ``ValueMLP``.
        action_dim: action size, forwarded to ``ReplayBuffer``.
        device: torch device for the networks and the sampled batches.
        seed: seeds Python/NumPy/torch RNGs (via ``set_seed``) and the buffer.
        bucket: root directory handed to ``save``.
        batch_size: transitions sampled per gradient step.
        epoch_num: number of gradient steps. Each step draws one batch; these
            are not full passes over the data, the name is kept for compatibility.
        lr: Adam learning rate for both networks.
        buffer_size: ``ReplayBuffer`` capacity.

    Raises:
        ValueError: if *trajectories* is empty.
    """
    set_seed(seed)
    model1 = ValueMLP(hidden_dim=256, input_dim=state_dim, output_dim=1).to(device)
    model2 = ValueMLP(hidden_dim=256, input_dim=state_dim, output_dim=1).to(device)
    bellman_loss = BellmanLoss()
    optimizer1 = torch.optim.Adam(model1.parameters(), lr=lr)
    optimizer2 = torch.optim.Adam(model2.parameters(), lr=lr)

    buffer = ReplayBuffer(state_dim=state_dim, action_dim=action_dim, buffer_size=buffer_size, device=device, seed=seed)
    buffer.load_data(_merge_trajectories(trajectories))

    for step in range(epoch_num):
        batch = [b.to(device) for b in buffer.sample(batch_size=batch_size)]
        # Indices 0 / 2 / 3 are used as state / reward / next state; presumably
        # ReplayBuffer.sample returns (states, actions, rewards, next_states,
        # dones) -- confirm. NOTE(review): terminals are not passed to the loss,
        # so unless BellmanLoss handles it, terminal transitions still bootstrap
        # from the next state.
        states, rewards, next_states = batch[0], batch[2], batch[3]

        value_state_1 = model1(states)
        value_next_1 = model1(next_states)
        value_state_2 = model2(states)
        value_next_2 = model2(next_states)
        total_loss = (bellman_loss(value_state_1, value_next_1, rewards)
                      + bellman_loss(value_state_2, value_next_2, rewards))

        # Log a plain float rather than a tensor that still carries the graph.
        logger.print(f"total loss at epoch {step}", total_loss.item())

        optimizer1.zero_grad()
        optimizer2.zero_grad()
        total_loss.backward()
        optimizer1.step()
        optimizer2.step()

    save(model1, model2, bucket)


def _resolve_checkpoint_path(checkpoint_dir, save_checkpoints):
    """Return the diffusion checkpoint file to load from *checkpoint_dir*.

    With per-step checkpoints enabled, files are expected to be named
    ``state_<step>.pt`` and the one with the highest step is chosen; otherwise
    the single ``state.pt`` is used.

    Raises:
        FileNotFoundError: if per-step checkpoints are enabled but none exist.
    """
    import re

    if not save_checkpoints:
        return os.path.join(checkpoint_dir, 'state.pt')

    step_pattern = re.compile(r'state_(\d+)\.pt')
    steps = [int(match.group(1))
             for match in map(step_pattern.fullmatch, os.listdir(checkpoint_dir))
             if match]
    if not steps:
        raise FileNotFoundError(f"no 'state_<step>.pt' checkpoint found in {checkpoint_dir}")
    return os.path.join(checkpoint_dir, f'state_{max(steps)}.pt')


def create(**deps):
    """Sample trajectories from a trained diffusion planner and fit value networks on them.

    Steps:
      1. apply *deps* to ``RUN`` / ``Config`` and log the parameters;
      2. load the trainer state from ``<Config.bucket>/<logger.prefix>/checkpoint``;
      3. rebuild dataset, renderer, model, diffusion and trainer from ``Config``;
      4. build ``num_eval`` trajectories of ``traj_length`` steps from the EMA
         model's samples (environment stepping is currently disabled) and
         pickle them with ``save_traj``;
      5. train two value networks on them with ``train_val``.
    """
    RUN._update(deps)
    Config._update(deps)

    logger.remove('*.pkl')
    logger.remove("traceback.err")
    logger.log_params(Config=vars(Config), RUN=vars(RUN))

    # Prefer CUDA as before, but fall back to CPU instead of crashing on hosts without it.
    Config.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    checkpoint_dir = os.path.join(Config.bucket, logger.prefix, 'checkpoint')
    loadpath = _resolve_checkpoint_path(checkpoint_dir, Config.save_checkpoints)
    state_dict = torch.load(loadpath, map_location=Config.device)

    # Load configs
    torch.backends.cudnn.benchmark = True
    utils.set_seed(Config.seed)

    dataset_config = utils.Config(
        Config.loader,
        savepath='dataset_config.pkl',
        env=Config.dataset,
        horizon=Config.horizon,
        normalizer=Config.normalizer,
        preprocess_fns=Config.preprocess_fns,
        use_padding=Config.use_padding,
        max_path_length=Config.max_path_length,
        include_returns=Config.include_returns,
        returns_scale=Config.returns_scale,
    )

    render_config = utils.Config(
        Config.renderer,
        savepath='render_config.pkl',
        env=Config.dataset,
    )

    dataset = dataset_config()
    renderer = render_config()

    observation_dim = dataset.observation_dim
    action_dim = dataset.action_dim

    if Config.diffusion == 'models.GaussianInvDynDiffusion':
        transition_dim = observation_dim
    else:
        transition_dim = observation_dim + action_dim

    model_config = utils.Config(
        Config.model,
        savepath='model_config.pkl',
        horizon=Config.horizon,
        transition_dim=transition_dim,
        cond_dim=observation_dim,
        dim_mults=Config.dim_mults,
        dim=Config.dim,
        returns_condition=Config.returns_condition,
        device=Config.device,
    )

    diffusion_config = utils.Config(
        Config.diffusion,
        savepath='diffusion_config.pkl',
        horizon=Config.horizon,
        observation_dim=observation_dim,
        action_dim=action_dim,
        n_timesteps=Config.n_diffusion_steps,
        loss_type=Config.loss_type,
        clip_denoised=Config.clip_denoised,
        predict_epsilon=Config.predict_epsilon,
        hidden_dim=Config.hidden_dim,
        ## loss weighting
        action_weight=Config.action_weight,
        loss_weights=Config.loss_weights,
        loss_discount=Config.loss_discount,
        returns_condition=Config.returns_condition,
        device=Config.device,
        condition_guidance_w=Config.condition_guidance_w,
    )

    trainer_config = utils.Config(
        utils.Trainer,
        savepath='trainer_config.pkl',
        train_batch_size=Config.batch_size,
        train_lr=Config.learning_rate,
        gradient_accumulate_every=Config.gradient_accumulate_every,
        ema_decay=Config.ema_decay,
        sample_freq=Config.sample_freq,
        save_freq=Config.save_freq,
        log_freq=Config.log_freq,
        label_freq=int(Config.n_train_steps // Config.n_saves),
        save_parallel=Config.save_parallel,
        bucket=Config.bucket,
        n_reference=Config.n_reference,
        train_device=Config.device,
    )

    model = model_config()
    diffusion = diffusion_config(model)
    trainer = trainer_config(diffusion, dataset, renderer)
    logger.print(utils.report_parameters(model), color='green')
    trainer.step = state_dict['step']
    trainer.model.load_state_dict(state_dict['model'])
    trainer.ema_model.load_state_dict(state_dict['ema'])

    num_eval = 2000
    device = Config.device

    env_list = [gym.make(Config.dataset) for _ in range(num_eval)]
    dones = [0 for _ in range(num_eval)]
    episode_rewards = [0 for _ in range(num_eval)]

    assert trainer.ema_model.condition_guidance_w == Config.condition_guidance_w
    returns = to_device(Config.test_ret * torch.ones(num_eval, 1), device)

    t = 0
    obs_list = [env.reset()[None] for env in env_list]
    obs = np.concatenate(obs_list, axis=0)
    recorded_obs = [deepcopy(obs[:, None])]

    # episode_rewards / recorded_obs / reward_list / true_list / obs_list only
    # feed the (currently commented-out) environment-stepping code below.
    reward_list = [[] for i in range(num_eval)]
    true_list = [[] for i in range(num_eval)]

    traj_length = 1000
    state_dim = env_list[0].observation_space.shape[0]
    action_dim = env_list[0].action_space.shape[0]
    length = [0 for i in range(num_eval)]
    traj = [Trajectory(length=traj_length, state_dim=state_dim, action_dim=action_dim) for i in range(num_eval)]

    while sum(dones) < num_eval:
        # Normalize into a separate variable: ``obs`` is never reassigned while
        # environment stepping is disabled, so normalizing it in place would
        # re-normalize the same array on every iteration.
        # TODO(review): every step is conditioned on the reset observation;
        # confirm whether the predicted next state (samples[:, 1]) should be fed back.
        normed_obs = dataset.normalizer.normalize(obs, 'observations')
        conditions = {0: to_torch(normed_obs, device=device)}
        samples = trainer.ema_model.conditional_sample(conditions, returns=returns)
        # The last channel of each sample is treated as a predicted reward.
        samples, sample_reward = samples[:, :, :-1], samples[:, :, -1]
        obs_comb = torch.cat([samples[:, 0, :], samples[:, 1, :]], dim=-1)  # TODO:increase window length
        obs_comb = obs_comb.reshape(-1, 2 * observation_dim)
        action = trainer.ema_model.inv_model(obs_comb)

        samples = to_np(samples)
        action = to_np(action)
        # NOTE(review): states stay in normalized space (they are unnormalized
        # only for rendering below) while actions are unnormalized -- confirm
        # this mix is what the value-function training expects.
        sample_t = to_np(samples[:, 0, :])

        reward_t = to_np(sample_reward[:, 0])

        action = dataset.normalizer.unnormalize(action, 'actions')

        if t == 0:
            normed_observations = samples[:, :, :]
            observations = dataset.normalizer.unnormalize(normed_observations, 'observations')
            os.makedirs('images', exist_ok=True)
            savepath = os.path.join('images', 'sample-planned.png')
            renderer.composite(savepath, observations)

        obs_list = []

        reward_t = np.array(reward_t)
        # Magic scale factor; presumably undoes a reward scaling applied during
        # training -- confirm against the dataset/returns configuration.
        reward_t = reward_t * 4

        for i in range(num_eval):
            traj[i].add_item(input_state=sample_t[i, :], input_action=action[i], input_reward=reward_t[i], idx=length[i])
            length[i] += 1
            # this_obs, this_reward, this_done, _ = env_list[i].step(action[i])
            # true_list[i].append(this_reward)
            # reward_list[i].append(reward_t)
            # logger.print(f"R_t: {R_t[i] * 400 * 4}")    #for horizon 100 (100 * 3)
            # logger.print(f"R_t_1: {R_t_1[i] * 400 * 4}")
            # logger.print(f"this reward {this_reward}")

            # obs_list.append(this_obs[None])
            if length[i] == traj_length:
                dones[i] = 1

        # obs = np.concatenate(obs_list, axis=0)
        # recorded_obs.append(deepcopy(obs[:, None]))
        t += 1

    save_traj(traj, Config.bucket)

    # for i in range(num_eval):
    #     print(f"true {i}", true_list[i])
    #     print(f"reward {i}", reward_list[i] * 4)
    #
    # recorded_obs = np.concatenate(recorded_obs, axis=1)
    # savepath = os.path.join('images', f'sample-executed.png')
    # renderer.composite(savepath, recorded_obs)
    # episode_rewards = np.array(episode_rewards)
    #
    # logger.print(f"average_ep_reward: {np.mean(episode_rewards)}, std_ep_reward: {np.std(episode_rewards)}",
    #              color='green')
    # logger.log_metrics_summary(
    #     {'average_ep_reward': np.mean(episode_rewards), 'std_ep_reward': np.std(episode_rewards)})
    train_val(traj, state_dim, action_dim, device, Config.seed, Config.bucket)


if __name__ == "__main__":
    # Script entry point: call create() with no keyword overrides, so the
    # RUN/Config updates receive an empty dict (presumably leaving them unchanged).
    create()