import os
import time
from copy import deepcopy
import uuid

import numpy as np
import pprint
from IS import Importance_Sampling
import gym
import torch
from tqdm import trange
import absl.app
import absl.flags

from .sac import SAC
from .replay_buffer import ReplayBuffer, batch_to_torch, sample_observations, subsample_two_batch, subsample_batch
from .model import TanhGaussianPolicy, FullyConnectedQFunction, SamplerPolicy
from .sampler import StepSampler, TrajSampler, DiffSampler
from .utils import Timer, define_flags_with_default, set_random_seed, print_flags, get_user_flags, prefix_metrics
from .utils import WandBLogger
from .viskit.logging import logger, setup_logger
import pickle

class Flag:
    """Attribute-style view over a mapping of flag names to values.

    Every ``(key, value)`` pair yielded by ``data.items()`` is set as an
    attribute on the instance, so ``Flag({'seed': 0}).seed == 0``.
    """

    def __init__(self, data):
        for entry in data.items():
            attr_name, attr_value = entry
            setattr(self, attr_name, attr_value)

def sac_train(env_name, diff_env, save_name=""):
    """Train a SAC agent on dataset samples mixed with rollouts from ``diff_env``.

    Each epoch (1) runs one importance-sampling training step on a dataset
    batch, (2) generates rollouts with ``DiffSampler`` into a replay buffer,
    (3) performs ``n_train_step_per_epoch`` SAC updates on batches that
    concatenate dataset samples with replay-buffer samples, and (4) evaluates
    the policy in the gym environment, saving a pickle whenever the average
    normalized return improves.

    Args:
        env_name: Gym environment id; passed to ``gym.make`` and used as the
            logger's ``exp_prefix``.
        diff_env: Object exposing ``dataset`` (with ``normalizer``, ``length``
            and ``fields``), ``diffuser`` and ``rw_model``.
        save_name: Optional run name; used for the WandB logging config, as
            the experiment id when non-empty, and in the saved pickle names.

    Returns:
        None. Results are reported through ``wandb_logger`` / ``logger`` and
        model snapshots are written via ``wandb_logger.save_pickle``.
    """
    FLAGS_DEF = define_flags_with_default(
        env=env_name,
        max_traj_length=1000,
        replay_buffer_size=50000 * 5 * 5,
        seed=0,
        device='cuda',
        save_model=True,

        reward_scale=1.0,
        reward_bias=0.0,
        clip_action=0.999,

        policy_arch='256-256',
        qf_arch='256-256',
        orthogonal_init=False,
        policy_log_std_multiplier=1.0,
        policy_log_std_offset=-1.0,

        n_epochs=2000,
        n_train_step_per_epoch=1000,
        rollout_length=5,
        rollout_batch_size=50000,
        real_ratio=0.05,
        threshold=0.01,
        eval_period=1,
        eval_n_trajs=3,
        sample_from_replay_buffer=True,

        batch_size=256,

        sac=SAC.get_default_config({"policy_lr": 1e-4, "qf_lr": 3e-4}),  # MOPO
        logging=WandBLogger.get_default_config({"save_name": save_name}),
    )

    FLAGS = Flag(FLAGS_DEF)
    variant = get_user_flags(FLAGS, FLAGS_DEF)
    wandb_logger = WandBLogger(config=FLAGS.logging, variant=variant)
    setup_logger(
        variant=variant,
        exp_id=wandb_logger.experiment_id if not save_name else save_name,
        seed=FLAGS.seed,
        base_log_dir=FLAGS.logging.output_dir,
        include_exp_prefix_sub_dir=False,
        exp_prefix=env_name,
    )

    set_random_seed(FLAGS.seed)

    train_sampler = DiffSampler(diff_env, FLAGS.rollout_length)
    eval_sampler = TrajSampler(
        gym.make(FLAGS.env).unwrapped,
        FLAGS.max_traj_length,
        normalizer=diff_env.dataset.normalizer,
    )

    replay_buffer = ReplayBuffer(FLAGS.replay_buffer_size)

    # Each training batch holds `real_batch` dataset transitions and
    # `fake_batch` replay-buffer transitions, totalling FLAGS.batch_size.
    real_batch = int(FLAGS.batch_size * FLAGS.real_ratio)
    fake_batch = FLAGS.batch_size - real_batch
    dataset_length = diff_env.dataset.length

    # NOTE(review): `dataset` aliases diff_env.dataset.fields, so unless
    # `fields` hands back a copy the writes below are visible to diff_env
    # (and would be re-applied if sac_train ran twice on it) -- confirm.
    dataset = diff_env.dataset.fields

    dataset['rewards'] = dataset['rewards'] * FLAGS.reward_scale + FLAGS.reward_bias
    dataset['dones'] = dataset['dones'].flatten()
    dataset['rewards'] = dataset['rewards'].flatten()
    for key in ['rewards', 'actions', 'dones', 'observations', 'next_observations']:
        # Use the configured device rather than a hard-coded "cuda" so the
        # dataset lives on the same device as the networks.
        dataset[key] = torch.tensor(dataset[key], dtype=torch.float32, device=FLAGS.device)

    observation_dim = eval_sampler.env.observation_space.shape[0]
    action_dim = eval_sampler.env.action_space.shape[0]

    policy = TanhGaussianPolicy(
        observation_dim,
        action_dim,
        arch=FLAGS.policy_arch,
        log_std_multiplier=FLAGS.policy_log_std_multiplier,
        log_std_offset=FLAGS.policy_log_std_offset,
        orthogonal_init=FLAGS.orthogonal_init,
    )

    qf1 = FullyConnectedQFunction(
        observation_dim,
        action_dim,
        arch=FLAGS.qf_arch,
        orthogonal_init=FLAGS.orthogonal_init,
    )
    target_qf1 = deepcopy(qf1)

    qf2 = FullyConnectedQFunction(
        observation_dim,
        action_dim,
        arch=FLAGS.qf_arch,
        orthogonal_init=FLAGS.orthogonal_init,
    )
    target_qf2 = deepcopy(qf2)

    # Importance sampling: the trainer is built from diff_env's diffusion and
    # reward models. The behavior policy is a freshly initialised network (no
    # checkpoint is loaded in this function), used as the denominator of the
    # importance ratio.
    impo_samp_training = Importance_Sampling(diffusion_model=diff_env.diffuser, rw_model=diff_env.rw_model)
    behavior_policy = TanhGaussianPolicy(
        observation_dim,
        action_dim,
        arch=FLAGS.policy_arch,
        log_std_multiplier=FLAGS.policy_log_std_multiplier,
        log_std_offset=FLAGS.policy_log_std_offset,
        orthogonal_init=FLAGS.orthogonal_init,
    )
    behavior_policy.to(device=FLAGS.device)
    behavior_policy = SamplerPolicy(behavior_policy, FLAGS.device)

    # A non-negative target entropy is replaced by -|action_dim|.
    if FLAGS.sac.target_entropy >= 0.0:
        FLAGS.sac.target_entropy = -np.prod(eval_sampler.env.action_space.shape).item()

    sac = SAC(FLAGS.sac, policy, qf1, qf2, target_qf1, target_qf2)
    sac.torch_to_device(FLAGS.device)

    sampler_policy = SamplerPolicy(policy, FLAGS.device)

    viskit_metrics = {}
    # `metrics` persists across epochs so that 'best_score' and the most
    # recent evaluation numbers carry over between epochs.
    metrics = {}
    metrics['best_score'] = 0
    for epoch in trange(FLAGS.n_epochs):
        metrics['epoch'] = epoch

        # Importance-sampling update. The period is 1, so this runs every
        # epoch; the modulo is kept as a knob for a longer period.
        if (epoch + 1) % 1 == 0:
            batch_is = subsample_batch(dataset, 10000)
            with torch.no_grad():
                log_prob = sampler_policy.log_prob(batch_is['observations'], batch_is['actions'], return_tensor=True)
                behavior_log_prob = behavior_policy.log_prob(batch_is['observations'], batch_is['actions'], return_tensor=True)
            # Ratio pi(a|s) / pi_behavior(a|s), computed in log space.
            probs = torch.exp(log_prob - behavior_log_prob)
            impo_samp_training.IS_train(batch_is, probs, tensor=True)

        # Generate rollouts into the replay buffer.
        with Timer() as rollout_timer:
            init_obss = sample_observations(
                dataset_length,
                dataset,
                FLAGS.rollout_batch_size,
                replay_buffer=replay_buffer,
                sample_from_buffer=FLAGS.sample_from_replay_buffer,
            )
            train_sampler.sample(
                sampler_policy, init_obss, replay_buffer=replay_buffer, threshold=FLAGS.threshold
            )
            metrics['env_steps'] = replay_buffer.total_steps

        with Timer() as train_timer:
            for batch_idx in range(FLAGS.n_train_step_per_epoch):
                batch = subsample_batch(dataset, real_batch)
                batch2 = replay_buffer.sample(fake_batch)
                for key in batch.keys():
                    batch[key] = torch.cat((batch[key], batch2[key]), dim=0)
                # Only the last update of the epoch contributes SAC metrics.
                if batch_idx + 1 == FLAGS.n_train_step_per_epoch:
                    metrics.update(
                        prefix_metrics(sac.train(batch), 'sac')
                    )
                else:
                    sac.train(batch)

        with Timer() as eval_timer:
            if epoch == 0 or (epoch + 1) % FLAGS.eval_period == 0:
                trajs = eval_sampler.sample(
                    sampler_policy, FLAGS.eval_n_trajs, deterministic=True
                )

                # Compute per-trajectory returns and their normalized scores
                # once and reuse them for every statistic below.
                returns = [np.sum(t['rewards']) for t in trajs]
                normalized_returns = [
                    eval_sampler.env.get_normalized_score(ret) for ret in returns
                ]

                metrics['average_return'] = np.mean(returns)
                metrics['average_traj_length'] = np.mean([len(t['rewards']) for t in trajs])
                metrics['average_normalized_return'] = np.mean(normalized_returns)
                metrics['std_normalized_return'] = np.std(normalized_returns)

                if metrics['average_normalized_return'] > metrics['best_score']:
                    metrics['best_score'] = metrics['average_normalized_return']
                    if FLAGS.save_model:
                        save_data = {'sac': sac, 'variant': variant, 'epoch': epoch}
                        wandb_logger.save_pickle(save_data, f'best_model_{save_name}.pkl')

        # NOTE(review): rollout time is measured above but deliberately left
        # out of the logged metrics, so 'epoch_time' covers train + eval only.
        metrics['train_time'] = train_timer()
        metrics['eval_time'] = eval_timer()
        metrics['epoch_time'] = train_timer() + eval_timer()
        wandb_logger.log(metrics)
        viskit_metrics.update(metrics)
        logger.record_dict(metrics)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)

    if FLAGS.save_model:
        save_data = {'sac': sac, 'variant': variant, 'epoch': epoch}
        wandb_logger.save_pickle(save_data, f'model_{save_name}.pkl')


# Script entry point.
# NOTE(review): absl.app.run calls its target with the argv list as the only
# argument, whereas sac_train requires (env_name, diff_env). As written this
# entry point looks unusable -- sac_train appears meant to be imported and
# called directly. Confirm the intended usage.
if __name__ == '__main__':
    absl.app.run(sac_train)
