import gymnasium as gym
import torch
from torch.nn.utils import clip_grad_norm_
from rl_tools import Sampler, ReplayMemory, ContToDiscreteActWrap
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from time import time
import argparse
from os.path import join
from gymnasium.wrappers import FlattenObservation, TimeLimit
from gymnasium.experimental.wrappers import RecordVideoV0
from dwex_nn import DWExNNvanilla
from copy import deepcopy


def get_logits_ensemble(dwexnns, obs, device):
    """Return the ensemble-averaged logits for a single numpy observation.

    Args:
        dwexnns: sequence of ensemble members, each exposing ``get_logits(tensor)``.
        obs: numpy array holding one observation; a leading batch axis of
            size 1 is added before it is handed to the networks.
        device: torch device (or device string) the observation tensor is created on.

    Returns:
        The element-wise mean of the members' logits.

    Raises:
        ZeroDivisionError: if ``dwexnns`` is empty.
    """
    # Convert (and transfer) the observation once instead of once per member.
    obs_t = torch.tensor(obs[None, :].astype(np.float32), device=device)
    return sum([dwexnn.get_logits(obs_t) for dwexnn in dwexnns]) / len(dwexnns)

def get_logits_ensemble_torch(dwexnns, obs, no_old=False):
    """Average the logits of every ensemble member for an already-built tensor batch.

    ``obs`` and ``no_old`` are forwarded positionally to each member's
    ``get_logits``; the per-member results are summed and divided by the
    ensemble size.
    """
    accumulated = 0
    for member in dwexnns:
        accumulated = accumulated + member.get_logits(obs, no_old)
    return accumulated / len(dwexnns)

def numpy_softmax(obs, dwexnns, device):
    """Sample an action from the categorical policy defined by the ensemble logits.

    Returns a pair ``(action, entropy)``, both squeezed and converted to
    numpy on the CPU. No gradients are tracked.
    """
    with torch.no_grad():
        ensemble_logits = get_logits_ensemble(dwexnns, obs, device)
        policy = torch.distributions.Categorical(logits=ensemble_logits)
        action = policy.sample().squeeze().cpu().numpy()
        entropy = policy.entropy().squeeze().cpu().numpy()
        return action, entropy


def numpy_argmax_policy(obs, dwexnns, device):
    """Pick the most probable action under the ensemble's categorical policy.

    Returns a pair ``(action, entropy)`` as numpy values on the CPU; the
    entropy is that of the full categorical distribution, not of the
    greedy choice. No gradients are tracked.
    """
    with torch.no_grad():
        ensemble_logits = get_logits_ensemble(dwexnns, obs, device)
        policy = torch.distributions.Categorical(logits=ensemble_logits)
        greedy_action = policy.probs.argmax().cpu().numpy()
        entropy = policy.entropy().squeeze().cpu().numpy()
        return greedy_action, entropy


def numpy_egreedy_softpolicy(obs, dwexnns, eps, device):
    """Epsilon-greedy wrapper around the sampled (softmax) policy.

    The softmax policy is always queried first; with probability ``eps`` its
    action is replaced by a uniformly random action index in
    ``[0, dwexnns[0].output_size)``. The entropy of the softmax policy is
    returned in either case.
    """
    with torch.no_grad():
        policy_action, entrop = numpy_softmax(obs, dwexnns, device)
    take_random = np.random.rand() < eps
    chosen = np.random.randint(dwexnns[0].output_size) if take_random else policy_action
    return chosen, entrop



def numpy_egreedy_detpolicy(obs, dwexnns, eps, device):
    """Epsilon-greedy wrapper around the deterministic (argmax) policy.

    The argmax policy is always queried first; with probability ``eps`` its
    action is replaced by a uniformly random action index in
    ``[0, dwexnns[0].output_size)``. The entropy reported by the argmax
    policy is returned in either case.
    """
    with torch.no_grad():
        policy_action, entrop = numpy_argmax_policy(obs, dwexnns, device)
    take_random = np.random.rand() < eps
    chosen = np.random.randint(dwexnns[0].output_size) if take_random else policy_action
    return chosen, entrop



def update_target(sources, targets=None, tau=None, update_type=None):
    """Refresh the target networks from the online networks.

    Args:
        sources: iterable of online modules.
        targets: existing target modules; required for a ``'soft'`` update,
            ignored (and replaced) for a ``'hard'`` update.
        tau: interpolation factor of the soft update; each target parameter
            becomes ``(1 - tau) * target + tau * source``. Required for ``'soft'``.
        update_type: ``'hard'`` (deep-copy every source) or ``'soft'``
            (in-place interpolation of the parameters).

    Returns:
        The target modules, all switched to evaluation mode. For ``'hard'``
        this is a new list of copies; for ``'soft'`` it is the ``targets``
        object that was passed in.

    Raises:
        ValueError: if ``update_type`` is neither ``'hard'`` nor ``'soft'``,
            or if a soft update is requested without ``targets`` or ``tau``.
    """
    if update_type == 'hard':
        targets = [deepcopy(source) for source in sources]
    elif update_type == 'soft':
        # Fail with a clear message instead of a TypeError from zip()/arithmetic.
        if targets is None or tau is None:
            raise ValueError('Soft target update requires both "targets" and "tau"')
        for source, target in zip(sources, targets):
            # Only parameters() are interpolated; buffers are left untouched.
            for target_param, param in zip(target.parameters(), source.parameters()):
                target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
    else:
        raise ValueError(f'Update type must be "hard" or "soft", not "{update_type}"')

    for target in targets:
        target.train(False)
    return targets

def run(logging_path, env_name='CartPole-v1', seed=0, max_trans=1_000_000, trans_per_iter=5000, n_ensemble=1, mode='mean',
        nb_neurone_perit=256, kl_weight=5., init_ew=1.0, final_ew=0.05, nb_hidden=2, device='cpu', l2_weight=1e-5,
        max_expand=300, udr=1., rwd_scale=1., nl_last='relu', w_correction=1, rep_mem_size=50_000, n_eval_episodes=25, eval_interval=50_000,
        lr=1e-3, batch_size=128, target_type='hard', hard_target_steps=200, soft_target_polyak=0.005, clip_grad=0.0, init_eps=0.1, final_eps=0.1, end_decay=500_000,
        use_wandb=False, det_explore=False, render=False):
    """Train an ensemble of ``DWExNNvanilla`` Q-networks with entropy-regularised targets.

    Each outer iteration collects ``trans_per_iter`` transitions with an
    epsilon-greedy policy, stores them in the replay memory, caches the
    current policy logits for every stored next-observation, and then performs
    ``int(udr * trans_per_iter)`` gradient steps on the squared Bellman error.
    Scalars are written to a TensorBoard ``SummaryWriter`` at ``logging_path``.

    Args:
        logging_path: directory handed to ``SummaryWriter``.
        env_name: Gymnasium environment id. ``MinAtar*`` environments are
            flattened; the listed MuJoCo ``-v4`` environments are wrapped in
            ``ContToDiscreteActWrap``.
        seed: seed for ``torch`` and ``numpy``.
        max_trans: training stops once this many transitions were collected.
        trans_per_iter: transitions collected per outer iteration.
        n_ensemble: number of Q-networks in the ensemble.
        mode: how the ensemble's next-state values form the target:
            ``'indep'`` (one target per member), ``'mean'``, ``'min'`` or ``'median'``.
        nb_neurone_perit, kl_weight, nb_hidden, max_expand, w_correction:
            forwarded to ``DWExNNvanilla`` (as ``expand_rate``, ``kl_weight``,
            ``nb_hidden``, ``max_expand``, ``use_w_correction``).
        init_ew, final_ew: entropy weight at the start / after ``end_decay``
            transitions (linearly interpolated, then divided by ``log(n_act)``).
        device: requested torch device; forced to CPU when CUDA is unavailable.
        l2_weight: Adam ``weight_decay``.
        udr: gradient steps per collected transition.
        rwd_scale: multiplier applied to rewards inside the target.
        nl_last: ``'relu'`` selects an in-place ReLU as ``last_layer_nl``;
            anything else selects Tanh.
        rep_mem_size: replay memory capacity.
        n_eval_episodes: evaluation rollouts per evaluation.
        eval_interval: evaluate when ``total_trans`` is a multiple of this value.
        lr, batch_size: Adam learning rate and replay batch size.
        target_type: ``'hard'`` or ``'soft'`` target-network updates.
        hard_target_steps: gradient steps between hard target copies.
        soft_target_polyak: interpolation factor for soft target updates.
        clip_grad: per-network gradient-norm clip; ``0.0`` disables clipping.
        init_eps, final_eps: exploration epsilon, linearly interpolated over
            ``end_decay`` transitions.
        end_decay: number of transitions over which epsilon and the entropy
            weight are interpolated.
        use_wandb: register the networks' ``train_q`` / ``train_feat`` with ``wandb.watch``.
        det_explore: explore around the argmax policy instead of the sampled policy.
        render: record evaluation videos into ``"videos"``.

    Raises:
        ValueError: if ``target_type`` or ``mode`` is not one of the supported values.
    """
    # Validate up front: a bad mode used to surface only as a NameError after
    # a whole rollout had already been collected.
    if target_type not in ('hard', 'soft'):
        raise ValueError(f'target_type must be "hard" or "soft", not "{target_type}"')
    if mode not in ('indep', 'mean', 'min', 'median'):
        raise ValueError(f'mode must be "indep", "mean", "min" or "median", not "{mode}"')

    torch.set_num_threads(2)
    torch.manual_seed(seed)
    np.random.seed(seed)

    if not torch.cuda.is_available():
        device = "cpu"

    device = torch.device(device)
    print('device', device)

    # preparing env and env_eval
    env = gym.make(env_name)
    env_eval = gym.make(env_name, render_mode='rgb_array')
    if env_name.startswith('MinAtar'):
        env = FlattenObservation(env)
        env_eval = FlattenObservation(env_eval)
    elif env_name in ['Hopper-v4', 'Ant-v4', 'Walker2d-v4', 'HalfCheetah-v4', 'Humanoid-v4']:
        env = ContToDiscreteActWrap(env)
        env_eval = ContToDiscreteActWrap(env_eval)

    env = TimeLimit(env, max_episode_steps=5000)
    env_eval = TimeLimit(env_eval, max_episode_steps=5000)

    if render:
        env_eval = RecordVideoV0(env_eval, "videos", lambda x: x % n_eval_episodes == 0)

    gamma = 0.99

    s_dim = env.observation_space.shape[0]
    n_act = env.action_space.n

    sampler = Sampler(env)
    sampler_eval = Sampler(env_eval)
    print('logging path', logging_path)
    logger = SummaryWriter(logging_path)

    try:
        # Every member gets its own non-linearity instance.
        if nl_last == 'relu':
            make_last_nl = lambda: torch.nn.ReLU(inplace=True)
        else:
            make_last_nl = torch.nn.Tanh
        qfuncs = [DWExNNvanilla(s_dim, n_act, nb_hidden=nb_hidden, expand_rate=nb_neurone_perit, max_expand=max_expand,
                                use_w_correction=w_correction, kl_weight=kl_weight, entropy_weight=init_ew, device=device,
                                last_layer_nl=make_last_nl()) for _ in range(n_ensemble)]

        if use_wandb:
            # Imported here so run() also works when called from another module;
            # the script entry point only imports wandb at module level.
            import wandb
            for qfunc in qfuncs:
                wandb.watch(qfunc.train_q)
                wandb.watch(qfunc.train_feat)

        repmem = ReplayMemory(rep_mem_size, s_dim, device)

        total_trans = 0
        qoptim = torch.optim.Adam([p for qfunc in qfuncs for p in qfunc.parameters()], lr=lr, weight_decay=l2_weight)

        qtars = update_target(qfuncs, update_type='hard')

        explore_policy = numpy_egreedy_detpolicy if det_explore else numpy_egreedy_softpolicy
        logits_chunk = 200  # batch size used when caching logits over the whole memory

        while total_trans < max_trans:
            # Linear schedule shared by the entropy weight and epsilon.
            decay_frac = min(total_trans / end_decay, 1)
            eweight = (decay_frac * final_ew + (1 - decay_frac) * init_ew) / np.log(n_act)
            for qfunc in qfuncs:
                qfunc.set_entropy_weight(eweight)
            logger.add_scalar('pars/entrop_weight', eweight, total_trans)

            # sampling new trans
            for qfunc in qfuncs:
                qfunc.train(False)
            eps = (decay_frac * final_eps + (1 - decay_frac) * init_eps)
            new_trans, returns, entropies, _ = sampler.rollouts(
                policy=lambda x: explore_policy(x, qfuncs, eps, device),
                min_trans=trans_per_iter, max_trans=trans_per_iter)

            # logging returns of the exploration rollouts
            for ret, entr in zip(returns, entropies):
                logger.add_scalar('eval/return', ret.value, ret.step)
                logger.add_scalar('eval/return_n_entropy', ret.value + eweight * entr, ret.step)

            repmem.add_trans(new_trans)
            total_trans += trans_per_iter

            # NOTE(review): the four values below are computed "for logging" but
            # never logged. They are kept because the replay sample may advance
            # an RNG stream — confirm before removing.
            testb = repmem.sample(200, device=device)
            with torch.no_grad():
                old_dist = torch.distributions.Categorical(logits=get_logits_ensemble_torch(qfuncs, testb.obs))
                old_logits_tilde = get_logits_ensemble_torch(qfuncs, testb.obs, no_old=True)
                pol_entropy = old_dist.entropy().mean()

                # Cache the pre-update policy logits of every stored
                # next-observation, evaluated in chunks.
                mem_size = repmem.size
                nologits = torch.zeros(mem_size, n_act, device=device)
                for start in range(0, mem_size, logits_chunk):
                    stop = min(start + logits_chunk, mem_size)
                    nologits[start:stop] = get_logits_ensemble_torch(qfuncs, repmem.repmem.nobs[start:stop])

            # learning q function
            for qfunc in qfuncs:
                qfunc.train(True)

            n_grad_steps = int(udr * trans_per_iter)
            log_step_base = int(udr * (total_trans - trans_per_iter))
            for grad_steps in range(n_grad_steps):
                # Update target networks
                if target_type == 'hard' and grad_steps % hard_target_steps == 0:
                    qtars = update_target(qfuncs, update_type='hard')
                elif target_type == 'soft':
                    qtars = update_target(qfuncs, qtars, soft_target_polyak, update_type='soft')

                qoptim.zero_grad()
                db, idxs = repmem.sample_with_idxs(batch_size, device=device)

                curr_qalls = [qfunc(db.obs) for qfunc in qfuncs]

                with torch.no_grad():
                    next_qalls = [qtar(db.nobs) for qtar in qtars]

                    # sample next action from the cached policy logits
                    nol = nologits[idxs]
                    nopol = torch.distributions.Categorical(logits=nol)
                    na = nopol.sample().view(-1, 1)
                    qnos = [qall.gather(dim=1, index=na) for qall in next_qalls]

                    ent_term = eweight * nopol.entropy()[:, None]

                scaled_rwd = rwd_scale * db.rwd
                not_done = 1 - db.terminated
                if mode == 'indep':
                    # one target per ensemble member
                    targs = [scaled_rwd + gamma * not_done * (qno + ent_term) for qno in qnos]
                else:
                    if mode == 'mean':
                        next_val = sum(qnos) / n_ensemble
                    elif mode == 'min':
                        next_val = torch.hstack(qnos).min(1, True)[0]
                    else:  # 'median'
                        next_val = torch.hstack(qnos).median(1, True)[0]
                    shared_targ = scaled_rwd + gamma * not_done * (next_val + ent_term)
                    targs = [shared_targ] * len(curr_qalls)

                lossq = sum([(qall.gather(dim=1, index=db.act) - targ).pow(2).mean()
                             for qall, targ in zip(curr_qalls, targs)]) / n_ensemble

                lossq.backward()

                if grad_steps % 100 == 0:
                    # Diagnostic only, so it is computed just when it is logged.
                    with torch.no_grad():
                        loss_logqdist = (((eweight * nol - sum(next_qalls) / n_ensemble) ** 2) * not_done).mean()
                    logger.add_scalar('loss/bellerror', lossq.item(), log_step_base + grad_steps)
                    logger.add_scalar('loss/logqdist', loss_logqdist.item(), log_step_base + grad_steps)

                if clip_grad != 0.0:
                    for qfunc in qfuncs:
                        clip_grad_norm_(qfunc.parameters(), clip_grad)

                qoptim.step()

            for qfunc in qfuncs:
                qfunc.train(False)
            with torch.no_grad():
                # evaluate policy
                # NOTE(review): this only fires when total_trans is an exact
                # multiple of eval_interval; with a trans_per_iter that does not
                # divide eval_interval, evaluations are skipped.
                if total_trans % eval_interval == 0:
                    eval_rollouts = [sampler_eval.rollouts(lambda x: numpy_argmax_policy(x, qfuncs, device), 1, np.inf)
                                     for _ in range(n_eval_episodes)]
                    logger.add_scalar('eval/poleval', np.mean([r[1][0].value for r in eval_rollouts]), total_trans)
                    logger.add_scalar('eval/mean_ep_length', np.mean([r[3][0] for r in eval_rollouts]), total_trans)
    finally:
        # Flush pending TensorBoard events and release the environments
        # (including a video recorder, if any) even when training aborts.
        logger.close()
        env.close()
        env_eval.close()


if __name__ == '__main__':
    # (flag, type, default) of every value-carrying command-line option,
    # in the order they are registered.
    valued_options = [
        ('--experiment', str, ''),
        ('--env-name', str, 'Hopper-v4'),
        ('--seed', int, 0),
        ('--logging-dir', str, ''),
        ('--max-trans', int, 5_000_000),
        ('--nb-hidden', int, 2),
        ('--max-expand', int, 300),
        ('--n-ensemble', int, 2),
        ('--w-correction', int, 1),
        ('--trans-per-iter', int, 5000),
        ('--rep-mem-size', int, 50_000),
        ('--device', str, 'cuda'),
        ('--mode', str, 'mean'),
        ('--nl-last', str, 'relu'),
        ('--init-ew', float, 2.0),
        ('--final-ew', float, 0.4),
        ('--end-decay', float, 500_000),
        ('--lr', float, 1e-3),
        ('--batch-size', int, 128),
        ('--nb-neurone-perit', int, 256),
        ('--init-eps', float, 0.1),
        ('--final-eps', float, 0.1),
        ('--l2-weight', float, 0.0),
        ('--kl-weight', float, 20.),
        ('--rwd-scale', float, 1.),
        ('--udr', float, 1.),
        ('--clip-grad', float, 0.0),
        ('--n-eval-episodes', int, 50),
        ('--eval-interval', int, 50_000),
        ('--target-type', str, 'hard'),
        ('--hard-target-steps', int, 200),
        ('--soft-target-polyak', float, 0.005),
    ]
    # Boolean switches, all off by default.
    switches = ['--det-explore', '--wandb', '--render']

    argp = argparse.ArgumentParser()
    for flag, value_type, default_value in valued_options:
        argp.add_argument(flag, type=value_type, default=default_value)
    for flag in switches:
        argp.add_argument(flag, action='store_true')

    args = argp.parse_args()

    # Keyword arguments for run(): the parsed options minus the two that only
    # shape the logging path.
    paras = vars(args).copy()
    logging_dir = paras.pop('logging_dir')
    paras.pop('experiment')
    if logging_dir == '':
        # No directory given: fall back to a timestamped one.
        logging_dir = f'logs/{time()}'

    experiment = f'-{args.experiment}' if args.experiment != '' else ''
    run_name = join(f'dwexvanil{experiment}_target{args.target_type}_htar{args.hard_target_steps}_star{args.soft_target_polyak}'
                    f'_n_esbl{args.n_ensemble}_nhid{args.nb_hidden}_nneur{args.nb_neurone_perit}_l2w{args.l2_weight}_lr{args.lr}_batch{args.batch_size}'
                    f'_mode{args.mode}_inew{args.init_ew}_fiew{args.final_ew}_mex{args.max_expand}_rs{args.rwd_scale}'
                    f'_tpi{args.trans_per_iter}_udr{args.udr}_klw{args.kl_weight}_nll{args.nl_last}_wcor{args.w_correction}'
                    f'_rm_size{args.rep_mem_size}_ieps{args.init_eps}_feps{args.final_eps}',
                    args.env_name, str(args.seed))

    paras['logging_path'] = join(logging_dir, run_name)

    if args.wandb:
        # Imported at module level on purpose so that run() can reference it.
        import wandb

        wandb.init(
            project='NeuRL',
            config=paras,
            sync_tensorboard=True,
            name=join(run_name, str(args.seed)),
            monitor_gym=True,
        )

    # run() names this flag use_wandb.
    paras['use_wandb'] = paras.pop('wandb')

    run(**paras)
