import os
import random
import uuid
from copy import deepcopy
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import d4rl
import gym
import numpy as np
import pyrallis
import torch
import torch.nn as nn
import torch.nn.functional
from tqdm import trange
from config import TrainConfig
from utils import set_seed, set_env_seed, modify_reward, compute_mean_std, normalize_states, wrap_env, modify_reward_online, eval_actor, is_goal_reached
from utils import qlearning_dataset_with_timeouts, KD_tree
from buffer import ReplayBuffer, SequenceReplayBuffer
from network import Actor, Critic
from algo import AdvantageWeightedActorCritic

# Environment-name prefixes of goal-based tasks. Both training entry points
# pass this tuple to ``str.startswith`` on ``config.env_name`` to decide
# whether success/regret bookkeeping applies to the run.
ENVS_WITH_GOAL = ("antmaze", "pen", "door", "hammer", "relocate")

@pyrallis.wrap()
def train_awac(config: TrainConfig):
    """Train baseline AWAC: offline pretraining followed by online tuning.

    The function is a no-op unless ``config.algo == "awac"``.

    Workflow, as implemented below:
      1. Create a training env and a separate evaluation env, load the D4RL
         dataset, optionally rescale rewards, and normalize observations with
         the mean/std computed from the dataset observations.
      2. Build the actor, two critics, their Adam optimizers and the
         ``AdvantageWeightedActorCritic`` trainer.
      3. Run ``offline_iterations + online_iterations`` gradient updates.
         From iteration ``offline_iterations`` onward, one environment step is
         taken per iteration and stored in the replay buffer before the update.
      4. Every ``config.eval_frequency`` iterations, evaluate the actor. During
         the online phase a line ``"<online_iter> <metric>"`` is appended to
         ``config.results_path``; the metric is the success rate (in percent)
         for goal environments and the D4RL-normalized score otherwise. During
         the offline phase the trainer state is saved to ``checkpoint.pt`` if
         ``config.checkpoints_path`` is set.

    Args:
        config: Run configuration, parsed by ``pyrallis``.

    Note:
        ``update_result``, ``online_log`` and ``eval_log`` are populated but not
        sent to any logger in this function.
    """
    if config.algo != "awac":
        return
    env = gym.make(config.env_name)
    eval_env = gym.make(config.env_name)

    is_env_with_goal = config.env_name.startswith(ENVS_WITH_GOAL)

    max_steps = env._max_episode_steps

    set_seed(config.seed, env)
    set_env_seed(eval_env, config.eval_seed)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    dataset = d4rl.qlearning_dataset(env)

    # Keyword arguments later forwarded to ``modify_reward_online`` so online
    # rewards are transformed the same way as the dataset rewards.
    reward_mod_dict = {}
    if config.normalize_reward:
        reward_mod_dict = modify_reward(dataset, config.env_name)

    state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
    dataset["observations"] = normalize_states(
        dataset["observations"], state_mean, state_std
    )
    dataset["next_observations"] = normalize_states(
        dataset["next_observations"], state_mean, state_std
    )
    # Both envs are wrapped with the dataset statistics so online observations
    # match the normalization applied to the dataset.
    env = wrap_env(env, state_mean=state_mean, state_std=state_std)
    eval_env = wrap_env(eval_env, state_mean=state_mean, state_std=state_std)
    replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    replay_buffer.load_d4rl_dataset(dataset)

    actor_critic_kwargs = {
        "state_dim": state_dim,
        "action_dim": action_dim,
        "hidden_dim": config.hidden_dim,
    }

    actor = Actor(**actor_critic_kwargs)
    actor.to(config.device)
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=config.learning_rate)
    critic_1 = Critic(**actor_critic_kwargs)
    critic_2 = Critic(**actor_critic_kwargs)
    critic_1.to(config.device)
    critic_2.to(config.device)
    critic_1_optimizer = torch.optim.Adam(critic_1.parameters(), lr=config.learning_rate)
    critic_2_optimizer = torch.optim.Adam(critic_2.parameters(), lr=config.learning_rate)

    awac = AdvantageWeightedActorCritic(
        actor=actor,
        actor_optimizer=actor_optimizer,
        critic_1=critic_1,
        critic_1_optimizer=critic_1_optimizer,
        critic_2=critic_2,
        critic_2_optimizer=critic_2_optimizer,
        gamma=config.gamma,
        tau=config.tau,
        awac_lambda=config.awac_lambda,
    )

    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)
        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
            pyrallis.dump(config, f)

    full_eval_scores, full_normalized_eval_scores = [], []
    state, done = env.reset(), False
    episode_step = 0
    episode_return = 0
    goal_achieved = False

    eval_successes = []
    train_successes = []

    print("Offline pretraining")
    for t in trange(
        int(config.offline_iterations) + int(config.online_iterations), ncols=80
    ):
        if t == config.offline_iterations:
            print("Online tuning")
        online_log = {}
        if t >= config.offline_iterations:
            episode_step += 1
            # Acting in the environment needs no gradients; without no_grad an
            # autograd graph would be built and thrown away on every step.
            with torch.no_grad():
                action, _ = actor(
                    torch.tensor(
                        state.reshape(1, -1), device=config.device, dtype=torch.float32
                    )
                )
            action = action.detach().cpu().numpy().flatten()
            next_state, reward, done, env_infos = env.step(action)

            if not goal_achieved:
                goal_achieved = is_goal_reached(reward, env_infos)
            # The episode return is accumulated from the raw (unmodified) reward.
            episode_return += reward
            # A `done` raised exactly at the step limit is treated as a timeout,
            # not as a true terminal transition.
            real_done = bool(done) and episode_step < max_steps

            if config.normalize_reward:
                reward = modify_reward_online(reward, config.env_name, **reward_mod_dict)

            replay_buffer.add_transition(state, action, reward, next_state, real_done)
            state = next_state
            if done:
                state, done = env.reset(), False
                # Valid only for envs with goal, e.g. AntMaze, Adroit
                if is_env_with_goal:
                    train_successes.append(goal_achieved)
                    online_log["train/regret"] = np.mean(1 - np.array(train_successes))
                    online_log["train/is_success"] = float(goal_achieved)
                online_log["train/episode_return"] = episode_return
                normalized_return = eval_env.get_normalized_score(episode_return)
                online_log["train/d4rl_normalized_episode_return"] = (
                    normalized_return * 100.0
                )
                online_log["train/episode_length"] = episode_step
                episode_return = 0
                episode_step = 0
                goal_achieved = False

        batch = replay_buffer.sample(config.batch_size)
        batch = [b.to(config.device) for b in batch]
        update_result = awac.update(batch)
        is_offline_phase = t < config.offline_iterations
        update_result["offline_iter" if is_offline_phase else "online_iter"] = (
            t if is_offline_phase else t - config.offline_iterations
        )
        update_result.update(online_log)
        if (t + 1) % config.eval_frequency == 0:
            # NOTE(review): the eval env is seeded from `config.eval_seed` above
            # but evaluation passes `config.test_seed` -- confirm both fields
            # exist on TrainConfig and that the distinction is intended.
            eval_scores, success_rate = eval_actor(
                eval_env, actor, config.device, config.n_test_episodes, config.test_seed
            )
            eval_log = {}

            full_eval_scores.append(eval_scores)
            if hasattr(eval_env, "get_normalized_score"):
                normalized = eval_env.get_normalized_score(np.mean(eval_scores))
                normalized_eval_score = normalized * 100.0

                if t >= config.offline_iterations:
                    iteration = t - config.offline_iterations
                    # Goal environments report success rate, others the
                    # normalized score; both are expressed in percent.
                    if is_env_with_goal:
                        metric = success_rate * 100.0
                    else:
                        metric = normalized_eval_score
                    with open(config.results_path, 'a') as file:
                        file.write(f"{iteration} {metric}\n")

                # Valid only for envs with goal, e.g. AntMaze, Adroit
                if t >= config.offline_iterations and is_env_with_goal:
                    eval_successes.append(success_rate)
                    # Regret is computed from the online *training* episodes.
                    eval_log["eval/regret"] = np.mean(1 - np.array(train_successes))
                    eval_log["eval/success_rate"] = success_rate
                full_normalized_eval_scores.append(normalized_eval_score)
                eval_log["eval/d4rl_normalized_score"] = normalized_eval_score

            # Checkpoints are written only during the offline phase, and only
            # on evaluation iterations; each save overwrites the previous one.
            if t < int(config.offline_iterations) and config.checkpoints_path:
                torch.save(
                    awac.state_dict(),
                    os.path.join(config.checkpoints_path, "checkpoint.pt"),
                )

@pyrallis.wrap()
def train_pes_awac(config: TrainConfig):
    """Train the "pes-awac" variant: AWAC with a distance-filtered online buffer.

    The function is a no-op unless ``config.algo == 'pes-awac'``.

    Workflow, as implemented below:
      1. Load the dataset via ``qlearning_dataset_with_timeouts``, optionally
         rescale rewards, and normalize observations with dataset statistics.
      2. Fill three buffers: an offline buffer with the whole dataset, a
         ``SequenceReplayBuffer`` built from the dataset, and an online buffer
         initialised from ``sequence_replay_buffer.get_buffer_data_dict()``.
         A ``KD_tree`` is built over the sequence buffer's concatenated
         (observation, action) data.
      3. Optionally resume from ``checkpoint.pt`` (``config.load_model``), in
         which case the offline phase is skipped.
      4. During the online phase each environment transition is added to the
         online buffer only if its KD-tree query value is below a threshold
         (or unconditionally while the online buffer is still small). When an
         episode ends, it is offered to the sequence buffer via
         ``update_top_episodes``; if that returns a truthy value the KD-tree is
         updated from the sequence buffer.
      5. Every update uses a batch concatenated from the online and offline
         buffers, split according to ``config.online_ratio``.
      6. Every ``config.eval_frequency`` iterations the actor is evaluated on a
         dedicated evaluation environment; online-phase scores are appended to
         ``config.results_path``.

    Args:
        config: Run configuration, parsed by ``pyrallis``.

    Note:
        ``log_dict``, ``online_log`` and ``eval_log`` are populated but not sent
        to any logger in this function (the wandb calls are commented out).
    """
    if config.algo != 'pes-awac':
        return
    env = gym.make(config.env_name)
    # Evaluation gets its own environment: running evaluation episodes on the
    # training env would leave it in a different episode than the one `state`,
    # `episode_step` and `episode_data` below are tracking.
    eval_env = gym.make(config.env_name)

    is_env_with_goal = config.env_name.startswith(ENVS_WITH_GOAL)

    max_episode_steps = env._max_episode_steps

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    dataset = qlearning_dataset_with_timeouts(env)

    # Keyword arguments later forwarded to ``modify_reward_online``.
    reward_mod_dict = {}
    if config.normalize_reward:
        reward_mod_dict = modify_reward(dataset, config.env_name)

    state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
    dataset["observations"] = normalize_states(
        dataset["observations"], state_mean, state_std
    )
    dataset["next_observations"] = normalize_states(
        dataset["next_observations"], state_mean, state_std
    )
    env = wrap_env(env, state_mean=state_mean, state_std=state_std)
    eval_env = wrap_env(eval_env, state_mean=state_mean, state_std=state_std)

    offline_replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    online_replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    sequence_replay_buffer = SequenceReplayBuffer(
        state_dim,
        action_dim,
        config.max_sequence_length,
        config.num_sequence,
        config.env_name
    )

    offline_replay_buffer.load_d4rl_dataset(dataset)
    sequence_replay_buffer.load_dataset(dataset)
    # The online buffer starts from the sequence buffer's contents rather than
    # empty, and the KD-tree indexes the same (observation, action) data.
    online_replay_buffer.load_d4rl_dataset(sequence_replay_buffer.get_buffer_data_dict())
    offline_kd_tree = KD_tree(data=sequence_replay_buffer.get_obs_action_concat(), k=1)

    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)
        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
            pyrallis.dump(config, f)

    # Set seeds
    seed = config.seed
    set_seed(seed, env)

    actor_critic_kwargs = {
        "state_dim": state_dim,
        "action_dim": action_dim,
        "hidden_dim": config.hidden_dim,
    }

    actor = Actor(**actor_critic_kwargs)
    actor.to(config.device)
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=config.learning_rate)
    critic_1 = Critic(**actor_critic_kwargs)
    critic_2 = Critic(**actor_critic_kwargs)
    critic_1.to(config.device)
    critic_2.to(config.device)
    critic_1_optimizer = torch.optim.Adam(critic_1.parameters(), lr=config.learning_rate)
    critic_2_optimizer = torch.optim.Adam(critic_2.parameters(), lr=config.learning_rate)

    awac = AdvantageWeightedActorCritic(
        actor=actor,
        actor_optimizer=actor_optimizer,
        critic_1=critic_1,
        critic_1_optimizer=critic_1_optimizer,
        critic_2=critic_2,
        critic_2_optimizer=critic_2_optimizer,
        gamma=config.gamma,
        tau=config.tau,
        awac_lambda=config.awac_lambda,
    )

    print("---------------------------------------")
    print(f"Training AWAC, Env: {config.env_name}, Seed: {seed}")
    print("---------------------------------------")

    if config.load_model:
        # Resume from the checkpoint written during offline pretraining (same
        # path the save step below uses) and skip straight to the online phase.
        weight_path = os.path.join(config.checkpoints_path, "checkpoint.pt")
        awac.load_state_dict(torch.load(weight_path, map_location=config.device))
        actor = awac._actor
        t = config.offline_iterations
    else:
        t = 0
        print("Offline pretraining")

    evaluations = []
    state, done = env.reset(), False
    episode_return = 0
    episode_step = 0
    episode_data = {
        "observations": [],
        "actions": [],
        "rewards": [],
        "next_observations": [],
        "terminals": []
    }
    goal_achieved = False

    eval_successes = []
    train_successes = []

    total_iterations = config.offline_iterations + config.online_iterations
    # Below this online-buffer size every transition is stored unfiltered and
    # the gradient update for that iteration is skipped (value hard-coded in
    # the original implementation).
    min_online_buffer_size = 999
    # The online/offline split of each training batch is fixed for the run.
    online_batch_size = int(config.batch_size * config.online_ratio)
    offline_batch_size = config.batch_size - online_batch_size

    while t < total_iterations:
        if t == config.offline_iterations:
            print("Online tuning")
        online_log = {}
        if t >= config.offline_iterations:
            episode_step += 1
            # Acting in the environment needs no gradients; without no_grad an
            # autograd graph would be built and thrown away on every step.
            with torch.no_grad():
                action, _ = actor(
                    torch.tensor(
                        state.reshape(1, -1), device=config.device, dtype=torch.float32
                    )
                )
            action = action.detach().cpu().numpy().flatten()
            next_state, reward, done, env_infos = env.step(action)

            # A `done` raised exactly at the step limit is treated as a timeout,
            # not as a true terminal transition.
            real_done = bool(done) and episode_step < max_episode_steps

            # The per-episode record keeps the raw reward (appended before any
            # online reward modification below).
            episode_data["observations"].append(state)
            episode_data["actions"].append(action)
            episode_data["next_observations"].append(next_state)
            episode_data["rewards"].append(reward)
            episode_data["terminals"].append(real_done)

            if not goal_achieved:
                goal_achieved = is_goal_reached(reward, env_infos)

            episode_return += reward

            if config.normalize_reward:
                reward = modify_reward_online(
                    reward,
                    config.env_name,
                    **reward_mod_dict,
                )

            warming_up = online_replay_buffer._size < min_online_buffer_size
            if warming_up:
                keep_transition = True
            else:
                # Presumably the distance from (state, action) to its nearest
                # neighbour among the sequence-buffer data -- confirm against
                # KD_tree.query.
                distance = offline_kd_tree.query(
                    np.concatenate((state, action), axis=-1)
                ).item()
                # NOTE(review): `t` counts from the start of offline training
                # and is divided by `offline_iterations`; this raises
                # ZeroDivisionError when offline_iterations == 0. Confirm the
                # intended schedule.
                threshold = (
                    config.threshold_distance
                    + config.threshold_coefficient * t / config.offline_iterations
                )
                keep_transition = distance < threshold
            if keep_transition:
                online_replay_buffer.add_transition(
                    state, action, reward, next_state, real_done
                )

            # Always advance the observation and handle episode boundaries,
            # including during warm-up; otherwise the policy would keep acting
            # on a stale state and the env would be stepped past `done`.
            state = next_state
            if (
                done
                or episode_step >= config.max_sequence_length
                or t + 1 == total_iterations
            ):
                episode_data["observations"] = np.array(episode_data["observations"])
                episode_data["actions"] = np.array(episode_data["actions"])
                episode_data["next_observations"] = np.array(episode_data["next_observations"])
                episode_data["rewards"] = np.array(episode_data["rewards"])
                episode_data["terminals"] = np.array(episode_data["terminals"])

                # Refresh the KD-tree only when the sequence buffer reports
                # that it changed.
                if sequence_replay_buffer.update_top_episodes(episode_data):
                    offline_kd_tree.update(sequence_replay_buffer.get_obs_action_concat())

                state, done = env.reset(), False
                # Valid only for envs with goal, e.g. AntMaze, Adroit
                if is_env_with_goal:
                    train_successes.append(goal_achieved)
                    online_log["train/regret"] = np.mean(1 - np.array(train_successes))
                    online_log["train/is_success"] = float(goal_achieved)
                online_log["train/episode_return"] = episode_return
                normalized_return = env.get_normalized_score(episode_return)
                online_log["train/d4rl_normalized_episode_return"] = (
                    normalized_return * 100.0
                )
                online_log["train/episode_length"] = episode_step
                episode_return = 0
                episode_step = 0
                episode_data = {
                    "observations": [],
                    "actions": [],
                    "rewards": [],
                    "next_observations": [],
                    "terminals": []
                }
                goal_achieved = False

            if warming_up:
                # As in the original warm-up path: no gradient update,
                # evaluation or checkpointing for this iteration.
                t += 1
                continue

        assert online_replay_buffer._size >= online_batch_size
        assert offline_replay_buffer._size >= offline_batch_size

        online_batch = online_replay_buffer.sample(online_batch_size)
        offline_batch = offline_replay_buffer.sample(offline_batch_size)

        # Concatenate the two samples field by field along the batch dimension.
        batch = [
            torch.cat((online_part, offline_part), dim=0)
            for online_part, offline_part in zip(online_batch, offline_batch)
        ]
        batch = [b.to(config.device) for b in batch]
        log_dict = awac.update(batch)
        log_dict["offline_iter" if t < config.offline_iterations else "online_iter"] = (
            t if t < config.offline_iterations else t - config.offline_iterations
        )
        log_dict.update(online_log)
        # Evaluate episode
        if (t + 1) % config.eval_frequency == 0:
            print(f"Time steps: {t + 1}")
            eval_scores, success_rate = eval_actor(
                eval_env,
                actor,
                device=config.device,
                n_episodes=config.n_test_episodes,
                seed=config.seed,
            )
            eval_score = eval_scores.mean()
            eval_log = {}
            normalized = eval_env.get_normalized_score(np.mean(eval_scores))
            normalized_eval_score = normalized * 100.0
            if t >= config.offline_iterations:
                iteration = t - config.offline_iterations
                # NOTE(review): train_awac writes the success rate for goal
                # environments; here the normalized score is written for every
                # environment. Confirm which metric downstream tooling expects.
                with open(config.results_path, 'a') as file:
                    file.write(f"{iteration} {normalized_eval_score}\n")

            # Valid only for envs with goal, e.g. AntMaze, Adroit
            if t >= config.offline_iterations and is_env_with_goal:
                eval_successes.append(success_rate)
                # Regret is computed from the online *training* episodes.
                eval_log["eval/regret"] = np.mean(1 - np.array(train_successes))
                eval_log["eval/success_rate"] = success_rate

            eval_log["eval/d4rl_normalized_score"] = normalized_eval_score
            evaluations.append(normalized_eval_score)
            print("---------------------------------------")
            print(
                f"Evaluation over {config.n_test_episodes} episodes: "
                f"{eval_score:.3f} , D4RL score: {normalized_eval_score:.3f}"
            )
            print("---------------------------------------")

        # Checkpoints are written only during the offline phase; each save
        # overwrites the previous one.
        if t < config.offline_iterations and (t + 1) % config.save_checkpoints_freq == 0:
            if config.checkpoints_path:
                torch.save(
                    awac.state_dict(),
                    os.path.join(config.checkpoints_path, "checkpoint.pt"),
                )
        t += 1
                
if __name__ == "__main__":
    # Both entry points are always invoked; each one returns immediately when
    # ``config.algo`` names a different algorithm.
    for entry_point in (train_awac, train_pes_awac):
        entry_point()
    
    
