import os
import torch
from src.models import ImageEncoder
from src.utils.variables_and_paths import get_zeroshot_path
from src.utils.distributed import is_main_process

def _resolve_pretrained_state(args, model_cache_path, log_prefix="", sync_cache=False):
    """Find or build the pretrained weights on the calling process.

    Lookup order: the path returned by ``get_zeroshot_path`` for "MNISTVal",
    then ``model_cache_path``, then a freshly constructed ``ImageEncoder``.
    Any exception raised during that lookup is printed and answered with a
    freshly constructed ``ImageEncoder`` as well.

    Args:
        args: Namespace providing ``model`` and ``model_location``.
        model_cache_path: Cache file to read from / write to.
        log_prefix: Text prepended to every progress message.
        sync_cache: When True, whatever was resolved is also guaranteed to be
            on disk at ``model_cache_path`` (and each write is logged), so
            other processes can read it. When False, only a newly created
            model is written, silently.

    Returns:
        The object loaded by ``torch.load`` (on CPU) or the fresh encoder's
        ``state_dict()``.
    """
    # NOTE(review): the fallback branches return a state dict, whereas the
    # type of the object stored at the zero-shot path is not visible here —
    # confirm callers can handle both.
    try:
        ptm_path = get_zeroshot_path(
            args.model_location, "MNISTVal", model=args.model)
        if os.path.exists(ptm_path):
            print(f"{log_prefix}Loading pretrained model from {ptm_path}")
            ptm_check = torch.load(ptm_path, map_location="cpu")
            if sync_cache:
                torch.save(ptm_check, model_cache_path)
                print(f"{log_prefix}Model saved to {model_cache_path}")
        elif os.path.exists(model_cache_path):
            print(f"{log_prefix}Loading pretrained model from cache {model_cache_path}")
            ptm_check = torch.load(model_cache_path, map_location="cpu")
        else:
            print(f"{log_prefix}Pretrained model not found, creating new model")
            ptm_check = ImageEncoder(args.model).state_dict()
            torch.save(ptm_check, model_cache_path)
            if sync_cache:
                print(f"{log_prefix}New model saved to {model_cache_path}")
    except Exception as e:
        # Deliberate best-effort: report the problem and fall back to a new model.
        print(f"{log_prefix}Error loading pretrained model: {e}")
        ptm_check = ImageEncoder(args.model).state_dict()
        if sync_cache:
            torch.save(ptm_check, model_cache_path)
            print(f"{log_prefix}Fallback model saved to {model_cache_path}")
    return ptm_check


def _broadcast_path(path, device):
    """Send rank 0's ``path`` string to every process and return it.

    The string length is broadcast first so that every rank allocates a
    payload buffer of exactly the right size. A fixed-size buffer would make
    the sender's tensor larger than the receivers' for long paths.

    Args:
        path: The string to send; only the main process's value is used.
        device: Device on which the communication tensors are allocated.

    Returns:
        The path string held by the main process.
    """
    # Assumes is_main_process() is true exactly on rank 0, the broadcast
    # source — TODO confirm against src.utils.distributed.
    sender = is_main_process()
    code_points = [ord(ch) for ch in path] if sender else []

    length = torch.tensor([len(code_points)], dtype=torch.int32, device=device)
    torch.distributed.broadcast(length, 0)
    num_chars = int(length.item())

    if sender:
        payload = torch.tensor(code_points, dtype=torch.int32, device=device)
    else:
        payload = torch.zeros(num_chars, dtype=torch.int32, device=device)
    torch.distributed.broadcast(payload, 0)

    return "".join(chr(code) for code in payload.tolist())


def load_pretrained_model(args):
    """Load pretrained model parameters, with one loader in distributed runs.

    When ``args.world_size`` exists and is greater than 1, only the main
    process resolves the weights and writes them to a cache file under
    ``<args.save_dir>/pretrained_models``; the cache path is then broadcast,
    all processes synchronize on a barrier, and every process loads the cache
    file onto ``args.device``. Otherwise the weights are resolved directly in
    the calling process.

    Args:
        args: Namespace providing ``save_dir``, ``model`` and
            ``model_location``; in distributed runs also ``world_size``,
            ``device`` and ``rank``.

    Returns:
        The loaded checkpoint object, or the ``state_dict()`` of a new
        ``ImageEncoder`` when no checkpoint could be loaded.

    Raises:
        FileNotFoundError: In distributed runs, if the cache file is not
            visible to the calling process after the barrier.
    """
    # Absolute path so the location does not depend on later cwd changes.
    cache_dir = os.path.abspath(os.path.join(args.save_dir, "pretrained_models"))
    model_cache_path = os.path.join(cache_dir, f"{args.model}_pretrained.pt")

    distributed = hasattr(args, 'world_size') and args.world_size > 1
    if not distributed:
        os.makedirs(cache_dir, exist_ok=True)
        return _resolve_pretrained_state(args, model_cache_path)

    if is_main_process():
        os.makedirs(cache_dir, exist_ok=True)
        print(f"Main process: Creating cache directory {cache_dir}")
        print(f"Main process: Model will be saved to {model_cache_path}")
        # NOTE(review): if this raises, the other ranks stay blocked in the
        # broadcast below until the backend times out.
        _resolve_pretrained_state(
            args, model_cache_path, log_prefix="Main process: ", sync_cache=True)

    # Every rank adopts the main process's path, then waits until the file
    # has been written before reading it.
    model_cache_path = _broadcast_path(model_cache_path, args.device)
    torch.distributed.barrier()

    if not is_main_process():
        print(f"Process {args.rank}: Loading pretrained model from {model_cache_path}")

    if not os.path.exists(model_cache_path):
        raise FileNotFoundError(f"Process {args.rank} cannot find model file: {model_cache_path}")
    return torch.load(model_cache_path, map_location=args.device)

def apply_merged_vector(base_state_dict, merged_vector, alpha, device, method, model_name):
    """Build an encoder whose weights are the base weights plus a scaled vector.

    A new ``ImageEncoder`` is created for ``model_name`` and initialised from
    ``base_state_dict``. For every key of ``merged_vector`` that also exists
    in the encoder's state dict, the entry becomes
    ``base + alpha * merged_vector[key]`` (computed on ``device``); keys
    missing from the state dict are ignored.

    Args:
        base_state_dict: State dictionary loaded into the new encoder.
        merged_vector: Mapping from parameter name to the tensor to add.
        alpha: Scaling coefficient applied to each merged tensor.
        device: Device on which the addition is carried out.
        method: Merging method; not referenced by this function.
        model_name: Model name passed to ``ImageEncoder``.

    Returns:
        merged_model: The ``ImageEncoder`` with the updated weights loaded.
    """
    merged_model = ImageEncoder(model_name)
    merged_model.load_state_dict(base_state_dict)
    weights = merged_model.state_dict()

    # Compute all shifted entries first, then overwrite them in one step.
    shifted = {
        name: weights[name].to(device) + alpha * merged_vector[name].to(device)
        for name in merged_vector
        if name in weights
    }
    weights.update(shifted)

    merged_model.load_state_dict(weights)
    return merged_model