
import dataclasses as dc
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import time
import torch
import yaml

from copy import deepcopy
from pathlib import Path

from algos.playground import Playground, Param4TrainEval, DataLogger


# Upper bound asserted on the mean test-episode length in `train_online`.
MUJOCO_MAX_EPISODE_LENGTH = 1000.0
# Debug switch: when True, `train_online` prints one transition every 1000
# steps and `train` shrinks the run to 5 epochs with a batch size of 4.
# DEBUGGING = True
DEBUGGING = False


@dc.dataclass
class Param4TrainEval4Mujoco(Param4TrainEval):
    """`Param4TrainEval` extended with the fields this MuJoCo script reads."""

    # Algorithm identifier; compared against 'sac' / 'mdac' / 'mdac_v' below.
    algo_name: str = None
    # Environment id as passed on the command line.
    env_name: str = None
    # Write a resumable checkpoint every this many epochs; 0 disables it.
    save_to_resume_every: int = 0
    # Checkpoint to resume from (its stem must look like 'step<N>'), or None.
    resume_from: str = None
    # When True, per-update scale statistics are collected and logged.
    check_scales: bool = False

    def __post_init__(self):
        # Nothing extra to do here: defer entirely to the parent's hook.
        super().__post_init__()



class Playground4Mujoco(Playground):

    def initialize(self):
        """Print the initial parameter magnitudes of the agent's sub-networks.

        For each of the attributes `q_net`, `amortized`, `policy`, `qf1` and
        `actor` that the agent actually has, the sum of squared parameter
        entries is printed, followed by the total over all of them.  When
        `check_scales` is enabled, a dedicated `DataLogger` for the scale
        statistics is created as `self.logger4scales`.
        """

        def get_param_norm(model):
            """Return the sum of squared entries over all parameters of `model`."""
            return sum(torch.sum(p**2).item() for p in model.parameters())

        norm_all = 0
        with torch.no_grad():
            for name in ('q_net', 'amortized', 'policy', 'qf1', 'actor'):
                if hasattr(self.agent, name):
                    norm = get_param_norm(getattr(self.agent, name))
                    norm_all += norm
                    print(f'{name}:', norm)

        # Report the accumulated total, not just the last sub-network's value
        # (which would also be undefined if none of the attributes exist).
        print('Initial norm of {pi, Q} =', norm_all)

        if self.param_trev.check_scales:
            self.logger4scales = DataLogger(verbose=False)


    def pre_update(self, test_now=False):
        """Arm or disarm the agent's `logged_*` buffers before a round of updates.

        With `check_scales` on and a recognised algorithm ('mdac' or 'mdac_v'),
        each of that algorithm's `logged_<name>` attributes is set to a fresh
        empty list when `test_now` is truthy and to None otherwise.  With
        `check_scales` off, the 'mdac' attribute set is always set to None.
        Any other algorithm under `check_scales` is left untouched.
        """
        mdac_fields = (
            'tau', 'kappa', 'reward', 'log_pi', 'bonus',
            'next_log_pi', 'next_bonus', 'next_q',
            'clipped', 'next_clipped',
        )
        mdac_v_fields = ('reward', 'q_sa', 'v_s', 'kappa', 'bonus')

        if not self.param_trev.check_scales:
            fields, collect = mdac_fields, False
        elif self.param_trev.algo_name == 'mdac':
            fields, collect = mdac_fields, test_now
        elif self.param_trev.algo_name == 'mdac_v':
            fields, collect = mdac_v_fields, test_now
        else:
            return

        for field in fields:
            # A new list per attribute so the buffers never alias each other.
            setattr(self.agent, f'logged_{field}', [] if collect else None)


    def at_update(self, test_now=False):
        """Hook invoked after every single agent update; a no-op in this class."""
        return None


    def post_update(self, test_now=False):
        """Print the ranges of the quantities logged during the last update round.

        Only acts when `check_scales` is on, the algorithm is 'mdac' and
        `test_now` is truthy; in that case the agent's `logged_*` lists are
        concatenated along dim 0 and their min/max (or mean, for the two
        `clipped` quantities) are printed.
        """
        if not self.param_trev.check_scales:
            return
        if self.param_trev.algo_name != 'mdac':
            return
        if not test_now:
            return

        agent = self.agent
        # Concatenate everything up front, then print.
        ranged = [
            ('tau    ', torch.cat(agent.logged_tau, dim=0)),
            ('kappa  ', torch.cat(agent.logged_kappa, dim=0)),
            ('reward ', torch.cat(agent.logged_reward, dim=0)),
            ('log_pi ', torch.cat(agent.logged_log_pi, dim=0)),
            ('bonus  ', torch.cat(agent.logged_bonus, dim=0)),
            ('nlog_pi', torch.cat(agent.logged_next_log_pi, dim=0)),
            ('nbonus ', torch.cat(agent.logged_next_bonus, dim=0)),
            ('next_q ', torch.cat(agent.logged_next_q, dim=0)),
        ]
        clipped = torch.cat(agent.logged_clipped, dim=0)
        next_clipped = torch.cat(agent.logged_next_clipped, dim=0)

        for label, values in ranged:
            # The braces format a (min, max) tuple inside literal parentheses.
            print(f'{label}: (min, max) = ({values.min().item(), values.max().item()})')
        print(f'     clipped: (mean)     = ({clipped.mean().item()})')
        print(f'next_clipped: (mean)     = ({next_clipped.mean().item()})')



    def test_prediction(self, epoch):
        """Print a sample Q prediction and algorithm coefficients when verbose.

        Resets `self.test_env` with a seed drawn from `self.dice`, samples one
        action from its action space, and prints the agent's prediction for
        that pair.  For 'sac' and 'mdac' the current temperature-related
        coefficients are printed as well.  `epoch` is not used.
        """
        if not self.param_trev.verbose:
            return

        reset_seed = self.dice.randint(2**31-1)
        obs, _ = self.test_env.reset(seed=reset_seed)
        action = self.test_env.action_space.sample()

        q_sa = self.agent.predict_as_np(obs, action)
        print(f'q_sa = {q_sa}')

        algo = self.param_trev.algo_name
        if algo == 'sac':
            print(f'target_entropy    = {self.agent.target_entropy}')
            print(f'log_coeff_entropy = {self.agent.log_coeff_entropy}')
        elif algo == 'mdac':
            tau = self.agent.get_coeff_entropy().cpu().item()
            kappa = self.agent.get_coeff_kl().cpu().item()
            print(f'(tau,  kappa) = ({tau}, {kappa})')
            if not self.agent.autotune_redual:
                alpha, beta = self.agent.get_alpha_beta()
                print(f'(alpha, beta) = ({alpha.cpu().item()}, {beta.cpu().item()})')



    def train_online(self, info=None):
        """Run the online training loop: interact, update, evaluate and log.

        Every step acts in `self.env` (uniformly sampled actions before
        `start_steps`, the agent's own actions afterwards) and stores the
        transition in the agent's buffer.  Once `update_after` is reached,
        every `update_every` steps the agent is updated `update_every` times.
        At each epoch boundary (`steps_per_epoch`) the agent is evaluated with
        `test_agent`, the best model so far is tracked and optionally saved,
        metrics are logged and, if enabled, a resumable checkpoint is written.

        Args:
            info: Unused; the name is rebound locally when metrics are logged.
        """

        start_time = time.perf_counter()

        # Names of the agent's `logged_<name>` buffers summarised per epoch
        # when `check_scales` is enabled for 'mdac'.
        scale_names = (
            'tau', 'kappa', 'reward', 'log_pi', 'bonus',
            'next_log_pi', 'next_bonus', 'next_q',
            'clipped', 'next_clipped',
        )

        ckpt_saved_before_at = None
        step_starts_at = 0
        if self.param_trev.resume_from is not None:
            # Restore the checkpoint *before* the first reset: loading replaces
            # `self.env`, so an observation taken earlier would belong to an
            # environment that is no longer the one being stepped.
            self.load_to_resume(load_path=self.param_trev.resume_from)
            # `lstrip` strips a character set, not a prefix; that is fine here
            # because digits are not among the characters of 'step'.
            step_starts_at = int(self.param_trev.resume_from.stem.lstrip('step')) + 1
            print(f"Resuming the training at ep {step_starts_at} based on {self.param_trev.resume_from}.")

        seed = self.dice.randint(2**31-1)
        o, _ = self.env.reset(seed=seed)
        ep_ret, ep_len = 0, 0
        ep_ret_all, ep_len_all, num_eps = 0, 0, 0
        total_eps = 0
        ave_loss, ave_errors = 0, 0
        old_model_tag = None
        best_step = None
        best_test_ep_ret = None
        best_test_ep_len = None
        longest_test_ep_len = 0

        # ---- main loop ----
        for t in range(step_starts_at, self.param_trev.total_steps):

            if t >= self.param_trev.start_steps:
                a = self.agent.act(o)
            else:
                a = self.env.action_space.sample()

            o2, r, d, truncated, _ = self.env.step(a)

            ep_ret += r
            ep_len += 1

            # Only the termination flag `d` is stored, not `truncated`.
            self.agent.buffer.store(o, a, r, o2, d)
            if DEBUGGING and t % 1000 == 0:
                print(f'{t = }:')
                print(f'    {o = }:')
                print(f'    {a = }:')
                print(f'    {r = }:')
                print(f'    {o2 = }:')
                print(f'    {d = }:')

            o = o2

            # ---- end of an episode ----
            if d or truncated or ep_len == self.param_trev.max_ep_len:
                ep_ret_all += ep_ret
                ep_len_all += ep_len
                num_eps += 1
                seed = self.dice.randint(2**31-1)
                o, _ = self.env.reset(seed=seed)
                ep_ret, ep_len = 0, 0

            # ---- update of agent parameters ----
            if (t+1) >= self.param_trev.update_after and (t+1) % self.param_trev.update_every == 0:
                test_now = (t+1) % self.param_trev.steps_per_epoch == 0
                self.pre_update(test_now)
                ave_loss, ave_errors = 0, 0
                for j in range(self.param_trev.update_every):
                    # With one update per step the policy schedule follows the
                    # global step; otherwise it follows the inner index.
                    if self.param_trev.update_every == 1:
                        update_policy = (t+1)%self.param_trev.policy_update_every==0
                    else:
                        update_policy = (j%self.param_trev.policy_update_every==0)
                    loss, errors = self.agent.update(
                        update_policy=update_policy,
                    )
                    ave_loss    += loss
                    ave_errors  += errors
                    self.at_update(test_now)
                self.post_update(test_now)
                ave_loss    /= self.param_trev.update_every
                ave_errors  /= self.param_trev.update_every

            # ---- test, logging and misc ----
            if (t+1) % self.param_trev.steps_per_epoch == 0:
                epoch = (t+1) // self.param_trev.steps_per_epoch
                total_eps += num_eps
                test_ep_ret, test_ep_disc, test_ep_len = self.test_agent()

                self.test_prediction(epoch)

                # Ties count as an improvement, so the latest best is kept.
                best_updated = False
                if best_test_ep_ret is None:
                    best_updated = True
                elif best_test_ep_ret <= test_ep_ret.mean():
                    best_updated = True
                if best_updated:
                    best_step = t+1
                    best_test_ep_ret = test_ep_ret.mean()
                    best_test_ep_len = test_ep_len.mean()
                    print(f'*** Best Updated: t={best_step}, z={best_test_ep_ret}')
                    if self.param_trev.save_model:
                        tag = f't{best_step}_z{int(best_test_ep_ret)}'
                        self.save_model(tag, verbose=True)
                        if old_model_tag is not None and self.param_trev.remove_old_models:
                            self.remove_model(old_model_tag)
                        old_model_tag = tag

                if longest_test_ep_len < test_ep_len.mean():
                    longest_test_ep_len = test_ep_len.mean()
                assert longest_test_ep_len <= MUJOCO_MAX_EPISODE_LENGTH

                if self.logger:

                    info = {}
                    info.update(
                        epoch=epoch, episodes=total_eps,
                        elapsed=time.perf_counter() - start_time,
                        env_steps=t,
                        train_loss=ave_loss, train_error=ave_errors,
                        train_ep_len=ep_len_all / num_eps if num_eps > 0 else ep_len_all,
                        train_ep_ret=ep_ret_all / num_eps if num_eps > 0 else ep_ret_all,
                        test_ep_len=test_ep_len.mean(), test_ep_ret=test_ep_ret.mean(),
                        best_step=best_step, best_test_ep_ret=best_test_ep_ret,
                        best_test_ep_len=best_test_ep_len, longest_test_ep_len=longest_test_ep_len,
                    )
                    if self.param_trev.algo_name == 'mdac':
                        info.update(
                            tau = self.agent.get_coeff_entropy().cpu().item(),
                            kappa = self.agent.get_coeff_kl().cpu().item(),
                        )
                    self.logger.info(**info)
                    if self.param_trev.check_scales and self.param_trev.algo_name == 'mdac':
                        # `env_steps` goes into both kinds of row so that the
                        # NaN warm-up rows carry the same keys as real ones.
                        info = {'env_steps': t}
                        if (t+1) >= self.param_trev.update_after:
                            # Built per name so each statistic is taken from
                            # its own tensor (kappa_max comes from kappa).
                            for name in scale_names:
                                values = torch.cat(getattr(self.agent, f'logged_{name}'), dim=0)
                                info[f'{name}_min'] = values.min()
                                info[f'{name}_mean'] = values.mean()
                                info[f'{name}_max'] = values.max()
                        else:
                            # No update has run yet: record NaN placeholders.
                            for name in scale_names:
                                for stat in ('min', 'mean', 'max'):
                                    info[f'{name}_{stat}'] = np.nan
                        self.logger4scales.info(**info)
                        self.logger4scales.save(self.param_trev.log_file.parent/'scales')

                    self.logger.save(self.param_trev.log_file)

                # Training-episode statistics are per epoch.
                ep_ret_all, ep_len_all, num_eps = 0, 0, 0

                if self.param_trev.save_model:
                    self.save_model('recent')

                if self.param_trev.save_to_resume_every > 0:
                    if (t+1) >= self.param_trev.update_after and epoch % self.param_trev.save_to_resume_every == 0:
                        ckpt_path = self.param_trev.save_model_dir / f'step{t}.ckpt'
                        self.save_to_resume(save_path=ckpt_path)
                        # Keep only the most recent checkpoint of this run.
                        if ckpt_saved_before_at is not None:
                            old_ckpt_path = self.param_trev.save_model_dir / f'step{ckpt_saved_before_at}.ckpt'
                            old_ckpt_path.unlink(missing_ok=True)
                        ckpt_saved_before_at = t

        # ---- after main loop ----
        if self.logger:
            self.logger.save(self.param_trev.log_file)

        if self.param_trev.save_batch:
            self.save_batch()


    def test_agent(self):
        """Roll out `num_test_episodes` evaluation episodes on `self.test_env`.

        Each episode is reset with a seed drawn from `self.dice` and runs until
        the environment reports termination or truncation, or `max_ep_len`
        steps have been taken.

        Returns:
            Tuple `(ep_rets, ep_discs, ep_lens)` of 1-D numpy arrays with one
            entry per episode: the undiscounted return, the return discounted
            by `self.agent.gamma`, and the number of steps taken.
        """
        n_episodes = self.param_trev.num_test_episodes
        ep_rets, ep_discs, ep_lens = (np.zeros(n_episodes) for _ in range(3))

        # Flip to True to print transitions of the first episode.
        show_trajectory = False

        for j in range(n_episodes):
            o, _ = self.test_env.reset(seed=self.dice.randint(2**31-1))
            d = truncated = False
            ep_ret = ep_disc = 0
            step = 0
            while not d and not truncated and step != self.param_trev.max_ep_len:
                a = self.agent.act(o, greedy=self.param_trev.greedy_test)
                o2, r, d, truncated, _ = self.test_env.step(a)
                ep_ret += r
                ep_disc += r * self.agent.gamma ** step
                step += 1
                if show_trajectory and j == 0 and (step%100==0 or d or truncated):
                    print(f'{step = }')
                    print(f'    {o = }')
                    print(f'    {a = }')
                    print(f'    {r = }')
                    print(f'    {d = }')
                    print(f'    {truncated = }')
                o = o2
            if show_trajectory and j == 0:
                print(f'    {o = }')

            ep_rets[j], ep_discs[j], ep_lens[j] = ep_ret, ep_disc, step

        return ep_rets, ep_discs, ep_lens


    def test(self):
        """Evaluate the agent with `test_agent` and print the per-episode results.

        Prints the episode lengths, discounted returns and undiscounted
        returns returned by `test_agent`.
        """
        test_ep_ret, test_ep_disc, test_ep_len = self.test_agent()
        print(f'{test_ep_len = }')
        print(f'{test_ep_disc = }')
        print(f'{test_ep_ret = }')


    def save_to_resume(self, save_path: Path):
        """Write everything needed to resume training to `save_path`.

        The checkpoint is a dict holding the logger, the agent's state dict,
        both environments and the RNG states of `random`, `numpy` and `torch`
        (plus the CUDA RNG state when the agent lives on a CUDA device).  It is
        serialised with `torch.save`, i.e. the objects are pickled.

        Args:
            save_path: Destination file of the checkpoint.
        """
        checkpoint = {
            'logger': self.logger,
            'agent': self.agent.get_state_dict(),
            'env': self.env,
            'test_env': self.test_env,
            'random_state': random.getstate(),
            'random_state_np': np.random.get_state(),
            'random_state_torch': torch.random.get_rng_state(),
        }
        # Compare on the device *type*: an indexed device such as "cuda:0" is
        # not equal to torch.device("cuda"), which would silently skip this.
        device = torch.device(self.agent.device)
        if device.type == 'cuda':
            checkpoint['random_state_torch_cuda'] = torch.cuda.get_rng_state(device)

        torch.save(checkpoint, save_path)
        print(f'Checkpoint was saved as {save_path}')


    def load_to_resume(self, load_path: Path):
        """Restore the training state written by `save_to_resume`.

        Replaces `self.logger`, `self.env` and `self.test_env` with the objects
        stored in the checkpoint, loads the agent's state dict and restores the
        RNG states of `random`, `numpy` and `torch` (and CUDA, when the agent
        is on a CUDA device and the checkpoint contains that state).

        Args:
            load_path: Checkpoint file to read.
        """
        # SECURITY: the checkpoint contains pickled Python objects (logger,
        # environments), so it needs full unpickling; `weights_only=False` is
        # required where `torch.load` defaults to weights-only loading.
        # Only load checkpoints from a trusted source.
        checkpoint = torch.load(load_path, weights_only=False)
        self.logger = checkpoint['logger']
        self.agent.set_state_dict(checkpoint['agent'])
        self.env = checkpoint['env']
        self.test_env = checkpoint['test_env']
        random.setstate(checkpoint['random_state'])
        np.random.set_state(checkpoint['random_state_np'])
        torch.random.set_rng_state(checkpoint['random_state_torch'])

        # Compare on the device *type* so indexed devices ("cuda:0") match too.
        device = torch.device(self.agent.device)
        if device.type == 'cuda':
            cuda_state = checkpoint.get('random_state_torch_cuda')
            if cuda_state is not None:
                torch.cuda.set_rng_state(cuda_state, device)
            else:
                # The checkpoint was written without a CUDA RNG state.
                print('Warning: checkpoint has no CUDA RNG state; CUDA RNG was not restored.')

        print(f'Checkpoint was loaded from {load_path}')



def train(args):
    """Build environments, parameters and agent from `args`, then train online.

    A fresh run logs to `./log/<env>/<algo>/<timestamp>-s<seed>`.  When
    `args.resume_from` is given as "<run_tag>,<step>", the run directory
    `<run_tag>` is reused, its stored env/agent parameters are loaded from
    YAML, and `args.resume_from` is replaced in place by the path of the
    checkpoint `model/step<step>.ckpt` inside it.

    Args:
        args: Parsed command-line namespace; reads `env`, `algo`, `seed`,
            `verbose`, `check_scales`, `save_to_resume_every`, `resume_from`
            and, for 'mdac', `bound_f` / `bound_g`.

    Raises:
        ValueError: If `args.env` or `args.algo` is not supported, or if
            `args.resume_from` is not of the form "<run_tag>,<step>".
    """

    log_dir_ea = Path('./log').resolve() / args.env / args.algo

    if args.resume_from is None:
        log_dir = log_dir_ea / f"{time.strftime('%Y-%m%d-%H%M%S-%Z', time.localtime())}-s{args.seed}"
        param_env = None
        param_agent = None
    else:
        resume_parts = args.resume_from.split(",")
        if len(resume_parts) < 2:
            raise ValueError(
                f"--resume_from must look like '<run_tag>,<step>', got {args.resume_from!r}"
            )
        resume_tag, resume_step = resume_parts[0], resume_parts[1]
        log_dir = log_dir_ea / resume_tag
        # From here on `args.resume_from` is the checkpoint path itself.
        args.resume_from = log_dir / 'model' / f'step{resume_step}.ckpt'
        with open(log_dir_ea / resume_tag / 'param_env.yaml') as f:
            param_env = yaml.safe_load(f)
        with open(log_dir_ea / resume_tag / 'param_agent.yaml') as f:
            param_agent = yaml.safe_load(f)

    if param_env is None:
        param_env = dict()
    import gymnasium as gym
    env = gym.make(args.env)
    test_env = gym.make(args.env)
    env.action_space.seed(args.seed)
    test_env.action_space.seed(args.seed)

    # Epoch budget per supported environment (kept per-env so that a single
    # entry can be tuned without touching the others).
    epochs_by_env = {
        'Hopper-v4': int(3e3),
        'Walker2d-v4': int(3e3),
        'HalfCheetah-v4': int(3e3),
        'Ant-v4': int(3e3),
        'Humanoid-v4': int(3e3),
        'HumanoidStandup-v4': int(3e3),
    }
    if args.env not in epochs_by_env:
        raise ValueError(
            f"Unsupported env {args.env!r}; expected one of {sorted(epochs_by_env)}"
        )
    num_epochs = epochs_by_env[args.env]

    greedy_test = True
    # A greedy policy is evaluated on a single episode, otherwise on ten.
    num_test_episodes = 1 if greedy_test else 10
    batch_size = 256
    if DEBUGGING:
        num_epochs = int(5)
        batch_size = 4

    param_trev = Param4TrainEval4Mujoco(
        algo_name=args.algo, env_name=args.env,
        max_ep_len = 1e4,
        num_test_episodes = num_test_episodes,
        num_epochs = num_epochs,
        steps_per_epoch = int(1e3),
        start_steps = int(5e3),
        update_after = int(5e3),
        update_every = 2,
        policy_update_every = 2,
        greedy_test = greedy_test,
        logging = True,
        reward_scale = 1.,
        log_dir = log_dir,
        verbose = args.verbose,
        check_scales = args.check_scales,
        save_model = True,
        remove_old_models = True,
        seed = args.seed,
        save_to_resume_every = args.save_to_resume_every,
        resume_from = args.resume_from,
    )

    # For every algorithm: parameters restored from YAML (resume) take
    # precedence over the defaults spelled out below.
    if args.algo == 'td3_clean':
        from algos.td3_clean import TD3 as Agent
        from algos.td3_clean import Param4TD3 as Param4Agent
        if param_agent is not None:
            param_agent = Param4Agent(**param_agent)
        else:
            param_agent = Param4Agent(
                buffer_size = int(1e6),
                batch_size = batch_size,
                gamma = 0.99,
                tau = 0.005,
                seed = args.seed,
                cuda = True,
            )

    elif args.algo == 'sac_clean':
        from algos.sac_clean import SoftActorCritic as Agent
        from algos.sac_clean import Param4SAC as Param4Agent
        if param_agent is not None:
            param_agent = Param4Agent(**param_agent)
        else:
            param_agent = Param4Agent(
                buffer_size = int(1e6),
                batch_size = batch_size,
                gamma = 0.99,
                autotune_ent_coeff = True,
                tau = 0.005,
                policy_lr = 3e-4,
                q_lr = 3e-4,
                seed = args.seed,
                cuda = True,
                policy_type = 'gaussian',
                bonus_squash = True,
            )

    elif args.algo == 'sac':
        from algos.sac import SoftActorCritic as Agent
        from algos.sac import Param4SAC as Param4Agent
        if param_agent is not None:
            param_agent = Param4Agent(**param_agent)
        else:
            param_agent = Param4Agent(
                buffer_size = int(1e6),
                n_hidden = [256, 256],
                batch_size = batch_size,
                gamma = 0.99,
                autotune_ent_coeff = True,
                action_scaler = 'naive',
                use_double = True,
                lr_policy = 3e-4,
                lr = 3e-4,
                grad_clipping = None,
                seed = args.seed,
                cuda = True,
            )

    elif args.algo == 'mdac':
        from algos.mdac import MirrorDescentActorCritic as Agent
        from algos.mdac import Param4MDAC as Param4Agent

        gamma = 0.99
        # Fixed KL coefficient derived from the discount factor.
        kappa = 1 - (1 - gamma)**2

        if param_agent is not None:
            param_agent = Param4Agent(**param_agent)
        else:
            param_agent = Param4Agent(
                buffer_size = int(1e6),
                n_hidden=[256, 256],
                batch_size = batch_size,
                gamma = gamma,
                use_double = True,
                lr = 3e-4,
                lr_policy = 3e-4,
                action_scaler = 'naive',
                grad_clipping = None,
                seed = args.seed,
                cuda = True,
                autotune_kl_coeff = False,
                coeff_kl = kappa,
                autotune_ent_coeff = True,
                bound_f = args.bound_f,
                bound_g = args.bound_g,
                policy_type = 'gaussian',
                explorer_type = 'gaussian',
            )

    else:
        raise ValueError(
            f"Unsupported algo {args.algo!r}; expected one of "
            "'td3_clean', 'sac_clean', 'sac', 'mdac'"
        )

    # Make sure the run directory exists before the parameter dumps are
    # written; this is a no-op when it has already been created.
    log_dir.mkdir(parents=True, exist_ok=True)
    with open(log_dir / 'param_env.yaml', 'w') as file:
        yaml.dump(param_env, file)
    with open(log_dir / 'param_agent.yaml', 'w') as file:
        yaml.dump(dc.asdict(param_agent), file)
    with open(log_dir / 'param_trev.yaml', 'w') as file:
        yaml.dump(dc.asdict(param_trev), file)

    agent = Agent(
        env.observation_space, env.action_space,
        **dc.asdict(param_agent)
    )

    playground = Playground4Mujoco(
        param=param_trev, agent=agent, env=env, test_env=test_env,
    )
    playground.train_online()



if __name__ == "__main__":

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='Ant-v4')
    parser.add_argument('--algo', type=str, default='mdac')
    parser.add_argument('--seed', type=int, default=459)
    parser.add_argument('--train', action='store_true')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--check_scales', action='store_true')
    # NOTE: --load_path and --load_tag are parsed but not read in this file.
    parser.add_argument('--load_path', type=str, default=None)
    parser.add_argument('--load_tag', type=str, default=None)
    parser.add_argument('--save_to_resume_every', type=int, default=0)
    # Expected form: "<run_tag>,<step>" (see `train`).
    parser.add_argument('--resume_from', type=str, default=None)

    parser.add_argument('--bound_f', type=str, default='rclip')
    parser.add_argument('--bound_g', type=str, default='rclip')

    args = parser.parse_args()

    # Seed every RNG the script relies on.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if args.train:
        train(args)
    if args.test:
        # This file defines no module-level `test(args)` (only the method
        # `Playground4Mujoco.test`), so fail with an explicit message instead
        # of a NameError.
        raise NotImplementedError(
            "--test is not implemented: this script defines no module-level test(args)."
        )
