"""Shared training utilities for supervised and RL training"""

import os
import json
import torch
import torch.distributed as dist
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
import wandb
import functools

from .model_utils import apply_alpaca_chat_template, apply_round_chat_template, load_lora_hist
from .notification_utils import notify_exception, notify_completion

def _resolve_local_rank():
    """Return the GPU index this process should use on its node.

    Prefers the ``LOCAL_RANK`` environment variable (set by torchrun) and falls
    back to ``global_rank % visible_gpu_count``; the fallback requires an
    initialised process group.
    """
    env_local_rank = os.environ.get("LOCAL_RANK")
    if env_local_rank is not None:
        return int(env_local_rank)
    return dist.get_rank() % torch.cuda.device_count()


def setup_distributed():
    """Set up NCCL distributed training when more than one GPU is visible.

    Returns:
        ``(local_rank, multi_gpu)``. ``multi_gpu`` is False when at most one
        GPU is visible or when process-group initialisation fails; in both
        cases ``local_rank`` is 0. Otherwise ``local_rank`` is this process's
        GPU index, which has also been made the current CUDA device.
    """
    multi_gpu = torch.cuda.device_count() > 1
    if not multi_gpu:
        return 0, multi_gpu

    if dist.is_initialized():
        # Group was created elsewhere: report (and select) this process's real
        # GPU rather than returning 0 on every rank, which would make all
        # processes behave as the main one.
        local_rank = _resolve_local_rank()
        torch.cuda.set_device(local_rank)
        return local_rank, multi_gpu

    try:
        # Initialize process group with timeout
        dist.init_process_group(
            backend="nccl",
            timeout=torch.distributed.default_pg_timeout
        )
        local_rank = _resolve_local_rank()

        # Set device for current process so later `.to("cuda")` calls land here
        torch.cuda.set_device(local_rank)

        print(f"Distributed training initialized. Rank: {dist.get_rank()}, Local rank: {local_rank}")
    except Exception as e:
        # Best-effort: fall back to single-process training on GPU 0.
        print(f"Failed to initialize distributed training: {e}")
        local_rank = 0
        multi_gpu = False

    return local_rank, multi_gpu

def setup_experiment(config, local_rank, project_name="molgen"):
    """Timestamp the run directories and, on local rank 0, start a wandb run.

    Mutates ``config`` in place: ``exp_save_dir`` and ``lora_dir`` each get a
    sub-directory named ``str(datetime.now())`` with spaces replaced by
    underscores (note the name still contains ``:`` and microseconds).

    When a ``torch.distributed`` process group is initialised, rank 0's
    timestamp is broadcast so every rank ends up with identical paths. This
    makes the call collective: all ranks must call it.

    On ``local_rank == 0`` it additionally creates ``config.exp_save_dir``,
    calls ``wandb.init`` and logs every config attribute as a one-row
    ``wandb.Table`` under the key ``"config"``.
    """
    timestamp = str(datetime.now()).replace(" ", "_")
    if dist.is_initialized():
        # Each rank reads its own clock, so the strings almost never match;
        # without this, non-zero ranks would point at directories that rank 0
        # never creates.
        shared = [timestamp]
        dist.broadcast_object_list(shared, src=0)
        timestamp = shared[0]

    config.exp_save_dir = os.path.join(config.exp_save_dir, timestamp)
    config.lora_dir = os.path.join(config.lora_dir, timestamp)

    # NOTE(review): gated on the *local* rank, so on multi-node jobs each
    # node's local rank 0 starts its own wandb run — confirm this is intended.
    if local_rank == 0:
        os.makedirs(config.exp_save_dir, exist_ok=True)
        wandb.init(project=project_name, name=config.exp_name, dir=config.exp_save_dir)
        params_dict = vars(config)
        # One row: column names are config attributes, the row holds their values.
        table = wandb.Table(
            columns=list(params_dict.keys()),
            data=[list(params_dict.values())])
        wandb.log({"config": table})

def setup_tokenizer(config):
    """Load the tokenizer for ``config.base_model_name`` and tag its response marker.

    Padding reuses the EOS token. The padding side is ``config.padding_side``
    when that attribute exists, otherwise ``"right"``.

    Every returned tokenizer carries a ``response_split_id`` attribute holding
    the literal text that introduces the assistant turn in the selected prompt
    format. It is presumably used downstream to separate prompt from
    completion; confirm with callers.

    ``config.use_alpaca`` takes precedence. Otherwise the model family is
    detected by substring match on the model name, checked in the order
    Qwen, ChemLLM, ChemDFM, with the ``<|start_header_id|>`` header format as
    the fallback. For Alpaca and ChemDFM, ``apply_chat_template`` is rebound to
    a project-local template function; other families keep the tokenizer's own.
    """
    tok = AutoTokenizer.from_pretrained(config.base_model_name, cache_dir=config.cache_dir)
    tok.pad_token = tok.eos_token
    tok.padding_side = getattr(config, 'padding_side', "right")

    name = config.base_model_name
    if config.use_alpaca:
        tok.apply_chat_template = apply_alpaca_chat_template.__get__(tok)
        tok.response_split_id = "### Response:"
        return tok

    if "Qwen" in name:
        # The marker includes an empty <think></think> block before the answer.
        split_id = "<|im_start|>assistant\n<think>\n\n</think>\n\n"
    elif "ChemLLM" in name:
        split_id = "<|im_start|>assistant\n"
    elif "ChemDFM" in name:
        tok.apply_chat_template = apply_round_chat_template.__get__(tok)
        split_id = "Assistant:"
    else:
        split_id = "<|start_header_id|>assistant<|end_header_id|>"

    tok.response_split_id = split_id
    return tok

def setup_model(config):
    """Load the base causal LM onto the GPU, optionally restoring LoRA history.

    Weights are loaded in ``torch.float16`` without a ``device_map`` and then
    moved with ``.to("cuda")``, i.e. onto the *current* CUDA device. In
    distributed runs that is whichever device ``torch.cuda.set_device``
    selected beforehand.

    Args:
        config: must provide ``base_model_name`` and ``cache_dir``. If it has a
            truthy ``load_directory``, that directory is passed to
            ``load_lora_hist`` together with the loaded model.

    Returns:
        ``(model, lora_hist)``. ``lora_hist`` is a new empty list when no
        ``load_directory`` is given.

    Raises:
        Whatever loading raised, after printing it and emptying the CUDA cache.
    """
    try:
        base = AutoModelForCausalLM.from_pretrained(
            config.base_model_name,
            torch_dtype=torch.float16,
            cache_dir=config.cache_dir,
        )
        model = base.to("cuda")

        history_dir = getattr(config, 'load_directory', None)
        if not history_dir:
            return model, []

        model, lora_hist = load_lora_hist(history_dir, model, cache_dir=config.cache_dir)
        return model, lora_hist
    except Exception as err:
        print(f"Error loading model: {err}")
        # Release cached allocator blocks before propagating the failure.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        raise

def setup_lora(model, config):
    """Wrap ``model`` with LoRA adapters built from ``config``.

    Adapter hyper-parameters come from ``config.lora_r``, ``config.lora_alpha``,
    ``config.lora_dropout`` and ``config.lora_target_modules``. The bias mode
    is fixed to ``"none"`` and the task type to ``"CAUSAL_LM"``.

    Returns:
        The model returned by ``get_peft_model``.
    """
    lora_settings = dict(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=config.lora_target_modules,
    )
    return get_peft_model(model, LoraConfig(**lora_settings))

def save_experiment(config, model, lora_hist, local_rank):
    """Persist the trained adapter and run metadata, then close the wandb run.

    Does nothing unless ``local_rank == 0``. On that process it:
      1. calls ``model.save_pretrained(config.lora_dir)``;
      2. appends ``config.lora_dir`` to ``lora_hist``, mutating the caller's list;
      3. writes ``config.json`` (``vars(config)``) and ``lora_hist.json`` into
         ``config.exp_save_dir``, which is expected to exist already because it
         is not created here;
      4. calls ``wandb.finish()``.
    """
    if local_rank != 0:
        return

    adapter_dir = config.lora_dir
    model.save_pretrained(adapter_dir)
    lora_hist.append(adapter_dir)

    config_path = os.path.join(config.exp_save_dir, "config.json")
    with open(config_path, "w") as fh:
        json.dump(vars(config), fh, indent=2)

    hist_path = os.path.join(config.exp_save_dir, "lora_hist.json")
    with open(hist_path, "w") as fh:
        json.dump(lora_hist, fh, indent=2)

    wandb.finish()

def _is_main_process():
    """Return True on global rank 0, or when no rank information is available."""
    if dist.is_initialized():
        return dist.get_rank() == 0
    # The wrapped function may already have destroyed the process group
    # (e.g. via dist.destroy_process_group), so fall back to the
    # launcher-provided RANK instead of letting every rank notify.
    return int(os.environ.get("RANK", "0")) == 0


def _notify_main_process(args, kwargs, send):
    """Best-effort: on the main process, call ``send(exp_name)``.

    ``exp_name`` is ``config.exp_name``. The config is taken from the first
    positional argument, or from the ``config`` keyword when there are no
    positional arguments; ``'Unknown'`` is used when it is missing.
    Notification errors are printed, never raised, so a broken notification
    channel can neither discard a successful result nor mask the real
    training error.
    """
    if not _is_main_process():
        return
    config = args[0] if args else kwargs.get('config')
    exp_name = getattr(config, 'exp_name', 'Unknown') if config else 'Unknown'
    try:
        send(exp_name)
    except Exception as notify_err:
        print(f"Failed to send telegram notification: {notify_err}")


def with_telegram_notifications(func):
    """Decorator that sends telegram notifications for training functions.

    On success, ``notify_completion(exp_name)`` is sent and ``func``'s result
    is returned. If ``func`` raises, ``notify_exception(exp_name, exc)`` is sent
    and the original exception is re-raised. Notifications are only sent from
    the main process (global rank 0), and notification failures are printed
    rather than propagated.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
        except Exception as e:
            _notify_main_process(args, kwargs, lambda name: notify_exception(name, e))
            raise
        # Outside the try: a notification problem must not be reported as a
        # training failure.
        _notify_main_process(args, kwargs, lambda name: notify_completion(name))
        return result
    return wrapper

def cleanup_distributed(multi_gpu):
    """Tear down the process group (if any) and release cached CUDA memory.

    Args:
        multi_gpu: the flag returned by ``setup_distributed``. Teardown is only
            attempted when it is truthy and a process group is initialised.

    A barrier runs before ``destroy_process_group``. Any error during teardown
    is printed rather than raised. When CUDA is available, the allocator cache
    is emptied and the device is synchronised afterwards.
    """
    teardown_needed = multi_gpu and dist.is_initialized()
    if teardown_needed:
        try:
            # Let every rank reach this point before the group disappears.
            dist.barrier()
            dist.destroy_process_group()
            print("Distributed training cleaned up successfully")
        except Exception as err:
            print(f"Error during distributed cleanup: {err}")

    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()