import warnings
warnings.filterwarnings("ignore")
import wandb
import argparse
import os,sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import torch
import numpy as np
import gym, random
from IFactor.utils.wrapper import GymMinAtar, OneHotAction
from myenv.cartpole.cartpole import CartPoleWorldEnv
from IFactor.training.config import CartPole1Config
from IFactor.training.dtrainer_cartpole import Trainer
from IFactor.training.evaluator import Evaluator
# To restrict the run to a single GPU, uncomment:
# os.environ["CUDA_VISIBLE_DEVICES"] = '1'
# Select SDL's "dummy" video driver -- presumably so any pygame/SDL rendering
# done by the environment works without a display (headless runs); confirm.
os.environ.update(SDL_VIDEODRIVER='dummy')

def main(args):
    """Train an IFactor agent on the CartPole world ("cartpole1") and log to wandb.

    Builds the training, test and evaluation environments plus a
    ``CartPole1Config``, optionally resumes from a checkpoint, then runs the
    interaction loop. Every step the agent acts in ``env`` and the transition is
    stored in ``trainer.buffer``. On the schedules from ``trainer.config`` the
    loop calls ``train_batch``, ``update_target``, ``save_model`` and the
    evaluator.

    Args:
        args: parsed command-line namespace. Fields read: ``id``, ``seed``,
            ``device`` ('cpu' forces the CPU; 'cuda' or 'cuda:N' selects a GPU
            when one is available), ``noise``, ``distractor``, ``batch_size``,
            ``seq_len``, ``rssm_type``, ``disentangle``, ``horizon`` and
            ``resume``.

    Side effects:
        Creates ``results/cartpole1/<id>/models`` and writes
        ``models_best.pth`` there whenever the mean of the last 20 episode
        scores improves. ``trainer.save_model`` is also called periodically;
        where it writes is decided by the Trainer.
    """
    wandb.login()
    env_name = 'cartpole1'
    exp_id = args.id

    # Directories for saving results.
    result_dir = os.path.join('results', '{}'.format(env_name), '{}'.format(exp_id))
    model_dir = os.path.join(result_dir, 'models')
    gif_dir = os.path.join(result_dir, 'visualization')
    # Directory for learnt models.
    os.makedirs(model_dir, exist_ok=True)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # args.device is a string, so test its value rather than its truthiness:
    # '--device cpu' must actually select the CPU.
    requested = str(args.device or '').strip().lower()
    if torch.cuda.is_available() and requested not in ('', 'cpu'):
        # Honour an explicit index such as 'cuda:1'; any other non-CPU value
        # maps to the default CUDA device.
        is_cuda_spec = requested == 'cuda' or requested.startswith('cuda:')
        device = torch.device(requested if is_cuda_spec else 'cuda')
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        device = torch.device('cpu')
    print('using :', device)

    env = OneHotAction(CartPoleWorldEnv(state_obs_noise=args.noise, distractor=args.distractor))
    test_env = OneHotAction(CartPoleWorldEnv(state_obs_noise=args.noise, distractor=args.distractor))
    # eval_env differs only by full_state=True; it is passed to eval_block_wise.
    eval_env = OneHotAction(CartPoleWorldEnv(state_obs_noise=args.noise, distractor=args.distractor, full_state=True))
    obs_shape = env.observation_space.shape
    action_size = env.action_space.shape[0]
    obs_dtype = np.uint8
    action_dtype = np.float32

    config = CartPole1Config(
        env=env_name,
        obs_shape=obs_shape,
        action_size=action_size,
        obs_dtype=obs_dtype,
        action_dtype=action_dtype,
        seq_len=args.seq_len,
        batch_size=args.batch_size,
        model_dir=model_dir,
        gif_dir=gif_dir,
        rssm_type=args.rssm_type,
        disentangle=args.disentangle,
        horizon=args.horizon,
    )

    config_dict = config.__dict__
    trainer = Trainer(config, device)
    resume_step = 0
    if args.resume:
        # NOTE(review): the meaning of 200000 (checkpoint step / id) is defined
        # by Trainer.resume_training -- confirm before changing it.
        resume_step = trainer.resume_training(model_dir, 200000)

    evaluator = Evaluator(config, device)

    with wandb.init(project='Cartpole1', entity="yuren", config=config_dict):
        # Training loop.
        print('...training...')
        train_metrics = {}
        eval_metrics = {}
        trainer.collect_seed_episodes(env)
        obs, score = env.reset(), 0
        done = False
        prev_rssmstate = trainer.RSSM._init_rssm_state(1)
        prev_action = torch.zeros(1, trainer.action_size).to(trainer.device)
        episode_actor_ent = []
        scores = []
        # Running maxima start at -inf so the first measurement is always
        # recorded, even when mean scores or R^2 values are negative.
        best_mean_score = float('-inf')
        best_save_path = os.path.join(model_dir, 'models_best.pth')
        best_mean_r2 = float('-inf')
        print("Enter training iteration")
        for step in range(1 + resume_step, trainer.config.train_steps + 1 + resume_step):
            # "(step - 1) % period == 0" fires on steps 1, 1 + period, ... -- the
            # same steps as "step % period == 1" when period > 1 -- but also
            # works for period == 1, where "step % 1 == 1" is never true.
            if (step - 1) % trainer.config.train_every == 0:
                train_metrics = trainer.train_batch(train_metrics)
            if step % trainer.config.slow_target_update == 0:
                trainer.update_target()
            if (step - 1) % trainer.config.save_every == 0:
                # NOTE(review): this rebinds model_dir to save_model's return
                # value, which is what eval_block_wise receives below -- confirm
                # save_model returns a directory path.
                model_dir = trainer.save_model(step)
            if (step - 1) % trainer.config.eval_every == 0:
                eval_score = evaluator.eval_agent(test_env, trainer.RSSM, trainer.ObsEncoder, trainer.ObsDecoder, trainer.ActionModel, step)
                r2 = evaluator.eval_block_wise(eval_env, trainer.RSSM, trainer.ObsEncoder, trainer.ActionModel, model_dir, data_size_dict={'s1': 2, 's2': 2, 's3': 1, 's4': 4})
                # r2 is read positionally; entries 0-3 and 6-9 feed the two
                # per-block means logged below.
                eval_metrics['s12hs1_r2'] = r2[0]
                eval_metrics['s22hs2_r2'] = r2[1]
                eval_metrics['s32hs3_r2'] = r2[2]
                eval_metrics['s42hs4_r2'] = r2[3]
                eval_metrics['mean_s2hs'] = np.mean([r2[0], r2[1], r2[2], r2[3]])
                eval_metrics['s132hs13_r2'] = r2[4]
                eval_metrics['s242hs24_r2'] = r2[5]
                eval_metrics['hs12s1_r2'] = r2[6]
                eval_metrics['hs22s2_r2'] = r2[7]
                eval_metrics['hs32s3_r2'] = r2[8]
                eval_metrics['hs42s4_r2'] = r2[9]
                eval_metrics['mean_hs2s'] = np.mean([r2[6], r2[7], r2[8], r2[9]])
                eval_metrics['eval_rewards'] = eval_score
                wandb.log(eval_metrics, step=step)
                # Checkpoint selection uses only the s3/s4 scores.
                eval_s34_r2 = (eval_metrics['s32hs3_r2'] + eval_metrics['s42hs4_r2']) / 2.0
                if eval_s34_r2 > best_mean_r2:
                    best_mean_r2 = eval_s34_r2
                    # NOTE(review): same call as the periodic save above; confirm
                    # the Trainer keeps this checkpoint distinguishable.
                    trainer.save_model(step)

            with torch.no_grad():
                obs_tensor = torch.tensor(obs, dtype=torch.float32)
                if obs.dtype == np.uint8:
                    # Rescale uint8 values from [0, 255] to [-0.5, 0.5].
                    obs_tensor = obs_tensor.div(255).sub_(0.5)
                embed = trainer.ObsEncoder(obs_tensor.unsqueeze(0).to(trainer.device))
                _, posterior_rssm_state = trainer.RSSM.rssm_observe(embed, prev_action, not done, prev_rssmstate)
                asr_state = trainer.RSSM.get_asr_state(posterior_rssm_state)
                action, action_dist = trainer.ActionModel(asr_state)
                action = trainer.ActionModel.add_exploration(action, step).detach()
                episode_actor_ent.append(torch.mean(action_dist.entropy()).item())

            # Convert once: the same array is sent to the env and stored in the buffer.
            action_np = action.squeeze(0).cpu().numpy()
            next_obs, rew, done, _ = env.step(action_np)
            score += rew
            trainer.buffer.add(obs, action_np, rew, done)

            if done:
                train_metrics['train_rewards'] = score
                train_metrics['action_ent'] = np.mean(episode_actor_ent)
                wandb.log(train_metrics, step=step)
                scores.append(score)
                # Sliding window of the last 20 episode scores; the mean is only
                # considered once more than 20 episodes have finished.
                if len(scores) > 20:
                    scores.pop(0)
                    current_average = np.mean(scores)
                    if current_average > best_mean_score:
                        best_mean_score = current_average
                        print('saving best model with mean score : ', best_mean_score)
                        save_dict = trainer.get_save_dict()
                        torch.save(save_dict, best_save_path)

                # Start a new episode with a fresh recurrent state and zero action.
                obs, score = env.reset(), 0
                done = False
                prev_rssmstate = trainer.RSSM._init_rssm_state(1)
                prev_action = torch.zeros(1, trainer.action_size).to(trainer.device)
                episode_actor_ent = []
            else:
                obs = next_obs
                prev_rssmstate = posterior_rssm_state
                prev_action = action

    # Evaluation of the probably-best saved model (currently disabled):
    # evaluator.eval_saved_agent(env, best_save_path)

def build_parser():
    """Return the command-line parser for this training script.

    Boolean options come in pairs (``--noise`` / ``--no-noise`` and so on).
    ``noise``, ``distractor`` and ``disentangle`` default to True, so only
    their ``--no-*`` forms change anything. ``resume`` defaults to False.
    These flags take no value: write ``--no-noise``, not ``--noise False``.

    Example::

        python test/cartpole_run.py --no-noise --distractor --id 4
    """
    # There are many hyper-parameters; to run an ablation over a particular
    # one, add a command-line option for it here.
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", type=str, default='cartpole', help='environment name (currently not read by main)')
    parser.add_argument("--id", type=str, default='0', help='Experiment ID')
    parser.add_argument('--seed', type=int, default=0, help='Random seed')
    parser.add_argument('--horizon', type=int, default=12, help='horizon passed to CartPole1Config')
    parser.add_argument('--device', default='cuda', help='CUDA or CPU')
    parser.add_argument('--batch_size', type=int, default=20, help='Batch size')
    parser.add_argument('--seq_len', type=int, default=30, help='Sequence Length (chunk length)')
    parser.add_argument('--rssm_type', type=str, default="continuous", help='rssm_type: continuous or discrete')
    parser.add_argument('--noise', action='store_true', help='noise in the dynamics (default: on)')
    parser.add_argument('--resume', action='store_true', help='resume training from a checkpoint in the model directory')
    parser.add_argument('--distractor', action='store_true', help='distractor in the input observation (default: on)')
    parser.add_argument('--disentangle', action='store_true', help='disentangle option passed to CartPole1Config (default: on)')
    parser.add_argument('--no-noise', dest='noise', action='store_false', help='disable --noise')
    parser.add_argument('--no-distractor', dest='distractor', action='store_false', help='disable --distractor')
    parser.add_argument('--no-disentangle', dest='disentangle', action='store_false', help='disable --disentangle')
    parser.set_defaults(noise=True, distractor=True, disentangle=True, resume=False)
    return parser


if __name__ == "__main__":
    main(build_parser().parse_args())
