from argparse import Namespace

import numpy as np
import pytorch_lightning as pl
import torch as t

from ..simulators import SIMULATOR
from ..solvers import get_solver


class BaseTrainer(pl.LightningModule):
    """Base Lightning module that trains a solver and periodically evaluates it.

    The solver is built by ``get_solver(args)``. Subclasses supply the
    objective by overriding :meth:`compute_loss`. Evaluation rolls out
    ``n_evaluation_runs`` fresh ``SIMULATOR`` episodes and logs the mean
    episode reward and the success rate.
    """

    def __init__(
        self,
        args: Namespace,
    ):
        """Build the solver and cache the training/evaluation settings.

        Args:
            args: Parsed options. Passed to ``get_solver``; this class reads
                ``learning_rate``, ``evaluation_frequency`` and
                ``n_evaluation_runs`` from it.
        """
        super().__init__()

        self.solver = get_solver(args)
        self.learning_rate = args.learning_rate
        self.evaluation_frequency = args.evaluation_frequency
        self.n_evaluation_runs = args.n_evaluation_runs

    def compute_loss(
        self,
        states,
        actions,
        rewards,
        values,
    ):
        """Compute the training loss for one batch; must be overridden.

        ``training_step`` calls this with the unpacked batch and expects a
        ``(loss, metrics)`` pair, where ``metrics`` is a mapping whose items
        are logged under the ``train/`` prefix.

        NOTE(review): the parameter names suggest a batch of states, actions,
        rewards and value targets; the exact shapes are defined by the data
        pipeline and the subclass -- confirm there.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def training_step(
        self,
        batch,
    ):
        """Run one optimisation step, evaluating first when it is due.

        Args:
            batch: Iterable that is unpacked positionally into
                :meth:`compute_loss`.

        Returns:
            The loss returned by :meth:`compute_loss`.
        """
        # Evaluation is triggered whenever the global step is a multiple of
        # the configured frequency (this includes step 0).
        if self.global_step % self.evaluation_frequency == 0:
            self.evaluate()

        loss, train_metrics = self.compute_loss(*batch)

        for k, v in train_metrics.items():
            self.log(f"train/{k}", v)

        return loss

    def evaluate(self):
        """Roll out evaluation episodes and log ``eval/reward`` and ``eval/success_rate``.

        At global step 0 no rollouts are performed; ``SIMULATOR.failed`` and a
        success rate of ``0.0`` are logged instead. If ``n_evaluation_runs``
        yields no episodes, nothing is rolled out or logged.
        """
        if self.global_step == 0:
            self.log("eval/reward", SIMULATOR.failed)
            self.log("eval/success_rate", 0.0)
            return

        env_ids = list(range(self.n_evaluation_runs))
        if not env_ids:
            # Nothing to average over; concatenating an empty list of states
            # would fail, so skip evaluation entirely.
            return

        self.solver.eval()
        try:
            with t.inference_mode():
                episode_rewards, success = self._run_evaluation_episodes(env_ids)
        finally:
            # Always hand the solver back in training mode, even if a rollout
            # raised, so a failed evaluation cannot leave it in eval mode.
            self.solver.train()

        self.log("eval/reward", float(np.mean(list(episode_rewards.values()))))
        self.log("eval/success_rate", float(np.mean(list(success.values()))))

    def _run_evaluation_episodes(self, env_ids):
        """Step one simulator per id until all terminate or ``max_steps`` is hit.

        Args:
            env_ids: Non-empty list of episode identifiers.

        Returns:
            ``(episode_rewards, success)``: two dicts keyed by episode id,
            holding the summed reward and whether ``is_solved()`` was true at
            termination. Episodes that never terminate within
            ``SIMULATOR.max_steps`` keep ``success == False``.
        """
        episode_rewards = {env_id: 0 for env_id in env_ids}
        success = {env_id: False for env_id in env_ids}
        all_envs = {env_id: SIMULATOR() for env_id in env_ids}

        for env in all_envs.values():
            env.reset()

        alive = list(env_ids)
        for _ in range(SIMULATOR.max_steps):
            # One batched prediction for every still-running episode; the
            # per-environment state tensors are joined along dim 0.
            states = t.cat(
                [all_envs[env_id].state_tensor() for env_id in alive],
                dim=0,
            ).to(self.device)
            actions = self.solver.predict(states)

            finished = set()
            for env_id, a in zip(alive, actions):
                _, reward, terminal = all_envs[env_id].step(int(a))
                episode_rewards[env_id] += reward

                if terminal:
                    success[env_id] = all_envs[env_id].is_solved()
                    finished.add(env_id)

            # Drop finished episodes in a single pass, preserving order so the
            # next batch of states still lines up with the returned actions.
            alive = [env_id for env_id in alive if env_id not in finished]

            if not alive:
                break

        return episode_rewards, success

    def configure_optimizers(self):
        """Return an AdamW optimizer over the solver's parameters."""
        return t.optim.AdamW(self.solver.parameters(), lr=self.learning_rate)
