from typing import Optional, Dict, Tuple, List, Literal


import numpy as np
import torch
import torch.nn as nn
import json
import argparse
import tensorboardX
import os
import time
import pathlib

from .base import Experiment
from alg.model.base import RLModel
from .train import Train
from core import DType, Taskinfo
from alg.buffer import ObjectOrientedBuffer
from alg.model import *
from core.causal_graph import CausalGraph, ObjectOrientedCausalGraph
from alg.cem import cross_entropy_method, CEMArgs
import alg.model.mask_generator as mg
import alg.functional as F
import utils
from utils.typings import ObjectTensors, NamedTensors, TransitionModel


# Keys under which `Test.collect` records values in its `utils.Log`:
# the reward of every environment step, and the summed reward of every
# episode that ended (terminated or truncated) during collection.
_REWARD = 'reward'
_RETURN = 'return'

class Test(Experiment):
    """Evaluate a previously trained model stored under an experiment path.

    On construction the instance attaches to the experiment at ``path``
    (``follow_experiment``), rebuilds the model from the saved ``'model'``
    arguments via :meth:`load_model`, switches it to evaluation mode and
    finally calls :meth:`setup`.

    Subclasses must implement :meth:`load_model` and
    :meth:`get_transition_model`; they may override :meth:`setup` to prepare
    extra state (it runs after the model has been loaded).
    """

    use_existing_path = True

    class Args(utils.Struct):
        """Options of a test run.

        Attributes:
            device: device string handed to ``torch.device``.
            n_timestep: number of environment steps collected with the CEM
                actor in :meth:`Test.main`.
            eval_n_sample: number of random-actor samples collected for the
                model evaluation.
            eval_batchsize: batch size used when iterating over the
                evaluation buffer.
            cem_args: arguments forwarded to ``cross_entropy_method``.  When
                omitted, a fresh ``CEMArgs()`` is created for this instance.
        """

        def __init__(self,
            device: str = 'cuda',
            n_timestep: int = 1000,
            eval_n_sample: int = 10000,
            eval_batchsize: int = 512,
            cem_args: Optional[CEMArgs] = None,
        ):
            self.device = device
            self.n_timestep = n_timestep
            self.eval_n_sample = eval_n_sample
            self.eval_batchsize = eval_batchsize
            # Build the default per instance: a default evaluated in the
            # signature would be one object shared by every Args instance.
            self.cem_args = CEMArgs() if cem_args is None else cem_args

    def __init__(self,
        path: str,
        args: Optional['Test.Args'] = None,
        env_options: Optional[dict] = None,
        label: str = 'test',
    ) -> None:
        """Attach to an existing experiment and load its model.

        Args:
            path: experiment path handed to ``follow_experiment``.
            args: test options; defaults to ``Test.Args()``.
            env_options: environment options handed to ``follow_experiment``;
                defaults to a new empty dict.
            label: first component of the keys used by :meth:`main` when
                saving results.
        """
        super().__init__()

        if args is None:
            args = self.Args()
        # A fresh dict per call avoids sharing one mutable default object.
        if env_options is None:
            env_options = {}

        self.device = torch.device(args.device)
        self.dtype = DType.Real.torch
        self.n_timestep = args.n_timestep
        self.eval_n_sample = args.eval_n_sample
        self.eval_batchsize = args.eval_batchsize
        self.cem_args = args.cem_args
        self.label = label

        self.follow_experiment(path, env_options)

        model_args = utils.Struct.from_dict(self.load_args('model'))
        self.model = self.load_model(model_args)
        self.model.train(False)  # evaluation mode

        self.setup()

    def setup(self):
        """Hook called at the end of ``__init__``; does nothing by default."""
        pass

    def load_model(self, model_args: utils.Struct) -> RLModel:
        """Build the model and load its saved weights (subclass hook)."""
        raise NotImplementedError

    def reset_env(self):
        """Reset the environment and return whatever ``env.reset()`` returns."""
        return self.env.reset()

    def __get_action_cem(self):
        """Choose an action with the cross-entropy method on the learned model."""
        return cross_entropy_method(
            self.env, self.get_transition_model(), None, self.device, self.cem_args)

    def __get_action_random(self):
        """Sample an action from the environment's action space."""
        return self.env.action_space.sample()

    def __print_log(self, log: utils.Log):
        """Print the mean episode return and mean step reward stored in ``log``."""
        print(f"- average episode return: {log[_RETURN].mean}")
        print(f"- average step reward: {log[_REWARD].mean}")

    def collect(self, n_sample: int, actor: Literal['random', 'cem'] = 'random',
                buffer: Optional[ObjectOrientedBuffer] = None):
        """Collect real-environment samples and record rewards and returns.

        Args:
            n_sample: number of environment steps to perform.
            actor: ``'random'`` samples from the action space; ``'cem'`` plans
                with the cross-entropy method on the learned transition model.
            buffer: if given, every transition is added to it as
                ``(attrs, next_state, reward)``.

        Returns:
            A ``utils.Log`` holding every step reward under ``'reward'`` and
            the return of every finished episode under ``'return'``.

        Raises:
            ValueError: if ``actor`` is neither ``'random'`` nor ``'cem'``.
        """
        if actor == 'random':
            get_action = self.__get_action_random
        elif actor == 'cem':
            get_action = self.__get_action_cem
        else:
            # An explicit exception survives `python -O`, unlike an assert.
            raise ValueError(
                f"unknown actor {actor!r}; expected 'random' or 'cem'")

        log = utils.Log()
        episodic_return = 0.

        self.reset_env()

        _timer = time.time()
        for i_sample in range(n_sample):

            # print progress every 5 seconds
            _new_timer = time.time()
            if _new_timer - _timer >= 5:
                print(f"Collecting samples... ({i_sample}/{n_sample})")
                self.__print_log(log)
                _timer = _new_timer

            # interact with the environment
            a = get_action()
            next_state, reward, terminated, truncated, attrs = self.env.step(a)

            # record information
            episodic_return += reward
            log[_REWARD] = reward

            # on episode end: record the return, then start a new episode
            if truncated or terminated:
                log[_RETURN] = episodic_return
                episodic_return = 0.
                self.reset_env()

            if buffer is not None:
                buffer.add(attrs, next_state, reward)

        return log

    def __eval_batch(self, attrs: ObjectTensors, next_state: ObjectTensors,
                     objmask: NamedTensors, reward: torch.Tensor, model: TransitionModel,
                    ):
        """Return the log-probability of ``next_state`` under ``model`` for one batch.

        ``reward`` is accepted so a whole buffer batch can be unpacked into
        this call, but it is not used here.
        """
        s = model(attrs, objmask)

        label = F.raws2labels(self.envinfo, next_state)
        logprob = F.sum_logprob_by_class(F.logprob(s, label), objmask)

        return float(logprob)

    @torch.no_grad()
    def eval_model(self):
        """Evaluate the transition model's log-likelihood on random-actor data.

        Collects ``eval_n_sample`` transitions with the random actor into a
        new buffer, then scores every batch of one pass over the buffer with
        the current transition model.

        Returns:
            A ``utils.Log`` with the per-batch values stored under ``'logp'``.
        """
        buffer = ObjectOrientedBuffer(self.eval_n_sample, self.envinfo)

        print("collecting samples for model evaluation")
        self.collect(self.eval_n_sample, 'random', buffer)

        log = utils.Log()
        batch_size = self.eval_batchsize
        for batch in buffer.epoch(batch_size, self.device):
            logp = self.__eval_batch(*batch, self.get_transition_model())
            log['logp'] = logp

        print("model evaluation:")
        print(f"- loglikelihood: {log['logp'].mean}")

        return log

    def get_transition_model(self) -> TransitionModel:
        """Return the transition model used for evaluation and planning (subclass hook)."""
        raise NotImplementedError

    def main(self):
        """Run the model evaluation, then a CEM rollout, and save the results.

        Saved under ``(label, ...)``: the mean log-likelihood, the mean step
        reward and the mean episode return.
        """
        label = self.label

        eval_log = self.eval_model()
        self.save_result((label, 'loglikelihood'), eval_log['logp'].mean)

        collect_log = self.collect(self.n_timestep, 'cem')
        print("finished:")
        self.__print_log(collect_log)
        self.save_result((label, 'reward'), collect_log[_REWARD].mean)
        self.save_result((label, 'return'), collect_log[_RETURN].mean)


class TestOOC(Test):
    """Test an ``OOCModel`` whose masks come from the saved causal graph."""

    def load_model(self, model_args: utils.Struct) -> RLModel:
        """Build an ``OOCModel`` and load the saved weights into it.

        Keys are matched non-strictly (``strict=False``).
        """
        m = OOCModel(self.env, model_args, self.device, self.dtype)
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be loaded when testing on `self.device`.
        state_dict = torch.load(self._file_path('model', 'nn'),
                                map_location=self.device)
        m.load_state_dict(state_dict, strict=False)
        return m

    def load_causal_graph(self):
        """Load ``causal-graph.json`` from the experiment path into a new graph."""
        causal_graph = ObjectOrientedCausalGraph(self.envinfo)
        with open(self.path / 'causal-graph.json', 'r', encoding='utf-8') as f:
            d = json.load(f)
        causal_graph.load_state_dict(d)
        return causal_graph

    def setup(self):
        """Create the mask generators and load the causal graph into the graph one."""
        super().setup()
        # mask generators
        self.__maskgen_graph = mg.GraphMaskGenerator(self.envinfo, self.device)
        # NOTE(review): the full mask generator is not read anywhere in this
        # class; kept because its constructor's side effects are unknown.
        self.__maskgen_full = mg.FullMaskGenerator(self.envinfo, self.device)

        # causal graph
        causal_graph = self.load_causal_graph()
        self.__maskgen_graph.load_graph(causal_graph)

    def get_transition_model(self):
        """Return the model's transition model using the graph mask generator."""
        self.model: OOCModel
        return self.model.make_transition_model(self.__maskgen_graph)


class TestFull(Test):
    """Test an ``OOCModel`` using the full mask generator."""

    def load_model(self, model_args: utils.Struct) -> RLModel:
        """Build an ``OOCModel`` and load the saved weights into it.

        Keys are matched non-strictly (``strict=False``).
        """
        m = OOCModel(self.env, model_args, self.device, self.dtype)
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be loaded when testing on `self.device`.
        state_dict = torch.load(self._file_path('model', 'nn'),
                                map_location=self.device)
        m.load_state_dict(state_dict, strict=False)
        return m

    def setup(self):
        """Create the full mask generator."""
        super().setup()
        # mask generators
        self.__maskgen_full = mg.FullMaskGenerator(self.envinfo, self.device)

    def get_transition_model(self):
        """Return the model's transition model using the full mask generator."""
        self.model: OOCModel
        return self.model.make_transition_model(self.__maskgen_full)


class TestMLP(Test):
    """Test an ``MLPModel``."""

    def load_model(self, model_args: utils.Struct) -> RLModel:
        """Build an ``MLPModel`` and load the saved weights into it.

        Keys are matched non-strictly (``strict=False``).
        """
        m = MLPModel(self.env, model_args, self.device, self.dtype)
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be loaded when testing on `self.device`.
        state_dict = torch.load(self._file_path('model', 'nn'),
                                map_location=self.device)
        m.load_state_dict(state_dict, strict=False)
        return m

    def get_transition_model(self):
        """Return the transition model built by the loaded ``MLPModel``."""
        assert isinstance(self.model, MLPModel)
        return self.model.make_transition_model()


class TestCDL(Test):
    """Test a ``CDLModel`` together with the saved causal graph."""

    def load_model(self, model_args: utils.Struct) -> RLModel:
        """Build a ``CDLModel`` and load the saved weights into it.

        Keys are matched non-strictly (``strict=False``).
        """
        m = CDLModel(self.env, model_args, self.device, self.dtype)
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be loaded when testing on `self.device`.
        state_dict = torch.load(self._file_path('model', 'nn'),
                                map_location=self.device)
        m.load_state_dict(state_dict, strict=False)
        return m

    def load_causal_graph(self):
        """Load ``causal-graph.json`` from the experiment path into a new graph."""
        causal_graph = CausalGraph(self.taskinfo)
        with open(self.path / 'causal-graph.json', 'r', encoding='utf-8') as f:
            d = json.load(f)
        causal_graph.load_state_dict(d)
        return causal_graph

    def setup(self):
        """Load the causal graph used by :meth:`get_transition_model`."""
        # Call the parent hook first, as the sibling test classes do.
        super().setup()
        # causal graph
        self.causal_graph = self.load_causal_graph()

    def get_transition_model(self):
        """Return the model's ``'graph'`` transition model for the loaded causal graph."""
        self.model: CDLModel
        return self.model.make_transition_model('graph', causal_graph=self.causal_graph)


class TestGRU(Test):
    """Test a ``GRUModel`` together with the saved causal graph."""

    def load_model(self, model_args: utils.Struct) -> RLModel:
        """Build a ``GRUModel`` and load the saved weights into it.

        Keys are matched non-strictly (``strict=False``).
        """
        m = GRUModel(self.env, model_args, self.device, self.dtype)
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be loaded when testing on `self.device`.
        state_dict = torch.load(self._file_path('model', 'nn'),
                                map_location=self.device)
        m.load_state_dict(state_dict, strict=False)
        return m

    def load_causal_graph(self):
        """Load ``causal-graph.json`` from the experiment path into a new graph."""
        causal_graph = CausalGraph(self.taskinfo)
        with open(self.path / 'causal-graph.json', 'r', encoding='utf-8') as f:
            d = json.load(f)
        causal_graph.load_state_dict(d)
        return causal_graph

    def setup(self):
        """Load the causal graph used by :meth:`get_transition_model`."""
        super().setup()
        self.causal_graph = self.load_causal_graph()

    def get_transition_model(self):
        """Return the model's transition model for the loaded causal graph."""
        self.model: GRUModel
        return self.model.make_transition_model(causal_graph=self.causal_graph)
    

class TestTICSA(Test):
    """Test a ``TISCAModel``."""

    def load_model(self, model_args: utils.Struct) -> RLModel:
        """Build a ``TISCAModel`` and load the saved weights into it.

        Keys are matched non-strictly (``strict=False``).
        """
        m = TISCAModel(self.env, model_args, self.device, self.dtype)
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be loaded when testing on `self.device`.
        state_dict = torch.load(self._file_path('model', 'nn'),
                                map_location=self.device)
        m.load_state_dict(state_dict, strict=False)
        return m

    def get_transition_model(self):
        """Return the model's transition model, built with ``tau=0``."""
        self.model: TISCAModel
        return self.model.make_transition_model(tau=0)


class TestGNN(Test):
    """Test a ``GNNModel``."""

    def load_model(self, model_args: utils.Struct) -> RLModel:
        """Build a ``GNNModel`` and load the saved weights into it.

        Keys are matched non-strictly (``strict=False``).
        """
        m = GNNModel(self.env, model_args, self.device, self.dtype)
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be loaded when testing on `self.device`.
        state_dict = torch.load(self._file_path('model', 'nn'),
                                map_location=self.device)
        m.load_state_dict(state_dict, strict=False)
        return m

    def get_transition_model(self):
        """Return the transition model built by the loaded ``GNNModel``."""
        assert isinstance(self.model, GNNModel)
        return self.model.make_transition_model()