#!/usr/bin/env python3
"""
Beta-DAgger training script using a standard MLP student.

Adapted from beta_dagger_ghn_student.py, but removes GHN and trains a single 3x256 MLP directly.
"""
import os
os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1"

import math
import random
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
import tyro
from torch.utils.tensorboard import SummaryWriter
import wandb

from mani_skill.utils import gym_utils
from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper
from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
sys.path.insert(0, str(Path(__file__).parent / "train"))

from extraction.models import TeacherAgent
from train.hyper.model import MlpNetwork

@dataclass
class Args:
    """Command-line configuration for Beta-DAgger training of an MLP student.

    Instances are built by ``tyro.cli(Args)`` in the script entry point.
    """

    # Experiment bookkeeping
    exp_name: Optional[str] = None    # Run name; derived from the script file name when None
    seed: int = 1                     # Seed for `random`, NumPy and PyTorch
    torch_deterministic: bool = True  # Value assigned to torch.backends.cudnn.deterministic
    cuda: bool = True                 # Use CUDA when it is available
    track: bool = False               # Log to Weights & Biases in addition to TensorBoard
    wandb_project_name: str = "ManiSkill"
    wandb_entity: Optional[str] = None
    wandb_group: str = "DAgger-MLP"
    capture_video: bool = False       # Wrap the eval envs in RecordEpisode
    save_model: bool = True           # Save the final student checkpoint
    train_dir: Optional[str] = None   # Root directory for runs ("runs_dagger_mlp" when None)

    # Environment specific arguments
    env_id: str = "PickCube-v1"
    num_envs: int = 512
    num_eval_envs: int = 16
    num_steps: int = 50  # Steps per rollout
    num_eval_steps: int = 50
    eval_reconfiguration_freq: Optional[int] = 1
    control_mode: Optional[str] = "pd_joint_delta_pos"

    # NOTE(review): not read anywhere in the visible script.
    partial_reset: bool = True

    # Teacher checkpoint (required; the entry point raises ValueError when it is None)
    teacher_checkpoint: Optional[str] = None

    # DAgger parameters
    total_iterations: int = 100       # Number of DAgger iterations
    beta_decay_rate: float = 0.9      # Exponential decay: beta = p^(iteration - 1)
    bc_batch_size: int = 512          # BC training batch size
    buffer_size: int = 500000         # Replay buffer size
    bc_updates_per_iter: int = 100    # BC gradient steps per iteration
    learning_rate: float = 3e-4
    min_learning_rate: float = 1e-5   # For cosine annealing
    eval_freq: int = 10               # Evaluation frequency (iterations)
    log_freq: int = 1                 # Logging frequency (iterations)

    # MLP specific parameters
    # The student is a "3x256" MLP: 3 hidden layers of width 256.
    # In the MlpNetwork config this corresponds to fc_layers=(256, 256, 256).
    mlp_hidden_layers: tuple = (256, 256, 256)

    # Optimization parameters
    amp: bool = True                  # Automatic mixed precision
    grad_clip: float = 1.0            # Gradient clipping (max norm); values <= 0 disable it

    # Final evaluation
    # NOTE(review): not read anywhere in the visible script.
    eval_save_video: bool = True


class ReplayBuffer:
    """Fixed-capacity circular buffer of (observation, action) pairs for BC training.

    Storage is preallocated on ``device``; once the buffer is full, new rows
    overwrite the oldest ones.
    """

    def __init__(self, max_size, obs_dim, act_dim, device):
        """Allocate zero-filled storage for ``max_size`` rows.

        Args:
            max_size: Maximum number of rows kept.
            obs_dim: Width of one observation row.
            act_dim: Width of one action row.
            device: Torch device on which the storage tensors live.
        """
        self.max_size = max_size
        self.ptr = 0    # Index of the next row to be written
        self.size = 0   # Number of valid rows currently stored
        self.device = device
        self.obs = torch.zeros((max_size, obs_dim), device=device)
        self.actions = torch.zeros((max_size, act_dim), device=device)

    def add(self, obs, actions):
        """Append a batch of rows, wrapping around at the end of the storage.

        An empty batch is a no-op. A batch larger than the buffer capacity keeps
        only its last ``max_size`` rows, since the earlier rows would be
        overwritten by the later ones anyway.

        Args:
            obs: Tensor of shape ``(batch, obs_dim)``.
            actions: Tensor of shape ``(batch, act_dim)``.
        """
        batch_size = obs.shape[0]
        if batch_size == 0:
            return

        if batch_size > self.max_size:
            obs = obs[-self.max_size:]
            actions = actions[-self.max_size:]
            batch_size = self.max_size

        end = self.ptr + batch_size
        if end > self.max_size:
            # Split the write: fill the tail of the storage, then wrap to the front.
            head = self.max_size - self.ptr
            tail = batch_size - head
            self.obs[self.ptr:] = obs[:head]
            self.actions[self.ptr:] = actions[:head]
            self.obs[:tail] = obs[head:]
            self.actions[:tail] = actions[head:]
            self.ptr = tail
        else:
            self.obs[self.ptr:end] = obs
            self.actions[self.ptr:end] = actions
            self.ptr = end % self.max_size

        self.size = min(self.size + batch_size, self.max_size)

    def sample(self, batch_size):
        """Return ``batch_size`` rows drawn uniformly with replacement.

        Returns:
            Tuple ``(obs, actions)`` of tensors with ``batch_size`` rows each.

        Raises:
            ValueError: If the buffer holds no data yet.
        """
        if self.size == 0:
            raise ValueError("Cannot sample from an empty replay buffer")
        indices = torch.randint(0, self.size, (batch_size,), device=self.device)
        return self.obs[indices], self.actions[indices]


def dagger_rollout(student_policy, teacher, envs, num_envs, num_steps, beta, device,
                   reset_threshold=0.2):
    """Collect (observation, teacher action) pairs using a beta-mixture policy.

    At every step both the student and the teacher are evaluated on the current
    observations. Each environment independently executes the teacher's action
    with probability ``beta`` and the student's action otherwise; the stored
    label is always the teacher's action.

    Data is recorded for environments that have not yet reported done since the
    last reset, plus one extra step for environments that reported done on the
    previous step. When the fraction of still-active environments drops below
    ``reset_threshold``, all environments are reset.

    Args:
        student_policy: Callable mapping an observation batch to an action batch.
        teacher: Callable mapping an observation batch to an action batch.
        envs: Vectorized environment exposing ``reset(options=...)`` and ``step(action)``.
        num_envs: Number of parallel environments.
        num_steps: Number of environment steps to roll out.
        beta: Per-step, per-environment probability of executing the teacher's action.
        device: Torch device used for the masks and the random draws.
        reset_threshold: Fraction of active environments below which a mass reset
            is triggered.

    Returns:
        Tuple ``(obs_data, action_data)`` with the collected observations and the
        matching teacher actions, concatenated along dim 0. Both have zero rows
        when nothing was collected.
    """
    all_obs = []
    all_teacher_actions = []
    teacher_action = None  # Stays None only if the loop below never runs

    # Reset with new configurations at start of each rollout
    # NOTE(review): confirm that the environment honors this reset option key.
    obs, _ = envs.reset(options={"reconfiguration_freq": 1})

    # Track which envs are still pre-termination
    active_mask = torch.ones(num_envs, dtype=torch.bool, device=device)
    # Track which envs should have data collected THIS step (terminated last step)
    collect_mask = torch.zeros(num_envs, dtype=torch.bool, device=device)

    for _ in range(num_steps):
        with torch.no_grad():
            student_action = student_policy(obs)

            # Teacher labels ALL states
            teacher_action = teacher(obs)

            # Per-step beta mixing: use teacher with probability beta
            use_teacher = torch.rand(num_envs, device=device) < beta
            action = torch.where(use_teacher.unsqueeze(-1), teacher_action, student_action)

        # Collect data; boolean-mask indexing already returns a copy, so no clone is needed
        collect_this_step = active_mask | collect_mask
        if collect_this_step.any():
            all_obs.append(obs[collect_this_step])
            all_teacher_actions.append(teacher_action[collect_this_step])

        # Step environment
        obs, _reward, terminations, truncations, _infos = envs.step(action)

        done = terminations | truncations

        # Envs that just finished get exactly one more sample on the NEXT iteration
        collect_mask = active_mask & done

        # Update active mask
        active_mask = active_mask & ~done

        # Mass reset once too few envs are still active
        if active_mask.sum() < num_envs * reset_threshold:
            obs, _ = envs.reset(options={"reconfiguration_freq": 1})
            active_mask = torch.ones(num_envs, dtype=torch.bool, device=device)
            collect_mask = torch.zeros(num_envs, dtype=torch.bool, device=device)

    if all_obs:
        obs_data = torch.cat(all_obs, dim=0)
        action_data = torch.cat(all_teacher_actions, dim=0)
    else:
        if teacher_action is None:
            # No step was taken (num_steps <= 0): query the teacher once just to
            # learn the action width for the empty result.
            with torch.no_grad():
                teacher_action = teacher(obs)
        obs_data = torch.empty(0, obs.shape[1], device=device)
        action_data = torch.empty(0, teacher_action.shape[1], device=device)

    return obs_data, action_data


class Logger:
    """Forward scalar metrics to Weights & Biases and/or a TensorBoard writer."""

    def __init__(self, log_wandb=False, tensorboard: SummaryWriter = None) -> None:
        # Either sink may be disabled: log_wandb=False skips wandb, and a falsy
        # writer skips TensorBoard.
        self.log_wandb = log_wandb
        self.writer = tensorboard

    def add_scalar(self, tag, scalar_value, step):
        """Record ``scalar_value`` under ``tag`` at ``step`` on every enabled sink."""
        if self.log_wandb:
            wandb.log({tag: scalar_value}, step=step)
        if not self.writer:
            return
        self.writer.add_scalar(tag, scalar_value, step)

    def close(self):
        """Close the TensorBoard writer, if one was provided."""
        writer = self.writer
        if not writer:
            return
        writer.close()


if __name__ == "__main__":
    # Script entry point: parse arguments, build the environments, load the
    # teacher, train the MLP student with Beta-DAgger, and save the result.
    args = tyro.cli(Args)
    if args.teacher_checkpoint is None:
        raise ValueError("--teacher-checkpoint required")

    if args.exp_name is None:
        args.exp_name = os.path.basename(__file__)[: -len(".py")]
        run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    else:
        run_name = args.exp_name

    run_name = os.path.join(args.train_dir or "runs_dagger_mlp", run_name)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic
    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # The autocast context and GradScaler below are CUDA-specific, so mixed
    # precision is only enabled when training actually runs on a CUDA device.
    use_amp = args.amp and device.type == "cuda"

    # ========== ENVIRONMENT SETUP ==========
    env_kwargs = dict(obs_mode="state", render_mode="rgb_array", sim_backend="physx_cuda")
    if args.control_mode:
        env_kwargs["control_mode"] = args.control_mode

    envs = gym.make(args.env_id, num_envs=args.num_envs, **env_kwargs)
    eval_envs = gym.make(args.env_id, num_envs=args.num_eval_envs, reconfiguration_freq=args.eval_reconfiguration_freq,
                         human_render_camera_configs=dict(shader_pack="default"), **env_kwargs)

    if isinstance(envs.action_space, gym.spaces.Dict):
        envs = FlattenActionSpaceWrapper(envs)
        eval_envs = FlattenActionSpaceWrapper(eval_envs)

    if args.capture_video:
        eval_output_dir = f"{run_name}/videos"
        print(f"Saving eval videos to {eval_output_dir}")
        eval_envs = RecordEpisode(eval_envs, output_dir=eval_output_dir, save_trajectory=False,
                                   max_steps_per_video=args.num_eval_steps, video_fps=30)

    envs = ManiSkillVectorEnv(envs, args.num_envs, ignore_terminations=True, record_metrics=True)
    eval_envs = ManiSkillVectorEnv(eval_envs, args.num_eval_envs, ignore_terminations=True, record_metrics=True)

    # Flattened action / observation sizes used to build the teacher, student and buffer
    n_act = math.prod(envs.single_action_space.shape)
    n_obs = math.prod(envs.single_observation_space.shape)

    # ========== LOGGING SETUP ==========
    if args.track:
        config = vars(args)
        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=False,
            config=config,
            name=run_name,
            save_code=True,
            group=args.wandb_group,
            tags=["dagger", "mlp", "bc"]
        )
    writer = SummaryWriter(run_name)
    logger = Logger(log_wandb=args.track, tensorboard=writer)

    # ========== LOAD TEACHER ==========
    print("Loading teacher checkpoint...")
    teacher = TeacherAgent(n_obs, n_act, device=device)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted source (consider weights_only=True if the checkpoint permits it).
    teacher_state = torch.load(args.teacher_checkpoint, map_location=device)
    # Keep only the actor-mean weights and strip their prefix so the keys match
    # the sub-module's own state dict.
    teacher.actor_mean.load_state_dict({k.replace('actor_mean.', ''): v
                                         for k, v in teacher_state.items() if 'actor_mean' in k})
    teacher.eval()

    # ========== INITIALIZE MLP STUDENT ==========
    print(f"Initializing MLP Student with layers {args.mlp_hidden_layers}...")

    # MlpNetwork comes from train.hyper.model (see the imports at the top of the file).
    student_policy = MlpNetwork(
        fc_layers=args.mlp_hidden_layers,
        inp_dim=n_obs,
        out_dim=n_act
    ).to(device)

    optimizer = optim.Adam(student_policy.parameters(), lr=args.learning_rate, eps=1e-5)

    # The scheduler is stepped once per DAgger iteration, hence T_max=total_iterations.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=args.total_iterations,
        eta_min=args.min_learning_rate
    )

    scaler = torch.amp.GradScaler('cuda') if use_amp else None

    print(f"MLP Parameters: {sum(p.numel() for p in student_policy.parameters())}")

    # ========== INITIALIZE REPLAY BUFFER ==========
    replay_buffer = ReplayBuffer(args.buffer_size, n_obs, n_act, device)

    # ========== MAIN TRAINING LOOP ==========
    print("\nStarting Beta-DAgger training for MLP...")
    pbar = tqdm.tqdm(range(1, args.total_iterations + 1))

    for iteration in pbar:
        # A decay rate of 0 means "never execute the teacher"; it is special-cased
        # because 0 ** 0 == 1 would otherwise give beta = 1 on the first iteration.
        if args.beta_decay_rate == 0:
            beta = 0.0
        else:
            beta = args.beta_decay_rate ** (iteration - 1)

        # ===== COLLECT DAGGER DATA =====
        student_policy.eval()
        obs_data, action_data = dagger_rollout(
            student_policy, teacher, envs, args.num_envs, args.num_steps, beta, device
        )
        replay_buffer.add(obs_data, action_data)

        # ===== BC TRAINING ON MLP =====
        student_policy.train()
        total_loss = 0.0

        for _ in range(args.bc_updates_per_iter):
            # Sample data from replay buffer
            obs_batch, action_batch = replay_buffer.sample(args.bc_batch_size)

            optimizer.zero_grad()

            if use_amp:
                with torch.amp.autocast('cuda'):
                    pred_action = student_policy(obs_batch)
                    loss = F.mse_loss(pred_action, action_batch)

                scaler.scale(loss).backward()

                if args.grad_clip > 0:
                    # Gradients must be unscaled before clipping by norm
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(student_policy.parameters(), args.grad_clip)

                scaler.step(optimizer)
                scaler.update()
            else:
                pred_action = student_policy(obs_batch)
                loss = F.mse_loss(pred_action, action_batch)
                loss.backward()

                if args.grad_clip > 0:
                    torch.nn.utils.clip_grad_norm_(student_policy.parameters(), args.grad_clip)

                optimizer.step()

            total_loss += loss.item()

        scheduler.step()
        avg_loss = total_loss / args.bc_updates_per_iter

        # ===== LOGGING =====
        if iteration % args.log_freq == 0:
            logger.add_scalar("dagger/beta", beta, iteration)
            logger.add_scalar("bc/loss", avg_loss, iteration)

            pbar.set_description(f"Beta: {beta:.3f}, Loss: {avg_loss:.4f}")

        # ===== EVALUATION =====
        if iteration % args.eval_freq == 0:
            student_policy.eval()

            # Episode statistics are gathered on every step that reports
            # "final_info", not only on the last one, so they are not lost when
            # an episode ends before the final evaluation step.
            eval_returns = []
            eval_successes = []

            eval_obs, _ = eval_envs.reset()
            for _ in range(args.num_eval_steps):
                with torch.no_grad():
                    eval_action = student_policy(eval_obs)
                    eval_obs, _, _, _, eval_infos = eval_envs.step(eval_action)

                if "final_info" in eval_infos:
                    # NOTE(review): values are taken for all envs, assuming every
                    # eval env finishes its episode on the same step — confirm.
                    ep_data = eval_infos["final_info"]["episode"]
                    eval_returns.append(ep_data["return"].float())
                    eval_successes.append(ep_data["success_once"].float())

            if eval_returns:
                mean_return = torch.stack(eval_returns).mean()
                mean_success = torch.stack(eval_successes).mean()
                logger.add_scalar("eval/return", mean_return, iteration)
                logger.add_scalar("eval/success_once", mean_success, iteration)
                print(f"Iter {iteration} Eval: Return={mean_return:.3f}, Success={mean_success:.3f}")

            student_policy.train()

    # ========== SAVE FINAL CHECKPOINT ==========
    if args.save_model:
        os.makedirs(run_name, exist_ok=True)
        print(f"Saving MLP checkpoint to {run_name}/mlp_final_ckpt.pt")
        # Save just the state dict and config
        torch.save({
            'model_state_dict': student_policy.state_dict(),
            'mlp_layers': args.mlp_hidden_layers,
        }, f"{run_name}/mlp_final_ckpt.pt")

    logger.close()
    envs.close()
    eval_envs.close()
    print("Training Complete!")
