import os
from copy import deepcopy
from dataclasses import asdict, dataclass
from pathlib import Path

import d4rl
import gym
import numpy as np
import pyrallis
import torch

from config import TrainConfig, ENVS_WITH_GOAL
from utils import modify_reward, modify_reward_online, is_goal_reached, compute_mean_std, normalize_states, wandb_init, wrap_env, set_env_seed, set_seed, eval_actor
from buffer import ReplayBuffer, SequenceReplayBuffer
from network import FullyConnectedQFunction, TanhGaussianPolicy
from algo import ContinuousCQL
from utils import KD_tree, qlearning_dataset_with_timeouts


@pyrallis.wrap()
def train_cql(config: TrainConfig):
    """Pretrain a CQL agent offline, then fine-tune it online in the environment.

    The function returns immediately unless ``config.algo == 'cql'``.

    Pipeline, as implemented below:

    1. Build a training env and a separate evaluation env from ``config.env``.
    2. Load the dataset with ``qlearning_dataset_with_timeouts``, optionally
       modify rewards (``config.normalize_reward``) and normalize states
       (``config.normalize``), and load it into a single ``ReplayBuffer``.
    3. Build two Q-functions, a tanh-Gaussian actor and a ``ContinuousCQL``
       trainer; optionally restore ``checkpoint.pt`` from
       ``config.checkpoints_path`` and skip straight to the online phase.
    4. Run ``offline_iterations + online_iterations`` gradient steps.  Once
       ``t >= config.offline_iterations`` each step also performs one
       environment step with the actor and stores the transition in the same
       replay buffer.
    5. Every ``config.eval_freq`` steps evaluate the actor; during the online
       phase the result is appended to ``config.results_path``.
    6. During the offline phase only, save the trainer state every
       ``config.save_checkpoints_freq`` steps.

    Args:
        config: Run configuration (parsed by ``pyrallis.wrap``).

    Returns:
        None.  Results are reported through prints, the results file and the
        checkpoint file.
    """
    # Function-local import: only used for the elapsed-time print below.
    import time

    if config.algo != "cql":
        return
    env = gym.make(config.env)
    eval_env = gym.make(config.env)

    is_env_with_goal = config.env.startswith(ENVS_WITH_GOAL)

    # Used to tell a time-limit truncation apart from a real terminal state.
    max_steps = env._max_episode_steps

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    dataset = qlearning_dataset_with_timeouts(env)
    # Extra keyword arguments forwarded to modify_reward_online so that online
    # rewards are transformed consistently with the offline dataset.
    reward_mod_dict = {}
    if config.normalize_reward:
        reward_mod_dict = modify_reward(
            dataset,
            config.env,
            reward_scale=config.reward_scale,
            reward_bias=config.reward_bias,
        )

    if config.normalize:
        state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
    else:
        state_mean, state_std = 0, 1

    dataset["observations"] = normalize_states(
        dataset["observations"], state_mean, state_std
    )
    dataset["next_observations"] = normalize_states(
        dataset["next_observations"], state_mean, state_std
    )
    env = wrap_env(env, state_mean=state_mean, state_std=state_std)
    eval_env = wrap_env(eval_env, state_mean=state_mean, state_std=state_std)
    replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    replay_buffer.load_dataset(dataset)

    max_action = float(env.action_space.high[0])

    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)
        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
            pyrallis.dump(config, f)

    # Set seeds
    seed = config.seed
    set_seed(seed, env)
    set_env_seed(eval_env, config.eval_seed)

    critic_1 = FullyConnectedQFunction(
        state_dim,
        action_dim,
        config.orthogonal_init,
        config.q_n_hidden_layers,
    ).to(config.device)
    critic_2 = FullyConnectedQFunction(
        state_dim,
        action_dim,
        config.orthogonal_init,
        config.q_n_hidden_layers,
    ).to(config.device)
    critic_1_optimizer = torch.optim.Adam(list(critic_1.parameters()), config.qf_lr)
    critic_2_optimizer = torch.optim.Adam(list(critic_2.parameters()), config.qf_lr)

    actor = TanhGaussianPolicy(
        state_dim, action_dim, max_action, orthogonal_init=config.orthogonal_init
    ).to(config.device)
    actor_optimizer = torch.optim.Adam(actor.parameters(), config.policy_lr)

    kwargs = {
        "critic_1": critic_1,
        "critic_2": critic_2,
        "critic_1_optimizer": critic_1_optimizer,
        "critic_2_optimizer": critic_2_optimizer,
        "actor": actor,
        "actor_optimizer": actor_optimizer,
        "discount": config.discount,
        "soft_target_update_rate": config.soft_target_update_rate,
        "device": config.device,
        # CQL
        "target_entropy": -np.prod(env.action_space.shape).item(),
        "alpha_multiplier": config.alpha_multiplier,
        "use_automatic_entropy_tuning": config.use_automatic_entropy_tuning,
        "backup_entropy": config.backup_entropy,
        "policy_lr": config.policy_lr,
        "qf_lr": config.qf_lr,
        "bc_steps": config.bc_steps,
        "target_update_period": config.target_update_period,
        "cql_n_actions": config.cql_n_actions,
        "cql_importance_sample": config.cql_importance_sample,
        "cql_lagrange": config.cql_lagrange,
        "cql_target_action_gap": config.cql_target_action_gap,
        "cql_temp": config.cql_temp,
        "cql_alpha": config.cql_alpha,
        "cql_max_target_backup": config.cql_max_target_backup,
        "cql_clip_diff_min": config.cql_clip_diff_min,
        "cql_clip_diff_max": config.cql_clip_diff_max,
    }

    print("---------------------------------------")
    print(f"Training CQL, Env: {config.env}, Seed: {seed}")
    print("---------------------------------------")

    # Build the trainer that owns the actor, the critics and their optimizers.
    trainer = ContinuousCQL(**kwargs)

    t = 0

    if config.load_model:
        # Resume from the offline checkpoint and jump directly to online tuning.
        weight_path = os.path.join(config.checkpoints_path, "checkpoint.pt")
        trainer.load_state_dict(torch.load(weight_path, map_location=config.device))
        actor = trainer.actor
        t = int(config.offline_iterations)
    else:
        print("Offline pretraining")

    evaluations = []
    state, done = env.reset(), False
    episode_return = 0
    episode_step = 0
    goal_achieved = False

    eval_successes = []
    train_successes = []

    total_iterations = int(config.offline_iterations) + int(config.online_iterations)

    # Reference point for the "time_consume" print; starting at loop entry makes
    # the first reported value an elapsed duration instead of a raw timestamp.
    last_time = time.time()

    while t < total_iterations:
        if t == config.offline_iterations:
            print("Online tuning")
            trainer.cql_alpha = config.cql_alpha_online
        online_log = {}
        if t >= config.offline_iterations:
            # Online phase: one environment step per gradient step.
            episode_step += 1
            action, _ = actor(
                torch.tensor(
                    state.reshape(1, -1), device=config.device, dtype=torch.float32
                )
            )
            action = action.cpu().data.numpy().flatten()
            next_state, reward, done, env_infos = env.step(action)
            if not goal_achieved:
                goal_achieved = is_goal_reached(reward, env_infos)

            # The episode return is tracked with the unmodified reward.
            episode_return += reward
            real_done = False  # Episode can timeout which is different from done
            if done and episode_step < max_steps:
                real_done = True

            if config.normalize_reward:
                reward = modify_reward_online(
                    reward,
                    config.env,
                    reward_scale=config.reward_scale,
                    reward_bias=config.reward_bias,
                    **reward_mod_dict,
                )
            replay_buffer.add_transition(state, action, reward, next_state, real_done)
            state = next_state

            if done:
                state, done = env.reset(), False
                # Valid only for envs with goal, e.g. AntMaze, Adroit
                if is_env_with_goal:
                    train_successes.append(goal_achieved)
                    online_log["train/regret"] = np.mean(1 - np.array(train_successes))
                    online_log["train/is_success"] = float(goal_achieved)
                online_log["train/episode_return"] = episode_return
                normalized_return = eval_env.get_normalized_score(episode_return)
                online_log["train/d4rl_normalized_episode_return"] = (
                    normalized_return * 100.0
                )
                online_log["train/episode_length"] = episode_step
                episode_return = 0
                episode_step = 0
                goal_achieved = False

        batch = replay_buffer.sample(config.batch_size)
        batch = [b.to(config.device) for b in batch]
        if config.heterogeneous:
            log_dict = trainer.heter_train(batch)
        else:
            log_dict = trainer.train(batch)
        log_dict["offline_iter" if t < config.offline_iterations else "online_iter"] = (
            t if t < config.offline_iterations else t - config.offline_iterations
        )
        log_dict.update(online_log)
        # log_dict / eval_log are assembled for an external logger; no logger
        # call is active in this function.

        # Evaluate episode
        if (t + 1) % config.eval_freq == 0:
            print(f"Time steps: {t + 1}")
            eval_scores, success_rate = eval_actor(
                eval_env,
                actor,
                device=config.device,
                n_episodes=config.n_episodes,
                seed=config.seed,
            )
            eval_score = eval_scores.mean()
            eval_log = {}
            normalized = eval_env.get_normalized_score(np.mean(eval_scores))
            normalized_eval_score = normalized * 100.0
            normalized_success_rate = success_rate * 100
            if t >= config.offline_iterations:
                iteration = t - config.offline_iterations
                cur_time = time.time()
                # Goal environments report success rate, all others report the
                # normalized score.
                with open(config.results_path, 'a') as file:
                    if is_env_with_goal:
                        file.write(f"{iteration} {normalized_success_rate}\n")
                    else:
                        file.write(f"{iteration} {normalized_eval_score}\n")
                print(f"time_consume:{cur_time-last_time}")
                last_time = cur_time

            # Valid only for envs with goal, e.g. AntMaze, Adroit
            if t >= config.offline_iterations and is_env_with_goal:
                eval_successes.append(success_rate)
                # NOTE(review): regret is computed from train_successes, not
                # eval_successes -- confirm this is intended.
                eval_log["eval/regret"] = np.mean(1 - np.array(train_successes))
                eval_log["eval/success_rate"] = success_rate

            eval_log["eval/d4rl_normalized_score"] = normalized_eval_score
            evaluations.append(normalized_eval_score)
            print("---------------------------------------")
            print(
                f"Evaluation over {config.n_episodes} episodes: "
                f"{eval_score:.3f} , D4RL score: {normalized_eval_score:.3f}"
            )
            print("---------------------------------------")

        # Checkpoints are written during the offline phase only, so the file
        # loaded by config.load_model never contains online updates.
        if t < config.offline_iterations and (t + 1) % config.save_checkpoints_freq == 0:
            if config.checkpoints_path:
                torch.save(
                    trainer.state_dict(),
                    os.path.join(config.checkpoints_path, "checkpoint.pt"),
                )
        t += 1
        
@pyrallis.wrap()
def train_pes_cql(config: TrainConfig):
    """Pretrain CQL offline, then fine-tune online with filtered online data.

    The function returns immediately unless ``config.algo == 'pes-cql'``.

    Differences from plain CQL fine-tuning, as implemented below:

    * Three buffers are kept: an offline ``ReplayBuffer`` holding the full
      dataset, a ``SequenceReplayBuffer`` built from the same dataset, and an
      online ``ReplayBuffer`` seeded from the sequence buffer's contents.
    * A ``KD_tree`` is built over the sequence buffer's concatenated
      (observation, action) data.  During online tuning a new transition is
      stored in the online buffer only if its distance to the tree is below
      ``threshold_distance + threshold_coefficient * t / offline_iterations``.
    * While the online buffer holds fewer than ``warmup_size`` transitions,
      every online transition is stored without the distance test and the
      gradient step, evaluation and checkpointing for that iteration are
      skipped.
    * At the end of each online episode (done, ``max_sequence_length`` steps
      reached, or last iteration) the episode is offered to the sequence
      buffer; if it is accepted the KD-tree is refreshed.
    * Every training batch concatenates ``batch_size * online_ratio`` samples
      from the online buffer with the remainder from the offline buffer.

    Args:
        config: Run configuration (parsed by ``pyrallis.wrap``).

    Returns:
        None.  Results are reported through prints, the results file and the
        checkpoint file.
    """
    if config.algo != "pes-cql":
        return
    env = gym.make(config.env)
    eval_env = gym.make(config.env)

    # Used to tell a time-limit truncation apart from a real terminal state.
    max_episode_steps = env._max_episode_steps

    is_env_with_goal = config.env.startswith(ENVS_WITH_GOAL)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    dataset = qlearning_dataset_with_timeouts(env)

    # Extra keyword arguments forwarded to modify_reward_online so that online
    # rewards are transformed consistently with the offline dataset.
    reward_mod_dict = {}
    if config.normalize_reward:
        reward_mod_dict = modify_reward(
            dataset,
            config.env,
            reward_scale=config.reward_scale,
            reward_bias=config.reward_bias,
        )

    if config.normalize:
        state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
    else:
        state_mean, state_std = 0, 1

    dataset["observations"] = normalize_states(
        dataset["observations"], state_mean, state_std
    )
    dataset["next_observations"] = normalize_states(
        dataset["next_observations"], state_mean, state_std
    )
    env = wrap_env(env, state_mean=state_mean, state_std=state_std)
    eval_env = wrap_env(eval_env, state_mean=state_mean, state_std=state_std)
    offline_replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    online_replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    sequence_replay_buffer = SequenceReplayBuffer(
        state_dim,
        action_dim,
        config.max_sequence_length,
        config.num_sequence,
        config.env
    )

    offline_replay_buffer.load_dataset(dataset)
    sequence_replay_buffer.load_dataset(dataset)
    # The online buffer starts out with whatever the sequence buffer retained.
    online_replay_buffer.load_dataset(sequence_replay_buffer.get_buffer_data_dict())
    offline_kd_tree = KD_tree(data=sequence_replay_buffer.get_obs_action_concat(), k=1)

    max_action = float(env.action_space.high[0])

    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)
        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
            pyrallis.dump(config, f)

    # Set seeds
    seed = config.seed
    set_seed(seed, env)
    set_env_seed(eval_env, config.eval_seed)

    critic_1 = FullyConnectedQFunction(
        state_dim,
        action_dim,
        config.orthogonal_init,
        config.q_n_hidden_layers,
    ).to(config.device)
    critic_2 = FullyConnectedQFunction(
        state_dim,
        action_dim,
        config.orthogonal_init,
        config.q_n_hidden_layers,
    ).to(config.device)
    critic_1_optimizer = torch.optim.Adam(list(critic_1.parameters()), config.qf_lr)
    critic_2_optimizer = torch.optim.Adam(list(critic_2.parameters()), config.qf_lr)

    actor = TanhGaussianPolicy(
        state_dim, action_dim, max_action, orthogonal_init=config.orthogonal_init
    ).to(config.device)
    actor_optimizer = torch.optim.Adam(actor.parameters(), config.policy_lr)

    kwargs = {
        "critic_1": critic_1,
        "critic_2": critic_2,
        "critic_1_optimizer": critic_1_optimizer,
        "critic_2_optimizer": critic_2_optimizer,
        "actor": actor,
        "actor_optimizer": actor_optimizer,
        "discount": config.discount,
        "soft_target_update_rate": config.soft_target_update_rate,
        "device": config.device,
        # CQL
        "target_entropy": -np.prod(env.action_space.shape).item(),
        "alpha_multiplier": config.alpha_multiplier,
        "use_automatic_entropy_tuning": config.use_automatic_entropy_tuning,
        "backup_entropy": config.backup_entropy,
        "policy_lr": config.policy_lr,
        "qf_lr": config.qf_lr,
        "bc_steps": config.bc_steps,
        "target_update_period": config.target_update_period,
        "cql_n_actions": config.cql_n_actions,
        "cql_importance_sample": config.cql_importance_sample,
        "cql_lagrange": config.cql_lagrange,
        "cql_target_action_gap": config.cql_target_action_gap,
        "cql_temp": config.cql_temp,
        "cql_alpha": config.cql_alpha,
        "cql_max_target_backup": config.cql_max_target_backup,
        "cql_clip_diff_min": config.cql_clip_diff_min,
        "cql_clip_diff_max": config.cql_clip_diff_max,
    }

    print("---------------------------------------")
    print(f"Training CQL, Env: {config.env}, Seed: {seed}")
    print("---------------------------------------")

    # Build the trainer that owns the actor, the critics and their optimizers.
    trainer = ContinuousCQL(**kwargs)

    if config.load_model:
        weight_path = os.path.join(config.checkpoints_path, "checkpoint.pt")
        trainer.load_state_dict(torch.load(weight_path, map_location=config.device))
        actor = trainer.actor

    # Keys of the per-episode transition record handed to the sequence buffer.
    episode_keys = (
        "observations",
        "actions",
        "rewards",
        "next_observations",
        "terminals",
    )

    evaluations = []
    state, done = env.reset(), False
    episode_return = 0
    episode_step = 0
    episode_data = {key: [] for key in episode_keys}
    goal_achieved = False

    eval_successes = []
    train_successes = []

    if config.load_model:
        # Resume from the offline checkpoint and jump directly to online tuning.
        t = config.offline_iterations
    else:
        t = 0
        print("Offline pretraining")

    # Below this many stored transitions the online buffer is in warm-up.
    warmup_size = 999

    # Loop-invariant split of every training batch between the two buffers.
    online_batch_size = int(config.batch_size * config.online_ratio)
    offline_batch_size = config.batch_size - online_batch_size

    while t < config.offline_iterations + config.online_iterations:
        if t == config.offline_iterations:
            print("Online tuning")
            trainer.cql_alpha = config.cql_alpha_online
        online_log = {}
        if t >= config.offline_iterations:
            # Online phase: one environment step per iteration.
            episode_step += 1
            action, _ = actor(
                torch.tensor(
                    state.reshape(1, -1), device=config.device, dtype=torch.float32
                )
            )
            action = action.cpu().data.numpy().flatten()
            next_state, reward, done, env_infos = env.step(action)

            real_done = False  # a time-limit truncation is not a terminal state
            if done and episode_step < max_episode_steps:
                real_done = True

            # The episode record keeps the reward as returned by the env,
            # i.e. before modify_reward_online is applied below.
            episode_data["observations"].append(state)
            episode_data["actions"].append(action)
            episode_data["next_observations"].append(next_state)
            episode_data["rewards"].append(reward)
            episode_data["terminals"].append(real_done)

            if not goal_achieved:
                goal_achieved = is_goal_reached(reward, env_infos)

            episode_return += reward

            if config.normalize_reward:
                reward = modify_reward_online(
                    reward,
                    config.env,
                    reward_scale=config.reward_scale,
                    reward_bias=config.reward_bias,
                    **reward_mod_dict,
                )

            # Decide whether this transition enters the online buffer.
            warming_up = online_replay_buffer._size < warmup_size
            if warming_up:
                keep_transition = True
            else:
                distance = offline_kd_tree.query(
                    np.concatenate((state, action), axis=-1)
                ).item()
                threshold = (
                    config.threshold_distance
                    + config.threshold_coefficient * t / config.offline_iterations
                )
                keep_transition = distance < threshold
            if keep_transition:
                online_replay_buffer.add_transition(
                    state, action, reward, next_state, real_done
                )

            # Advance the rollout on every online step, including warm-up.
            # Previously the warm-up path skipped this and the episode-end
            # handling, so the actor kept acting on a stale observation and
            # the env was never reset.
            state = next_state
            if (
                done
                or episode_step >= config.max_sequence_length
                or t + 1 == config.offline_iterations + config.online_iterations
            ):
                episode_arrays = {
                    key: np.array(values) for key, values in episode_data.items()
                }

                # Refresh the KD-tree only when the sequence buffer reports
                # that it changed.
                if sequence_replay_buffer.update_top_episodes(episode_arrays):
                    offline_kd_tree.update(sequence_replay_buffer.get_obs_action_concat())

                state, done = env.reset(), False
                # Valid only for envs with goal, e.g. AntMaze, Adroit
                if is_env_with_goal:
                    train_successes.append(goal_achieved)
                    online_log["train/regret"] = np.mean(1 - np.array(train_successes))
                    online_log["train/is_success"] = float(goal_achieved)
                online_log["train/episode_return"] = episode_return
                normalized_return = eval_env.get_normalized_score(episode_return)
                online_log["train/d4rl_normalized_episode_return"] = (
                    normalized_return * 100.0
                )
                online_log["train/episode_length"] = episode_step
                episode_return = 0
                episode_step = 0
                episode_data = {key: [] for key in episode_keys}
                goal_achieved = False

            if warming_up:
                # Warm-up iterations only collect data: no gradient step,
                # evaluation or checkpointing.
                t += 1
                continue

        assert online_replay_buffer._size >= online_batch_size
        assert offline_replay_buffer._size >= offline_batch_size

        online_batch = online_replay_buffer.sample(online_batch_size)
        offline_batch = offline_replay_buffer.sample(offline_batch_size)

        # Online samples come first in each concatenated batch component.
        batch = [
            torch.cat((online_part, offline_part), dim=0)
            for online_part, offline_part in zip(online_batch, offline_batch)
        ]
        batch = [b.to(config.device) for b in batch]
        if config.heterogeneous:
            log_dict = trainer.heter_train(batch)
        else:
            log_dict = trainer.train(batch)
        log_dict["offline_iter" if t < config.offline_iterations else "online_iter"] = (
            t if t < config.offline_iterations else t - config.offline_iterations
        )
        log_dict.update(online_log)
        # log_dict / eval_log are assembled for an external logger; no logger
        # call is active in this function.

        # Evaluate episode
        if (t + 1) % config.eval_freq == 0:
            print(f"Time steps: {t + 1}")
            eval_scores, success_rate = eval_actor(
                eval_env,
                actor,
                device=config.device,
                n_episodes=config.n_episodes,
                seed=config.seed,
            )
            eval_score = eval_scores.mean()
            eval_log = {}
            normalized = eval_env.get_normalized_score(np.mean(eval_scores))
            normalized_eval_score = normalized * 100.0
            if t >= config.offline_iterations:
                iteration = t - config.offline_iterations
                with open(config.results_path, 'a') as file:
                    file.write(f"{iteration} {normalized_eval_score}\n")

            # Valid only for envs with goal, e.g. AntMaze, Adroit
            if t >= config.offline_iterations and is_env_with_goal:
                eval_successes.append(success_rate)
                # NOTE(review): regret is computed from train_successes, not
                # eval_successes -- confirm this is intended.
                eval_log["eval/regret"] = np.mean(1 - np.array(train_successes))
                eval_log["eval/success_rate"] = success_rate

            eval_log["eval/d4rl_normalized_score"] = normalized_eval_score
            evaluations.append(normalized_eval_score)
            print("---------------------------------------")
            print(
                f"Evaluation over {config.n_episodes} episodes: "
                f"{eval_score:.3f} , D4RL score: {normalized_eval_score:.3f}"
            )
            print("---------------------------------------")

        # Checkpoints are written during the offline phase only.
        if t < config.offline_iterations and (t + 1) % config.save_checkpoints_freq == 0:
            if config.checkpoints_path:
                torch.save(
                    trainer.state_dict(),
                    os.path.join(config.checkpoints_path, "checkpoint.pt"),
                )
        t += 1