
import diffuser.utils as utils
from ml_logger import logger, RUN
import torch
from torch.distributions.multivariate_normal import MultivariateNormal
from copy import deepcopy
import numpy as np
import os
import gym
from diffuser.utils.timer import Timer
from config.locomotion_config import Config
from diffuser.utils.arrays import to_torch, to_np, to_device
from diffuser.datasets.d4rl import suppress_output
from diffuser.models.value_func_model import ValueMLP
from diffuser.models.forward_dynamics import ForwardDynamics
from diffuser.models.bisimulation_metric_model import BisimNet
from diffuser.datasets.sequence import CustomSequenceDataset
from collections import namedtuple
from diffuser.utils.trajectory import Trajectory
from scripts.create_trajectory import save_traj
from scripts.create_trajectory import train_val
import pickle


# Record type with four fields: trajectories, conditions, returns, rewards.
# NOTE(review): the namedtuple's typename is 'Batch' while the module-level
# name is RewardBatch, so repr() shows 'Batch(...)' — confirm this is intended.
RewardBatch = namedtuple('Batch', 'trajectories conditions returns rewards')


def save_data(dataset, bucket, part_num):
    """Pickle ``dataset`` into the run's checkpoint directory.

    The file is written to
    ``<bucket>/<logger.prefix>/checkpoint/dataset<part_num>.dat``; the
    checkpoint directory is created first if it does not exist yet.

    Args:
        dataset: any picklable object.
        bucket: root directory under which the run prefix lives.
        part_num: value appended (via ``str``) to the file stem.
    """
    checkpoint_dir = os.path.join(bucket, logger.prefix, 'checkpoint')
    os.makedirs(checkpoint_dir, exist_ok=True)

    target = os.path.join(checkpoint_dir, "dataset" + str(part_num) + ".dat")
    with open(target, "wb") as handle:
        pickle.dump(dataset, handle)


def cycle(dl):
    """Yield the items of ``dl`` forever, re-iterating it after each pass.

    ``dl`` is iterated afresh on every pass, so it must be re-iterable
    (a one-shot iterator would produce items only on the first pass).
    """
    while True:
        yield from dl


def concat_state(states, batch_size, transpose=False):
    """Return all ordered pairs of rows of ``states`` as a flat 2-D tensor.

    ``states`` must be 2-D with ``batch_size`` rows: the concatenation and
    the final reshape to ``(batch_size ** 2, 2 * states.shape[1])`` only
    work when the row count equals ``batch_size``.

    Row ``i * batch_size + j`` of the result is ``[states[j], states[i]]``
    by default, or ``[states[i], states[j]]`` when ``transpose`` is true.

    Args:
        states: tensor of shape ``(batch_size, D)``.
        batch_size: number of rows in ``states``.
        transpose: swap the two halves of every pair.

    Returns:
        Tensor of shape ``(batch_size ** 2, 2 * D)``.
    """
    # Entry [i, j] of `repeated` is states[j]; after swapping the first two
    # axes, entry [i, j] of `swapped` is states[i].
    repeated = states.unsqueeze(0).tile((batch_size, 1, 1))
    swapped = repeated.transpose(0, 1)

    halves = (swapped, repeated) if transpose else (repeated, swapped)
    paired = torch.cat(halves, dim=2)

    feature_dim = states.shape[1]
    return paired.reshape((batch_size ** 2, feature_dim * 2))


def retrain_val(**deps):
    """Merge saved trajectories into the path list and call ``train_val`` on it.

    Steps, in order:
      1. Apply ``deps`` to ``RUN`` and ``Config`` and log the parameters.
      2. Load two diffusion checkpoints, two ``ValueMLP`` networks and seven
         ``ForwardDynamics`` models from
         ``<Config.bucket>/<logger.prefix>/checkpoint``.
      3. Load the path list from ``new_dataset.dat`` and append one entry per
         trajectory found in ``trajectories.dat``.
      4. Pickle the merged list to ``temp_new_dataset.dat``.
      5. Copy the merged paths into ``Trajectory`` objects and pass them to
         ``train_val``.

    Args:
        **deps: overrides forwarded to ``RUN._update`` and ``Config._update``.

    Returns:
        None.

    NOTE(review): many objects built here (value networks, forward models,
    ``part5.dat`` contents, the hyper-parameter locals, the data loader) are
    never read again inside this function; see the inline notes.
    """
    RUN._update(deps)
    Config._update(deps)

    logger.remove('*.pkl')
    logger.remove("traceback.err")
    logger.log_params(Config=vars(Config), RUN=vars(RUN))

    # The device is forced to CUDA here, after the parameters were logged above.
    Config.device = 'cuda'

    # NOTE(review): `prefix` is assigned but never read below.
    if Config.predict_epsilon:
        prefix = f'predict_epsilon_{Config.n_diffusion_steps}_1000000.0'
    else:
        prefix = f'predict_x0_{Config.n_diffusion_steps}_1000000.0'

    loadpath = os.path.join(Config.bucket, logger.prefix, 'checkpoint')
    next_loadpath = os.path.join(Config.bucket, logger.prefix, 'checkpoint')

    # NOTE(review): BUG — `self` is not defined in this function, so the
    # `Config.save_checkpoints` branches below raise NameError when taken.
    # The intended step value is not visible here; only the `else` branches
    # ('state.pt' / 'state_include_next.pt') can currently run.
    if Config.save_checkpoints:
        loadpath = os.path.join(loadpath, f'state_{self.step}.pt')
    else:
        loadpath = os.path.join(loadpath, 'state.pt')

    if Config.save_checkpoints:
        next_loadpath = os.path.join(next_loadpath, f'state_{self.step}.pt')
    else:
        next_loadpath = os.path.join(next_loadpath, 'state_include_next.pt')

    state_dict = torch.load(loadpath, map_location=Config.device)
    next_state_dict = torch.load(next_loadpath, map_location=Config.device)

    # Load configs
    torch.backends.cudnn.benchmark = True
    utils.set_seed(Config.seed)

    dataset_config = utils.Config(
        Config.loader,
        savepath='dataset_config.pkl',
        env=Config.dataset,
        horizon=Config.horizon,
        normalizer=Config.normalizer,
        preprocess_fns=Config.preprocess_fns,
        use_padding=Config.use_padding,
        max_path_length=Config.max_path_length,
        include_returns=Config.include_returns,
        returns_scale=Config.returns_scale,
    )

    render_config = utils.Config(
        Config.renderer,
        savepath='render_config.pkl',
        env=Config.dataset,
    )

    dataset = dataset_config()
    renderer = render_config()

    observation_dim = dataset.observation_dim
    action_dim = dataset.action_dim

    # With both inverse-dynamics diffusion classes selected, the transition
    # holds observations only; otherwise observations and actions.
    if Config.diffusion == 'models.GaussianInvDynDiffusion' and Config.next_diffusion == 'models.NextGaussianInvDynDiffusion':
        transition_dim = observation_dim
    else:
        transition_dim = observation_dim + action_dim

    model_config = utils.Config(
        Config.model,
        savepath='model_config.pkl',
        horizon=Config.horizon,
        transition_dim=transition_dim,
        cond_dim=observation_dim,
        dim_mults=Config.dim_mults,
        dim=Config.dim,
        returns_condition=Config.returns_condition,
        device=Config.device,
    )

    # NOTE(review): same savepath ('model_config.pkl') as model_config above —
    # presumably the second config overwrites the first on disk; confirm.
    next_model_config = utils.Config(
        Config.next_model,
        savepath='model_config.pkl',
        horizon=Config.horizon,
        transition_dim=transition_dim,
        cond_dim=observation_dim,
        dim_mults=Config.dim_mults,
        dim=Config.dim,
        returns_condition=Config.returns_condition,
        device=Config.device,
    )

    diffusion_config = utils.Config(
        Config.diffusion,
        savepath='diffusion_config.pkl',
        horizon=Config.horizon,
        observation_dim=observation_dim,
        action_dim=action_dim,
        n_timesteps=Config.n_diffusion_steps,
        loss_type=Config.loss_type,
        clip_denoised=Config.clip_denoised,
        predict_epsilon=Config.predict_epsilon,
        hidden_dim=Config.hidden_dim,
        ## loss weighting
        action_weight=Config.action_weight,
        loss_weights=Config.loss_weights,
        loss_discount=Config.loss_discount,
        returns_condition=Config.returns_condition,
        device=Config.device,
        condition_guidance_w=Config.condition_guidance_w,
    )

    # NOTE(review): same savepath ('diffusion_config.pkl') as diffusion_config
    # above — presumably overwritten; confirm.
    next_diffusion_config = utils.Config(
        Config.next_diffusion,
        savepath='diffusion_config.pkl',
        horizon=Config.horizon,
        observation_dim=observation_dim,
        action_dim=action_dim,
        n_timesteps=Config.n_diffusion_steps,
        loss_type=Config.loss_type,
        clip_denoised=Config.clip_denoised,
        predict_epsilon=Config.predict_epsilon,
        hidden_dim=Config.hidden_dim,
        ## loss weighting
        action_weight=Config.action_weight,
        loss_weights=Config.loss_weights,
        loss_discount=Config.loss_discount,
        returns_condition=Config.returns_condition,
        device=Config.device,
        condition_guidance_w=Config.condition_guidance_w,
    )

    trainer_config = utils.Config(
        utils.Trainer,
        savepath='trainer_config.pkl',
        train_batch_size=Config.batch_size,
        train_lr=Config.learning_rate,
        gradient_accumulate_every=Config.gradient_accumulate_every,
        ema_decay=Config.ema_decay,
        sample_freq=Config.sample_freq,
        save_freq=Config.save_freq,
        log_freq=Config.log_freq,
        label_freq=int(Config.n_train_steps // Config.n_saves),
        save_parallel=Config.save_parallel,
        bucket=Config.bucket,
        n_reference=Config.n_reference,
        train_device=Config.device,
    )

    # Instantiate both model/diffusion pairs; the two trainers are built from
    # the same trainer_config and share the dataset and renderer.
    model = model_config()
    next_model = next_model_config()
    diffusion = diffusion_config(model)
    next_diffusion = next_diffusion_config(next_model)
    trainer = trainer_config(diffusion, dataset, renderer)
    next_trainer = trainer_config(next_diffusion, dataset, renderer)
    # Only the parameters of `model` are reported, not those of `next_model`.
    logger.print(utils.report_parameters(model), color='green')

    # Restore step counter, model weights and EMA weights for both trainers.
    trainer.step = state_dict['step']
    trainer.model.load_state_dict(state_dict['model'])
    trainer.ema_model.load_state_dict(state_dict['ema'])

    next_trainer.step = next_state_dict['step']
    next_trainer.model.load_state_dict(next_state_dict['model'])
    next_trainer.ema_model.load_state_dict(next_state_dict['ema'])

    # Two value networks restored from 'V_model.pt' (keys 'model1'/'model2').
    # NOTE(review): neither network is read again in this function.
    state_dim = observation_dim
    V_model_1 = ValueMLP(hidden_dim=256, input_dim=state_dim, output_dim=1).to(Config.device)
    V_model_2 = ValueMLP(hidden_dim=256, input_dim=state_dim, output_dim=1).to(Config.device)
    value_loadpath = os.path.join(Config.bucket, logger.prefix, 'checkpoint')
    value_loadpath = os.path.join(value_loadpath, 'V_model.pt')
    value_state_dict = torch.load(value_loadpath, map_location=Config.device)
    V_model_1.load_state_dict(value_state_dict['model1'])
    V_model_2.load_state_dict(value_state_dict['model2'])

    # Seven forward-dynamics models restored from 'forward.pt' under the keys
    # 'model0' .. 'model6'.
    # NOTE(review): the loop rebinds `model` (previously the diffusion
    # backbone), and `forward_models` is not read again in this function.
    forward_models = []
    forward_loadpath = os.path.join(Config.bucket, logger.prefix, 'checkpoint')
    forward_loadpath = os.path.join(forward_loadpath, 'forward.pt')
    forward_state_dict = torch.load(forward_loadpath, map_location=Config.device)
    for i in range(7):
        hidden_dim = 200
        model = ForwardDynamics(state_dim=state_dim, hidden_dim=hidden_dim).to(Config.device)
        model_idx = "model" + str(i)
        model.load_state_dict(forward_state_dict[model_idx])
        forward_models.append(model)

    # path_num = len(dataset.indices)
    path_num = dataset.fields.normed_observations.shape[0]
    paths = []
    # Hard-coded to False, so the `if first_epoch:` branch below never runs and
    # the paths always come from 'new_dataset.dat'.
    first_epoch = False # After the first epoch: also false
    # NOTE(review): `obs_numpy` is filled in both branches but not read again
    # in this function.
    obs_numpy = np.array([])
    loadpath = os.path.join(Config.bucket, logger.prefix, 'checkpoint')
    loadpath = os.path.join(loadpath, "new_dataset.dat")
    if first_epoch:
        # Build [observations, actions, rewards] per path from the dataset
        # fields, truncated to each path's recorded length.
        for path_ind in range(path_num):
            # path_ind, start, end = dataset.indices[i]
            observations = dataset.fields.normed_observations[path_ind, :dataset.fields.path_lengths[path_ind]]
            if path_ind == 0:
                obs_numpy = observations
            else:
                obs_numpy = np.concatenate((obs_numpy, observations), axis=0)
            actions = dataset.fields.normed_actions[path_ind, :dataset.fields.path_lengths[path_ind]]
            rewards = dataset.fields.rewards[path_ind, :dataset.fields.path_lengths[path_ind]] / 4
            # Observations shifted by one step, with the last observation
            # repeated at the end. NOTE(review): `next_item` is computed but
            # not stored in `paths`.
            next_item = np.concatenate(
                (dataset.fields.normed_observations[path_ind, 1:dataset.fields.path_lengths[path_ind]],
                 np.array([dataset.fields.normed_observations[path_ind, dataset.fields.path_lengths[path_ind] - 1]])),
                axis=0)
            paths.append([observations, actions, rewards])
    else:
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # the checkpoint directory is assumed to be trusted.
        with open(loadpath, "rb") as f:
            paths = pickle.load(f)
        path_num = len(paths)
        for path_ind in range(path_num):
            if path_ind == 0:
                obs_numpy = paths[path_ind][0]
            else:
                obs_numpy = np.concatenate((obs_numpy, paths[path_ind][0]), axis=0)

    # NOTE(review): `part_num`, `part_length` and `paths_part` are not read
    # again in this function; 'part5.dat' must still exist for the open() to
    # succeed.
    loadpath = os.path.join(Config.bucket, logger.prefix, 'checkpoint')
    os.makedirs(loadpath, exist_ok=True)
    part_num = 5
    part_length = path_num // 6
    loadpath = os.path.join(loadpath, "part5.dat")
    with open(loadpath, "rb") as f:
        paths_part = pickle.load(f)

    # logger.print(paths[-1][0].shape, paths[-1][1].shape, paths[-1][2].shape, paths[-1][3].shape)

    # NOTE(review): `num_layer` is only referenced by the commented-out
    # BisimNet code below.
    num_layer = 1
    # bisim_model = BisimNet(state_dim=state_dim, num_layers=num_layer).to(Config.device)
    # bisim_loadpath = os.path.join(Config.bucket, logger.prefix, 'checkpoint')
    # bisim_loadpath = os.path.join(bisim_loadpath, 'bisim.pt')
    # bisim_state_dict = torch.load(bisim_loadpath, map_location=Config.device)
    # bisim_model.load_state_dict(bisim_state_dict['online'])

    # NOTE(review): none of the locals assigned from here down to `timer` are
    # read again in this function.
    discount = 0.99
    returns_scale = 400
    max_path_length = 1000
    discounts = discount ** np.arange(max_path_length)[:, None]
    epoch_num = 1
    train_diffuser_epoch = 10
    horizon = 10
    threshold = 0.1
    returns = to_device(Config.test_ret * torch.ones(1, 1), Config.device)
    # dataloader = cycle(torch.utils.data.DataLoader(
    #     dataset, batch_size=batch_size, num_workers=0, shuffle=True, pin_memory=True
    # ))
    timer = Timer()

    # NOTE(review): `new_paths` is rebound to `paths` further down;
    # `stitching_num`, `count` and `obs_numpy_new` are not read again.
    new_paths = []
    stitching_num = 0
    count = 0
    obs_numpy_new = np.array([])
    # loadpath = os.path.join(Config.bucket, logger.prefix, 'checkpoint')
    # loadpath = os.path.join(loadpath, "new_dataset.dat")
    # with open(loadpath, "rb") as f:
    #     paths = pickle.load(f)
    # NOTE(review): the dataset and the cycling loader are constructed but no
    # batch is drawn from `new_dataloader` in this function.
    new_dataset = CustomSequenceDataset(paths, dataset.indices)
    new_dataloader = cycle(torch.utils.data.DataLoader(
        new_dataset, batch_size=Config.batch_size, num_workers=0, shuffle=True, pin_memory=True
    ))
    loadpath = os.path.join(Config.bucket, logger.prefix, 'checkpoint')
    os.makedirs(loadpath, exist_ok=True)
    loadpath = os.path.join(loadpath, 'trajectories.dat')
    with open(loadpath, 'rb') as f:
        traj = pickle.load(f)
    # Turn every stored trajectory into four stacked arrays and append them to
    # `paths`. The loop variable rebinds `path_num` (the path count set
    # earlier); it is recomputed from `new_paths` after the dump below.
    for path_num in range(len(traj)):
        length = len(traj[path_num].state)
        # These empty arrays are what gets appended when `length` is 0.
        observations = np.array([])
        actions = np.array([])
        rewards = np.array([])
        next_state = np.array([])
        for idx in range(length):
            if idx == 0:
                # First item: wrap each value so axis 0 indexes the step
                # (rewards get an extra trailing axis).
                observations, actions, rewards, next_state = traj[path_num].get_item(idx)
                observations = np.array([observations])
                actions = np.array([actions])
                rewards = np.array([[rewards]])
                next_state = np.array([next_state])
            else:
                # Later items: append one row per step along axis 0.
                temp0, temp1, temp2, temp3 = traj[path_num].get_item(idx)
                observations = np.concatenate((observations, np.array([temp0])), axis=0)
                actions = np.concatenate((actions, np.array([temp1])), axis=0)
                rewards = np.concatenate((rewards, np.array([[temp2]])), axis=0)
                next_state = np.concatenate((next_state, np.array([temp3])), axis=0)

        # NOTE(review): these entries have four elements, whereas the
        # `first_epoch` branch above builds three-element entries — confirm
        # what the entries loaded from 'new_dataset.dat' look like.
        paths.append([observations, actions, rewards, next_state])

    # `new_paths` is the same list object as `paths`, not a copy.
    new_paths = paths
    savepath = os.path.join(Config.bucket, logger.prefix, 'checkpoint')
    os.makedirs(savepath, exist_ok=True)
    filename = "temp_new_dataset.dat"
    savepath = os.path.join(savepath, filename)
    with open(savepath, "wb") as f:
        pickle.dump(new_paths, f)

    # Copy every path into a Trajectory, one step at a time. Only the first
    # three elements of each entry (state, action, reward) are passed on.
    # NOTE(review): assumes no path is longer than traj_length (1000) —
    # confirm against the Trajectory implementation.
    path_num = len(new_paths)
    traj_length = 1000
    new_traj = [Trajectory(length=traj_length, state_dim=state_dim, action_dim=action_dim) for i in range(path_num)]
    for path_ind in range(path_num):
        length = paths[path_ind][0].shape[0]
        for i in range(length):
            new_traj[path_ind].add_item(input_state=paths[path_ind][0][i], input_action=paths[path_ind][1][i], input_reward=paths[path_ind][2][i], idx=i)

    # for i in range(num_eval):
    #     print(f"true {i}", true_list[i])
    #     print(f"reward {i}", reward_list[i] * 4)
    #
    # recorded_obs = np.concatenate(recorded_obs, axis=1)
    # savepath = os.path.join('images', f'sample-executed.png')
    # renderer.composite(savepath, recorded_obs)
    # episode_rewards = np.array(episode_rewards)
    #
    # logger.print(f"average_ep_reward: {np.mean(episode_rewards)}, std_ep_reward: {np.std(episode_rewards)}",
    #              color='green')
    # logger.log_metrics_summary(
    #     {'average_ep_reward': np.mean(episode_rewards), 'std_ep_reward': np.std(episode_rewards)})
    train_val(new_traj, state_dim, action_dim, Config.device, Config.seed, Config.bucket)
