import copy
import inspect
from typing import Optional

import numpy as np


class Config:
    """Hyper-parameter container populated from a plain mapping.

    Class-level attributes hold the defaults; the mapping handed to the
    constructor overrides them per instance. Names that are only annotated
    (no value) have no default: they must be supplied by the caller or, for
    ``batch_size`` / ``obs_net_arch`` / ``cell_net_arch``, are derived by
    ``update_variables``.
    """

    interval: int = 25
    do_test: bool = False
    render: bool = False
    train_from_scratch: bool = True
    seed: Optional[int] = None
    use_origin_score: bool = False

    env_name: str
    env_type: str
    # Class-level default only; every instance receives its own copy in
    # __init__ so in-place edits do not leak between instances.
    env_config: dict = {}
    n_actions: int
    state_shape: tuple
    obs_shape: tuple
    cell_shape: tuple
    obs_is_color: bool = False
    total_rollouts_per_env: int = 50000
    max_frames_per_episode: int = 4500
    rollout_length: int = 128
    n_epochs: int = 4
    n_mini_batch: int = 4
    lr: float = 1e-4
    n_workers: int = 16
    batch_size: int  # derived in update_variables()
    policy_use_rnn: bool = False
    model_norm_type: int = 0  # [None, batchnorm, layernorm]
    adv_norm_type: int = 0

    feature_net_arch: dict
    obs_net_arch: dict  # derived from feature_net_arch when not supplied
    cell_net_arch: dict  # derived when not supplied, see update_flatten_size()
    extra_hidden_size: int
    # Class-level default only; copied per instance like env_config.
    policy_kwargs: dict = {}

    # rl related
    ext_gamma: float = 0.999
    int_gamma: float = 0.99
    lam: float = 0.95
    ext_adv_coeff: float = 2.0
    int_adv_coeff: float = 1.0
    ent_coeff: float = 1e-3
    clip_range: float = 0.10

    def __init__(self, dict):
        """Apply the overrides in ``dict`` and compute the derived fields.

        Args:
            dict: Mapping of attribute name to value. Every entry is set on
                the instance as-is (no copy), so nested containers such as
                ``feature_net_arch`` are later updated in place.
        """
        # NOTE: the parameter shadows the builtin `dict`; the name is kept so
        # that existing keyword callers keep working.

        # Give this instance private copies of the mutable class defaults,
        # otherwise `cfg.env_config[...] = ...` would edit the class attribute
        # and therefore every other Config.
        for name in ("env_config", "policy_kwargs"):
            if name not in vars(self):
                setattr(self, name, copy.deepcopy(getattr(self, name)))

        for key, value in dict.items():
            setattr(self, key, value)
        self.update_variables()

    @staticmethod
    def conv_shape(input, kernel_size, stride, padding=0):
        """Return the output size of a convolution along each given dimension.

        Args:
            input: A single size, or a tuple/list of sizes.
            kernel_size: Kernel extent, applied to every dimension.
            stride: Stride, applied to every dimension.
            padding: Padding added on both sides of every dimension.

        Returns:
            A tuple of output sizes when ``input`` is a tuple or list,
            otherwise the single output size.
        """
        if isinstance(input, (tuple, list)):
            return tuple(
                (size + 2 * padding - kernel_size) // stride + 1 for size in input
            )
        return (input + 2 * padding - kernel_size) // stride + 1

    def update_variables(self):
        """Recompute every field that is derived from other fields."""
        self.batch_size = (self.rollout_length * self.n_workers) // self.n_mini_batch
        self.update_flatten_size()
        self.update_spec_net_arch()

    def update_spec_net_arch(self):
        """Hook called at the end of ``update_variables``; no-op here."""
        pass

    def _flattened_cnn_size(self, cnn_layers, spatial_shape):
        """Return the flattened feature count produced by ``cnn_layers``.

        Each layer is indexed as ``layer[1]`` = output channels,
        ``layer[2]`` = kernel size, ``layer[3]`` = stride and
        ``layer[4]`` = padding. The result is a plain ``int`` rather than a
        numpy scalar so the architecture dicts stay serialisable.
        """
        for layer in cnn_layers:
            spatial_shape = self.conv_shape(spatial_shape, layer[2], layer[3], layer[4])
        return int(cnn_layers[-1][1] * np.prod(spatial_shape))

    def update_flatten_size(self):
        """Fill in the size-dependent entries of the network architectures.

        Does nothing when ``feature_net_arch`` has no ``"cnn"`` entry.
        Otherwise, for ``feature_net_arch`` it sets ``mlp[0]`` to the
        flattened CNN output size and ``cnn[0][0]`` to ``state_shape[0]``,
        both in place. ``obs_net_arch`` and ``cell_net_arch`` are derived
        from it when they were not supplied; a supplied ``cell_net_arch``
        gets its own ``mlp[0]`` / ``cnn[0][0]`` computed from ``cell_shape``.
        """
        if "cnn" not in self.feature_net_arch:
            return
        assert len(self.obs_shape) == 3 and self.obs_shape[1:] == self.state_shape[1:], (
            "obs_shape must have 3 dimensions and share its trailing "
            "dimensions with state_shape"
        )

        feature_arch = self.feature_net_arch
        feature_arch["mlp"][0] = self._flattened_cnn_size(
            feature_arch["cnn"], self.obs_shape[1:]
        )
        feature_arch["cnn"][0][0] = self.state_shape[0]

        if not hasattr(self, "obs_net_arch"):
            self.obs_net_arch = copy.deepcopy(feature_arch)
            self.obs_net_arch["cnn"][0][0] = self.obs_shape[0]

        if hasattr(self, "cell_net_arch"):
            # Caller-supplied cell network: size it from its own layers.
            cell_arch = self.cell_net_arch
            cell_arch["mlp"][0] = self._flattened_cnn_size(
                cell_arch["cnn"], self.cell_shape[1:]
            )
            cell_arch["cnn"][0][0] = self.cell_shape[0]
            return

        cell_size = self.env_config.get("MG_cell_size") if (
            self.env_type == "MiniGrid" and "MG_cell_size" in self.env_config
        ) else None
        if cell_size is not None and cell_size > 1:
            # Only the leading dimension of cell_shape is scaled, so the
            # feature net's flattened size is reused unchanged.
            self.cell_shape = (self.cell_shape[0] * cell_size, *self.cell_shape[1:])
            self.cell_net_arch = copy.deepcopy(feature_arch)
            self.cell_net_arch["cnn"][0][0] = self.cell_shape[0]
        else:
            template = (
                feature_arch if self.cell_shape == self.state_shape else self.obs_net_arch
            )
            self.cell_net_arch = copy.deepcopy(template)

    def get_all_params(self):
        """Return a dict of every public, non-method attribute of the config.

        Includes class-level defaults as well as instance overrides. Bound
        methods and the class's static methods are left out; a callable that
        was stored on the instance as a value is kept.
        """
        params = {}
        for name, value in inspect.getmembers(self):
            if name.startswith("_") or inspect.ismethod(value):
                continue
            # A staticmethod looked up on an instance is a plain function, so
            # ismethod() misses it; inspect the raw, un-bound attribute.
            if isinstance(inspect.getattr_static(self, name, None), staticmethod):
                continue
            params[name] = value
        return params