import os
import tyro
from config import Args


# Training schedule (values carried over from the original hard-coded ones).
N_ITERATIONS = 5000
BASE_LR = 5e-4
WARMUP_ITERS = 100
EVAL_INTERVAL = 25
N_EVAL_EPISODES = 32
TARGET_TAU = 0.001  # soft-update rate for the reward net's target networks


def _set_lr(optimizers, lr):
    """Set the learning rate of every param group of each optimizer to ``lr``."""
    for optim in optimizers:
        for param_group in optim.param_groups:
            param_group['lr'] = lr


def _soft_update(target_net, source_net, tau):
    """Soft-update ``target_net`` in place: target <- tau * source + (1 - tau) * target."""
    for target_param, param in zip(target_net.parameters(), source_net.parameters()):
        target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)


def _log_scalars(writer, scalars, step):
    """Write each ``tag -> value`` pair of ``scalars`` to ``writer`` at ``step`` (in insertion order)."""
    for tag, value in scalars.items():
        writer.add_scalar(tag, value, step)


def _save_checkpoint(logdir, actor, critic, reward_net):
    """Save actor, critic and reward-net weights as ``.pt`` files under ``logdir``."""
    import torch

    actor.save(f"{logdir}/actor.pt")
    critic.save(f"{logdir}/critic.pt")
    torch.save(reward_net.state_dict(), f"{logdir}/reward_net.pt")


def main(args: "Args | None" = None):
    """Train the IPL (rule-based) agents and log/checkpoint to ``logs/<algo>/<env_name>``.

    Runs ``N_ITERATIONS`` iterations of: optional evaluation + checkpoint every
    ``EVAL_INTERVAL`` iterations, soft update of the reward net's target networks,
    rollout collection, and one trainer update. The final weights are saved once
    more after the last iteration.

    Args:
        args: Parsed ``Args`` config. When omitted, the module-level ``args`` bound
            by the script entry point is used; if that is not set either, the
            command line is parsed with ``tyro``. NOTE: this object is mutated in
            place -- the environment dimensions are written onto it.
    """
    if args is None:
        # Backward compatibility: the script entry point binds a global `args`
        # before calling main() with no arguments.
        args = globals().get("args")
        if args is None:
            args = tyro.cli(Args)

    env_name = args.env_name
    n_envs = args.n_envs
    n_envs_eval = args.n_envs_eval
    n_steps = args.n_steps
    h_dim = args.h_dim
    device = args.device

    algo = "ipl_rule_based"  # alternative: "ipl_llm_based"
    logdir = f"logs/{algo}/{env_name}"

    print(f"Environment: {env_name}")
    os.makedirs(logdir, exist_ok=True)

    # Project and torch imports are deferred until after the environments are
    # created, as in the original ordering. The reason is not visible here --
    # TODO confirm before moving them to the top of the file.
    from envs import VectorizedEnv

    envs = VectorizedEnv(env_name, n_envs)
    envs_eval = None
    trainer = None
    try:
        if n_envs_eval > 0:
            envs_eval = VectorizedEnv(env_name, n_envs_eval, seed=42)

        ob_dim, st_dim, ac_dim, n_agents, n_enemies, nf_al, nf_en = envs.get_env_infos()
        n_batches = n_envs * n_agents

        # Publish the environment dimensions on args for Runner/Trainer.
        args.ob_dim = ob_dim
        args.st_dim = st_dim
        args.ac_dim = ac_dim
        args.n_agents = n_agents
        args.n_enemies = n_enemies
        args.nf_al = nf_al
        args.nf_en = nf_en

        import torch
        from runner import Runner
        from trainer import Trainer
        from policy import Actor, Critic
        from ipl_iql import IPL_IQL

        torch.set_num_threads(1)
        torch.set_float32_matmul_precision('high')

        actor = Actor(ob_dim, ac_dim, h_dim=h_dim)
        critic = Critic(st_dim, h_dim=h_dim)
        reward_net = IPL_IQL(ob_dim, st_dim, ac_dim, n_agents, h_dim)
        runner = Runner(envs, args, device=device)
        trainer = Trainer(actor, critic, reward_net, logdir, args, device=device)

        # Zero-initialised (h, c)-style state pairs: one row per agent-env for the
        # actor, one per env for the critic.
        actor_rnn_state = (torch.zeros(1, n_batches, h_dim, device=device), torch.zeros(1, n_batches, h_dim, device=device))
        critic_rnn_state = (torch.zeros(1, n_envs, h_dim, device=device), torch.zeros(1, n_envs, h_dim, device=device))

        for step in range(N_ITERATIONS):
            # Linear warm-up; `step + 1` so the very first update does not run with lr=0.
            lr_now = BASE_LR * min(1.0, (step + 1) / WARMUP_ITERS)
            _set_lr((trainer.actor_optim, trainer.critic_optim), lr_now)

            if step % EVAL_INTERVAL == 0:
                if envs_eval is not None:
                    torch.cuda.empty_cache()
                    dead_allies, dead_enemies, winrates = runner.evaluate(envs_eval, actor, n_episodes=N_EVAL_EPISODES)
                    print(f"Eval - Game: {dead_allies:.3f}/{dead_enemies:.3f} - Winrate: {winrates:.3f}")
                    _log_scalars(trainer.writer, {
                        'eval/dead_allies': dead_allies,
                        'eval/dead_enemies': dead_enemies,
                        'eval/winrates': winrates,
                    }, step)
                # Checkpoint even when no eval env is configured, so weights are never lost.
                _save_checkpoint(logdir, actor, critic, reward_net)

            _soft_update(reward_net.target_Q_net, reward_net.eval_Q_net, TARGET_TAU)
            _soft_update(reward_net.target_mix_net, reward_net.eval_mix_net, TARGET_TAU)

            data = runner.collect(actor, n_steps, actor_rnn_state, desc=f"Step {step+1} - Collecting ...", verbose=False)
            pg_loss, entropy_loss, vf_loss, ipl_loss, actor_rnn_state, critic_rnn_state, info = trainer.update(data, actor_rnn_state, critic_rnn_state)
            dead_allies, dead_enemies, winrates = info
            print(f"Step {step+1} - Loss: {pg_loss:.4f}/{vf_loss:.4f}/{entropy_loss:.4f} - Game: {dead_allies:.3f}/{dead_enemies:.3f} - Winrate: {winrates:.3f} - IPL: {ipl_loss:.4f}")
            _log_scalars(trainer.writer, {
                'loss/pg_loss': pg_loss,
                'loss/vf_loss': vf_loss,
                'loss/entropy_loss': entropy_loss,
                'loss/ipl_loss': ipl_loss,
                'game/dead_allies': dead_allies,
                'game/dead_enemies': dead_enemies,
                'game/winrates': winrates,
            }, step)

        # Persist the weights produced by the final updates.
        _save_checkpoint(logdir, actor, critic, reward_net)
    finally:
        if trainer is not None:
            # Presumably a TensorBoard SummaryWriter; close it so buffered scalars are flushed.
            close_writer = getattr(trainer.writer, "close", None)
            if close_writer is not None:
                close_writer()
        if envs_eval is not None:
            envs_eval.close()
        envs.close()


if __name__ == "__main__":
    # Parse the command line into the module-level `args`; main() reads this
    # global, so it must be bound before main() is called.
    args = tyro.cli(Args)
    main()