import copy
import os
import random
import uuid
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import d4rl
import gym
import numpy as np
import pyrallis
import torch
import torch.nn as nn
import torch.nn.functional as F
from config import TrainConfig, ENVS_WITH_GOAL
from utils import modify_reward, compute_mean_std, normalize_states, wrap_env, set_seed, soft_update, eval_actor, modify_reward_online, is_goal_reached, KD_tree
from buffer import ReplayBuffer, SequenceReplayBuffer
from network import Actor, Critic
from algo import TD3_BC
from utils import qlearning_dataset_with_timeouts
import time

@pyrallis.wrap()
def train_td3_bc(config: TrainConfig):
    """Pretrain TD3+BC offline on a D4RL dataset, then fine-tune it online.

    Returns immediately unless ``config.algo == 'td3-bc'``, so the script's
    entry point can call every trainer and let the config select one.

    Phases (``t`` counts iterations across both):
      * ``t < config.offline_iterations``: gradient steps on the D4RL dataset
        only. Skipped when ``config.load_model`` restores
        ``<checkpoints_path>/checkpoint.pt``.
      * afterwards: each iteration takes one environment step with the
        current actor, appends the transition to the same replay buffer, and
        then takes one gradient step.

    Every ``config.eval_freq`` iterations the actor is evaluated on a
    separate, identically normalized environment, so evaluation never
    disturbs the episode in progress on the training environment. During
    the online phase, ``"<online_iter> <normalized score>"`` lines are
    appended to ``config.results_path``.
    """
    if config.algo != 'td3-bc':
        return
    env = gym.make(config.env)
    # Dedicated evaluation env: eval_actor resets/steps the env it is given,
    # which would otherwise clobber the online training episode.
    eval_env = gym.make(config.env)

    is_env_with_goal = config.env.startswith(ENVS_WITH_GOAL)
    max_steps = env._max_episode_steps

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    dataset = d4rl.qlearning_dataset(env)

    reward_mod_dict = {}
    if config.normalize_reward:
        # Presumably rescales the dataset rewards in place (the result is not
        # assigned back to `dataset`); the returned kwargs are forwarded to
        # modify_reward_online so online rewards get the same transform.
        reward_mod_dict = modify_reward(
            dataset,
            config.env,
            reward_scale=config.reward_scale,
            reward_bias=config.reward_bias,
        )

    if config.normalize:
        state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
    else:
        state_mean, state_std = 0, 1

    dataset["observations"] = normalize_states(
        dataset["observations"], state_mean, state_std
    )
    dataset["next_observations"] = normalize_states(
        dataset["next_observations"], state_mean, state_std
    )
    # Both envs must apply the same observation normalization as the dataset.
    env = wrap_env(env, state_mean=state_mean, state_std=state_std)
    eval_env = wrap_env(eval_env, state_mean=state_mean, state_std=state_std)
    replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    replay_buffer.load_d4rl_dataset(dataset)

    max_action = float(env.action_space.high[0])

    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)
        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
            pyrallis.dump(config, f)

    # Set seeds
    seed = config.seed
    set_seed(seed, env)

    actor = Actor(state_dim, action_dim, max_action).to(config.device)
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=3e-4)

    critic_1 = Critic(state_dim, action_dim).to(config.device)
    critic_1_optimizer = torch.optim.Adam(critic_1.parameters(), lr=3e-4)
    critic_2 = Critic(state_dim, action_dim).to(config.device)
    critic_2_optimizer = torch.optim.Adam(critic_2.parameters(), lr=3e-4)

    kwargs = {
        "max_action": max_action,
        "actor": actor,
        "actor_optimizer": actor_optimizer,
        "critic_1": critic_1,
        "critic_1_optimizer": critic_1_optimizer,
        "critic_2": critic_2,
        "critic_2_optimizer": critic_2_optimizer,
        "discount": config.discount,
        "tau": config.tau,
        "device": config.device,
        # TD3
        "policy_noise": config.policy_noise * max_action,
        "noise_clip": config.noise_clip * max_action,
        "policy_freq": config.policy_freq,
        # TD3 + BC
        "alpha": config.alpha,
    }

    print("---------------------------------------")
    print(f"Training TD3 + BC, Env: {config.env}, Seed: {seed}")
    print("---------------------------------------")

    # Initialize actor
    trainer = TD3_BC(**kwargs)

    if config.load_model:
        weight_path = os.path.join(config.checkpoints_path, "checkpoint.pt")
        trainer.load_state_dict(torch.load(weight_path, map_location=config.device))
        actor = trainer.actor
        # A restored checkpoint skips offline pretraining entirely.
        t = config.offline_iterations
    else:
        t = 0
        print("Offline pretraining")

    evaluations = []
    state, done = env.reset(), False
    episode_return = 0
    episode_step = 0
    goal_achieved = False

    eval_successes = []
    train_successes = []

    # Wall-clock time of the previous online evaluation, for timing prints.
    last_time = time.time()

    total_iterations = config.offline_iterations + config.online_iterations
    while t < total_iterations:
        if t == config.offline_iterations:
            print("Online tuning")
        online_log = {}
        if t >= config.offline_iterations:
            episode_step += 1
            action = actor(
                torch.tensor(
                    state.reshape(1, -1), device=config.device, dtype=torch.float32
                )
            )
            action = action.cpu().data.numpy().flatten()
            next_state, reward, done, env_infos = env.step(action)
            if not goal_achieved:
                goal_achieved = is_goal_reached(reward, env_infos)

            episode_return += reward
            # A time-limit truncation is not a true terminal state, so it must
            # not stop bootstrapping in the TD target.
            real_done = bool(done) and episode_step < max_steps

            if config.normalize_reward:
                reward = modify_reward_online(
                    reward,
                    config.env,
                    reward_scale=config.reward_scale,
                    reward_bias=config.reward_bias,
                    **reward_mod_dict,
                )
            replay_buffer.add_transition(state, action, reward, next_state, real_done)
            state = next_state

            if done:
                state, done = env.reset(), False
                # Valid only for envs with goal, e.g. AntMaze, Adroit
                if is_env_with_goal:
                    train_successes.append(goal_achieved)
                    online_log["train/regret"] = np.mean(1 - np.array(train_successes))
                    online_log["train/is_success"] = float(goal_achieved)
                online_log["train/episode_return"] = episode_return
                normalized_return = env.get_normalized_score(episode_return)
                online_log["train/d4rl_normalized_episode_return"] = (
                    normalized_return * 100.0
                )
                online_log["train/episode_length"] = episode_step
                episode_return = 0
                episode_step = 0
                goal_achieved = False

        batch = replay_buffer.sample(config.batch_size)
        batch = [b.to(config.device) for b in batch]
        if config.heterogeneous:
            log_dict = trainer.heter_train(batch)
        else:
            log_dict = trainer.train(batch)
        log_dict["offline_iter" if t < config.offline_iterations else "online_iter"] = (
            t if t < config.offline_iterations else t - config.offline_iterations
        )
        log_dict.update(online_log)
        # Evaluate episode
        if (t + 1) % config.eval_freq == 0:
            print(f"Time steps: {t + 1}")
            eval_scores, success_rate = eval_actor(
                eval_env,
                actor,
                device=config.device,
                n_episodes=config.n_episodes,
                seed=config.seed,
            )
            eval_score = eval_scores.mean()
            eval_log = {}
            normalized_eval_score = eval_env.get_normalized_score(eval_score) * 100.0
            if t >= config.offline_iterations:
                iteration = t - config.offline_iterations
                cur_time = time.time()
                directory = os.path.dirname(config.results_path)
                # dirname() is "" for a bare filename and os.makedirs("") raises.
                if directory:
                    os.makedirs(directory, exist_ok=True)
                with open(config.results_path, 'a') as file:
                    file.write(f"{iteration} {normalized_eval_score}\n")
                print(f"time_consume:{cur_time-last_time}")
                last_time = cur_time
            # Valid only for envs with goal, e.g. AntMaze, Adroit
            if t >= config.offline_iterations and is_env_with_goal:
                eval_successes.append(success_rate)
                eval_log["eval/regret"] = np.mean(1 - np.array(train_successes))
                eval_log["eval/success_rate"] = success_rate

            eval_log["eval/d4rl_normalized_score"] = normalized_eval_score
            evaluations.append(normalized_eval_score)
            print("---------------------------------------")
            print(
                f"Evaluation over {config.n_episodes} episodes: "
                f"{eval_score:.3f} , D4RL score: {normalized_eval_score:.3f}"
            )
            print("---------------------------------------")

        # Checkpoints are only written during offline pretraining.
        if t < config.offline_iterations and (t + 1) % config.save_checkpoints_freq == 0:
            if config.checkpoints_path:
                torch.save(
                    trainer.state_dict(),
                    os.path.join(config.checkpoints_path, "checkpoint.pt"),
                )

        t += 1

@pyrallis.wrap()
def train_pes_td3_bc(config: TrainConfig):
    """Offline-to-online TD3+BC with a distance-gated online replay buffer.

    Returns immediately unless ``config.algo == 'pes-td3-bc'``.

    Buffers built from the D4RL dataset (loaded with timeouts):
      * ``offline_replay_buffer`` holds the whole dataset;
      * ``sequence_replay_buffer`` is a ``SequenceReplayBuffer`` configured by
        ``config.max_sequence_length`` / ``config.num_sequence``; finished
        online episodes are offered to it via ``update_top_episodes``;
      * ``online_replay_buffer`` starts with the sequence buffer's data and
        receives online transitions.

    A KD tree (``k=1``) over the sequence buffer's concatenated
    (observation, action) rows gates online transitions:
      * while the online buffer holds fewer than 999 transitions, every
        transition is stored and no gradient step is taken that iteration;
      * afterwards a transition is stored only if its query distance is
        below ``threshold_distance + threshold_coefficient * t /
        offline_iterations``.
    Whenever ``update_top_episodes`` reports a change, the tree is updated.
    Online episodes are cut after ``config.max_sequence_length`` steps.

    Each gradient step draws ``int(batch_size * online_ratio)`` samples from
    the online buffer and the rest from the offline buffer. Evaluation runs
    on a separate, identically normalized environment so it never disturbs
    the episode in progress on the training environment.
    """
    if config.algo != 'pes-td3-bc':
        return
    env = gym.make(config.env)
    # Dedicated evaluation env: eval_actor resets/steps the env it is given,
    # which would otherwise clobber the online training episode.
    eval_env = gym.make(config.env)
    max_episode_steps = env._max_episode_steps

    is_env_with_goal = config.env.startswith(ENVS_WITH_GOAL)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    dataset = qlearning_dataset_with_timeouts(env)

    reward_mod_dict = {}
    if config.normalize_reward:
        # Presumably rescales the dataset rewards in place (the result is not
        # assigned back to `dataset`); the returned kwargs are forwarded to
        # modify_reward_online so online rewards get the same transform.
        reward_mod_dict = modify_reward(
            dataset,
            config.env,
            reward_scale=config.reward_scale,
            reward_bias=config.reward_bias,
        )

    if config.normalize:
        state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
    else:
        state_mean, state_std = 0, 1

    dataset["observations"] = normalize_states(
        dataset["observations"], state_mean, state_std
    )
    dataset["next_observations"] = normalize_states(
        dataset["next_observations"], state_mean, state_std
    )
    # Both envs must apply the same observation normalization as the dataset.
    env = wrap_env(env, state_mean=state_mean, state_std=state_std)
    eval_env = wrap_env(eval_env, state_mean=state_mean, state_std=state_std)

    offline_replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    online_replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    sequence_replay_buffer = SequenceReplayBuffer(
        state_dim,
        action_dim,
        config.max_sequence_length,
        config.num_sequence,
        config.env
    )

    offline_replay_buffer.load_d4rl_dataset(dataset)
    sequence_replay_buffer.load_dataset(dataset)
    online_replay_buffer.load_d4rl_dataset(sequence_replay_buffer.get_buffer_data_dict())
    offline_kd_tree = KD_tree(data=sequence_replay_buffer.get_obs_action_concat(), k=1)

    max_action = float(env.action_space.high[0])

    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)
        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
            pyrallis.dump(config, f)

    # Set seeds
    seed = config.seed
    set_seed(seed, env)

    actor = Actor(state_dim, action_dim, max_action).to(config.device)
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=3e-4)

    critic_1 = Critic(state_dim, action_dim).to(config.device)
    critic_1_optimizer = torch.optim.Adam(critic_1.parameters(), lr=3e-4)
    critic_2 = Critic(state_dim, action_dim).to(config.device)
    critic_2_optimizer = torch.optim.Adam(critic_2.parameters(), lr=3e-4)

    kwargs = {
        "max_action": max_action,
        "actor": actor,
        "actor_optimizer": actor_optimizer,
        "critic_1": critic_1,
        "critic_1_optimizer": critic_1_optimizer,
        "critic_2": critic_2,
        "critic_2_optimizer": critic_2_optimizer,
        "discount": config.discount,
        "tau": config.tau,
        "device": config.device,
        # TD3
        "policy_noise": config.policy_noise * max_action,
        "noise_clip": config.noise_clip * max_action,
        "policy_freq": config.policy_freq,
        # TD3 + BC
        "alpha": config.alpha,
    }

    print("---------------------------------------")
    print(f"Training TD3_BC, Env: {config.env}, Seed: {seed}")
    print("---------------------------------------")

    # Initialize actor
    trainer = TD3_BC(**kwargs)

    if config.load_model:
        weight_path = os.path.join(config.checkpoints_path, "checkpoint.pt")
        trainer.load_state_dict(torch.load(weight_path, map_location=config.device))
        actor = trainer.actor
        # A restored checkpoint skips offline pretraining entirely.
        t = config.offline_iterations
    else:
        t = 0
        print("Offline pretraining")

    evaluations = []
    state, done = env.reset(), False
    episode_return = 0
    episode_step = 0
    episode_data = {
        "observations": [],
        "actions": [],
        "rewards": [],
        "next_observations": [],
        "terminals": []
    }
    goal_achieved = False

    eval_successes = []
    train_successes = []

    # Online transitions are stored unfiltered until the online buffer holds
    # this many; gradient steps are skipped during that warm-up.
    warmup_size = 999
    total_iterations = config.offline_iterations + config.online_iterations

    while t < total_iterations:
        if t == config.offline_iterations:
            print("Online tuning")
        online_log = {}
        warming_up = False
        if t >= config.offline_iterations:
            episode_step += 1
            action = actor(
                torch.tensor(
                    state.reshape(1, -1), device=config.device, dtype=torch.float32
                )
            )
            action = action.cpu().data.numpy().flatten()
            next_state, reward, done, env_infos = env.step(action)

            # A time-limit truncation is not a true terminal state, so it must
            # not stop bootstrapping in the TD target.
            real_done = bool(done) and episode_step < max_episode_steps

            # Episode data keeps the raw (un-normalized) reward.
            episode_data["observations"].append(state)
            episode_data["actions"].append(action)
            episode_data["next_observations"].append(next_state)
            episode_data["rewards"].append(reward)
            episode_data["terminals"].append(real_done)

            if not goal_achieved:
                goal_achieved = is_goal_reached(reward, env_infos)

            episode_return += reward

            if config.normalize_reward:
                reward = modify_reward_online(
                    reward,
                    config.env,
                    reward_scale=config.reward_scale,
                    reward_bias=config.reward_bias,
                    **reward_mod_dict,
                )

            warming_up = online_replay_buffer._size < warmup_size
            if warming_up:
                online_replay_buffer.add_transition(state, action, reward, next_state, real_done)
            else:
                distance = offline_kd_tree.query(
                    np.concatenate((state, action), axis=-1)
                ).item()
                threshold = (
                    config.threshold_distance
                    + config.threshold_coefficient * t / config.offline_iterations
                )
                if distance < threshold:
                    online_replay_buffer.add_transition(
                        state, action, reward, next_state, real_done
                    )
            # Advance the state even while warming up; otherwise the policy
            # keeps acting from a stale observation.
            state = next_state
            if done or episode_step >= config.max_sequence_length or t + 1 == total_iterations:
                episode_data = {key: np.array(value) for key, value in episode_data.items()}

                if sequence_replay_buffer.update_top_episodes(episode_data):
                    offline_kd_tree.update(sequence_replay_buffer.get_obs_action_concat())

                state, done = env.reset(), False
                # Valid only for envs with goal, e.g. AntMaze, Adroit
                if is_env_with_goal:
                    train_successes.append(goal_achieved)
                    online_log["train/regret"] = np.mean(1 - np.array(train_successes))
                    online_log["train/is_success"] = float(goal_achieved)
                online_log["train/episode_return"] = episode_return
                normalized_return = env.get_normalized_score(episode_return)
                online_log["train/d4rl_normalized_episode_return"] = (
                    normalized_return * 100.0
                )
                online_log["train/episode_length"] = episode_step
                episode_return = 0
                episode_step = 0
                episode_data = {
                    "observations": [],
                    "actions": [],
                    "rewards": [],
                    "next_observations": [],
                    "terminals": []
                }
                goal_achieved = False

        if warming_up:
            # No gradient step / eval / checkpoint until the online buffer fills.
            t += 1
            continue

        online_batch_size = int(config.batch_size * config.online_ratio)
        offline_batch_size = config.batch_size - online_batch_size

        assert online_replay_buffer._size >= online_batch_size
        assert offline_replay_buffer._size >= offline_batch_size

        online_batch = online_replay_buffer.sample(online_batch_size)
        offline_batch = offline_replay_buffer.sample(offline_batch_size)

        # Concatenate each field (obs, action, ...) of the two batches.
        batch = [
            torch.cat((online_part, offline_part), dim=0)
            for online_part, offline_part in zip(online_batch, offline_batch)
        ]
        batch = [b.to(config.device) for b in batch]
        if config.heterogeneous:
            log_dict = trainer.heter_train(batch)
        else:
            log_dict = trainer.train(batch)
        log_dict["offline_iter" if t < config.offline_iterations else "online_iter"] = (
            t if t < config.offline_iterations else t - config.offline_iterations
        )
        log_dict.update(online_log)
        # Evaluate episode
        if (t + 1) % config.eval_freq == 0:
            print(f"Time steps: {t + 1}")
            eval_scores, success_rate = eval_actor(
                eval_env,
                actor,
                device=config.device,
                n_episodes=config.n_episodes,
                seed=config.seed,
            )
            eval_score = eval_scores.mean()
            eval_log = {}
            normalized_eval_score = eval_env.get_normalized_score(eval_score) * 100.0
            if t >= config.offline_iterations:
                iteration = t - config.offline_iterations
                directory = os.path.dirname(config.results_path)
                # dirname() is "" for a bare filename and os.makedirs("") raises.
                if directory:
                    os.makedirs(directory, exist_ok=True)
                with open(config.results_path, 'a') as file:
                    if is_env_with_goal:
                        file.write(f"{iteration} {success_rate}\n")
                    else:
                        file.write(f"{iteration} {normalized_eval_score}\n")

            # Valid only for envs with goal, e.g. AntMaze, Adroit
            if t >= config.offline_iterations and is_env_with_goal:
                eval_successes.append(success_rate)
                eval_log["eval/regret"] = np.mean(1 - np.array(train_successes))
                eval_log["eval/success_rate"] = success_rate

            eval_log["eval/d4rl_normalized_score"] = normalized_eval_score
            evaluations.append(normalized_eval_score)
            print("---------------------------------------")
            print(
                f"Evaluation over {config.n_episodes} episodes: "
                f"{eval_score:.3f} , D4RL score: {normalized_eval_score:.3f}"
            )
            print("---------------------------------------")

        # Checkpoints are only written during offline pretraining.
        if t < config.offline_iterations and (t + 1) % config.save_checkpoints_freq == 0:
            if config.checkpoints_path:
                torch.save(
                    trainer.state_dict(),
                    os.path.join(config.checkpoints_path, "checkpoint.pt"),
                )
        t += 1

if __name__ == "__main__":
    # Each trainer receives its config through its pyrallis wrapper and
    # returns immediately unless config.algo selects it, so invoking them in
    # turn runs only the matching one.
    for entry_point in (train_td3_bc, train_pes_td3_bc):
        entry_point()