import json
from os.path import join
import os
import numpy as np
import time
import latentplan.utils as utils
import latentplan.datasets as datasets
from latentplan.search import (
    enumerate_all,
    sample_with_prior,
    sample_with_prior_tree,
    beam_with_prior,
    beam_with_uniform,
    beam_mimic,
    make_prefix,
    extract_actions,
    update_context,
)
import torch
from mujoco_stochastic import hopper_high_noise
from mujoco_stochastic import HopperPiecewiseScaleEnv
from mujoco_stochastic import walker2d_high_noise
import mujoco_stochastic

# Debug aid: print the file path the `mujoco_stochastic` package was imported
# from, so it is visible which installation is in use.
print("mujoco position",mujoco_stochastic.__file__)

from mujoco_stochastic import hopper_mod_noise
from mujoco_stochastic import walker2d_mod_noise
from currency_exchange import currency_exchange
from hiv_treatment import hiv_treatment
import matplotlib.pyplot as plt

class Parser(utils.Parser):
    """Argument container for this planning script.

    Extends ``utils.Parser`` and declares only the two defaults below; every
    other ``args.*`` attribute read in this file comes from elsewhere
    (presumably the base parser / the selected config -- TODO confirm).
    """
    # Default dataset name; ``args.dataset`` is later passed to
    # ``datasets.load_environment`` and the model/config loaders.
    dataset: str = 'halfcheetah-medium-expert-v2'
    # Default config identifier (presumably a module path resolved by
    # ``utils.Parser`` -- TODO confirm).
    config: str = 'config.vqvae'

#######################
######## setup ########
#######################

# Parse the arguments for the 'plan' experiment, then coerce the settings that
# are used numerically (or as paths) below into concrete Python types.
args = Parser().parse_args('plan')

# Integer-valued search settings.
args.nb_samples = int(args.nb_samples)
args.n_expand = int(args.n_expand)
args.beam_width = int(args.beam_width)
args.n_actions = int(args.n_actions)

# Float-valued search settings.
args.b_percent = float(args.b_percent)
args.action_percent = float(args.action_percent)
args.pw_alpha = float(args.pw_alpha)

# More integer-valued settings.
args.mcts_itr = int(args.mcts_itr)
args.horizon = int(args.horizon)
args.rounds = int(args.rounds)

# Expand a leading '~' in the log and save locations.
args.logbase = os.path.expanduser(args.logbase)
args.savepath = os.path.expanduser(args.savepath)

# NOTE(review): bool() of any non-empty string (including 'False') is True --
# confirm the parser already yields a real bool/int here.
args.uniform = bool(args.uniform)


# `prob_weight` is optional: fall back to 5e2 when it is missing, None, or not
# parseable as a float. Only those failures are caught so that unrelated
# errors (and KeyboardInterrupt/SystemExit) are no longer silently swallowed.
try:
    args.prob_weight = float(args.prob_weight)
except (AttributeError, TypeError, ValueError):
    args.prob_weight = 5e2

print("check save path", args.savepath)
#######################
####### models ########
#######################
# To force a specific model, override args.exp_name here:
#args.exp_name = 'T-1-1'
#env = datasets.load_environment(args.dataset)
#env.seed(1*int(args.suffix)*100)

# The environment loaded from the dataset name is always created: it is the
# fallback rollout environment and is also used later for score normalisation.
d4rl_env = datasets.load_environment(args.dataset)
#d4rl_env.max_episode_steps = int(args.max_episode_steps)

# Pick the rollout environment from the dataset name: hopper and walker2d
# datasets are evaluated in their high-noise variants, everything else in the
# environment loaded above.
#
# Alternatives that were tried for hopper (kept for reference, disabled):
#   env = hopper_high_noise.HopperHighNoise(1*int(args.suffix)*100)
#   env = d4rl_env
#   env = hopper_mod_noise.HopperModNoise()
#   env = HopperPiecewiseScaleEnv.HopperPiecewiseScaleEnv(
#       xml_file="hopper.xml",   # or a path to your MuJoCo Hopper model
#       max_episode_steps=1000,
#       n_switches=4,
#       transition_length=200,
#       scale_max=5.0,
#       seed=42,
#   )
uses_hopper = 'hopper' in args.dataset
if uses_hopper:
    env = hopper_high_noise.HopperHighNoise()
elif 'walker2d' in args.dataset:
    env = walker2d_high_noise.Walker2DHighNoise()
else:
    env = d4rl_env
    #env.max_episode_steps = 100

# Every branch starts the episode the same way.
observation = env.reset()

if uses_hopper:
    # Only the hopper path pre-initialises `terminal` and announces itself.
    terminal = False
    print("high noise")

#env = currency_exchange.CurrencyExchange()
#env = hiv_treatment.HIVTreatment()
#env.seed(int(args.suffix)*100)
# Load the dataset object stored as 'data_config.pkl' for this experiment.
dataset = utils.load_from_config(
    args.logbase, args.dataset, args.exp_name, 'data_config.pkl',
)

# Load the model together with the epoch that was actually loaded.
gpt, gpt_epoch = utils.load_model(
    args.logbase, args.dataset, args.exp_name,
    epoch=args.gpt_epoch, device=args.device,
)
#gpt.reset_model()
#gpt.to('cuda')

# A prior model is loaded only for the planners listed here; for any other
# planner the name `prior` is left undefined.
planners_with_prior = (
    "sample_prior",
    "sample_prior_tree",
    "beam_prior",
    "beam_mimic",
    "beam_uniform",
    "mcts_prior",
)
if args.test_planner in planners_with_prior:
    prior, _ = utils.load_prior_model(
        args.logbase, args.dataset, args.exp_name,
        epoch=args.gpt_epoch, device=args.device,
    )

#prior.reset_model()

# The padding vector is the normalised form of an all-zero vector with one
# entry fewer than the model's transition dimension.
padding_vector = dataset.normalize_joined_single(np.zeros(gpt.transition_dim - 1))
gpt.set_padding_vector(padding_vector)
#######################
####### dataset #######
#######################

# A renderer is only created for locomotion tasks.
if args.task_type == "locomotion":
    renderer = utils.make_renderer(args)
timer = utils.timer.Timer()

# Frequently used dataset properties.
discount = dataset.discount
observation_dim = dataset.observation_dim
action_dim = dataset.action_dim

preprocess_fn = datasets.get_preprocess_fn(env.name)

#######################
###### main loop ######
#######################
VALUE_DIM = 1
REWARD_DIM = VALUE_DIM
# One observation, three actions and one value entry per transition.
transition_dim = observation_dim + 3 * action_dim + VALUE_DIM
print("observation_dim", observation_dim)
print("action_dim", action_dim)

# Running (un)discounted return of the episode.
total_reward = discount_return = 0

# antmaze observations/states carry two extra entries: the target goal, or
# zeros when the dataset disables the goal.
if "antmaze" in env.name:
    if dataset.disable_goal:
        goal_suffix = np.zeros([2], dtype=np.float32)
    else:
        goal_suffix = env.target_goal
    observation = np.concatenate([observation, goal_suffix])
    rollout = [np.concatenate([env.state_vector().copy(), goal_suffix])]
else:
    rollout = [np.concatenate([env.state_vector().copy()])]

## previous (tokenized) transitions for conditioning transformer
context = []
mses = []

# Episode length comes from the rollout environment.
T = env.max_episode_steps
print(T)
frames = []
gpt.eval()
total_high_loss_count = 0

# Bookkeeping for the rolling (state, action, reward) context fed to the planner.
context_window = []
context_interval = 3  # Process every 3 time steps
max_groups = 6  # Maximum number of groups to keep (t dimension)
rolling_window_size = context_interval * max_groups  # 3 * 6 = 18 entries
temp_states = []  # Temporary storage for states in current interval
temp_actions = []  # Temporary storage for actions in current interval
context_matrix = None
rolling_window = []
# Main rollout loop: plan (or reuse the previous plan), execute the first
# planned action, update the rolling context handed to the planner, and
# accumulate the episode return until the environment terminates or T steps
# have been taken.
for t in range(T):
    observation = preprocess_fn(observation)
    state = env.state_vector()
    if dataset.normalized_raw:
        observation = dataset.normalize_states(observation)

    # antmaze states get the same two extra entries as the observations.
    if "antmaze" in env.name:
        if dataset.disable_goal:
            state = np.concatenate([state, np.zeros([2], dtype=np.float32)])
        else:
            state = np.concatenate([state, env.target_goal])

    if t % args.plan_freq == 0:
        ## concatenate previous transitions and current observations to input to model
        prefix = make_prefix(observation, transition_dim, device=args.device)[-1, -1, None, None]

        ## sample sequence from model beginning with `prefix`
        if args.test_planner == 'beam_prior':
            prior.eval()
            start_time = time.time()
            sequence, _ = beam_with_prior(prior, gpt, prefix, context_matrix, denormalize_macro=dataset.denormalize_macro,
                              denormalize_val=dataset.denormalize_values,
                              normalize_val=dataset.normalize_values,
                              steps=int(args.horizon),
                              beam_width=args.beam_width,
                              n_expand=args.n_expand,
                              n_action=args.n_actions,
                              b_percent=args.b_percent,
                              action_percent=args.action_percent,
                              pw_alpha=args.pw_alpha,
                              mcts_itr=args.mcts_itr,
                              likelihood_weight=args.prob_weight,
                              prob_threshold=float(args.prob_threshold),
                              discount=discount)
            # Wall-clock time spent on this planning call, in seconds.
            elapsed_time = time.time() - start_time
            print("decision time:", elapsed_time)

        elif args.test_planner == 'mcts_prior':
            # NOTE(review): `mcts_with_prior` is not among the names imported at
            # the top of this file; selecting this planner raises NameError
            # unless it is brought into scope.
            prior.eval()
            start_time = time.time()
            sequence = mcts_with_prior(prior, gpt, prefix, denormalize_rew=dataset.denormalize_rewards,
                              denormalize_val=dataset.denormalize_values,
                              steps=int(args.horizon),
                              beam_width=args.beam_width,
                              n_expand=args.n_expand,
                              likelihood_weight=args.prob_weight,
                              prob_threshold=float(args.prob_threshold),
                              discount=discount)
            # Wall-clock time spent on this planning call, in seconds.
            elapsed_time = time.time() - start_time
            print("decision time:", elapsed_time)

        elif args.test_planner == 'beam_uniform':
            prior.eval()
            sequence = beam_with_uniform(prior, gpt, prefix, denormalize_rew=dataset.denormalize_rewards,
                                       denormalize_val=dataset.denormalize_values,
                                       steps=int(args.horizon),
                                       beam_width=args.beam_width,
                                       n_expand=args.n_expand,
                                       prob_threshold=float(args.prob_threshold),
                                       discount=discount)
        elif args.test_planner == 'beam_mimic':
            prior.eval()
            sequence = beam_mimic(prior, gpt, prefix, denormalize_rew=dataset.denormalize_rewards,
                                       denormalize_val=dataset.denormalize_values,
                                       steps=int(args.horizon),
                                       beam_width=args.beam_width,
                                       n_expand=args.n_expand,
                                       prob_threshold=float(args.prob_threshold),
                                       discount=discount)
        elif args.test_planner == "enumerate":
            sequence = enumerate_all(gpt, prefix, denormalize_rew=dataset.denormalize_rewards,
                                     denormalize_val=dataset.denormalize_values,
                                     discount=discount)
        elif args.test_planner == "sample_prior":
            prior.eval()
            sequence = sample_with_prior(prior, gpt, prefix, denormalize_rew=dataset.denormalize_rewards,
                              denormalize_val=dataset.denormalize_values,
                              steps=int(args.horizon),
                              nb_samples=args.nb_samples,
                              rounds=args.rounds,
                              prob_threshold=float(args.prob_threshold),
                              likelihood_weight=args.prob_weight,
                              uniform=args.uniform,
                              discount=discount)
        elif args.test_planner == "sample_prior_tree":
            prior.eval()
            sequence = sample_with_prior_tree(prior, gpt, prefix, denormalize_rew=dataset.denormalize_rewards,
                                         denormalize_val=dataset.denormalize_values,
                                         steps=int(args.horizon) - args.max_context_transitions - 1,
                                         discount=discount)
        else:
            # Previously an unknown planner surfaced later as a confusing
            # NameError on `sequence`; fail here with an explicit message.
            raise ValueError(f"unknown test_planner: {args.test_planner!r}")
    else:
        # Not a planning step: keep following the previous plan, dropping the
        # transition that was already executed.
        sequence = sequence[1:]

    if t == 0:
        # Values recorded once, from the very first plan, for the result file.
        first_value = float(dataset.denormalize_values(sequence[0,-2]))
        first_search_value = float(dataset.denormalize_values(sequence[-1, -2]))

    ## [ horizon x transition_dim ] convert sampled tokens to continuous latentplan
    sequence_recon = sequence

    ## [ action_dim ] index into sampled latentplan to grab first action
    feature_dim = dataset.observation_dim
    action = extract_actions(sequence_recon, feature_dim, action_dim, t=0)

    if dataset.normalized_raw:
        action = dataset.denormalize_actions(action)
        sequence_recon = dataset.denormalize_joined(sequence_recon)

    ## execute action in environment
    next_observation, reward, terminal, _ = env.step(action)

    # 1. Retrieve state and action separately.
    # NOTE(review): `prior` is only loaded for the prior-based planners, so the
    # 'enumerate' planner reaches this line without it. `prefix` is the one
    # built at the most recent planning step.
    context_state = prefix[:, 0, :prior.observation_dim]  # presumably (1, obs_dim) -- TODO confirm
    context_action_tensor = torch.tensor(dataset.normalize_actions(action), device=args.device).view(1, action.shape[0])  # shape: (1, act_dim)

    # 2. Append the (state, action, reward) tuple to the rolling window and
    #    keep only the most recent `rolling_window_size` entries.
    rolling_window.append((context_state, context_action_tensor, reward))
    if len(rolling_window) > rolling_window_size:
        rolling_window.pop(0)

    # 3. Rebuild the context matrix from the rolling window (skipped at t == 0).
    if t != 0:
        # Drop the oldest entries that do not fill a whole chunk of
        # `context_interval` steps. Slicing yields a new list, so the rolling
        # window itself is left untouched.
        remainder = len(rolling_window) % context_interval
        chunk_window = rolling_window[remainder:]

        # Skip if nothing remains after removing leftovers.
        if chunk_window:
            # 4. For each chunk of size context_interval, build a context row.
            #    In each chunk, keep the full (state + action) for the first
            #    element, and only the action for the subsequent elements.
            chunk_list = []
            for start_idx in range(0, len(chunk_window), context_interval):
                chunk = chunk_window[start_idx : start_idx + context_interval]  # a list of tuples

                # Unpack into `first_reward`, NOT `reward`: rebinding `reward`
                # here would overwrite the reward returned by env.step() above
                # and corrupt the return/score bookkeeping further down.
                first_state, first_action, first_reward = chunk[0]
                first_full = torch.cat([first_state, first_action], dim=1)  # state and action side by side

                # For the rest, keep only the action (and the reward for the sum):
                subsequent_actions = [entry[1] for entry in chunk[1:]]  # each has shape: (1, act_dim)
                subsequent_rewards = [entry[-1] for entry in chunk[1:]]  # one reward per remaining step

                # Concatenate along feature dimension.
                # If context_interval==1, there are no subsequent actions, so we use only first_full.
                if subsequent_actions:
                    chunk_tensor = torch.cat([first_full] + subsequent_actions, dim=1)
                else:
                    chunk_tensor = first_full

                # Sum of the rewards collected over the chunk, as a (1, 1) tensor.
                macro_rewards = (
                        torch.as_tensor(first_reward, device=chunk_tensor.device, dtype=chunk_tensor.dtype).sum()
                        + torch.as_tensor(subsequent_rewards, device=chunk_tensor.device,
                                          dtype=chunk_tensor.dtype).sum()
                ).reshape(1, 1)
                zero = torch.zeros(1, 1, device=chunk_tensor.device)
                # Row layout: [0, state, actions..., standardised macro reward].
                chunk_tensor = torch.cat([zero, chunk_tensor, (macro_rewards-dataset.macro_mean)/dataset.macro_std], dim=1)
                chunk_list.append(chunk_tensor)

            # 5. Stack all chunks into the context matrix: one row per chunk,
            #    with a leading axis of size 1.
            context_matrix = torch.cat(chunk_list, dim=0).unsqueeze(0)

    if "antmaze" in env.name:
        if dataset.disable_goal:
            next_observation = np.concatenate([next_observation, np.zeros([2], dtype=np.float32)])
        else:
            next_observation = np.concatenate([next_observation, env.target_goal])

    ## update return
    total_reward += reward
    discount_return += reward* discount**(t)
    score = d4rl_env.get_normalized_score(total_reward)
    rollout.append(state.copy())
    context = update_context(observation, action, reward, device=args.device)

    print(
        f'[ plan ] t: {t} / {T} | r: {reward:.2f} | R: {total_reward:.2f} | score: {score:.4f} | '
        f'time: {timer():.4f} | {args.dataset} | {args.exp_name} | {args.suffix}\n'
    )

    if t % args.vis_freq == 0 or terminal or t == T-1:
        # Make sure the output directory exists. Rendering of plans and
        # rollouts (renderer.render_plan / render_real / render_rollout) is
        # currently disabled because ffmpeg reports an error in some setups;
        # as a consequence `mses` is never filled.
        os.makedirs(args.savepath, exist_ok=True)

    if terminal:
        break

    observation = next_observation
## save result as a json file
# Make sure the target directory exists before writing into it.
os.makedirs(args.savepath, exist_ok=True)
json_path = join(args.savepath, 'rollout.json')

# `mses` can be empty; np.mean([]) would emit a RuntimeWarning and return NaN.
# Produce the same NaN explicitly, without the warning, so the written file
# keeps its previous content.
prediction_error = np.mean(mses) if mses else float('nan')

json_data = {'score': score, 'step': t, 'return': total_reward, 'term': terminal, 'gpt_epoch': gpt_epoch,
             'first_value': first_value, 'first_search_value': first_search_value, 'discount_return': discount_return,
             'prediction_error': prediction_error}

# Use a context manager so the file is flushed and closed even if
# serialisation fails part-way.
with open(json_path, 'w') as json_file:
    json.dump(json_data, json_file, indent=2, sort_keys=True)
