import numpy as np
import torch
import gym
import argparse
import os
import random
import math
import time
import copy
from pathlib import Path
import yaml


from algo.OTDF_ratio import OTDF
import algo.utils as utils
from envs.env_utils import call_terminal_func
from envs.common import call_env
from tensorboardX import SummaryWriter
import wandb

from info import REF_MIN_SCORE, REF_MAX_SCORE
print("REF_MIN_SCORE", REF_MIN_SCORE)
print("REF_MAX_SCORE", REF_MAX_SCORE)

import d4rl


def eval_policy(policy, env, eval_episodes=10, eval_cnt=None):
    """Roll out *policy* in *env* and return the mean undiscounted episode return.

    Args:
        policy: object exposing ``select_action(state)``; it receives the
            current observation converted with ``np.array``.
        env: environment whose ``reset()`` returns an observation and whose
            ``step(action)`` returns ``(next_state, reward, done, info)``.
        eval_episodes: number of episodes to run.
        eval_cnt: evaluation counter, only used as a tag in the printed line.

    Returns:
        Sum of all rewards collected over every episode, divided by
        ``eval_episodes``.
    """
    total_reward = 0.
    for _ in range(eval_episodes):
        observation = env.reset()
        finished = False
        while not finished:
            chosen_action = policy.select_action(np.array(observation))
            observation, step_reward, finished, _ = env.step(chosen_action)
            # One running sum across all episodes; averaged once at the end.
            total_reward += step_reward
    total_reward /= eval_episodes

    print("[{}] Evaluation over {} episodes: {}".format(eval_cnt, eval_episodes, total_reward))

    return total_reward




def check_nan(data, name):
    """Print whether *data* contains any NaN entries and how many.

    Args:
        data: array-like of numeric values; converted to a NumPy array.
        name: label placed at the start of the printed line.
    """
    # Build the NaN mask once and derive both statistics from it.
    nan_mask = np.isnan(np.asarray(data))
    any_nan = nan_mask.any()
    total_nan = nan_mask.sum()
    print(f"{name}: has_nan = {any_nan}, nan_count = {total_nan}")


if __name__ == "__main__":
    # Script entry point: trains an OTDF policy from two replay buffers (a
    # modified dataset loaded from ``datasets_modified/`` and the target-domain
    # dataset returned by ``utils.call_tar_dataset``) and periodically
    # evaluates it on the target environment.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", default="./logs")
    parser.add_argument("--policy", default="OTDF", help='policy to use, support OTDF')
    parser.add_argument("--env", default="halfcheetah")
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--save-model", action="store_true")
    parser.add_argument("--metric", default='cosine', type=str)
    parser.add_argument('--srctype', default='medium', type=str)
    parser.add_argument("--tartype", default='medium', type=str)
    # Integer default: argparse does not run ``type`` on non-string defaults,
    # so the previous ``1e6`` leaked a float into ``args.steps``.
    parser.add_argument("--steps", default=1_000_000, type=int)
    parser.add_argument("--weight", action="store_true")
    parser.add_argument("--proportion", default=0.8, type=float)
    parser.add_argument("--noreg", action="store_true")
    parser.add_argument("--reg_weight", default=0.5, type=float)
    parser.add_argument("--idn", default=None, type=str)
    parser.add_argument("--start_idx", default=0, type=int)
    parser.add_argument("--wandb_idn", default=None, type=str)
    parser.add_argument("--select_ratio", default=0, type=int)

    args = parser.parse_args()

    # ``--idn`` is the file name of the modified dataset. Fail fast with a
    # clear message instead of a late TypeError from ``os.path.join``.
    if args.idn is None:
        parser.error(
            "--idn is required: file name of the modified dataset inside "
            "datasets_modified/<env>_<tartype>/"
        )

    config_path = (
        Path(__file__).parent.absolute()
        / "config"
        / args.policy.lower()
        / f"{args.env.replace('-', '_')}.yaml"
    )
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    print("------------------------------------------------------------")
    print("Policy: {}, Env: {}, Seed: {}".format(args.policy, args.env, args.seed))
    print("------------------------------------------------------------")

    # Keep the env name exactly as given on the command line: it is the key
    # used for the reference-score lookup and for the output directory.
    ENVNAME = args.env

    outdir = (
        f"{args.dir}/{args.policy.lower()}/{args.env}"
        f"-srcdatatype-{args.srctype}-tardatatype-{args.tartype}/r{args.seed}"
    )
    writer = SummaryWriter('{}/tb'.format(outdir))
    if args.save_model:
        os.makedirs("{}/models".format(outdir), exist_ok=True)

    # From here on the env name uses dashes only.
    args.env = args.env.replace('_', '-')

    # Source env id is "<first dash-separated token>-<srctype>-v2".
    src_env_name = args.env.split('-')[0] + '-' + args.srctype + '-v2'
    src_env = gym.make(src_env_name)
    src_env.seed(args.seed)
    tar_env = call_env(config['tar_env_config'])
    tar_env.seed(args.seed)
    # Evaluation environments get a different seed offset than the training ones.
    src_eval_env = gym.make(src_env_name)
    src_eval_env.seed(args.seed + 100)
    tar_eval_env = call_env(config['tar_env_config'])
    tar_eval_env.seed(args.seed + 100)

    # Seed every RNG the script touches.
    src_env.action_space.seed(args.seed)
    tar_env.action_space.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)

    state_dim = src_env.observation_space.shape[0]
    action_dim = src_env.action_space.shape[0]
    max_action = float(src_env.action_space.high[0])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Command-line values override / extend the YAML config.
    config['metric'] = args.metric
    config.update({
        'state_dim': state_dim,
        'action_dim': action_dim,
        'max_action': max_action,
        'weight': args.weight,
        'proportion': float(args.proportion),
        'noreg': args.noreg,
        'reg_weight': args.reg_weight,
    })

    if args.policy.lower() == 'otdf':
        policy = OTDF(config, device)
    else:
        raise NotImplementedError

    print("policy", policy)

    # Write the run configuration inside the run directory. The previous
    # ``outdir + 'log.txt'`` lacked a separator and produced
    # ``.../r<seed>log.txt`` next to the run directory instead.
    os.makedirs(outdir, exist_ok=True)
    with open(os.path.join(outdir, 'log.txt'), 'w') as f:
        f.write('\n Policy: {}; Env: {}, seed: {}'.format(args.policy, args.env, args.seed))
        for item in config.items():
            f.write('\n {}'.format(item))

    src_replay_buffer = utils.OTReplayBuffer(state_dim, action_dim, device)
    tar_replay_buffer = utils.ReplayBuffer(state_dim, action_dim, device)

    tar_dataset = utils.call_tar_dataset(args.env, args.tartype)

    # Locate the modified dataset:
    # <cwd>/datasets_modified/<env with underscores>_<tartype>/<idn>
    base_dir = os.path.join(os.getcwd(), "datasets_modified")
    tar_env_underscore = args.env.replace('-', '_')
    target_env_name = f"{tar_env_underscore}_{args.tartype}"

    target_dir = os.path.join(base_dir, target_env_name)
    os.makedirs(target_dir, exist_ok=True)
    save_path = os.path.join(target_dir, args.idn)

    # Read every field exactly once; repeated key lookups on the loaded
    # object may re-read the data from disk.
    mod_data = np.load(save_path)
    mod_states = mod_data['states']
    mod_actions = mod_data['actions']
    mod_next_states = mod_data['next_states']
    mod_rewards = mod_data['rewards']
    mod_terminals = mod_data['terminals']

    # NOTE(review): the labels below do not match the keys they describe;
    # they are kept unchanged so existing log output stays comparable.
    check_nan(mod_states, 'low_sa_samples')
    check_nan(mod_actions, 'low_sa_samples_tran_true')
    check_nan(mod_next_states, 'both_low_samples')
    check_nan(mod_rewards, 'target_hat_lowgap')
    check_nan(mod_terminals, 'sa_low_sas_high_samples')

    num_samples = len(mod_states)
    terminal_flags = np.array(mod_terminals.reshape(-1))

    # Make the terminal array exactly as long as the state array: repeat it
    # cyclically when it is shorter, truncate it when it is longer.
    if len(terminal_flags) < num_samples:
        repeat_times = int(np.ceil(num_samples / len(terminal_flags)))
        terminal_flags = np.tile(terminal_flags, repeat_times)[:num_samples]
    else:
        terminal_flags = terminal_flags[:num_samples]

    concat_dataset = {
        'observations': mod_states,
        'actions': mod_actions,
        'next_observations': mod_next_states,
        'rewards': mod_rewards.reshape(-1),
        'terminals': terminal_flags,
    }

    # The "source" buffer is filled with the modified dataset loaded above;
    # the target buffer with the dataset from ``utils.call_tar_dataset``.
    src_replay_buffer.convert_D4RL(concat_dataset)
    tar_replay_buffer.convert_D4RL(tar_dataset)

    # Batch sizes handed to ``policy.train`` as (gen_batch, true_batch).
    if args.select_ratio == 0:
        print("Select ratio 1:7")
        gen_batch = 224
        true_batch = 32
    elif args.select_ratio == 1:
        print("Select ratio 1:3")
        # NOTE(review): 196 + 64 = 260 while the other settings sum to 256,
        # and 196:64 is not 3:1 -- possibly meant to be 192. Left unchanged
        # to keep existing experiments reproducible; confirm.
        gen_batch = 196
        true_batch = 64
    else:
        print("Select same ratio")
        gen_batch = 128
        true_batch = 128

    min_return = float(REF_MIN_SCORE[ENVNAME])
    max_return = float(REF_MAX_SCORE[ENVNAME])

    eval_cnt = 0
    try:
        if not args.noreg:
            policy.train_vae(tar_replay_buffer, config['batch_size'], writer)

        # Initial evaluation before any policy update; results are only printed.
        eval_policy(policy, src_eval_env, eval_cnt=eval_cnt)
        eval_policy(policy, tar_eval_env, eval_cnt=eval_cnt)
        eval_cnt += 1

        for t in range(args.steps):
            policy.train(src_replay_buffer, tar_replay_buffer, writer, gen_batch, true_batch)

            if (t + 1) % config['eval_freq'] == 0:
                tar_eval_return = eval_policy(policy, tar_eval_env, eval_cnt=eval_cnt)
                writer.add_scalar('test/target return', tar_eval_return, global_step=t + 1)
                eval_cnt += 1

                # Return rescaled to 0-100 between the reference min/max
                # scores; previously computed but never recorded.
                normalized_return = 100 * (tar_eval_return - min_return) / (max_return - min_return)
                writer.add_scalar('test/target normalized return', normalized_return, global_step=t + 1)

                if args.save_model:
                    policy.save('{}/models/model'.format(outdir))
    finally:
        # Flush and close the TensorBoard writer even if training fails.
        writer.close()






