import json
import os
from pathlib import Path

def get_config_regression(model_name, dataset_name, config_file=""):
    """
    Build the configuration dict for a regression run.

    Values are layered in increasing priority:
      1. built-in defaults (architecture, dropout, loss-weight and training
         settings),
      2. per-dataset overrides for 'mosi', 'mosei' and 'iemocap' (matched
         case-insensitively; any other name gets no overrides),
      3. the top-level keys of a JSON config file, if that file exists.

    Args:
        model_name: Model identifier; stored lower-cased under 'model_name'.
        dataset_name: Dataset identifier; stored lower-cased under
            'dataset_name' and used to select the dataset overrides.
        config_file: Path to a JSON file whose top-level keys override the
            result. "" (the default) means ``config/dec_config.json`` next to
            this module. Any other falsy value (e.g. None) skips file loading.
            A path that does not exist is silently ignored.

    Returns:
        A freshly built dict (no lists are shared between calls).

    If the file exists but cannot be read, is not valid UTF-8 JSON, or its top
    level is not a JSON object, a warning is printed and the file is ignored.
    """
    if config_file == "":
        config_file = Path(__file__).parent / "config" / "dec_config.json"

    dataset_key = dataset_name.lower()

    # Default configuration
    config = {
        'model_name': model_name.lower(),
        'dataset_name': dataset_key,
        'train_mode': 'regression',

        # Model architecture
        'use_bert': True,
        'use_finetune': True,
        'transformers': 'bert',
        'pretrained': 'bert-base-uncased',
        'need_data_aligned': False,

        # Feature dimensions (will be updated by dataloader)
        'feature_dims': [768, 74, 35],  # BERT, audio, video
        'dst_feature_dim_nheads': [40, 10],  # feature_dim, num_heads
        'nlevels': 5,

        # Conv1D parameters
        'conv1d_kernel_size_l': 1,
        'conv1d_kernel_size_a': 1,
        'conv1d_kernel_size_v': 1,

        # Prototype parameters
        'num_prototypes': 8,
        'lambda_ot': 0.1,
        'ot_num_iters': 50,

        # Dropout parameters
        'attn_dropout': 0.1,
        'attn_dropout_a': 0.0,
        'attn_dropout_v': 0.0,
        'relu_dropout': 0.1,
        'embed_dropout': 0.25,
        'res_dropout': 0.1,
        'output_dropout': 0.0,
        'text_dropout': 0.0,
        'attn_mask': True,

        # Loss weights
        'alpha1': 0.1,  # decoupling loss weight
        'alpha2': 0.1,  # alignment loss weight

        # Training parameters
        'batch_size': 24,
        'learning_rate': 1e-3,
        'weight_decay': 0.0,
        'num_epochs': 100,
        'patience': 20,
        'clip': 0.8,
        'when': 20,
        'factor': 0.1,

        # Data paths (to be set by user)
        'featurePath': '',
        'feature_T': '',
        'feature_A': '',
        'feature_V': '',
    }

    # Dataset-specific overrides. The table is built inside the function so
    # every call gets its own list objects that callers can mutate safely.
    # NOTE(review): seq_lens presumably follows the same modality order as
    # feature_dims (text, audio, video) — confirm against the dataloader.
    dataset_overrides = {
        'mosi': {
            'featurePath': './data/MOSI/mosi_data.pkl',
            'seq_lens': [50, 500, 375],
            'feature_dims': [768, 74, 35],
        },
        'mosei': {
            'featurePath': './data/MOSEI/mosei_data.pkl',
            'seq_lens': [50, 500, 500],
            'feature_dims': [768, 74, 35],
        },
        'iemocap': {
            'featurePath': './data/IEMOCAP/iemocap_data.pkl',
            'seq_lens': [50, 375, 500],
            'feature_dims': [768, 74, 35],
        },
    }
    config.update(dataset_overrides.get(dataset_key, {}))

    # Load from file if it exists; its top-level keys override everything.
    if config_file and os.path.exists(config_file):
        try:
            # JSON is UTF-8 by specification; don't depend on the locale.
            with open(config_file, 'r', encoding='utf-8') as f:
                file_config = json.load(f)
        except (OSError, ValueError) as e:
            # OSError: unreadable path (e.g. a directory or permissions).
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Warning: Could not load config file {config_file}: {e}")
        else:
            if isinstance(file_config, dict):
                config.update(file_config)
            else:
                print(
                    f"Warning: Could not load config file {config_file}: "
                    f"top-level JSON value must be an object, "
                    f"got {type(file_config).__name__}"
                )

    return config

def save_config(config, save_path):
    """
    Save a configuration dict to a JSON file (indented by 4 spaces).

    Missing parent directories are created. A bare filename (no directory
    component) is written to the current working directory.

    Args:
        config: JSON-serializable mapping to write.
        save_path: Destination file path (str or os.PathLike).
    """
    parent = os.path.dirname(save_path)
    # os.makedirs("") raises FileNotFoundError, so only create a parent
    # directory when the path actually has one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=4)

def load_config(config_path):
    """
    Load and return the parsed contents of a JSON configuration file.

    The file is decoded as UTF-8 regardless of the platform locale, since
    JSON text is UTF-8 by specification.

    Args:
        config_path: Path to the JSON file (str or os.PathLike).

    Returns:
        The decoded JSON value (a dict for files written by ``save_config``).

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the content is not valid UTF-8 JSON
            (json.JSONDecodeError / UnicodeDecodeError).
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)