import os
import tempfile
import shutil
from pathlib import Path
from utils.gcp import upload_model_to_gcp
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch


def _select_checkpoint_dir(ckpt_path, step_num=None):
    """Return the ``global_step_*`` directory under *ckpt_path* to convert.

    With *step_num* given, that exact step directory is required. Otherwise the
    directory with the highest numeric step suffix is chosen; ``global_step_*``
    directories whose suffix is not an integer are ignored.

    Raises:
        FileNotFoundError: If the requested step directory does not exist, or
            no numerically-suffixed ``global_step_*`` directory is found.
    """
    if step_num is not None:
        chosen = ckpt_path / f"global_step_{step_num}"
        if not chosen.exists():
            raise FileNotFoundError(f"Checkpoint directory {chosen} does not exist")
        print(f"Using specified checkpoint: {chosen}")
        return chosen

    numbered = []
    for entry in ckpt_path.iterdir():
        if not (entry.is_dir() and entry.name.startswith("global_step_")):
            continue
        try:
            step = int(entry.name.split("_")[-1])
        except ValueError:
            # e.g. "global_step_backup": not a numbered step, so skip it
            # instead of aborting the whole conversion.
            continue
        numbered.append((step, entry))

    if not numbered:
        raise FileNotFoundError("No global_step directories found")

    chosen = sorted(numbered, key=lambda item: item[0])[-1][1]
    print(f"Using latest checkpoint: {chosen}")
    return chosen


def _rank_sort_key(rank_file):
    """Sort key ordering shard files by numeric (world size, rank).

    A plain lexicographic sort would place ``rank_10`` before ``rank_2`` and
    scramble the concatenation order. Names that cannot be parsed sort last,
    by name.
    """
    prefix, _, rank = rank_file.stem.rpartition("_rank_")
    world_size = prefix.rpartition("_")[2]
    try:
        return (0, int(world_size), int(rank), rank_file.name)
    except ValueError:
        return (1, 0, 0, rank_file.name)


def _to_local_tensor(value):
    """Unwrap a DTensor-like checkpoint entry to the tensor stored on its rank."""
    if hasattr(value, "_local_tensor"):
        return value._local_tensor
    if hasattr(value, "to_local"):
        return value.to_local()
    return value


def _merge_rank_files(rank_files):
    """Load every rank file (in the given order) and merge them into one state dict.

    A key seen in several rank files is treated as a sharded parameter and its
    pieces are concatenated along dim 0 in rank order. If the first piece is
    0-dimensional it is kept as-is, and later 0-dimensional pieces are dropped.
    """
    shards_by_key = {}
    for rank_file in rank_files:
        print(f"Loading {rank_file.name}...")
        # NOTE(review): torch.load unpickles the file; only run this on
        # checkpoints from a trusted source.
        rank_state = torch.load(rank_file, map_location="cpu")

        for key, value in rank_state.items():
            tensor = _to_local_tensor(value)
            shards = shards_by_key.setdefault(key, [])
            if not shards:
                shards.append(tensor)
            elif shards[0].dim() > 0 and tensor.dim() > 0:
                # NOTE(review): assumes every repeated key is sharded along
                # dim 0 (no replicated tensors) -- confirm for your FSDP setup.
                shards.append(tensor)

    # Concatenate once per key (instead of re-copying a growing tensor for each
    # rank) and release the shard list as soon as it has been merged.
    merged = {}
    for key in list(shards_by_key):
        shards = shards_by_key.pop(key)
        merged[key] = shards[0] if len(shards) == 1 else torch.cat(shards, dim=0)
    return merged


def process_skyrl_model(ckpt_dir, step_num=None):
    """
    Process SkyRL FSDP checkpoint and convert to HuggingFace format.

    Args:
        ckpt_dir: Path to the checkpoint directory containing global_step_* subdirectories
        step_num: Optional specific step number to use. If None, uses the latest
            (highest-numbered) checkpoint.

    Returns:
        Path (as returned by ``tempfile.mkdtemp``) of a temporary directory
        containing the processed HuggingFace model and tokenizer. The caller is
        responsible for deleting it.

    Raises:
        FileNotFoundError: If the checkpoint directory, the selected step
            directory, its ``policy`` subdirectory, or the rank shard files
            are missing.
    """
    ckpt_path = Path(ckpt_dir)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist")

    latest_checkpoint = _select_checkpoint_dir(ckpt_path, step_num)

    # The model weights are in policy/ directory
    policy_dir = latest_checkpoint / "policy"
    if not policy_dir.exists():
        raise FileNotFoundError(f"Policy directory not found at {policy_dir}")

    # Locate the shards before building the model so a missing checkpoint
    # fails fast, and order them by numeric rank.
    rank_files = sorted(
        policy_dir.glob("model_world_size_*_rank_*.pt"), key=_rank_sort_key
    )
    print(f"Found {len(rank_files)} rank files")
    if not rank_files:
        raise FileNotFoundError(f"No FSDP checkpoint files found in {policy_dir}")

    # Create temporary directory for HF format
    temp_dir = tempfile.mkdtemp(prefix="skyrl_model_")

    try:
        print("Loading FSDP checkpoint...")

        # Load tokenizer and config
        hf_tokenizer_dir = policy_dir / "huggingface"
        if hf_tokenizer_dir.exists():
            pretrained_source = hf_tokenizer_dir
        else:
            # Fallback to base model - adjust this based on your model
            pretrained_source = "Qwen/Qwen2.5-7B-Instruct"
            print(
                f"Warning: {hf_tokenizer_dir} not found; "
                f"falling back to {pretrained_source} config and tokenizer"
            )
        config = AutoConfig.from_pretrained(pretrained_source)
        tokenizer = AutoTokenizer.from_pretrained(pretrained_source)

        # Create model with the config
        model = AutoModelForCausalLM.from_config(config)

        # Load sharded checkpoint files and merge them
        print("Loading and merging sharded checkpoint files...")
        merged_state_dict = _merge_rank_files(rank_files)
        print(f"Loaded {len(merged_state_dict)} parameters from checkpoint")

        # Load the merged state dict into the model
        missing_keys, unexpected_keys = model.load_state_dict(
            merged_state_dict, strict=False
        )
        if missing_keys:
            print(f"Missing keys: {missing_keys[:10]}...")  # Show first 10
        if unexpected_keys:
            print(f"Unexpected keys: {unexpected_keys[:10]}...")  # Show first 10

        # Save model in HuggingFace format
        print("Saving model in HuggingFace format...")
        model.save_pretrained(temp_dir)
        tokenizer.save_pretrained(temp_dir)

        return temp_dir

    except BaseException:
        # Clean up on any failure (including Ctrl-C) and re-raise unchanged.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise


def is_skyrl_checkpoint(model_path):
    """Check if the model path contains SkyRL checkpoint structure.

    Args:
        model_path: Path (str or path-like) to inspect.

    Returns:
        True if *model_path* is a directory with at least one immediate
        subdirectory whose name starts with ``global_step_``; False otherwise,
        including when the path does not exist or is not a directory.
    """
    path = Path(model_path)
    # is_dir() is False for missing paths AND for regular files, so iterdir()
    # below can never be called on a non-directory.
    if not path.is_dir():
        return False

    return any(
        entry.is_dir() and entry.name.startswith("global_step_")
        for entry in path.iterdir()
    )


def upload_model_to_gcp_main():
    """Upload the trained model to GCP.

    Configuration is read from environment variables:

    - ``GCP_MODEL_NAME``: name to save the model under in the bucket.
    - ``CKPT_DIR``: local model or SkyRL checkpoint directory.
    - ``GCP_BUCKET_NAME``: destination bucket (defaults to ``advisor-models``);
      if set to an empty string the upload is skipped.

    If ``CKPT_DIR`` looks like a SkyRL checkpoint it is first converted to
    HuggingFace format in a temporary directory, which is removed afterwards
    whether or not the upload succeeds. Errors are printed with a traceback
    rather than propagated.
    """
    # Use environment variables like the HF upload script
    gcp_model_name = os.environ.get(
        "GCP_MODEL_NAME", "advisor-models/reviews-level-advisor-qwen2.5-7b-v1"
    )
    ckpt_dir = os.environ.get(
        "CKPT_DIR", "/root/ckpts/reviews_level_v1_qwen2.5_7b_5ep_hf"
    )
    bucket_name = os.environ.get("GCP_BUCKET_NAME", "advisor-models")

    if not bucket_name:
        print("Warning: GCP_BUCKET_NAME not set. Skipping GCP upload.")
        return

    print(f"Uploading model to GCP as {gcp_model_name}")
    print(f"Checkpoint directory: {ckpt_dir}")
    print(f"GCP bucket: {bucket_name}")

    # Set only when a temporary converted model was created and must be removed.
    temp_model_dir = None
    try:
        # Check if this is a SkyRL checkpoint that needs processing
        if is_skyrl_checkpoint(ckpt_dir):
            print("Detected SkyRL checkpoint format. Processing FSDP model...")
            temp_model_dir = process_skyrl_model(ckpt_dir)
            upload_path = temp_model_dir
        else:
            print("Using model directory as-is...")
            upload_path = ckpt_dir

        gcp_path = upload_model_to_gcp(
            bucket_name=bucket_name,
            model_directory=upload_path,
            save_name=gcp_model_name,
        )
        print(f"Model successfully uploaded to GCP: {gcp_path}")

    except Exception as e:
        print(f"Error uploading model to GCP: {str(e)}")
        import traceback

        traceback.print_exc()

    finally:
        # Clean up the temporary directory even if the upload failed, so a
        # converted copy of the model is not left behind on disk.
        if temp_model_dir is not None and os.path.exists(temp_model_dir):
            shutil.rmtree(temp_model_dir, ignore_errors=True)
            print(f"Cleaned up temporary directory: {temp_model_dir}")


# Script entry point: run the upload only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    upload_model_to_gcp_main()
