#!/usr/bin/env python3
"""
Beta-DAgger training script using GHN (Graph HyperNetwork) as the student.

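DAgger implementation (beta-mixture variant):
- Uses a beta-mixture policy: action = teacher if random() < beta else student
- Labels visited states with teacher actions: each env contributes its states up to
  termination/truncation plus one post-termination step (see dagger_rollout)
- Trains via BC on the aggregated data
- Beta decays exponentially: beta = p^(iteration - 1) for 1-indexed iterations (p = 0 gives beta = 0)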
"""
import os
os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1"

import math
import random
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import gymnasium as gym
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import tqdm
import tyro
from torch.utils.tensorboard import SummaryWriter
import wandb

from mani_skill.utils import gym_utils
from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper
from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
sys.path.insert(0, str(Path(__file__).parent / "train"))

# Import shared model definitions
from extraction.models import TeacherAgent
from train.hyper.core import hyperActor


@dataclass
class Args:
    """Command-line options for Beta-DAgger training of a GHN student (parsed with ``tyro``)."""

    exp_name: Optional[str] = None  # run name; auto-generated from env/script/seed/time when None
    seed: int = 1
    torch_deterministic: bool = True  # sets torch.backends.cudnn.deterministic
    cuda: bool = True  # use CUDA when it is available
    track: bool = False  # log to Weights & Biases
    wandb_project_name: str = "ManiSkill"
    wandb_entity: Optional[str] = None
    wandb_group: str = "DAgger-GHN"
    capture_video: bool = False  # wrap the eval envs with RecordEpisode
    save_model: bool = True  # save the final GHN checkpoint
    train_dir: Optional[str] = None  # parent directory for run outputs ("runs_dagger" when None)

    # Environment specific arguments
    env_id: str = "PickCube-v1"
    num_envs: int = 512
    num_eval_envs: int = 16
    num_steps: int = 50  # Steps per rollout
    num_eval_steps: int = 50  # Steps per evaluation rollout
    eval_reconfiguration_freq: Optional[int] = 1
    control_mode: Optional[str] = "pd_joint_delta_pos"

    # Note: We use manual termination tracking with diverse configs on reset.
    # This parameter is kept for compatibility (recorded in the wandb config) but not actively used.
    partial_reset: bool = True

    # Teacher checkpoint (required; the script raises if it is not given)
    teacher_checkpoint: Optional[str] = None

    # DAgger parameters
    total_iterations: int = 100       # Number of DAgger iterations
    beta_decay_rate: float = 0.9      # Exponential decay: beta = p^(iteration - 1); p = 0 -> pure student
    bc_batch_size: int = 512          # BC training batch size (rounded down to a multiple of meta_batch_size)
    buffer_size: int = 500000         # Replay buffer size
    bc_updates_per_iter: int = 100    # BC gradient steps per iteration
    learning_rate: float = 3e-4
    min_learning_rate: float = 1e-5   # For cosine annealing
    eval_freq: int = 10               # Evaluation frequency (iterations)
    log_freq: int = 1                 # Logging frequency (iterations)

    # GHN specific parameters
    meta_batch_size: int = 8          # Number of architectures per training step
    architecture_sampling_mode: str = "uniform"
    dagger_num_archs: int = 8         # Number of architectures during rollout

    # Optimization parameters
    amp: bool = True                  # Automatic mixed precision
    grad_clip: float = 1.0            # Gradient clipping (<= 0 disables clipping)

    # Final evaluation
    eval_save_video: bool = True      # Record videos during the final all-architecture evaluation


class ReplayBuffer:
    """Fixed-capacity circular buffer of (observation, action) pairs for BC training.

    Storage is preallocated on ``device``. Once the buffer is full, the oldest rows
    are overwritten first.
    """

    def __init__(self, max_size, obs_dim, act_dim, device):
        self.max_size = max_size
        self.ptr = 0   # index of the next row to write
        self.size = 0  # number of valid rows (never exceeds max_size)
        self.device = device
        self.obs = torch.zeros((max_size, obs_dim), device=device)
        self.actions = torch.zeros((max_size, act_dim), device=device)

    def add(self, obs, actions):
        """Append a batch of rows, overwriting the oldest entries once the buffer is full.

        Args:
            obs: observations with one row per sample.
            actions: action labels aligned row-by-row with ``obs``.

        An empty batch is a no-op. A batch larger than ``max_size`` keeps only its
        last ``max_size`` rows, placed exactly where row-by-row insertion would have
        left them.
        """
        batch_size = obs.shape[0]
        if batch_size == 0:
            return
        if batch_size > self.max_size:
            # Rows that row-by-row insertion would overwrite within this same batch.
            dropped = batch_size - self.max_size
            obs = obs[dropped:]
            actions = actions[dropped:]
            self.ptr = (self.ptr + dropped) % self.max_size
            batch_size = self.max_size

        # Fill up to the end of storage, then wrap the remainder to the front.
        head = min(batch_size, self.max_size - self.ptr)
        self.obs[self.ptr:self.ptr + head] = obs[:head]
        self.actions[self.ptr:self.ptr + head] = actions[:head]
        tail = batch_size - head
        if tail > 0:
            self.obs[:tail] = obs[head:]
            self.actions[:tail] = actions[head:]

        self.ptr = (self.ptr + batch_size) % self.max_size
        self.size = min(self.size + batch_size, self.max_size)

    def sample(self, batch_size):
        """Draw ``batch_size`` rows uniformly, with replacement, from the filled part.

        Raises:
            ValueError: if the buffer holds no data yet.
        """
        if self.size == 0:
            raise ValueError("cannot sample from an empty ReplayBuffer")
        indices = torch.randint(0, self.size, (batch_size,), device=self.device)
        return self.obs[indices], self.actions[indices]


def dagger_rollout(ghn_actor, teacher, envs, num_envs, num_steps, num_archs, beta, device):
    """
    Collect data using a beta-mixture policy with GHN-generated students.

    Collects (state, teacher_action) for:
    1. Each env's trajectory while it is active (all steps up to and including the
       step on which it terminates: success, failure, or truncation)
    2. Plus 1 additional post-termination step (for stability after task completion)

    Envs are split into ``num_archs`` contiguous groups, one per sampled architecture.
    When ``num_envs`` is not divisible by ``num_archs``, the first groups get one extra env.

    Args:
        ghn_actor: GHN that generates student networks
        teacher: Teacher policy for labeling
        envs: Vectorized environment (ignore_terminations=True for manual control)
        num_envs: Total number of parallel environments
        num_steps: Steps per rollout
        num_archs: Number of architectures to sample
        beta: Mixture coefficient (probability of using teacher)
        device: Torch device

    Returns:
        obs_data: Tensor of observations (trajectory + 1 post-termination step)
        action_data: Tensor of teacher actions (labels)
    """
    all_obs = []
    all_teacher_actions = []

    # Sample random architectures (with replacement) and generate student networks.
    arch_indices = np.random.choice(len(ghn_actor.list_of_arcs), num_archs, replace=True)
    shape_inds = torch.stack([ghn_actor.list_of_shape_inds[i] for i in arch_indices])
    ghn_actor.set_graph(arch_indices, shape_inds)
    students = ghn_actor.current_model  # List of generated networks

    # Reset at the start of each rollout.
    # NOTE(review): ManiSkill's reset appears to read options["reconfigure"], not
    # "reconfiguration_freq" -- confirm whether this option has any effect.
    obs, _ = envs.reset(options={"reconfiguration_freq": 1})

    # Envs that have not yet terminated in the current episode.
    active_mask = torch.ones(num_envs, dtype=torch.bool, device=device)
    # Envs that terminated on the previous step (their next state is collected once).
    collect_mask = torch.zeros(num_envs, dtype=torch.bool, device=device)

    for _ in range(num_steps):
        with torch.no_grad():
            # Student actions: contiguous env group k is driven by students[k].
            # tensor_split covers every env even when num_envs % num_archs != 0.
            obs_groups = torch.tensor_split(obs, num_archs, dim=0)
            student_action = torch.cat(
                [students[k](group) for k, group in enumerate(obs_groups) if group.shape[0] > 0],
                dim=0,
            )

            # Teacher labels ALL states.
            teacher_action = teacher(obs)

            # Per-step beta mixing: use the teacher with probability beta.
            use_teacher = torch.rand(num_envs, device=device) < beta
            action = torch.where(use_teacher.unsqueeze(-1), teacher_action, student_action)

        # Collect data for active envs plus envs that terminated last step.
        collect_this_step = active_mask | collect_mask
        if collect_this_step.any():
            all_obs.append(obs[collect_this_step].clone())
            all_teacher_actions.append(teacher_action[collect_this_step].clone())

        obs, _reward, terminations, truncations, infos = envs.step(action)

        # With ignore_terminations=True, ManiSkillVectorEnv reports `terminations` as
        # all-False, so task success/failure must be read from the info dict instead.
        done = terminations | truncations
        if isinstance(infos, dict):
            for key in ("success", "fail"):
                flag = infos.get(key)
                if flag is not None:
                    done = done | torch.as_tensor(flag, device=done.device).bool()

        # Active envs that just ended get one post-termination step collected next iteration.
        collect_mask = active_mask & done
        active_mask = active_mask & ~done

        # Mass reset if <20% still active.
        if active_mask.sum() < num_envs * 0.2:
            obs, _ = envs.reset(options={"reconfiguration_freq": 1})
            active_mask = torch.ones(num_envs, dtype=torch.bool, device=device)
            collect_mask.fill_(False)  # Don't collect post-reset (not a true post-termination state)

    if not all_obs:
        # Only reachable when num_steps <= 0; teacher_action was never computed.
        act_dim = math.prod(envs.single_action_space.shape)
        return (
            torch.empty(0, obs.shape[1], device=device),
            torch.empty(0, act_dim, device=device),
        )

    return torch.cat(all_obs, dim=0), torch.cat(all_teacher_actions, dim=0)


class Logger:
    """Send scalar metrics to Weights & Biases and/or a TensorBoard writer.

    Args:
        log_wandb: when True, each scalar is also passed to ``wandb.log``.
        tensorboard: optional ``SummaryWriter``; ignored when falsy.
    """

    def __init__(self, log_wandb=False, tensorboard: SummaryWriter = None) -> None:
        self.log_wandb = log_wandb
        self.writer = tensorboard

    def add_scalar(self, tag, scalar_value, step):
        """Record ``scalar_value`` under ``tag`` at ``step`` on every enabled sink."""
        if self.log_wandb:
            wandb.log({tag: scalar_value}, step=step)
        sink = self.writer
        if not sink:
            return
        sink.add_scalar(tag, scalar_value, step)

    def close(self):
        """Close the TensorBoard writer, if one was supplied."""
        sink = self.writer
        if not sink:
            return
        sink.close()


def _make_arch_eval_env(env_id, num_archs, env_kwargs, video_dir=None, max_steps_per_video=None):
    """Build a vectorized eval env with one sub-env per GHN architecture.

    Args:
        env_id: environment id passed to ``gym.make``.
        num_archs: number of parallel sub-envs (one per architecture).
        env_kwargs: extra ``gym.make`` kwargs (obs mode, sim backend, control mode, ...).
        video_dir: when not None, wrap the env with ``RecordEpisode`` writing videos there.
        max_steps_per_video: forwarded to ``RecordEpisode``.

    Returns:
        A ``ManiSkillVectorEnv`` with ``ignore_terminations=True`` and metric recording.
    """
    env = gym.make(env_id, num_envs=num_archs, reconfiguration_freq=1, **env_kwargs)
    if isinstance(env.action_space, gym.spaces.Dict):
        env = FlattenActionSpaceWrapper(env)
    if video_dir is not None:
        env = RecordEpisode(
            env,
            output_dir=str(video_dir),
            save_trajectory=False,
            max_steps_per_video=max_steps_per_video,
            video_fps=30
        )
    return ManiSkillVectorEnv(env, num_archs, ignore_terminations=True, record_metrics=True)


def _evaluate_all_architectures(ghn_actor, env_id, env_kwargs, num_eval_steps, video_dir=None):
    """Roll out every architecture in ``ghn_actor.list_of_arcs`` for ``num_eval_steps`` steps.

    Sub-env ``i`` is driven by the network the GHN generates for architecture ``i``.
    The eval env is always closed, even if the rollout raises.

    Returns:
        (students, ep_data): the generated networks (``ghn_actor.current_model``) and
        ``infos["final_info"]["episode"]`` from the last step, or ``None`` when the
        last step reported no ``final_info`` (including when ``num_eval_steps == 0``).
    """
    num_archs = len(ghn_actor.list_of_arcs)
    all_indices = np.arange(num_archs)
    all_shape_inds = torch.stack([ghn_actor.list_of_shape_inds[i] for i in all_indices])
    ghn_actor.set_graph(all_indices, all_shape_inds)
    students = ghn_actor.current_model

    eval_env = _make_arch_eval_env(env_id, num_archs, env_kwargs, video_dir=video_dir,
                                   max_steps_per_video=num_eval_steps)
    infos = {}
    try:
        obs, _ = eval_env.reset()
        for _ in range(num_eval_steps):
            with torch.no_grad():
                actions = torch.cat([students[i](obs[i:i + 1]) for i in range(num_archs)], dim=0)
            obs, _reward, _terminations, _truncations, infos = eval_env.step(actions)
    finally:
        eval_env.close()

    ep_data = infos["final_info"]["episode"] if "final_info" in infos else None
    return students, ep_data


if __name__ == "__main__":
    args = tyro.cli(Args)
    if args.teacher_checkpoint is None:
        raise ValueError("--teacher-checkpoint required")

    if args.exp_name is None:
        args.exp_name = os.path.basename(__file__)[: -len(".py")]
        run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    else:
        run_name = args.exp_name

    run_name = os.path.join(args.train_dir or "runs_dagger", run_name)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic
    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
    # autocast/GradScaler below target CUDA; fall back to fp32 on a CPU device.
    use_amp = args.amp and device.type == "cuda"

    # ========== ENVIRONMENT SETUP ==========
    env_kwargs = dict(obs_mode="state", render_mode="rgb_array", sim_backend="physx_cuda")
    if args.control_mode:
        env_kwargs["control_mode"] = args.control_mode

    envs = gym.make(args.env_id, num_envs=args.num_envs, **env_kwargs)
    # NOTE(review): eval_envs (including the capture_video wrapper) is never stepped;
    # periodic and final evaluation build their own per-architecture envs instead.
    eval_envs = gym.make(args.env_id, num_envs=args.num_eval_envs, reconfiguration_freq=args.eval_reconfiguration_freq,
                         human_render_camera_configs=dict(shader_pack="default"), **env_kwargs)

    if isinstance(envs.action_space, gym.spaces.Dict):
        envs = FlattenActionSpaceWrapper(envs)
        eval_envs = FlattenActionSpaceWrapper(eval_envs)

    if args.capture_video:
        eval_output_dir = f"{run_name}/videos"
        print(f"Saving eval videos to {eval_output_dir}")
        eval_envs = RecordEpisode(eval_envs, output_dir=eval_output_dir, save_trajectory=False,
                                   max_steps_per_video=args.num_eval_steps, video_fps=30)

    # Training envs: ignore_terminations=True for manual control
    # We handle terminations manually in dagger_rollout() for diverse configs
    envs = ManiSkillVectorEnv(envs, args.num_envs,
                              ignore_terminations=True,
                              record_metrics=True)

    # Eval envs: always run full episodes without partial reset for consistent evaluation
    eval_envs = ManiSkillVectorEnv(eval_envs, args.num_eval_envs,
                                    ignore_terminations=True,
                                    record_metrics=True)

    n_act = math.prod(envs.single_action_space.shape)
    n_obs = math.prod(envs.single_observation_space.shape)
    max_episode_steps = gym_utils.find_max_episode_steps_value(envs._env)

    # ========== LOGGING SETUP ==========
    if args.track:
        # Copy: vars(args) is the dataclass __dict__, so mutating it in place would add
        # env_cfg/eval_env_cfg as attributes on args (and into the hyperparameter table).
        config = dict(vars(args))
        config["env_cfg"] = dict(**env_kwargs, num_envs=args.num_envs, env_id=args.env_id,
                                  env_horizon=max_episode_steps, partial_reset=args.partial_reset)
        config["eval_env_cfg"] = dict(**env_kwargs, num_envs=args.num_eval_envs, env_id=args.env_id,
                                       env_horizon=max_episode_steps)
        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=False,
            config=config,
            name=run_name,
            save_code=True,
            group=args.wandb_group,
            tags=["dagger", "ghn", "bc"]
        )
    writer = SummaryWriter(run_name)
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )
    logger = Logger(log_wandb=args.track, tensorboard=writer)

    # ========== LOAD TEACHER ==========
    print("Loading teacher checkpoint...")
    teacher = TeacherAgent(n_obs, n_act, device=device)
    teacher_state = torch.load(args.teacher_checkpoint, map_location=device)
    # Only the actor_mean sub-module is needed for labeling; strip its key prefix.
    teacher.actor_mean.load_state_dict({k.replace('actor_mean.', ''): v
                                         for k, v in teacher_state.items() if 'actor_mean' in k})
    teacher.eval()

    # ========== INITIALIZE GHN STUDENT ==========
    print("Initializing GHN hypernetwork...")
    ghn_actor = hyperActor(
        act_dim=n_act,
        obs_dim=n_obs,
        meta_batch_size=args.meta_batch_size,
        device=device,
        architecture_sampling_mode=args.architecture_sampling_mode,
        multi_gpu=False,
    )

    # Optimizer trains GHN parameters (not individual network weights)
    optimizer = optim.Adam(ghn_actor.ghn.parameters(), lr=args.learning_rate, eps=1e-5)

    # LR scheduler: cosine annealing, stepped once per DAgger iteration
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=args.total_iterations,
        eta_min=args.min_learning_rate
    )

    # Mixed precision scaler
    scaler = torch.amp.GradScaler('cuda') if use_amp else None

    print(f"GHN initialized with {sum(p.numel() for p in ghn_actor.ghn.parameters())} parameters")
    print(f"Number of possible architectures: {len(ghn_actor.list_of_arcs)}")
    print(f"Optimizations: AMP={use_amp}, Grad Clip={args.grad_clip}")
    print(f"LR Schedule: {args.learning_rate} -> {args.min_learning_rate} (cosine annealing)")
    print(f"Beta decay: {args.beta_decay_rate}^iteration (exponential)")
    print(f"Manual termination tracking: enabled (diverse configs on reset)")

    # ========== INITIALIZE REPLAY BUFFER ==========
    replay_buffer = ReplayBuffer(args.buffer_size, n_obs, n_act, device)

    # BC batch must split evenly across the meta-batch of architectures; never let it reach 0.
    effective_batch_size = max(
        args.meta_batch_size,
        (args.bc_batch_size // args.meta_batch_size) * args.meta_batch_size,
    )

    # ========== MAIN TRAINING LOOP ==========
    print("\nStarting Beta-DAgger training...")
    start_time = time.time()
    pbar = tqdm.tqdm(range(1, args.total_iterations + 1))

    for iteration in pbar:
        # Compute beta (exponential decay); iteration 1 gives beta = 1.
        # If rate is 0, we want pure student (beta=0) immediately (0 ** 0 would be 1).
        if args.beta_decay_rate == 0:
            beta = 0.0
        else:
            beta = args.beta_decay_rate ** (iteration - 1)

        # ===== COLLECT DAGGER DATA =====
        ghn_actor.eval()
        obs_data, action_data = dagger_rollout(
            ghn_actor, teacher, envs, args.num_envs, args.num_steps,
            args.dagger_num_archs, beta, device
        )
        replay_buffer.add(obs_data, action_data)

        # ===== BC TRAINING ON GHN =====
        ghn_actor.train()
        total_loss = 0.0

        for bc_step in range(args.bc_updates_per_iter):
            # Sample new architectures for this training step
            ghn_actor.change_graph(repeat_sample=False)

            # Sample data from replay buffer
            obs_batch, action_batch = replay_buffer.sample(effective_batch_size)

            optimizer.zero_grad()

            # Forward pass with mixed precision
            if use_amp:
                with torch.amp.autocast('cuda'):
                    pred_action = ghn_actor(obs_batch, track=False)
                    loss = F.mse_loss(pred_action, action_batch)

                scaler.scale(loss).backward()

                if args.grad_clip > 0:
                    # Unscale first so the clip threshold applies to true gradient norms.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(ghn_actor.ghn.parameters(), args.grad_clip)

                scaler.step(optimizer)
                scaler.update()
            else:
                pred_action = ghn_actor(obs_batch, track=False)
                loss = F.mse_loss(pred_action, action_batch)
                loss.backward()

                if args.grad_clip > 0:
                    torch.nn.utils.clip_grad_norm_(ghn_actor.ghn.parameters(), args.grad_clip)

                optimizer.step()

            total_loss += loss.item()

        scheduler.step()
        avg_loss = total_loss / max(1, args.bc_updates_per_iter)

        # ===== LOGGING =====
        if iteration % args.log_freq == 0:
            logger.add_scalar("dagger/beta", beta, iteration)
            logger.add_scalar("bc/loss", avg_loss, iteration)
            logger.add_scalar("bc/learning_rate", optimizer.param_groups[0]["lr"], iteration)
            logger.add_scalar("dagger/buffer_size", replay_buffer.size, iteration)
            logger.add_scalar("dagger/samples_collected", obs_data.shape[0], iteration)

            pbar.set_description(f"Beta: {beta:.3f}, Loss: {avg_loss:.4f}")

        # ===== EVALUATION =====
        if iteration % args.eval_freq == 0:
            num_archs = len(ghn_actor.list_of_arcs)
            print(f"\n[Iter {iteration}] Evaluating all {num_archs} architectures...")
            ghn_actor.eval()

            # Test ALL architectures, one env per architecture
            _, ep_data = _evaluate_all_architectures(
                ghn_actor, args.env_id, env_kwargs, args.num_eval_steps
            )

            if ep_data is not None:
                mean_return = ep_data["return"].float().mean().item()
                mean_success_once = ep_data["success_once"].float().mean().item()
                mean_success_at_end = ep_data["success_at_end"].float().mean().item()

                logger.add_scalar("eval/return", mean_return, iteration)
                logger.add_scalar("eval/success_once", mean_success_once, iteration)
                logger.add_scalar("eval/success_at_end", mean_success_at_end, iteration)

                print(f"Iter {iteration} Eval (all {num_archs} archs): "
                      f"return={mean_return:.3f} "
                      f"success_once={mean_success_once:.3f} "
                      f"success_at_end={mean_success_at_end:.3f}")
            else:
                print(f"Iter {iteration} Eval: no final_info after {args.num_eval_steps} steps; "
                      "eval metrics not logged")

            ghn_actor.train()

    # ========== SAVE FINAL CHECKPOINT ==========
    if args.save_model:
        os.makedirs(run_name, exist_ok=True)
        print(f"Saving GHN checkpoint to {run_name}/ghn_final_ckpt.pt")
        torch.save({
            'ghn_state_dict': ghn_actor.ghn.state_dict(),
            'ghn_config': ghn_actor.ghn_config,
            'architectures': ghn_actor.list_of_arcs,
        }, f"{run_name}/ghn_final_ckpt.pt")

    logger.close()
    envs.close()
    eval_envs.close()

    # ========== FINAL COMPREHENSIVE EVALUATION ==========
    print("\n" + "="*80)
    print("FINAL EVALUATION: Testing all architectures")
    print("="*80)

    import pandas as pd

    eval_dir = Path(run_name) / "final_evaluation"
    eval_dir.mkdir(parents=True, exist_ok=True)

    ghn_actor.eval()

    all_architectures = ghn_actor.list_of_arcs
    num_archs = len(all_architectures)

    print(f"\nEvaluating {num_archs} architectures in parallel")
    print(f"  Video recording: {'ENABLED' if args.eval_save_video else 'DISABLED'}")

    video_dir = None
    if args.eval_save_video:
        video_dir = eval_dir / "videos"
        video_dir.mkdir(parents=True, exist_ok=True)

    all_students, ep_data = _evaluate_all_architectures(
        ghn_actor, args.env_id, env_kwargs, args.num_eval_steps, video_dir=video_dir
    )

    # Collect and save results
    results = []
    if ep_data is not None:
        for arch_idx, architecture in enumerate(all_architectures):
            num_params = sum(p.numel() for p in all_students[arch_idx].parameters())
            result = {
                'architecture': str(tuple(architecture)),
                'num_params': num_params,
                'return': ep_data["return"][arch_idx].item(),
                'success_once': ep_data["success_once"][arch_idx].item(),
                'success_at_end': ep_data["success_at_end"][arch_idx].item(),
            }
            results.append(result)
            print(f"[{arch_idx+1}/{num_archs}] {architecture}: "
                  f"Return={result['return']:.3f}, "
                  f"Success(once)={result['success_once']:.3f}")
    else:
        print(f"No final_info after {args.num_eval_steps} eval steps; no per-architecture results")

    # Save to CSV
    csv_path = eval_dir / "all_architectures_results.csv"
    df = pd.DataFrame(results)
    df.to_csv(csv_path, index=False)

    print("\n" + "="*80)
    print(f"Results saved to: {csv_path}")
    print("="*80)

    if len(df) > 0:
        print("\nSummary Statistics:")
        print(f"  Best Return: {df['return'].max():.3f}")
        print(f"  Best Success Once: {df['success_once'].max():.3f}")
        print(f"  Avg Return: {df['return'].mean():.3f} +/- {df['return'].std():.3f}")
        print(f"  Avg Success Once: {df['success_once'].mean():.3f}")

    print("\nTraining complete!")
    print(f"  GHN checkpoint: {run_name}/ghn_final_ckpt.pt")
    print(f"  Results CSV: {csv_path}")

    if args.track:
        wandb.finish()
