import glob
import os
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0" 
import random

import numpy as np
import tqdm
from absl import app, flags
from ml_collections import config_flags
from tensorboardX import SummaryWriter
import wandb
from wandb_log import init_wandb_or_disable

import sys
sys.path.append("../")
from jax_rl.agents import AWACLearner, PESACLearner
from jax_rl.datasets import ReplayBuffer
from jax_rl.evaluation_fetch import evaluate
from jax_rl.utils_fetch import make_env
from jax_rl.utils import _should_record

FLAGS = flags.FLAGS

# --- Experiment / bookkeeping flags ------------------------------------------
flags.DEFINE_string('env_name', 'FetchPushDense-v2', 'Environment name.')
flags.DEFINE_string('save_dir', './tmp/', 'Tensorboard logging dir.')
flags.DEFINE_integer('seed', 42, 'Random seed.')
flags.DEFINE_integer('eval_episodes', 10,
                     'Number of episodes used for evaluation.')
flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
flags.DEFINE_integer('eval_interval', 10000, 'Eval interval.')
flags.DEFINE_integer('batch_size', 256, 'Mini batch size.')
flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.')
# Before this step the loop samples random actions and performs no updates.
flags.DEFINE_integer('start_training', int(1e4),
                     'Number of environment steps taken with random actions '
                     'before training starts.')
flags.DEFINE_boolean('tqdm', True, 'Use tqdm progress bar.')
flags.DEFINE_boolean('save_video', True, 'Save videos during evaluation.')
flags.DEFINE_integer('video_interval', 10000, 'Video saving interval.')

# --- Equivariance (PERL) flags -----------------------------------------------
flags.DEFINE_boolean('perl_value', True, 'Use PERL for value function')
flags.DEFINE_boolean('perl_policy', True, 'Use PERL for policy function')
# The previous help text was a copy of the `perl_value` description. This flag
# actually picks the symmetry group: it is appended to the environment name to
# look up `environment_symmetries` and is evaluated to build the group object.
flags.DEFINE_string('group', '',
                    'Symmetry group expression (evaluated against '
                    'emlp.groups). Empty uses the environment default.')
flags.DEFINE_float('equiv_wd', 1e-6, 'Policy Equivariant weight decay')
flags.DEFINE_float('basic_wd', 1e-6, 'Policy Basic weight decay')
flags.DEFINE_float('cequiv_wd', 0.0, 'Critic Equivariant weight decay')
flags.DEFINE_float('cbasic_wd', 0.0, 'Critic Basic weight decay')
# Values given on the command line arrive as strings; `main` casts them to int.
flags.DEFINE_list('hidden_dims', [256, 256], 'Dimension of hidden layers')  # Ant: [320,320]
flags.DEFINE_boolean('small_init', True, 'Use smaller init for last policy layer')
flags.DEFINE_boolean('old_rep', False, "Use original rep allocation heuristic")

# --- SAC optimisation flags ---------------------------------------------------
flags.DEFINE_boolean("gan_betas", False, "use GAN betas or not")
flags.DEFINE_float("tau", 0.005, 'tau for SAC updates')
flags.DEFINE_boolean('standardize', False, "Use equivariant standardization of the state")
flags.DEFINE_float('clipping', 0.5, 'Gradient Norm magnitude at which to clip')
flags.DEFINE_integer('ncritic', 1, 'Number of critic updates per policy update')

# --- Network widths for the gater and dynamics models ------------------------
flags.DEFINE_list("g_hidden_dims", [128, 128],
                  "Dimension of gater hidden layers")
flags.DEFINE_list("d_hidden_dims", [256, 256],
                  "Dimension of dynamics hidden layers")

# --- Dynamics-model updates ---------------------------------------------------
flags.DEFINE_integer("dynamics_batch_size", 256, "Dynamics batch size")
flags.DEFINE_integer("dynamics_update_freq", 1,
                     "Number of intervals for dynamics update")
flags.DEFINE_integer("dynamics_grad_steps", 2,
                     "Number of gradient steps for dynamics update")
flags.DEFINE_float("dynamics_grad_clipping", 1.0,
                   "Gradient Norm magnitude at which to clip for dyn model updates")
flags.DEFINE_float("dyn_lr", 3e-4, "Learning rate for dynamics updates")

# --- Lambda (gater) updates ---------------------------------------------------
flags.DEFINE_integer("lambda_batch_size", 256, "Lambda batch size")
flags.DEFINE_integer("lambda_train_start", 50_000,
                     "Number of steps before lambda training starts")
flags.DEFINE_integer("lambda_update_freq", 1,
                     "Number of intervals for lambda update")
flags.DEFINE_integer("lambda_grad_steps", 1,
                     "Number of lambda grad steps per update")
flags.DEFINE_float("lam_grad_clipping", 0.5,
                   "Gradient Norm magnitude at which to clip for lambda updates")
# Forwarded to the agent as `gater_q_lr` and `gater_p_lr` respectively.
flags.DEFINE_float("lambda_q_lr", 1e-4, "Learning rate for lambda updates")
flags.DEFINE_float("lambda_p_lr", 1e-4, "Learning rate for lambda updates")
flags.DEFINE_float("gater_bias", 1.2, "Initial bias for gater network")

# --- Gater statistics / thresholds --------------------------------------------
flags.DEFINE_integer("lam_stats_train_start", 10_000,
                     "Number of steps before lambda stats training starts")
# Batched stats updates are currently disabled; flags kept for reference.
# flags.DEFINE_integer('lam_stats_update_freq', 1000, 'Number of intervals for lambda stats updates')
# flags.DEFINE_integer('lam_stats_grad_steps', 20, 'Number of lambda stats grad steps per update')
flags.DEFINE_integer("threshold_steps", 100,
                     "Number of intervals for threshold updates")

flags.DEFINE_boolean("det_lam", True,
                     "Use lambda as probability or deterministic 0/1")
flags.DEFINE_float("exp_tau", 0.8, "Expectile tau for gater training")
flags.DEFINE_integer(
    "exp_samples", 4,
    "Number of samples for expectile estimation. If 0, use batch actions")
flags.DEFINE_boolean("adaptive_k_std", False,
                     "Use adaptive k_std for gater stats")


# Hyperparameter config file; left unlocked (lock_config=False).
config_flags.DEFINE_config_file(
    "config",
    "configs/sac_default.py",
    "File path to the training hyperparameter configuration.",
    lock_config=False,
)

# --- Weights & Biases logging -------------------------------------------------
flags.DEFINE_string("wandb_project", None, "W&B project name")
flags.DEFINE_string("wandb_entity", "aerl", "W&B entity name")
# When left unset, `main` fills the run name in from the model name and seed.
flags.DEFINE_string("wandb_run_name", None, "W&B run name")
flags.DEFINE_string("wandb_group", None, "W&B group name")
flags.DEFINE_list("wandb_tags", [], "W&B tags")
flags.DEFINE_string("wandb_mode", "online",
                    "W&B mode: online, offline, disabled")
# Debug mode skips the TensorBoard writer, W&B init and all metric logging.
flags.DEFINE_boolean("debug", False, "Debug mode, no logging")

from representations import environment_symmetries
from emlp.groups import *
from jax import jit,vmap


def _close_env_quietly(env):
    """Close ``env`` on a best-effort basis, reporting (not raising) failures."""
    try:
        env.close()
    except Exception as exc:  # cleanup must never abort a training run
        print(f"Warning: failed to close environment: {exc!r}")


def main(_):
    """Train a PE-SAC agent online and evaluate it periodically.

    Builds the environment, the agent keyword arguments, the replay buffer
    and a ``PESACLearner`` from the absl flags, then runs the interaction
    loop for ``FLAGS.max_steps`` steps:

    * random actions are taken while ``i < FLAGS.start_training``;
    * afterwards each step runs the critic/policy update and, depending on
      the corresponding flags, dynamics, gater-statistics and lambda updates;
    * every ``FLAGS.eval_interval`` steps a fresh evaluation environment is
      created, evaluated and closed again.

    Unless ``FLAGS.debug`` is set, metrics are written to TensorBoard and
    W&B, and evaluation returns are saved to ``<seed>.txt`` in the run
    directory.

    Args:
        _: Unused argv remainder supplied by ``app.run``.

    Raises:
        NotImplementedError: If the configured ``algo`` is not ``'sac'``.
    """
    method = 'perl'
    print("CWD = ", os.getcwd())
    fname = f'{method}_seed{FLAGS.seed}'
    save_dir = FLAGS.save_dir
    run_dir = os.path.join(save_dir, FLAGS.env_name, fname)

    # No writer is created in debug mode; every use below is guarded by the
    # same `not FLAGS.debug` condition.
    summary_writer = None
    if not FLAGS.debug:
        summary_writer = SummaryWriter(run_dir)

    # Training videos are disabled; only evaluation episodes may be recorded.
    video_train_folder = None
    env = make_env(FLAGS.env_name, FLAGS.seed, video_train_folder)

    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    # ---- Agent configuration ------------------------------------------------
    kwargs = dict(FLAGS.config)
    if FLAGS.group:
        symmetry_key = FLAGS.env_name + '_' + FLAGS.group
    else:
        symmetry_key = FLAGS.env_name
    kwargs.update(environment_symmetries[symmetry_key])

    # List flags arrive as strings when given on the command line.
    kwargs['hidden_dims'] = tuple(int(hd) for hd in FLAGS.hidden_dims)
    kwargs['g_hidden_dims'] = tuple(int(hd) for hd in FLAGS.g_hidden_dims)
    kwargs['d_hidden_dims'] = tuple(int(hd) for hd in FLAGS.d_hidden_dims)

    kwargs['perl_value'] = FLAGS.perl_value
    kwargs['perl_policy'] = FLAGS.perl_policy
    if FLAGS.group:
        # NOTE(security): this evaluates a command-line string as Python so
        # that expressions built from `emlp.groups` names can be passed.
        # Only run this script with trusted arguments.
        kwargs['symmetry_group'] = eval(FLAGS.group)

    # The configured reps are factories that take the symmetry group.
    kwargs['state_rep'] = kwargs['state_rep'](kwargs['symmetry_group'])
    kwargs['action_rep'] = kwargs['action_rep'](kwargs['symmetry_group'])
    if FLAGS.old_rep:
        kwargs.pop('middle_rep', None)

    model_bits = [
        "PE-SAC",
        FLAGS.env_name,
        "VAL" if FLAGS.perl_value else "NO-VAL",
        "POL" if FLAGS.perl_policy else "NO-POL",
    ]
    if FLAGS.group:
        model_bits.append(FLAGS.group)
    model_name = "_".join(model_bits)

    if FLAGS.wandb_run_name is None:
        FLAGS.wandb_run_name = model_name + f"_seed{FLAGS.seed}"
    if not FLAGS.debug:
        init_wandb_or_disable(FLAGS, fname, model_name=model_name, extra_cfg=kwargs)

    # ---- Replay buffer and agent --------------------------------------------
    replay_buffer_size = kwargs.pop('replay_buffer_size')
    if kwargs['action_space'] == 'continuous':
        action_dim = env.action_space.shape[0]
    else:
        action_dim = 1
    replay_buffer = ReplayBuffer(env.observation_space, action_dim,
                                 replay_buffer_size or FLAGS.max_steps,
                                 kwargs['state_rep'],
                                 kwargs['state_transform'],
                                 kwargs['inv_state_transform'],
                                 FLAGS.standardize)

    algo = kwargs.pop('algo')
    if algo != 'sac':
        # Explicit raise instead of `assert`, which is stripped under -O.
        raise NotImplementedError("other RL algos not yet supported")

    standardizer = None
    if FLAGS.standardize:
        standardizer = replay_buffer.running_stats.standardize

    agent = PESACLearner(FLAGS.seed,
                         env.observation_space.sample()[np.newaxis],
                         np.asarray(env.action_space.sample())[None],
                         actor_basic_wd=FLAGS.basic_wd,
                         actor_equiv_wd=FLAGS.equiv_wd,
                         critic_basic_wd=FLAGS.cbasic_wd,
                         critic_equiv_wd=FLAGS.cequiv_wd,
                         standardizer=standardizer,
                         clipping=FLAGS.clipping,
                         gan_betas=FLAGS.gan_betas,
                         tau=FLAGS.tau,
                         det_lam=FLAGS.det_lam,
                         gater_q_lr=FLAGS.lambda_q_lr,
                         gater_p_lr=FLAGS.lambda_p_lr,
                         dyn_lr=FLAGS.dyn_lr,
                         exp_tau=FLAGS.exp_tau,
                         exp_samples=FLAGS.exp_samples,
                         adaptive_k_std=FLAGS.adaptive_k_std,
                         dyn_clipping=FLAGS.dynamics_grad_clipping,
                         lam_clipping=FLAGS.lam_grad_clipping,
                         thr_steps=FLAGS.threshold_steps,
                         lam_train_start=FLAGS.lambda_train_start,
                         gater_bias=FLAGS.gater_bias,
                         **kwargs)

    # Latest metrics from the auxiliary updates. They start empty so merging
    # them into `update_info` is safe even before the corresponding update
    # has run (e.g. zero grad steps, or a terminal first transition).
    stats_info = {}
    d_info = {}
    lam_info = {}

    eval_returns = []
    (observation, _), done = env.reset(seed=FLAGS.seed), False
    for i in tqdm.tqdm(range(1, FLAGS.max_steps + 1),
                       smoothing=0.1,
                       disable=not FLAGS.tqdm):
        if i < FLAGS.start_training:
            action = env.action_space.sample()
        else:
            action = agent.sample_actions(observation)
        next_observation, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated

        # The mask is 0 only for a true terminal state; a truncated episode
        # keeps mask 1.
        mask = 0.0 if (terminated and not truncated) else 1.0

        if i >= FLAGS.lam_stats_train_start and mask == 1.0:
            stats_info = agent.update_gater_stats(observation, action)

        replay_buffer.insert(observation, action, reward, mask,
                             next_observation)
        observation = next_observation

        if done:
            (observation, _), done = env.reset(seed=FLAGS.seed + i), False

            if not FLAGS.debug:
                # Fall back to the loop counter when the env wrapper does not
                # report a total timestep count; use it for both loggers.
                step_to_use = info.get('total', {}).get('timesteps', i)
                episode_stats = info['episode']
                wandb.log({f"training/{k}": v for k, v in episode_stats.items()},
                          step=step_to_use)
                for k, v in episode_stats.items():
                    summary_writer.add_scalar(f'training/{k}', v, step_to_use)

        if i >= FLAGS.start_training:
            # Extra critic-only updates before the joint critic/policy update.
            for _ in range(FLAGS.ncritic - 1):
                batch = replay_buffer.sample(FLAGS.batch_size)
                agent.update(batch, update_policy=False)
            batch = replay_buffer.sample(FLAGS.batch_size)
            update_info = agent.update(batch)

            # Dynamics model training.
            if i % FLAGS.dynamics_update_freq == 0:
                for _ in range(FLAGS.dynamics_grad_steps):
                    l_batch = replay_buffer.sample_valids(FLAGS.dynamics_batch_size)
                    d_info = agent.update_dynamics(l_batch)
                update_info.update(d_info)

            # Gater statistics are updated per transition above; only the
            # most recent metrics are merged here.
            if i >= FLAGS.lam_stats_train_start:
                update_info.update(stats_info)

            # Lambda training.
            if i >= FLAGS.lambda_train_start and i % FLAGS.lambda_update_freq == 0:
                for _ in range(FLAGS.lambda_grad_steps):
                    l_batch = replay_buffer.sample_valids(FLAGS.lambda_batch_size)
                    lam_info = agent.update_lambda(l_batch)
                # Merge as dicts rather than keyword arguments: keyword
                # expansion fails on overlapping or non-string keys.
                update_info.update(d_info)
                update_info.update(lam_info)

            if i % FLAGS.log_interval == 0 and not FLAGS.debug:
                for k, v in update_info.items():
                    summary_writer.add_scalar(f'training/{k}', v, i)
                summary_writer.flush()
                wandb.log({f"training/{k}": v for k, v in update_info.items()}, step=i)

        if i % FLAGS.eval_interval == 0:
            record_now = _should_record(i, FLAGS.save_video, FLAGS.video_interval)
            eval_vid_dir = None
            if record_now:
                eval_vid_dir = os.path.join(run_dir, 'eval_videos', f"step_{i}")

            eval_env = make_env(FLAGS.env_name, FLAGS.seed + 42, save_folder=eval_vid_dir)
            try:
                eval_stats = evaluate(agent, eval_env, FLAGS.eval_episodes, FLAGS.seed)

                step_to_use = info.get('total', {}).get('timesteps', i)
                if not FLAGS.debug:
                    wandb.log({f"evaluation/average_{k}s": v for k, v in eval_stats.items()},
                              step=step_to_use)

                    for k, v in eval_stats.items():
                        summary_writer.add_scalar(f'evaluation/average_{k}s', v,
                                                  step_to_use)
                    summary_writer.flush()

                    if eval_vid_dir is not None:
                        mp4s = sorted(glob.glob(os.path.join(eval_vid_dir, '*.mp4')))
                        for ep_idx, vf in enumerate(mp4s):
                            wandb.log(
                                {f"eval/video/ep{ep_idx}": wandb.Video(
                                    vf, fps=30, format="mp4",
                                    caption=f"{FLAGS.env_name} @ step {step_to_use} (ep {ep_idx})")},
                                step=step_to_use
                            )

                    # The run directory is only created by the SummaryWriter,
                    # so the returns file is written in non-debug mode only.
                    eval_returns.append((step_to_use, eval_stats['return']))
                    np.savetxt(os.path.join(run_dir, f'{FLAGS.seed}.txt'),
                               eval_returns,
                               fmt=['%d', '%.1f'])
            finally:
                # A new eval env is built at every evaluation, so it must be
                # released in debug mode and on failure as well.
                _close_env_quietly(eval_env)

    _close_env_quietly(env)
    if summary_writer is not None:
        summary_writer.close()
    wandb.finish()

# Script entry point: hand `main` to absl's application runner.
if __name__ == "__main__":
    app.run(main)
