from typing import Optional, Dict, Tuple, List, Literal


import numpy as np
import torch
import torch.nn as nn

from .train import Train
import utils
from utils.typings import ObjectArrays, ObjectTensors, NamedTensors, EnvModel


from alg.model.reward_predictor import OORewardPredictor, MLPRewardPredictor, RewardPredictor
from ._get_env import get_config


class TrainReward(Train):
    """Train a reward predictor by maximum likelihood.

    The predictor is selected with ``--network``: ``'mlp'`` builds an
    ``MLPRewardPredictor`` and ``'oo'`` (default) builds an
    ``OORewardPredictor``. The training loss is the negative mean
    log-likelihood of the observed rewards under the predictor's output.
    """

    @classmethod
    def set_parser(cls, parser):
        """Add the ``--network`` option on top of the base-class options."""
        super().set_parser(parser)
        parser.add_argument("--network", type=str, choices=['mlp', 'oo'], default='oo',
            help="which network is used as the reward predictor.")

    def __make_mlp_reward_predictor(self, argpath: Optional[str]):
        """Build an ``MLPRewardPredictor``.

        Hyper-parameters start from the environment config's
        ``mlpreward_args``, or from ``MLPRewardPredictor.Args()`` defaults if
        the config provides none. If ``argpath`` is given, they are then
        loaded from that path. The final args are saved under ``'model'``.
        """
        args = get_config(self.env_args.env_id).mlpreward_args
        if args is None:
            args = MLPRewardPredictor.Args()
        if argpath is not None:
            args.load(argpath)
        self.save_args('model', args)
        # NOTE(review): the MLP variant is built from `self.taskinfo`, while
        # the OO variant uses `self.envinfo` -- confirm this is intended.
        return MLPRewardPredictor(self.taskinfo, args, self.device, self.dtype)

    def __make_oo_reward_predictor(self, argpath: Optional[str]):
        """Build an ``OORewardPredictor``.

        This follows the same argument resolution as the MLP variant, but
        uses ``ooreward_args``.
        """
        args = get_config(self.env_args.env_id).ooreward_args
        if args is None:
            args = OORewardPredictor.Args()
        if argpath is not None:
            args.load(argpath)
        self.save_args('model', args)
        return OORewardPredictor(self.envinfo, args, self.device, self.dtype)

    def setup(self, args):
        """Build the requested predictor and register it as ``'reward_predictor'``.

        Raises:
            ValueError: if ``args.network`` is not ``'mlp'`` or ``'oo'``.
        """
        super().setup(args)

        self.model: RewardPredictor
        if args.network == 'mlp':
            self.model = self.__make_mlp_reward_predictor(args.model_args)
        elif args.network == 'oo':
            self.model = self.__make_oo_reward_predictor(args.model_args)
        else:
            # argparse `choices` only guards the command-line path. Fail loudly
            # here rather than registering an unset or stale `self.model`.
            raise ValueError(f"unknown reward predictor network: {args.network!r}")
        self.add_network('reward_predictor', self.model)
        self.save_args("network", args.network)

    def __loglikeli_r(self, model: RewardPredictor, raw_attributes: ObjectTensors,
                      reward: torch.Tensor, obj_mask: NamedTensors):
        """Return the mean log-likelihood of ``reward`` under ``model``'s prediction.

        ``model.forward`` returns an object exposing ``log_prob``, presumably
        a ``torch.distributions.Distribution`` (TODO confirm). The mean is
        taken over every element of the ``log_prob`` result.
        """
        r = model.forward(raw_attributes, obj_mask)
        ll_r = torch.mean(r.log_prob(reward))
        return ll_r

    def fit_batch(self, attrs: ObjectTensors, next_state: ObjectTensors,
                    obj_mask: NamedTensors, reward: torch.Tensor, eval=False
                    ) -> Tuple[float, ...]:
        """Evaluate one batch and, unless ``eval`` is set, take an optimizer step.

        ``next_state`` is not used for reward training. It is presumably kept
        to match the shared ``Train.fit_batch`` signature.
        ``eval`` shadows the builtin but is part of the public signature.

        Returns:
            ``(loss, ll_r)`` as Python floats, where ``loss == -ll_r``.
        """
        model = self.get_network('reward_predictor')
        assert isinstance(model, RewardPredictor)

        if eval:
            # No backward pass happens in evaluation, so don't record the
            # autograd graph; this saves memory without changing the values.
            with torch.no_grad():
                ll_r = self.__loglikeli_r(model, attrs, reward, obj_mask)
            loss = -ll_r
        else:
            ll_r = self.__loglikeli_r(model, attrs, reward, obj_mask)
            loss = -ll_r
            loss.backward()
            self.optim_step('reward_predictor')

        return float(loss), float(ll_r)

    def log_batch(self, log: utils.Log, *scalars: float):
        """Log the output of ``fit_batch``.

        The loss is passed to ``log(...)``, which presumably records the main
        value read back as ``log.mean``. The log-likelihood is stored under
        ``'r'``.
        """
        loss, r = scalars
        log(loss)
        log['r'] = r

    def print_batch(self, i_batch: int, n_batch: int, *scalars: float):
        """Print the ``(loss, ll_r)`` values returned by ``fit_batch`` for batch ``i_batch``."""
        loss, ll_r = scalars
        print(f"loss of batch {i_batch + 1}/{n_batch}: {loss}")
        print(f"- reward loglikelihood: {ll_r}")

    def record(self, train_log, eval_log):
        """Print the evaluation metrics and write them to ``self.writer`` at ``self.global_step``."""
        super().record(train_log, eval_log)

        print(f"evaluation loss: {eval_log.mean}")
        print(f"- reward loglikelihood: {eval_log['r'].mean}")

        self.writer.add_scalar('loss', eval_log.mean, self.global_step)
        self.writer.add_scalar('reward_loglikelihood', eval_log['r'].mean, self.global_step)
