import os
import random
import time

import gymnasium as gym
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import tyro
import wandb
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)
from stable_baselines3.common.buffers import ReplayBuffer

import utils
from args import Args
from models import QNetwork, EnsembleQNetwork, TreeQNetwork
from sparselearning.core import Masking


def make_env(env_id, seed, idx, capture_video, run_name):
    """Return a zero-argument factory that builds one preprocessed Atari env.

    The factory form is what ``gym.vector.SyncVectorEnv`` expects. Video
    recording is attached only when ``capture_video`` is truthy and this is
    the env at index 0.

    Args:
        env_id: Gymnasium environment id passed to ``gym.make``.
        seed: Seed applied to the wrapped env's action space.
        idx: Position of this env inside the vector env.
        capture_video: Whether the index-0 env should be wrapped in ``RecordVideo``.
        run_name: Used to build the ``videos/<run_name>`` output folder.
    """

    def thunk():
        wants_video = capture_video and idx == 0
        if wants_video:
            base = gym.make(env_id, render_mode="rgb_array")
            base = gym.wrappers.RecordVideo(base, f"videos/{run_name}")
        else:
            base = gym.make(env_id)

        # Statistics wrapper sits innermost, i.e. beneath reward clipping and
        # life-loss episode splitting applied further out.
        wrapped = gym.wrappers.RecordEpisodeStatistics(base)
        wrapped = MaxAndSkipEnv(NoopResetEnv(wrapped, noop_max=30), skip=4)
        wrapped = EpisodicLifeEnv(wrapped)
        if "FIRE" in wrapped.unwrapped.get_action_meanings():
            wrapped = FireResetEnv(wrapped)
        wrapped = ClipRewardEnv(wrapped)
        # Observation pipeline: resize -> grayscale -> stack the last 4 frames.
        wrapped = gym.wrappers.FrameStack(
            gym.wrappers.GrayScaleObservation(
                gym.wrappers.ResizeObservation(wrapped, (84, 84))
            ),
            4,
        )

        wrapped.action_space.seed(seed)
        return wrapped

    return thunk


def evaluate(
    eval_envs: "gym.vector.SyncVectorEnv",
    model: torch.nn.Module,
    eval_episodes: int = 10,
    epsilon: float = 0.05,
    device: "torch.device | None" = None,
):
    """Roll out an epsilon-greedy policy until ``eval_episodes`` episodes finish.

    At every step, with probability ``epsilon`` all sub-envs take a random
    action; otherwise each takes ``argmax`` of the model's Q-values.

    Args:
        eval_envs: Vector env to evaluate in; it is reset at the start.
        model: Network mapping a batch of observations to per-action values.
        eval_episodes: Number of completed episodes to average over.
        epsilon: Probability of taking random actions at a given step.
        device: Device the observations are moved to. Defaults to the device
            of the model's first parameter (CPU for a parameter-less model),
            so the function no longer depends on a script-level global.

    Returns:
        ``(avg_return, avg_length)``: means over the first ``eval_episodes``
        finished episodes that report ``"episode"`` statistics.

    The model's previous train/eval mode is restored on exit, even if a step
    raises.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    try:
        obs, _ = eval_envs.reset()
        episodic_returns = []
        episodic_lengths = []
        while len(episodic_returns) < eval_episodes:
            if random.random() < epsilon:
                actions = np.array(
                    [eval_envs.single_action_space.sample() for _ in range(eval_envs.num_envs)]
                )
            else:
                with torch.no_grad():
                    q_values = model(torch.Tensor(obs).to(device))
                    actions = torch.argmax(q_values, dim=1).cpu().numpy()
            next_obs, _, _, _, infos = eval_envs.step(actions)
            if "final_info" in infos:
                for info in infos["final_info"]:
                    # Entries can be None for sub-envs that did not finish this
                    # step (the training loop guards the same way).
                    if not info or "episode" not in info:
                        continue
                    episodic_returns.append(info["episode"]["r"])
                    episodic_lengths.append(info["episode"]["l"])
            obs = next_obs
    finally:
        model.train(was_training)

    avg_return = np.mean(episodic_returns[:eval_episodes])
    avg_length = np.mean(episodic_lengths[:eval_episodes])
    return avg_return, avg_length


def linear_schedule(start_e: float, end_e: float, duration: float, t: int):
    """Linearly interpolate from ``start_e`` to ``end_e`` over ``duration`` steps.

    The value is clamped at ``end_e`` once ``t`` reaches ``duration``. Both
    decaying (``end_e <= start_e``) and increasing schedules are supported.
    A non-positive ``duration`` means the schedule is already complete, so
    ``end_e`` is returned; this avoids a ZeroDivisionError when, e.g., the
    exploration fraction is 0.
    """
    if duration <= 0:
        return end_e
    slope = (end_e - start_e) / duration
    value = slope * t + start_e
    # Clamp on the side the schedule is moving towards.
    return max(value, end_e) if end_e <= start_e else min(value, end_e)


if __name__ == "__main__":
    args = tyro.cli(Args)
    # Explicit raise rather than `assert` so the check survives `python -O`;
    # the per-episode ensemble-member rotation below assumes exactly one env.
    if args.num_envs != 1:
        raise ValueError("vectorized envs are not supported at the moment")
    run_name = utils.set_exp_name(args)
    print(f'run_name: {run_name}')

    total_train_steps = (args.total_timesteps - args.learning_starts) // args.train_frequency
    print(f'total_timesteps: {args.total_timesteps:,}  learning_starts at {args.learning_starts:,}  '
          f'train_frequency: {args.train_frequency}. So total training steps: {total_train_steps:,}')

    # wandb is always initialised; it runs in "disabled" mode when tracking is off.
    wandb_mode = "online" if args.track else "disabled"

    wandb.init(
        project=args.wandb_project_name,
        entity=args.wandb_entity,
        config=vars(args),
        name=run_name,
        mode=wandb_mode,
        save_code=True,
    )

    # seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, args.seed + i, i, False, run_name) for i in range(args.num_envs)]
    )
    if not isinstance(envs.single_action_space, gym.spaces.Discrete):
        raise ValueError("only discrete action space is supported")

    eval_envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(1)]
    )

    if args.num_ensemble > 1:
        if args.blocks_in_head == -1:
            q_network = EnsembleQNetwork(envs, args).to(device)
            target_network = EnsembleQNetwork(envs, args).to(device)
            print(f"EnsembleQNetwork with {args.num_ensemble} members.")
        elif args.blocks_in_head > 0:
            q_network = TreeQNetwork(envs, args).to(device)
            target_network = TreeQNetwork(envs, args).to(device)
            print(f"TreeQNetwork with {args.blocks_in_head} blocks in each head. Total blocks: {q_network.total_blocks}. Shared blocks: {q_network.blocks_shared}.")
        else:
            raise ValueError("blocks_in_head must be -1 (full ensemble) or >0 (TreeQNetwork)")
    else:
        q_network = QNetwork(envs, args).to(device)
        target_network = QNetwork(envs, args).to(device)
    target_network.load_state_dict(q_network.state_dict())
    optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)

    mask = None
    if args.density < 1:
        print("Using sparse learning")
        mask = Masking(optimizer, total_train_steps, args=args)
        mask.add_module(q_network, sparse_init=args.sparse_init)

    # Acting: either one ensemble member per episode (rotated below) or the joint output.
    independent_sampling = args.num_ensemble > 1 and not args.joint_sampling
    # Learning: either a separate TD loss per member or one loss on the joint output.
    independent_training = args.num_ensemble > 1 and not args.joint_training

    if not args.joint_sampling:
        sample_q_idx = 0
    else:
        sample_q_idx = None

    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device,
        optimize_memory_usage=True,
        handle_timeout_termination=False,
    )
    start_time = time.time()

    obs, _ = envs.reset(seed=args.seed)
    for global_step in range(args.total_timesteps + 1):

        epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
        if global_step % 1000 == 0:
            wandb.log({"charts/epsilon": epsilon}, step=global_step)

        if random.random() < epsilon:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            # Greedy action selection needs no autograd graph.
            with torch.no_grad():
                if independent_sampling:
                    q_values = q_network(torch.Tensor(obs).to(device), idx=sample_q_idx)
                else:
                    q_values = q_network(torch.Tensor(obs).to(device))
            actions = torch.argmax(q_values, dim=1).cpu().numpy()

        # execute the game and log data.
        next_obs, rewards, terminations, truncations, infos = envs.step(actions)

        # record rewards for plotting purposes
        if "final_info" in infos:
            for info in infos["final_info"]:
                if info and "episode" in info:
                    wandb.log({
                            "charts/raw_episodic_return": info["episode"]["r"][0],
                            "charts/raw_episodic_length": info["episode"]["l"][0],
                        }, step=global_step,
                    )

                    # at episode end, reset the sample_q_idx for independent ensemble DQN
                    # this assumes num_envs=1, adjust code if using multiple envs
                    if independent_sampling:
                        sample_q_idx = (sample_q_idx + 1) % args.num_ensemble

        # save data to replay buffer; handle `final_observation`
        real_next_obs = next_obs.copy()
        for idx, trunc in enumerate(truncations):
            if trunc:
                real_next_obs[idx] = infos["final_observation"][idx]
        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

        # crucial step easy to overlook
        obs = next_obs

        # training
        if global_step > args.learning_starts:
            if global_step % args.train_frequency == 0:
                data = rb.sample(args.batch_size)

                # Compute targets
                with torch.no_grad():
                    if independent_training:
                        qs_target = target_network.get_all(data.next_observations)
                        td_targets = []
                        for i in range(args.num_ensemble):
                            target_max_i, _ = qs_target[i].max(dim=1)
                            td_target_i = data.rewards.flatten() + args.gamma * target_max_i * (1 - data.dones.flatten())
                            td_targets.append(td_target_i)
                    else:
                        # joint training (or single DQN)
                        target_max, _ = target_network(data.next_observations).max(dim=1)
                        td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())

                # Compute loss. `squeeze(1)` (not `squeeze()`) keeps a batch
                # dimension even when batch_size == 1, matching td_target's shape.
                if independent_training:
                    qs = q_network.get_all(data.observations)
                    cur_vals = [qs[i].gather(1, data.actions).squeeze(1) for i in range(args.num_ensemble)]
                    loss = sum(F.mse_loss(td_target_i, cur_val_i) for td_target_i, cur_val_i in zip(td_targets, cur_vals))
                else:
                    # joint training (or single DQN)
                    cur_val = q_network(data.observations).gather(1, data.actions).squeeze(1)
                    loss = F.mse_loss(td_target, cur_val)

                if global_step % 1000 == 0:
                    # Logging only: no graph needed for these extra computations.
                    with torch.no_grad():
                        if independent_training:
                            # Mean taken-action Q over members, consistent with the
                            # joint/single-DQN branch and the per-member logs below.
                            q_value_mean = torch.stack(cur_vals).mean().item()
                        else:
                            q_value_mean = cur_val.mean().item()
                        log_dict = {
                            "losses/td_loss": loss.item(),
                            "losses/q_values": q_value_mean,
                            "charts/SPS": int(global_step / (time.time() - start_time)),
                        }
                        if args.num_ensemble > 1:  # log the individual q values
                            for i in range(args.num_ensemble):
                                cur_val_i = q_network(data.observations, idx=i).gather(1, data.actions).squeeze(1)
                                log_dict[f"losses/q_values_{i}"] = cur_val_i.mean().item()
                    wandb.log(log_dict, step=global_step)

                # optimize the model
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if args.density < 1:
                    mask.step()

            # update target network
            if global_step % args.target_network_frequency == 0:
                for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
                    target_network_param.data.copy_(
                        args.tau * q_network_param.data + (1.0 - args.tau) * target_network_param.data
                    )

        if global_step % args.eval_freq == 0:
            avg_return, avg_length = evaluate(eval_envs, q_network, eval_episodes=10)
            wandb.log({
                    "eval/avg_episodic_return": avg_return,
                    "eval/avg_episodic_length": avg_length,
                }, step=global_step,
            )
            print(f"eval at global_step={global_step}, avg_return={avg_return}, avg_length={avg_length}")

    if args.save_model:
        model_path = f"runs/{run_name}/weights.model"
        # runs/<run_name>/ is not created anywhere visible in this script, so
        # make sure it exists before saving.
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        torch.save(q_network.state_dict(), model_path)
        print(f"Model saved to {model_path}")
    envs.close()
    eval_envs.close()
    wandb.finish()
