
from mcac.algos import SAC, TD3, GQE, AWAC, CQL, SSG, SAC_SIG, TD3_SIG, GQE_SIG, AWAC_SIG, CQL_SIG
import mcac.utils as utils
import mcac.utils.env_utils as eu
import mcac.utils.data_utils as du
import mcac.utils.pytorch_utils as ptu
from mcac.utils.arg_parser import parse_args
from mcac.utils.logx import EpochLogger

import numpy as np
from tqdm import trange
import os
import json

import torch
import torch.nn as nn

def img_gen(obs_start, hor, agent, replay_buffer_gen, ssg_module, discount=0.99):
    """Roll out an imagined trajectory with the SSG model and store it.

    Starting from ``obs_start``, the agent chooses actions.
    ``ssg_module.forward_net`` predicts the next observation from
    ``[obs, act]``, and ``ssg_module.reward_net`` predicts the reward from
    ``obs``.

    A step is kept only when its predicted reward is within 0.01 of 0 or
    of -1; the reward is then snapped to exactly 0.0 / -1.0. Any other step
    is discarded, and the next attempt starts again from the same
    observation.

    Kept transitions are labelled with a discounted reward-to-go
    (``'drtg'``) and a success flag (``'succ'``). They are then appended to
    ``replay_buffer_gen``.

    Args:
        obs_start: starting observation passed to ``agent.select_action``.
        hor: number of imagination attempts (rejected attempts also count);
            0 stores nothing.
        agent: object exposing ``select_action(obs)``.
        replay_buffer_gen: buffer exposing ``store_transition(dict)``.
        ssg_module: model exposing ``forward_net`` and ``reward_net``.
        discount: discount factor used for the ``'drtg'`` labels.
    """
    img_obs = obs_start
    img_ep_buf = []

    for _ in range(hor):
        img_act = agent.select_action(img_obs)

        obs_t, act_t = ptu.torchify(img_obs, img_act)
        # Predictions are only stored as data, so no autograd graph is needed.
        with torch.no_grad():
            next_obs_t = ssg_module.forward_net(torch.cat([obs_t, act_t], dim=-1))
            rew_t = ssg_module.reward_net(obs_t)

        img_obs = obs_t.cpu().detach().numpy().astype("float64")
        img_next_obs = next_obs_t.cpu().detach().numpy().astype("float64")
        img_act = act_t.cpu().detach().numpy()
        # reshape(-1) accepts both (1,) and (1, 1) reward outputs.
        img_rew = float(rew_t.reshape(-1)[0].item())

        # Only predictions landing on 0 or -1 are accepted (and snapped);
        # anything else is rejected and the next attempt retries from the
        # same observation.
        if -0.01 < img_rew < 0.01:
            img_rew = 0.0
        elif -1.01 < img_rew < -0.99:
            img_rew = -1.0
        else:
            continue

        img_ep_buf.append({
            'obs': img_obs,
            'next_obs': img_next_obs,
            'act': img_act,
            'rew': img_rew,
            'done': False,
            'expert': 0,
            'goal': 0.0,
            'mask': 1.0
        })
        img_obs = img_next_obs

    # Label transitions backwards with discounted reward-to-go. With a
    # non-zero mask on the final transition, its tail value is estimated as
    # that reward repeated forever: rew / (1 - discount).
    x, succ = 0, 0
    for j, transition in enumerate(reversed(img_ep_buf)):
        if j == 0:
            succ = succ or transition['goal']
            if not transition['mask']:
                x = transition['rew']
            elif discount < 1:
                x = transition['rew'] / (1 - discount)
            else:
                x = transition['rew'] * float('inf')
        else:
            x = transition['rew'] + transition['mask'] * discount * x

        transition['drtg'] = x
        transition['succ'] = succ
        del transition['goal']

    for transition in img_ep_buf:
        replay_buffer_gen.store_transition(transition)

def compute_img_horizon(loss, loss0, max_hor=50):
    """Map the current model loss to an imagination horizon.

    The horizon grows linearly as the loss drops relative to ``loss0``:
    ``max_hor * (1 - loss / loss0)``. This is 0 when ``loss == loss0`` and
    ``max_hor`` when ``loss == 0``. The result is clamped to
    ``[0, max_hor]`` and truncated to an int.

    Args:
        loss: current loss; anything ``float()`` accepts (e.g. a
            single-element tensor).
        loss0: reference loss the current loss is compared against.
        max_hor: horizon reached when the loss is 0 (default 50).

    Returns:
        An int in ``[0, max_hor]``. It is 0 when ``loss0`` is 0, because no
        scale can be derived from a zero reference (this previously raised
        ZeroDivisionError).
    """
    if loss0 == 0:
        return 0
    hor = max_hor * (1.0 - float(loss) / loss0)
    return int(min(max(hor, 0), max_hor))


def _make_agent(params, ssg_module, use_img):
    """Instantiate the SIG actor-critic variant named by ``params['algo']``.

    Raises:
        ValueError: if ``params['algo']`` is not 'sac', 'td3', 'gqe' or 'cql'.
    """
    agent_classes = {
        'sac': SAC_SIG,
        'td3': TD3_SIG,
        'gqe': GQE_SIG,
        'cql': CQL_SIG,
    }
    algo = params['algo']
    if algo not in agent_classes:
        # Fail fast; an unmatched name used to leave `agent` unbound.
        raise ValueError(
            f"Unsupported algo {algo!r}; expected one of {sorted(agent_classes)}")
    return agent_classes[algo](params, ssg_module, use_img)


def _label_returns(ep_buf, discount):
    """Annotate an episode's transitions in place with reward-to-go labels.

    Walking backwards, each transition gets ``'drtg'`` (discounted
    reward-to-go) and ``'succ'`` (the final transition's ``'goal'`` value).
    Its ``'goal'`` key is then removed.

    When the final transition has a non-zero ``'mask'``, its tail value is
    estimated as that reward repeated forever: ``rew / (1 - discount)``.
    """
    x, succ = 0, 0
    for j, transition in enumerate(reversed(ep_buf)):
        if j == 0:
            succ = succ or transition['goal']
            if not transition['mask']:
                x = transition['rew']
            elif discount < 1:
                x = transition['rew'] / (1 - discount)
            else:
                x = transition['rew'] * float('inf')
        else:
            x = transition['rew'] + transition['mask'] * discount * x
        transition['drtg'] = x
        transition['succ'] = succ
        del transition['goal']


class _TrainLoop:
    """Mutable state and per-step logic for the online phase of ``main``.

    Holds the shared step counter, evaluation epoch and SSG-loss history.
    These are used by real-environment episodes, agent updates, imagined
    rollouts and periodic evaluation.
    """

    def __init__(self, params, env, test_env, agent, ssg_module,
                 replay_buffer, replay_buffer_gen, logger, midpoint):
        self.params = params
        self.env = env
        self.test_env = test_env
        self.agent = agent
        self.ssg_module = ssg_module
        self.replay_buffer = replay_buffer
        self.replay_buffer_gen = replay_buffer_gen
        self.logger = logger
        self.midpoint = midpoint
        self.robosuite = params['env'] in eu.robosuite_envs
        self.i = 0                  # global step counter
        self.epoch = 0              # number of evaluations run so far
        self.loss_record = 0.0      # latest SSG loss returned by agent.update
        self.loss_record_list = []  # every SSG loss seen, as Python floats

    def run(self, total_timesteps, imgfreq):
        """Train until ``total_timesteps`` steps have been taken.

        Before ``midpoint``, every iteration collects one full
        real-environment episode. From ``midpoint`` on, an episode is
        collected only when the step counter is a multiple of ``imgfreq``;
        every other step just updates the agent and imagines data.
        """
        while self.i < total_timesteps:
            if self.i % 1000 == 0:
                print("now time step is{}".format(self.i))

            if self.i < self.midpoint or self.i % imgfreq == 0:
                self._run_episode()
            else:
                self._update_and_imagine()
                self._maybe_eval()
                self.i += 1

    def _maybe_eval(self):
        """Run ``do_eval`` when the step counter hits ``params['eval_freq']``."""
        if self.i % self.params['eval_freq'] != 0:
            return
        # Computed here rather than every step: np.std is O(len) and the
        # history grows with every update.
        if self.loss_record_list:
            loss_std = np.std(self.loss_record_list)
        else:
            loss_std = float('nan')
        do_eval(self.agent, self.test_env, self.logger,
                self.params['num_eval_episodes'], self.epoch, self.i,
                self.robosuite, self.loss_record, loss_std, self.midpoint,
                len(self.replay_buffer))
        self.epoch += 1

    def _select_action(self, obs):
        """Pick a training action: random warm-up, else policy (+ noise for TD3)."""
        params = self.params
        if self.i < params['start_timesteps']:
            return self.env.action_space.sample()
        act = self.agent.select_action(obs)
        if params['algo'] == 'td3':
            noise = np.random.normal(0, params['max_action'] * params['expl_noise'],
                                     size=params['d_act'])
            act = (act + noise).clip(-params['max_action'], params['max_action'])
        return act

    def _update_and_imagine(self):
        """Run the agent's updates for this step, then one imagined rollout."""
        params = self.params
        if self.i >= params['start_timesteps']:
            for _ in range(params['update_n_steps']):
                if len(self.replay_buffer) == 0:
                    break
                update_info, ssg_loss = self.agent.update(
                    self.replay_buffer, self.replay_buffer_gen, self.i, self.loss_record)
                if self.i % 1000 == 0:
                    print("ssg_loss")
                    print(ssg_loss)
                self.logger.store(**update_info)
                self.loss_record = ssg_loss
                self.loss_record_list.append(ssg_loss.item())

        # Imagination needs a baseline loss to size the horizon and a real
        # state to start from; without them loss_record_list[0] /
        # sample(1) would fail.
        if not self.loss_record_list or len(self.replay_buffer) == 0:
            return
        hor = compute_img_horizon(self.loss_record, self.loss_record_list[0])
        obs_start = self.replay_buffer.sample(1)['obs'][0]
        img_gen(obs_start, hor, self.agent, self.replay_buffer_gen, self.ssg_module)

    def _run_episode(self):
        """Collect one real episode, training at every step, then store it."""
        params = self.params
        obs, done, t = self.env.reset(), False, 0
        ep_buf, rets = [], []
        while not done and t < params['horizon']:
            self._maybe_eval()

            act = self._select_action(obs)
            next_obs, rew, done, info = self.env.step(act)

            # `t` is the 0-based step index, so the horizon's final step is
            # t + 1 == horizon; that step keeps mask=1 (bootstrapped) even if
            # the env reports done.
            if 'mask' in info:
                mask = info['mask']
            elif t + 1 == params['horizon']:
                mask = 1
            else:
                mask = float(not done)

            ep_buf.append({
                'obs': obs,
                'next_obs': next_obs,
                'act': act,
                'rew': utils.shift_reward(rew, params),
                'done': done,
                'expert': 0,
                'goal': info['goal'] if 'goal' in info else 0,
                'mask': mask,
            })
            obs = next_obs
            rets.append(rew)

            self._update_and_imagine()

            self.i += 1
            t += 1

        _label_returns(ep_buf, params['discount'])
        for transition in ep_buf:
            self.replay_buffer.store_transition(transition)

        if self.robosuite:
            self.env.close()

        self.logger.store(TrainEpRet=sum(rets), TrainEpLen=len(rets))


def main():
    """Train an SIG agent: offline pretraining, then online training with imagination.

    Setup:
        Parses CLI options with ``parse_args``, creates the log directory
        (``os.makedirs`` fails if it already exists) and writes
        ``hparams.json`` into it.

    Data and initial policy:
        Optionally generates offline data, then loads the replay buffer.
        Then either loads ``params['checkpoint']`` or pretrains for
        ``params['init_iters']`` updates.

    Online phase:
        Runs for ``params['total_timesteps']`` steps, switching from
        episode-per-iteration collection to mostly model-based updates at
        the halfway point.
    """
    params = parse_args()
    use_img = params['useimg']

    logdir = utils.get_file_prefix(params)
    params['data_folder'] = utils.get_data_dir(params)
    params['logdir'] = logdir

    utils.seed(params['seed'])
    os.makedirs(logdir)
    ptu.setup(params['device'])
    with open(os.path.join(logdir, 'hparams.json'), 'w') as f:
        json.dump(params, f)

    env, test_env = eu.make_env(params)
    logger = EpochLogger(output_dir=logdir, exp_name=params['exper_name'])

    ssg_module = SSG(params['d_obs'][0], params['d_act'][0],
                     params['hidden_size'], 1).to(ptu.TORCH_DEVICE)
    agent = _make_agent(params, ssg_module, use_img)

    if params['gen_data']:
        expert_policy = eu.make_expert_policy(params, test_env)
        du.generate_offline_data(test_env, expert_policy, params)
    replay_buffer = du.load_replay_buffer(params)
    replay_buffer_gen = utils.ReplayBuffer(int(10000000))

    if params['checkpoint'] is not None:
        agent.load(params['checkpoint'])
    else:
        print('Pretraining Policy')
        os.makedirs(os.path.join(logdir, 'pretrain_plots'))
        for step in trange(params['init_iters']):
            agent.update(replay_buffer, replay_buffer_gen, step, 1)
        if params['init_iters'] > 0:
            agent.save(os.path.join(logdir, 'pretrain'))

    if params['rb_checkpoint'] is not None:
        replay_buffer.load(params['rb_checkpoint'])

    total_timesteps = params['total_timesteps']
    # After this step, real episodes are only collected every `imgfreq` steps.
    midpoint = 0.5 * total_timesteps
    imgfreq = 250

    loop = _TrainLoop(params, env, test_env, agent, ssg_module,
                      replay_buffer, replay_buffer_gen, logger, midpoint)

    print("==========================start training===========================")
    loop.run(total_timesteps, imgfreq)


def do_eval(agent, test_env, logger, num_eval_episodes, epoch, i, robosuite, loss_record, loss_std, midpoint, rb1len):
    """Evaluate the agent on ``test_env`` and emit one row of tabular logs.

    Runs ``num_eval_episodes`` episodes. Each one steps
    ``agent.select_action(obs, evaluate=True)`` until the env reports
    ``done``, then stores ``TestEpRet`` / ``TestEpLen``.

    It then logs, in order:
      - the epoch summary;
      - the train/Q statistics (zero placeholders on epoch 0);
      - the supplied SSG loss and its std;
      - ``midpoint`` (as 'Switch') and the replay-buffer size (as 'rb1len').
    Finally it dumps the row.

    Returns:
        The return of the last evaluation episode. It is also written to the
        module-level ``ep_ret_record``; if no episode runs, that global's
        existing value is returned.
    """
    global ep_ret_record

    logger.store(rb1len=rb1len)
    print('Testing Agent')

    for _ in range(num_eval_episodes):
        obs = test_env.reset()
        done = False
        ep_ret = 0
        ep_len = 0
        while not done:
            obs, rew, done, _info = test_env.step(agent.select_action(obs, evaluate=True))
            ep_ret += rew
            ep_len += 1
        if robosuite:
            test_env.close()
        ep_ret_record = ep_ret
        logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)

    for key, value in (('Epoch', epoch), ('TotalEnvInteracts', i)):
        logger.log_tabular(key, value)
    logger.log_tabular('TestEpRet')
    logger.log_tabular('TestEpLen', average_only=True)

    if epoch == 0:
        # Nothing has been trained/stored yet; write zero placeholders,
        # presumably so this row has the same columns that the aggregated
        # TrainEpRet/TrainEpLen/Q1/Q2 entries produce later.
        for key in ('AverageTrainEpRet', 'StdTrainEpRet', 'TrainEpLen', 'Q1', 'Q2'):
            logger.log_tabular(key, 0)
    else:
        logger.log_tabular('TrainEpRet')
        for key in ('TrainEpLen', 'Q1', 'Q2'):
            logger.log_tabular(key, average_only=True)

    for key, value in (('loss', loss_record), ('loss_std', loss_std),
                       ('Switch', midpoint), ('rb1len', rb1len)):
        logger.log_tabular(key, value)

    logger.dump_tabular()
    return ep_ret_record


if __name__ == '__main__':
    # Script entry point: run the full pretraining + online training pipeline.
    main()
