from datetime import timedelta
import logging
import os

import torch
import torch.distributed as dist
import wandb
from peft import LoraConfig

from src.utils.load_hf_model import load_model_and_tokenizer

logger = logging.getLogger(__name__)


def setup_wandb(args, run_name: str):
    """Start a Weights & Biases run when logging is enabled for this process.

    Args:
        args: Parsed arguments object. ``wandb_project`` and ``wandb_entity``
            are read from it, and ``vars(args)`` is recorded as the run config.
        run_name: Name given to the W&B run.

    Returns:
        The value returned by ``wandb.init``, or ``None`` when
        ``args.wandb_project`` is falsy or the ``LOCAL_RANK`` environment
        variable (default ``'0'``) is not ``'0'``.
    """
    # No project configured: logging is disabled.
    if not args.wandb_project:
        return None

    # NOTE(review): LOCAL_RANK is assigned per node by typical launchers, so a
    # multi-node job may start one run per node -- confirm whether the global
    # RANK was intended here.
    if os.environ.get('LOCAL_RANK', '0') != '0':
        return None

    return wandb.init(
        project=args.wandb_project,
        entity=args.wandb_entity,
        name=run_name,
        config=vars(args),
        resume="allow",
    )


def get_peft_config(args):
    """Build the LoRA configuration used for causal-LM fine-tuning.

    Args:
        args: Parsed arguments object providing ``lora_r``, ``lora_alpha``
            and ``lora_dropout``.

    Returns:
        A ``LoraConfig`` with a fixed list of target modules, ``bias="none"``
        and ``task_type="CAUSAL_LM"``.
    """
    # Names of the modules that are passed to LoraConfig as ``target_modules``.
    lora_targets = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ]
    return LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=lora_targets,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )


def get_ds_config(args):
    """Build the default DeepSpeed configuration dictionary.

    Args:
        args: Parsed arguments object providing
            ``per_device_train_batch_size`` and
            ``gradient_accumulation_steps``.

    Returns:
        A dict with a ZeRO stage-2 section (offload devices set to
        ``"none"``), the micro batch size and accumulation steps taken from
        ``args``, and ``"auto"`` placeholders for the remaining sizes.
    """
    # ZeRO section: stage 2 with both offload targets set to "none".
    zero_section = {
        "stage": 2,
        "overlap_comm": False,
        "reduce_bucket_size": "auto",
        "contiguous_gradients": True,
        "offload_optimizer": {"device": "none"},
        "offload_param": {"device": "none"},
    }

    # NOTE(review): the "auto" values are presumably filled in by whatever
    # consumes this config (e.g. a trainer integration) -- confirm.
    ds_config = {
        "zero_optimization": zero_section,
        "gradient_clipping": "auto",
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": args.per_device_train_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "steps_per_print": 200,
    }
    return ds_config


def init_distributed_mode():
    """Initialize the default NCCL process group from environment variables.

    Reads ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` from the environment,
    creates the process group with a one-day timeout, selects CUDA device
    ``LOCAL_RANK`` for this process, waits at a barrier, and prints the rank
    and world size.

    Raises:
        RuntimeError: If any of ``RANK``, ``WORLD_SIZE`` or ``LOCAL_RANK`` is
            missing from the environment.
    """
    # Validate all three variables up front so that a missing LOCAL_RANK
    # produces the documented RuntimeError rather than a bare KeyError.
    required_vars = ('RANK', 'WORLD_SIZE', 'LOCAL_RANK')
    if any(name not in os.environ for name in required_vars):
        raise RuntimeError("Please set RANK, WORLD_SIZE, and LOCAL_RANK in the environment variables.")

    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    local_rank = int(os.environ['LOCAL_RANK'])

    # Initialize the process group with a timeout of one day.
    timeout = timedelta(days=1)
    dist.init_process_group(
        backend='nccl',
        init_method=None,
        timeout=timeout
    )
    # Select this process's GPU before the first collective (the barrier).
    torch.cuda.set_device(local_rank)
    dist.barrier()
    print(f"Distributed initialized. Rank: {rank}, World Size: {world_size}")


def initialize_models(args, model_path, wrap_value_head=False):
    """Load a model and tokenizer, installing a fallback chat template if needed.

    Args:
        args: Parsed arguments object, forwarded to ``get_peft_config``.
        model_path: Passed to ``load_model_and_tokenizer`` as ``model_name``.
        wrap_value_head: Forwarded unchanged to ``load_model_and_tokenizer``.

    Returns:
        The ``(model, tokenizer)`` pair returned by
        ``load_model_and_tokenizer``. When the tokenizer has no chat template,
        a default one is assigned to it before returning.
    """
    model, tokenizer = load_model_and_tokenizer(
        model_name=model_path,
        peft_config=get_peft_config(args),
        wrap_value_head=wrap_value_head,
    )

    # A tokenizer that already ships a chat template is left untouched.
    if tokenizer.chat_template is not None:
        return model, tokenizer

    # NOTE(review): the literal below starts with a newline and every line is
    # indented; depending on how the template is rendered, that whitespace may
    # end up in the formatted prompt -- confirm this is intended.
    fallback_template = """
        {% for message in messages %}
        {% if message['role'] == 'user' %}
        {{ '<|user|>\n' + message['content'] + eos_token }}
        {% elif message['role'] == 'system' %}
        {{ '<|system|>\n' + message['content'] + eos_token }}
        {% elif message['role'] == 'assistant' %}
        {{ '<|assistant|>\n' + message['content'] + eos_token }}
        {% endif %}
        {% if loop.last and add_generation_prompt %}
        {{ '<|assistant|>' }}
        {% endif %}
        {% endfor %}
        """
    tokenizer.chat_template = fallback_template
    logger.info("Set tokenizer chat template")
    return model, tokenizer

