import copy
import os
import random
import time
import uuid
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import d4rl
import gym
import numpy as np
import pyrallis
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.optim.lr_scheduler import CosineAnnealingLR

from config import TrainConfig
from utils import modify_reward, compute_mean_std, normalize_states, wrap_env, set_seed, set_env_seed, is_goal_reached, modify_reward_online, eval_actor
from buffer import ReplayBuffer, SequenceReplayBuffer
from network import TwinQ, ValueFunction, DeterministicPolicy, GaussianPolicy
from algo import ImplicitQLearning
from utils import KD_tree, qlearning_dataset_with_timeouts

# A sampled replay-buffer batch: a list of tensors, one per transition field.
TensorBatch = List[torch.Tensor]

# Numeric bounds kept at module level. They are not referenced anywhere in
# this file (presumably consumed by the policy / algorithm code — TODO confirm).
EXP_ADV_MAX = 100.0
LOG_STD_MIN = -20.0
LOG_STD_MAX = 2.0

# Name prefixes of goal-reaching environments (e.g. AntMaze, Adroit). Used
# with str.startswith, which accepts a tuple of candidate prefixes.
ENVS_WITH_GOAL = (
    "antmaze",
    "pen",
    "door",
    "hammer",
    "relocate",
)

@pyrallis.wrap()
def train_iql(config: TrainConfig):
    """Pretrain IQL offline on a D4RL dataset, then fine-tune it online.

    Does nothing unless ``config.algo == 'iql'``, so several entry points can
    share one command line.

    Schedule:
      * steps ``[0, offline_iterations)``: gradient steps on the dataset only;
      * steps ``[offline_iterations, offline_iterations + online_iterations)``:
        each step first collects one environment transition with the current
        actor (plus exploration) into the same replay buffer, then takes a
        gradient step.

    Every ``config.eval_freq`` steps the actor is evaluated on a dedicated
    evaluation environment. During the online phase one line
    ``"<online_iter> <score>"`` is appended to ``config.results_path``
    (success rate x100 for goal-based envs, D4RL-normalized score x100
    otherwise). During the offline phase a checkpoint is written every
    ``config.save_checkpoints_freq`` steps when ``config.checkpoints_path``
    is set.

    Args:
        config: Training configuration, parsed from the CLI by ``pyrallis``.
    """
    if config.algo != 'iql':
        return
    env = gym.make(config.env)

    is_env_with_goal = config.env.startswith(ENVS_WITH_GOAL)

    max_steps = env._max_episode_steps

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    dataset = qlearning_dataset_with_timeouts(env)

    # The returned parameters are forwarded to modify_reward_online so that
    # online rewards are transformed consistently with the offline data.
    reward_mod_dict = {}
    if config.normalize_reward:
        reward_mod_dict = modify_reward(dataset, config.env)

    if config.normalize:
        state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
    else:
        state_mean, state_std = 0, 1

    dataset["observations"] = normalize_states(
        dataset["observations"], state_mean, state_std
    )
    dataset["next_observations"] = normalize_states(
        dataset["next_observations"], state_mean, state_std
    )

    env = wrap_env(env, state_mean=state_mean, state_std=state_std)
    # Evaluation gets its own, identically normalized env. If eval episodes
    # ran on the training env, they would move it away from `state`, and the
    # next online transition would pair a stale state with an unrelated
    # next state.
    eval_env = wrap_env(
        gym.make(config.env), state_mean=state_mean, state_std=state_std
    )

    replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    replay_buffer.load_dataset(dataset)

    max_action = float(env.action_space.high[0])

    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)
        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
            pyrallis.dump(config, f)

    # Set seeds
    seed = config.seed
    set_seed(seed, env)

    q_network = TwinQ(state_dim, action_dim).to(config.device)
    v_network = ValueFunction(state_dim).to(config.device)
    actor = (
        DeterministicPolicy(
            state_dim, action_dim, max_action, dropout=config.actor_dropout
        )
        if config.iql_deterministic
        else GaussianPolicy(
            state_dim, action_dim, max_action, dropout=config.actor_dropout
        )
    ).to(config.device)
    v_optimizer = torch.optim.Adam(v_network.parameters(), lr=config.vf_lr)
    q_optimizer = torch.optim.Adam(q_network.parameters(), lr=config.qf_lr)
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=config.actor_lr)

    kwargs = {
        "max_action": max_action,
        "actor": actor,
        "actor_optimizer": actor_optimizer,
        "q_network": q_network,
        "q_optimizer": q_optimizer,
        "v_network": v_network,
        "v_optimizer": v_optimizer,
        "discount": config.discount,
        "tau": config.tau,
        "device": config.device,
        # IQL
        "beta": config.beta,
        "iql_tau": config.iql_tau,
        "max_steps": config.offline_iterations,
    }

    print("---------------------------------------")
    print(f"Training IQL, Env: {config.env}, Seed: {seed}")
    print("---------------------------------------")

    # Initialize actor
    trainer = ImplicitQLearning(**kwargs)

    t = 0

    if config.load_model:
        weight_path = os.path.join(config.checkpoints_path, "checkpoint.pt")
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be restored onto config.device.
        trainer.load_state_dict(torch.load(weight_path, map_location=config.device))
        actor = trainer.actor
        # The checkpoint already covers offline pretraining; go straight online.
        t = config.offline_iterations

    evaluations = []

    state, done = env.reset(), False
    episode_return = 0
    episode_step = 0
    goal_achieved = False

    eval_successes = []
    train_successes = []

    # Wall-clock time of the previous online evaluation (re-armed when online
    # tuning starts), used for the "time_consume" report.
    last_time = time.time()

    total_iterations = int(config.offline_iterations) + int(config.online_iterations)
    while t < total_iterations:
        if t == config.offline_iterations:
            print("Online tuning")
            last_time = time.time()
        online_log = {}
        if t >= config.offline_iterations:
            episode_step += 1
            # Acting only: no autograd graph is needed for action selection.
            with torch.no_grad():
                action = actor(
                    torch.tensor(
                        state.reshape(1, -1), device=config.device, dtype=torch.float32
                    )
                )
                if not config.iql_deterministic:
                    action = action.sample()
                else:
                    noise = (torch.randn_like(action) * config.expl_noise).clamp(
                        -config.noise_clip, config.noise_clip
                    )
                    action += noise
                action = torch.clamp(max_action * action, -max_action, max_action)
            action = action.cpu().numpy().flatten()
            next_state, reward, done, env_infos = env.step(action)

            if not goal_achieved:
                goal_achieved = is_goal_reached(reward, env_infos)
            episode_return += reward

            # `done` is also raised when the time limit is hit; only an
            # episode ending before max_steps is stored as a true terminal.
            real_done = False
            if done and episode_step < max_steps:
                real_done = True

            if config.normalize_reward:
                reward = modify_reward_online(reward, config.env, **reward_mod_dict)

            replay_buffer.add_transition(state, action, reward, next_state, real_done)
            state = next_state
            if done:
                state, done = env.reset(), False
                # Valid only for envs with goal, e.g. AntMaze, Adroit
                if is_env_with_goal:
                    train_successes.append(goal_achieved)
                    online_log["train/regret"] = np.mean(1 - np.array(train_successes))
                    online_log["train/is_success"] = float(goal_achieved)
                online_log["train/episode_return"] = episode_return
                normalized_return = env.get_normalized_score(episode_return)
                online_log["train/d4rl_normalized_episode_return"] = (
                    normalized_return * 100.0
                )
                online_log["train/episode_length"] = episode_step
                episode_return = 0
                episode_step = 0
                goal_achieved = False

        batch = replay_buffer.sample(config.batch_size)
        batch = [b.to(config.device) for b in batch]

        log_dict = trainer.train(batch)
        log_dict["offline_iter" if t < config.offline_iterations else "online_iter"] = (
            t if t < config.offline_iterations else t - config.offline_iterations
        )
        log_dict.update(online_log)
        # Evaluate episode
        if (t + 1) % config.eval_freq == 0:
            print(f"Time steps: {t + 1}")
            eval_scores, success_rate = eval_actor(
                eval_env,
                actor,
                device=config.device,
                n_episodes=config.n_episodes,
                seed=config.seed,
            )
            eval_score = eval_scores.mean()
            eval_log = {}
            normalized_eval_score = env.get_normalized_score(eval_score) * 100.0

            if t >= config.offline_iterations:
                iteration = t - config.offline_iterations
                cur_time = time.time()
                # Goal-based tasks (e.g. AntMaze, Adroit) are reported by
                # success rate, the others by D4RL-normalized score.
                result = success_rate * 100.0 if is_env_with_goal else normalized_eval_score
                with open(config.results_path, 'a') as file:
                    file.write(f"{iteration} {result}\n")
                print(f"time_consume:{cur_time-last_time}")
                last_time = cur_time
            # Valid only for envs with goal, e.g. AntMaze, Adroit
            if t >= config.offline_iterations and is_env_with_goal:
                eval_successes.append(success_rate)
                eval_log["eval/regret"] = np.mean(1 - np.array(train_successes))
                eval_log["eval/success_rate"] = success_rate
            evaluations.append(normalized_eval_score)
            eval_log["eval/d4rl_normalized_score"] = normalized_eval_score
            print("---------------------------------------")
            print(
                f"Evaluation over {config.n_episodes} episodes: "
                f"{eval_score:.3f} , D4RL score: {normalized_eval_score:.3f}"
            )
            print("---------------------------------------")

        if t < config.offline_iterations and (t + 1) % config.save_checkpoints_freq == 0:
            if config.checkpoints_path:
                torch.save(
                    trainer.state_dict(),
                    os.path.join(config.checkpoints_path, "checkpoint.pt"),
                )
        t += 1
                
@pyrallis.wrap()
def train_pes_iql(config: TrainConfig):
    """Pretrain IQL offline, then fine-tune online with distance-filtered data.

    Does nothing unless ``config.algo == 'pes-iql'``.

    Buffers:
      * ``offline_replay_buffer`` holds the full (normalized) D4RL dataset;
      * ``sequence_replay_buffer`` is loaded from the dataset
        (``config.max_sequence_length`` / ``config.num_sequence``). Its
        contents seed ``online_replay_buffer``, and its concatenated
        (observation, action) pairs build a KD-tree.

    Online phase, per step:
      * one transition is collected with the current actor;
      * while the online buffer holds fewer than 999 transitions (warm-up),
        every transition is stored and no gradient step, evaluation or
        checkpoint happens;
      * afterwards a transition is stored only if the KD-tree query for its
        (state, action) is below
        ``threshold_distance + threshold_coefficient * t / offline_iterations``;
      * finished (or ``max_sequence_length``-truncated) episodes are passed to
        ``sequence_replay_buffer.update_top_episodes``; when that returns a
        truthy value, the KD-tree is refreshed.

    Each gradient step concatenates ``online_ratio * batch_size`` samples from
    the online buffer with the remainder from the offline buffer. Online
    evaluations append ``"<online_iter> <d4rl_score> <accepted/online_iter>"``
    to ``config.results_path``.

    Args:
        config: Training configuration, parsed from the CLI by ``pyrallis``.
    """
    if config.algo != 'pes-iql':
        return
    env = gym.make(config.env)

    is_env_with_goal = config.env.startswith(ENVS_WITH_GOAL)

    max_episode_steps = env._max_episode_steps

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    dataset = qlearning_dataset_with_timeouts(env)

    # The returned parameters are forwarded to modify_reward_online so that
    # online rewards are transformed consistently with the offline data.
    reward_mod_dict = {}
    if config.normalize_reward:
        reward_mod_dict = modify_reward(dataset, config.env)

    if config.normalize:
        state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
    else:
        state_mean, state_std = 0, 1

    dataset["observations"] = normalize_states(
        dataset["observations"], state_mean, state_std
    )
    dataset["next_observations"] = normalize_states(
        dataset["next_observations"], state_mean, state_std
    )
    env = wrap_env(env, state_mean=state_mean, state_std=state_std)
    # Evaluation gets its own, identically normalized env. If eval episodes
    # ran on the training env, they would move it away from `state`, and the
    # next online transition would pair a stale state with an unrelated
    # next state.
    eval_env = wrap_env(
        gym.make(config.env), state_mean=state_mean, state_std=state_std
    )

    offline_replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    online_replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    sequence_replay_buffer = SequenceReplayBuffer(
        state_dim,
        action_dim,
        config.max_sequence_length,
        config.num_sequence,
        config.env
    )

    offline_replay_buffer.load_dataset(dataset)
    sequence_replay_buffer.load_dataset(dataset)
    online_replay_buffer.load_dataset(sequence_replay_buffer.get_buffer_data_dict())
    offline_kd_tree = KD_tree(data=sequence_replay_buffer.get_obs_action_concat(), k=1)

    max_action = float(env.action_space.high[0])

    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)
        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
            pyrallis.dump(config, f)

    # Set seeds
    seed = config.seed
    set_seed(seed, env)

    q_network = TwinQ(state_dim, action_dim).to(config.device)
    v_network = ValueFunction(state_dim).to(config.device)
    actor = (
        DeterministicPolicy(
            state_dim, action_dim, max_action, dropout=config.actor_dropout
        )
        if config.iql_deterministic
        else GaussianPolicy(
            state_dim, action_dim, max_action, dropout=config.actor_dropout
        )
    ).to(config.device)
    v_optimizer = torch.optim.Adam(v_network.parameters(), lr=config.vf_lr)
    q_optimizer = torch.optim.Adam(q_network.parameters(), lr=config.qf_lr)
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=config.actor_lr)

    kwargs = {
        "max_action": max_action,
        "actor": actor,
        "actor_optimizer": actor_optimizer,
        "q_network": q_network,
        "q_optimizer": q_optimizer,
        "v_network": v_network,
        "v_optimizer": v_optimizer,
        "discount": config.discount,
        "tau": config.tau,
        "device": config.device,
        # IQL
        "beta": config.beta,
        "iql_tau": config.iql_tau,
        "max_steps": config.offline_iterations,
    }

    print("---------------------------------------")
    print(f"Training IQL, Env: {config.env}, Seed: {seed}")
    print("---------------------------------------")

    # Initialize actor
    trainer = ImplicitQLearning(**kwargs)

    if config.load_model:
        weight_path = os.path.join(config.checkpoints_path, "checkpoint.pt")
        trainer.load_state_dict(torch.load(weight_path, map_location=config.device))
        actor = trainer.actor

    def new_episode_data():
        """Return an empty per-episode transition accumulator."""
        return {
            "observations": [],
            "actions": [],
            "rewards": [],
            "next_observations": [],
            "terminals": [],
        }

    evaluations = []
    state, done = env.reset(), False
    episode_return = 0
    episode_step = 0
    # Online transitions accepted by the KD-tree distance filter.
    add_num = 0
    episode_data = new_episode_data()
    goal_achieved = False

    eval_successes = []
    train_successes = []

    # Until the online buffer holds this many transitions, every online
    # transition is stored and no gradient step is taken.
    warmup_buffer_size = 999
    total_iterations = config.offline_iterations + config.online_iterations
    online_batch_size = int(config.batch_size * config.online_ratio)
    offline_batch_size = config.batch_size - online_batch_size
    # Denominator of the threshold schedule, guarded against zero offline steps.
    threshold_scale = max(config.offline_iterations, 1)

    if config.load_model:
        t = config.offline_iterations
    else:
        t = 0
        print("Offline pretraining")

    while t < total_iterations:
        if t == config.offline_iterations:
            print("Online tuning")
        online_log = {}
        in_warmup = False
        if t >= config.offline_iterations:
            episode_step += 1
            # Acting only: no autograd graph is needed for action selection.
            with torch.no_grad():
                action = actor(
                    torch.tensor(
                        state.reshape(1, -1), device=config.device, dtype=torch.float32
                    )
                )
                if not config.iql_deterministic:
                    action = action.sample()
                else:
                    noise = (torch.randn_like(action) * config.expl_noise).clamp(
                        -config.noise_clip, config.noise_clip
                    )
                    action += noise
                action = torch.clamp(max_action * action, -max_action, max_action)
            action = action.cpu().numpy().flatten()
            next_state, reward, done, env_infos = env.step(action)

            # `done` is also raised when the time limit is hit; only an
            # episode ending before max_episode_steps is a true terminal.
            real_done = False
            if done and episode_step < max_episode_steps:
                real_done = True

            episode_data["observations"].append(state)
            episode_data["actions"].append(action)
            episode_data["next_observations"].append(next_state)
            episode_data["rewards"].append(reward)
            episode_data["terminals"].append(real_done)

            if not goal_achieved:
                goal_achieved = is_goal_reached(reward, env_infos)

            episode_return += reward

            if config.normalize_reward:
                reward = modify_reward_online(
                    reward,
                    config.env,
                    **reward_mod_dict,
                )

            in_warmup = online_replay_buffer._size < warmup_buffer_size
            if in_warmup:
                online_replay_buffer.add_transition(state, action, reward, next_state, real_done)
            else:
                distance = offline_kd_tree.query(
                    np.concatenate((state, action), axis=-1)
                ).item()
                threshold = (
                    config.threshold_distance
                    + config.threshold_coefficient * t / threshold_scale
                )
                if distance < threshold:
                    online_replay_buffer.add_transition(state, action, reward, next_state, real_done)
                    add_num += 1

            # Advance the agent and close the episode even during warm-up, so
            # the next action is chosen from the current state and a finished
            # env is reset before it is stepped again.
            state = next_state
            if done or episode_step >= config.max_sequence_length or t + 1 == total_iterations:
                episode_data = {key: np.array(value) for key, value in episode_data.items()}

                if sequence_replay_buffer.update_top_episodes(episode_data):
                    offline_kd_tree.update(sequence_replay_buffer.get_obs_action_concat())

                state, done = env.reset(), False
                # Valid only for envs with goal, e.g. AntMaze, Adroit
                if is_env_with_goal:
                    train_successes.append(goal_achieved)
                    online_log["train/regret"] = np.mean(1 - np.array(train_successes))
                    online_log["train/is_success"] = float(goal_achieved)
                online_log["train/episode_return"] = episode_return
                normalized_return = env.get_normalized_score(episode_return)
                online_log["train/d4rl_normalized_episode_return"] = (
                    normalized_return * 100.0
                )
                online_log["train/episode_length"] = episode_step
                episode_return = 0
                episode_step = 0
                episode_data = new_episode_data()
                goal_achieved = False

        if in_warmup:
            # Warm-up step: data collection only (no update, eval or checkpoint).
            t += 1
            continue

        assert online_replay_buffer._size >= online_batch_size, "online buffer smaller than its batch share"
        assert offline_replay_buffer._size >= offline_batch_size, "offline buffer smaller than its batch share"

        online_batch = online_replay_buffer.sample(online_batch_size)
        offline_batch = offline_replay_buffer.sample(offline_batch_size)

        # Concatenate field by field: online samples first, then offline.
        batch = [
            torch.cat((online_field, offline_field), dim=0)
            for online_field, offline_field in zip(online_batch, offline_batch)
        ]
        batch = [b.to(config.device) for b in batch]
        log_dict = trainer.train(batch)
        log_dict["offline_iter" if t < config.offline_iterations else "online_iter"] = (
            t if t < config.offline_iterations else t - config.offline_iterations
        )
        log_dict.update(online_log)
        # Evaluate episode
        if (t + 1) % config.eval_freq == 0:
            print(f"Time steps: {t + 1}")
            eval_scores, success_rate = eval_actor(
                eval_env,
                actor,
                device=config.device,
                n_episodes=config.n_episodes,
                seed=config.seed,
            )
            eval_score = eval_scores.mean()
            eval_log = {}
            normalized_eval_score = env.get_normalized_score(eval_score) * 100.0
            if t >= config.offline_iterations:
                iteration = t - config.offline_iterations
                # Accepted transitions per online iteration; iteration is 0 on
                # the first online step, so avoid dividing by zero there.
                add_ratio = add_num / iteration if iteration > 0 else 0.0
                with open(config.results_path, 'a') as file:
                    file.write(f"{iteration} {normalized_eval_score} {add_ratio}\n")

            # Valid only for envs with goal, e.g. AntMaze, Adroit
            if t >= config.offline_iterations and is_env_with_goal:
                eval_successes.append(success_rate)
                eval_log["eval/regret"] = np.mean(1 - np.array(train_successes))
                eval_log["eval/success_rate"] = success_rate

            eval_log["eval/d4rl_normalized_score"] = normalized_eval_score
            evaluations.append(normalized_eval_score)
            print("---------------------------------------")
            print(
                f"Evaluation over {config.n_episodes} episodes: "
                f"{eval_score:.3f} , D4RL score: {normalized_eval_score:.3f}"
            )
            print("---------------------------------------")

        if t < config.offline_iterations and (t + 1) % config.save_checkpoints_freq == 0:
            if config.checkpoints_path:
                torch.save(
                    trainer.state_dict(),
                    os.path.join(config.checkpoints_path, "checkpoint.pt"),
                )
        t += 1

if __name__ == "__main__":
    # Each entry point parses the CLI config itself and returns immediately
    # unless config.algo selects it, so at most one of them actually trains.
    for entry_point in (train_iql, train_pes_iql):
        entry_point()
