
import dataclasses as dc
import os
import pathlib
import random
import time
import typing as ty
from abc import ABC, abstractmethod

import gym
import numpy as np
# import or_gym
import torch


# msg_r = lambda msg: f"\x1b[1;31;40m{msg}\x1b[0m"
# msg_g = lambda msg: f"\x1b[1;32;40m{msg}\x1b[0m"
# msg_y = lambda msg: f"\x1b[1;33;40m{msg}\x1b[0m"
# msg_b = lambda msg: f"\x1b[1;34;40m{msg}\x1b[0m"
# msg_p = lambda msg: f"\x1b[1;35;40m{msg}\x1b[0m"
# msg_s = lambda msg: f"\x1b[1;36;40m{msg}\x1b[0m"
# Console colouring helpers: each wraps its argument in an ANSI SGR
# foreground-colour escape and resets the colour afterwards
# (31 red, 32 green, 33 yellow, 34 blue, 35 magenta, 36 cyan).
def msg_r(msg):
    """Return ``msg`` formatted in red."""
    return f"\033[31m{msg}\033[0m"


def msg_g(msg):
    """Return ``msg`` formatted in green."""
    return f"\033[32m{msg}\033[0m"


def msg_y(msg):
    """Return ``msg`` formatted in yellow."""
    return f"\033[33m{msg}\033[0m"


def msg_b(msg):
    """Return ``msg`` formatted in blue."""
    return f"\033[34m{msg}\033[0m"


def msg_p(msg):
    """Return ``msg`` formatted in magenta."""
    return f"\033[35m{msg}\033[0m"


def msg_s(msg):
    """Return ``msg`` formatted in cyan."""
    return f"\033[36m{msg}\033[0m"


@dc.dataclass
class Param4TrainEval:
    """Hyper-parameters for online training and evaluation.

    ``__post_init__`` normalises ``log_dir`` to a ``pathlib.Path`` and creates
    it, plus ``model``/``batch`` sub-directories when ``save_model`` /
    ``save_batch`` are set. It then derives the non-field attributes
    ``log_file``, ``save_model_dir`` (only if ``save_model``),
    ``save_batch_dir`` (only if ``save_batch``) and
    ``total_steps = steps_per_epoch * num_epochs``.
    """

    # --- common ---
    # algo_name           : str = None
    # env_name            : str = None
    max_ep_len          : int = 1000    # step limit after which an episode is cut
    num_test_episodes   : int = 100     # evaluation episodes per epoch
    num_epochs          : int = 1000
    reward_scale        : float = 1.    # rewards are multiplied by this before being stored

    # --- online ---
    start_steps         : int = 100     # if t < start_steps, random action is performed
    update_after        : int = 100     # no agent updates before this many env steps
    update_every        : int = 10      # env steps between update rounds, and updates per round
    policy_update_every : int = 1       # update_policy=True on every k-th update of a round
    steps_per_epoch     : int = 100
    greedy_test         : bool = True   # forwarded as ``greedy`` to agent.act during evaluation

    # --- logging and misc ---
    logging             : bool = True
    log_dir             : ty.Union[str, pathlib.Path] = './tmp'
    # exp_name            : str = 'test'
    save_model          : bool = False
    remove_old_models   : bool = False
    save_batch          : bool = False
    seed                : int = 777
    verbose             : bool = True
    message_freq        : int = 10

    def __post_init__(self):
        """Create the output directories and derive dependent attributes."""
        # The default (and most callers) pass a plain string; the ``/`` joins
        # below require a Path, so normalise first.
        self.log_dir = pathlib.Path(self.log_dir)
        self.log_file = self.log_dir / 'score'
        self.log_dir.mkdir(parents=True, exist_ok=True)
        if self.save_model:
            self.save_model_dir = self.log_dir / 'model'
            self.save_model_dir.mkdir(parents=True, exist_ok=True)
        if self.save_batch:
            self.save_batch_dir = self.log_dir / 'batch'
            self.save_batch_dir.mkdir(parents=True, exist_ok=True)
        self.total_steps = self.steps_per_epoch * self.num_epochs



class DataLogger:
    """Accumulate metrics row by row and save them as a compressed ``.npz``.

    Every call to :meth:`info` records one row. All lists in ``self.log``
    always have exactly ``self.counter`` entries. A key first seen in a later
    call is back-filled with ``None`` for earlier rows, and a known key that
    is omitted from a call gets ``None`` for that row.
    """

    def __init__(self, verbose=True) -> None:
        """
        Args:
            verbose: if True, :meth:`info` prints the newly recorded row.
        """
        self.verbose = verbose
        self.log = {}       # key -> list of values, one per info() call
        self.counter = 0    # number of rows recorded so far

    def info(self, **kwargs):
        """Record one row of metrics passed as keyword arguments.

        When ``verbose`` is set, every key's value for this row is printed in
        sorted key order. Keys containing ``error``/``loss`` use ``{:e}``
        formatting, ``test_ep_ret`` keys are shown in blue and ``goal`` keys
        in red.
        """
        for key, value in kwargs.items():
            # New keys are padded with None for every earlier row.
            self.log.setdefault(key, [None] * self.counter).append(value)
        for key, values in self.log.items():
            # Keep all columns aligned: keys absent from this call get None.
            if key not in kwargs:
                values.append(None)

        if self.verbose:
            print('===============')
            for key in sorted(self.log.keys()):
                value = self.log[key][self.counter]
                if value is None:
                    # Padding entry; '{:e}' cannot format None.
                    print(f'{key}: {value}')
                elif 'error' in key or 'loss' in key:
                    print(f'{key}:', '{:e}'.format(value))
                elif 'test_ep_ret' in key:
                    print(f'{key}: {msg_b(value)}')
                elif 'goal' in key:
                    print(f'{key}: {msg_r(value)}')
                else:
                    print(f'{key}: {value}')
        self.counter += 1

    def save(self, log_file: str, info: dict = None):
        """Write every logged key as a numpy array to a compressed ``.npz``.

        Args:
            log_file: destination path. ``np.savez_compressed`` appends
                ``.npz`` when the name lacks it.
            info: optional extra entries stored alongside the log. When a
                name clashes, the logged values take precedence.
        """
        to_be_saved = dict(info) if info else {}
        for key, values in self.log.items():
            to_be_saved[key] = np.array(values)
        np.savez_compressed(log_file, **to_be_saved)
        print(log_file)


class Playground(ABC):
    """Abstract driver that trains, evaluates and checkpoints an agent.

    Subclasses implement the hooks ``initialize``, ``pre_update``,
    ``at_update``, ``post_update`` and ``test_prediction``. This base class
    provides the online training loop, evaluation, checkpointing and offline
    data collection.

    The agent must provide the members called here: ``act``, ``update``,
    ``update_target_network``, ``save``, ``load`` and ``gamma``. It must also
    have a ``buffer`` with ``store`` and ``ob_buf``/``ac_buf``/``rew_buf``/
    ``done_buf``. Their exact semantics are defined elsewhere.
    """

    def __init__(
            self,
            param: Param4TrainEval,
            agent: torch.nn.Module,
            env_name: ty.Optional[str] = None,
            env: gym.Env = None,
            test_env: gym.Env = None,
        ) -> None:
        """Set up environments, seeding and logging, then call ``initialize``.

        Args:
            param: training / evaluation hyper-parameters.
            agent: the learning agent.
            env_name: gym id. When given, the training and test envs are
                created with ``gym.make`` and seeded with ``param.seed`` and
                ``param.seed + 1`` respectively.
            env, test_env: pre-built environments, required when ``env_name``
                is None. They are used as-is (not re-seeded).

        Raises:
            ValueError: ``env_name`` is None and ``env`` or ``test_env`` is missing.
        """
        self.param_trev = param
        self.dice = np.random.RandomState(self.param_trev.seed)

        self.agent = agent
        # Stored so that collect_data can name its output files.
        self.env_name = env_name

        if env_name is None:
            # Previously ``env`` was checked twice and ``test_env`` never.
            if env is None or test_env is None:
                raise ValueError('env and test_env are required when env_name is None')
            self.env = env
            self.test_env = test_env
        else:
            self.env = gym.make(env_name)
            self.test_env = gym.make(env_name)
            self.env.seed(self.param_trev.seed)
            self.env.action_space.seed(self.param_trev.seed)
            # The test env gets a seed distinct from the training env.
            self.test_env.seed(self.param_trev.seed + 1)
            self.test_env.action_space.seed(self.param_trev.seed + 1)

        if self.param_trev.logging:
            self.logger = DataLogger(verbose=self.param_trev.verbose)
        else:
            self.logger = None

        self.initialize()

    @abstractmethod
    def initialize(self):
        """Subclass setup hook, called at the end of ``__init__``."""

    @abstractmethod
    def pre_update(self):
        """Hook called once before each round of agent updates."""

    @abstractmethod
    def at_update(self):
        """Hook called after every single ``agent.update`` call."""

    @abstractmethod
    def post_update(self):
        """Hook called once after each round of agent updates."""

    @abstractmethod
    def test_prediction(self):
        """Hook called at the end of every epoch, right after ``test_agent``."""

    def save_model(self, tag: str, verbose=False):
        """Save the agent into ``param.save_model_dir`` under ``tag``.

        ``save_model_dir`` only exists when ``param.save_model`` is True.
        """
        self.agent.save(self.param_trev.save_model_dir, tag, verbose)

    def remove_model(self, tag: str):
        """Delete every file in ``param.save_model_dir`` whose name contains ``tag``."""
        model_dir = self.param_trev.save_model_dir
        for file in os.listdir(model_dir):
            if tag in file:
                # os.path.join works whether model_dir is a str or a Path.
                os.remove(os.path.join(model_dir, file))

    def save_batch(self):
        """Dump the agent's replay buffer to ``<param.save_batch_dir>/SARD.npz``."""
        # ``save_batch_dir`` is a Path; ``+ '/SARD'`` raised TypeError.
        batch_file = os.path.join(self.param_trev.save_batch_dir, 'SARD')
        buffer = self.agent.buffer
        np.savez_compressed(
            batch_file,
            S=buffer.ob_buf,
            A=buffer.ac_buf,
            R=buffer.rew_buf,
            D=buffer.done_buf,
        )
        print(batch_file)

    def _save_log(self, info=None):
        """Write the logger history to ``param.log_file``.

        ``info`` is passed to ``logger.save`` as a second argument only when
        it is not None. Otherwise only the file name is passed.
        """
        if info is None:
            self.logger.save(self.param_trev.log_file)
        else:
            self.logger.save(self.param_trev.log_file, info)

    def train_online(self, info=None):
        """Run the online interact / update / evaluate loop.

        The loop runs for ``param.total_steps`` environment steps.

        - Acting: actions are uniformly random for the first
          ``start_steps`` steps, then come from ``agent.act``. Each
          transition (with reward scaled by ``reward_scale``) is stored in
          ``agent.buffer``.
        - Updating: once ``update_after`` steps are reached, every
          ``update_every`` steps a round of ``update_every`` agent updates
          runs.
        - Epoch end: every ``steps_per_epoch`` steps the agent is
          evaluated, the best mean test return is tracked (optionally
          checkpointed), and metrics are logged.

        Args:
            info: optional dict saved alongside the log by ``DataLogger.save``.
        """
        p = self.param_trev
        start_time = time.perf_counter()
        o, ep_ret, ep_len = self.env.reset(), 0, 0
        ep_ret_all, ep_len_all, num_eps = 0, 0, 0
        total_eps = 0
        ave_loss, ave_errors = 0, 0
        old_model_tag = None
        best_step = None
        best_test_ep_ret = None

        for t in range(p.total_steps):
            n_steps = t + 1  # env steps taken once this iteration's step is done

            if t >= p.start_steps:
                a = self.agent.act(o)
            else:
                a = self.env.action_space.sample()

            o2, r, d, _ = self.env.step(a)
            ep_ret += r
            ep_len += 1
            # Ignore a done signal that coincides with the time limit, so the
            # artificial cut is not stored as a terminal transition. This must
            # come after ``ep_len += 1``. Before the increment ep_len is at
            # most max_ep_len - 1, so the check never matched.
            d = False if ep_len == p.max_ep_len else d

            self.agent.buffer.store(o, a, r * p.reward_scale, o2, d)
            o = o2

            # End of an episode: true termination or time limit.
            if d or ep_len == p.max_ep_len:
                ep_ret_all += ep_ret
                ep_len_all += ep_len
                num_eps += 1
                o, ep_ret, ep_len = self.env.reset(), 0, 0

            # Round of agent updates.
            if n_steps >= p.update_after and n_steps % p.update_every == 0:
                self.pre_update()
                ave_loss, ave_errors = 0, 0
                for j in range(p.update_every):
                    loss, errors = self.agent.update(
                        update_policy=(j % p.policy_update_every == 0))
                    ave_loss += loss
                    ave_errors += errors
                    self.at_update()
                self.post_update()
                if n_steps % p.steps_per_epoch == 0:
                    self.agent.update_target_network()
                ave_loss /= p.update_every
                ave_errors /= p.update_every

            # End of an epoch: test, logging and checkpointing.
            if n_steps % p.steps_per_epoch == 0:
                epoch = n_steps // p.steps_per_epoch
                total_eps += num_eps
                test_ep_ret, _, test_ep_len = self.test_agent()

                self.test_prediction()

                mean_test_ret = test_ep_ret.mean()
                # Ties also count as an improvement (latest best wins).
                if best_test_ep_ret is None or best_test_ep_ret <= mean_test_ret:
                    best_step = n_steps
                    best_test_ep_ret = mean_test_ret

                    if p.save_model:
                        tag = f't{best_step}_z{int(best_test_ep_ret)}'
                        self.save_model(tag, verbose=True)
                        if old_model_tag is not None and p.remove_old_models:
                            self.remove_model(old_model_tag)
                        old_model_tag = tag

                if self.logger is not None:
                    self.logger.info(
                        epoch=epoch, episodes=total_eps,
                        elapsed=time.perf_counter() - start_time,
                        # NOTE(review): this is the 0-based index t, while
                        # best_step counts steps taken (t + 1).
                        env_steps=t,
                        train_loss=ave_loss, train_error=ave_errors,
                        train_ep_len=ep_len_all / num_eps if num_eps > 0 else ep_len_all,
                        train_ep_ret=ep_ret_all / num_eps if num_eps > 0 else ep_ret_all,
                        test_ep_len=test_ep_len.mean(), test_ep_ret=mean_test_ret,
                        best_step=best_step, best_test_ep_ret=best_test_ep_ret,
                    )
                    self._save_log(info)

                ep_ret_all, ep_len_all, num_eps = 0, 0, 0

                if p.save_model:
                    self.save_model('recent')

        # After the main loop.
        if self.logger is not None:
            self._save_log(info)

        if p.save_batch:
            self.save_batch()

    def collect_data(self, model_path, ratio_expert, batch_sizes):
        """Collect mixed expert/random transitions and save them in batches.

        The agent is loaded from ``model_path``. Collection alternates between
        an expert phase (``agent.act(o, greedy=True)``) and a random phase
        (``test_env.action_space.sample()``), switching only at episode
        boundaries. For each target size in ``batch_sizes``:

        - the expert phase runs until the cumulative expert transition count
          reaches ``size * ratio_expert``;
        - the random phase runs until the cumulative random count reaches
          ``size * (1 - ratio_expert)``;
        - the whole buffer is then written to
          ``rdrl/data/batch_<env_name>_r<ratio_expert>_N<n_transitions>.npz``.

        Collection stops after the last size.

        Returns:
            Averages over all collected episodes, in this order: return,
            discounted return (by ``agent.gamma``), length, and mean per-step
            reward.
        """
        self.agent.load(model_path)
        p = self.param_trev

        # NOTE(review): ReplayBuffer is not imported in this file; it must be
        # provided by the enclosing module for this method to run.
        self.buffer = ReplayBuffer(self.env.observation_space, int(batch_sizes[-1] * 1.1))

        data_dir = 'rdrl/data'
        ave_ep_ret, ave_ep_disc, ave_ep_len, ave_rew = 0, 0, 0, 0
        n_episodes = 0

        bsize_idx = 0
        n_transitions = 0
        n_transitions_expert = 0
        n_transitions_random = 0
        phase_expert = True

        while True:
            o, d, ep_ret, ep_disc, ep_len = self.test_env.reset(), False, 0, 0, 0
            while not (d or ep_len == p.max_ep_len):
                if phase_expert:
                    # Take deterministic actions in the expert phase.
                    a = self.agent.act(o, greedy=True)
                else:
                    a = self.test_env.action_space.sample()
                o2, r, d, _ = self.test_env.step(a)
                self.buffer.store(o, a, r, o2, d)

                ep_ret += r
                ep_disc += r * self.agent.gamma ** ep_len
                ep_len += 1
                o = o2

            n_episodes += 1
            ave_ep_ret += ep_ret
            ave_ep_disc += ep_disc
            ave_ep_len += ep_len
            ave_rew += ep_ret / ep_len

            n_transitions += ep_len
            if phase_expert:
                n_transitions_expert += ep_len
            else:
                n_transitions_random += ep_len

            target = batch_sizes[bsize_idx]
            if phase_expert and n_transitions_expert >= target * ratio_expert:
                print(phase_expert, n_transitions, n_transitions_expert, n_transitions_random)
                phase_expert = False
            elif not phase_expert and n_transitions_random >= target * (1 - ratio_expert):
                print(phase_expert, n_transitions, n_transitions_expert, n_transitions_random)
                bsize_idx += 1
                phase_expert = True

                os.makedirs(data_dir, exist_ok=True)
                log_file = os.path.join(
                    data_dir, f'batch_{self.env_name}_r{ratio_expert}_N{n_transitions}')
                np.savez_compressed(
                    log_file,
                    S=self.buffer.ob_buf,
                    A=self.buffer.ac_buf,
                    R=self.buffer.rew_buf,
                    D=self.buffer.done_buf,
                )
                print(log_file)
                if bsize_idx >= len(batch_sizes):
                    break

        # Average over the episodes actually collected; the count depends on
        # episode lengths, not on param.num_test_episodes.
        return (ave_ep_ret / n_episodes, ave_ep_disc / n_episodes,
                ave_ep_len / n_episodes, ave_rew / n_episodes)

    def test_agent(self):
        """Run ``param.num_test_episodes`` evaluation episodes on ``test_env``.

        Actions come from ``agent.act(o, greedy=param.greedy_test)``. An
        episode ends on ``done`` or after ``param.max_ep_len`` steps.

        Returns:
            Three float arrays of length ``num_test_episodes``: returns,
            returns discounted by ``agent.gamma``, and episode lengths.
        """
        p = self.param_trev
        ep_rets = np.zeros(p.num_test_episodes)
        ep_discs = np.zeros(p.num_test_episodes)
        ep_lens = np.zeros(p.num_test_episodes)
        for j in range(p.num_test_episodes):
            o, d, ep_ret, ep_disc, ep_len = self.test_env.reset(), False, 0, 0, 0
            while not (d or ep_len == p.max_ep_len):
                a = self.agent.act(o, greedy=p.greedy_test)
                o, r, d, _ = self.test_env.step(a)
                ep_ret += r
                ep_disc += r * self.agent.gamma ** ep_len
                ep_len += 1
            ep_rets[j] = ep_ret
            ep_discs[j] = ep_disc
            ep_lens[j] = ep_len

        return ep_rets, ep_discs, ep_lens


if __name__ == "__main__":

    import argparse
    import yaml

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--env', type=str, default='mdp2s-v0')
    parser.add_argument('--train', action='store_true')
    parser.add_argument('--plot', action='store_true')
    parser.add_argument('--collect', action='store_true')
    parser.add_argument('--model_path', type=str)
    args = parser.parse_args()
    # Validate before doing any work (asserts vanish under ``python -O``).
    if args.collect and args.model_path is None:
        parser.error('--model_path is required with --collect')

    with open(args.config, encoding='utf-8') as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)
    print(config)

    # NOTE(review): Playground is abstract and its constructor takes
    # (param, agent, ...); a concrete subclass and matching config are
    # needed for the calls below to succeed.
    if args.train:
        playground = Playground(**config)
        # Playground has no ``train`` method; the training loop is train_online.
        playground.train_online()
    if args.collect:
        ratio_expert = [0.9, 0.7, 0.5, 0.3, 0.1]
        batch_sizes = [1e3, 1e4, 1e5, 1e6]
        for r in ratio_expert:
            playground = Playground(**config)
            playground.collect_data(args.model_path, r, batch_sizes)
