import os
import json
import csv
import torch
import gc
from pathlib import Path
from .utils.configs import ApibenchDataConfig, TrainConfig
from .utils.utility import set_seed
from .openmodel import LoRAModelManager
from .utils.prepareDataset import convert_to_conversational, load_dataset_json
from .utils.wandb import WandbLogger
from dotenv import load_dotenv
from .utils.parser import TrainParser
from .train_loop import train
from skopt import gp_minimize
from skopt.space import Real, Integer
from skopt.utils import use_named_args
from skopt.space import Categorical

# Load variables from a .env file (if any) into the process environment;
# objective() later reads WANDB_API_KEY via os.getenv.
load_dotenv()

# Module-level state shared between main() and objective().
# gp_minimize calls objective() with hyperparameter values only, so everything
# else the objective needs is passed through these globals.
# Number of objective() calls so far; incremented at the start of each call.
trial_counter = 0
# One record per completed trial, appended by objective().
trial_results = []
# Training / validation datasets, assigned by main() before the search starts.
dataset_train = None
dataset_val = None
# Name of the experiment variant, assigned by main(); used in the train config
# and as the stem of the result file names.
variant_name = None

# Search space explored by gp_minimize. The order of the dimensions defines the
# order of the keyword arguments handed to objective() and of the CSV columns.
space = [
    # LoRA parameters (currently excluded from the search; re-enable by
    # uncommenting the dimensions below).
    # Integer(16, 128, name="lora_r"),
    # Integer(16, 256, name="lora_alpha"),
    # Real(0.0, 0.3, name="lora_dropout"),
    # Categorical([
    #     "q_proj,v_proj",                           # Query and Value (most common)
    #     "q_proj,k_proj,v_proj,o_proj",            # All attention layers (comprehensive)
    #     "q_proj,o_proj"                            # Query and Output (alternative approach)
    # ], name="target_modules"),

    # Training hyperparameters
    # lr and weight_decay span several orders of magnitude, so they are sampled
    # log-uniformly; a uniform prior would place ~90% of the samples in the top
    # decade and barely explore the small values.
    Real(1e-5, 1e-3, prior="log-uniform", name="lr"),            # Learning rate: 0.00001 to 0.001
    Real(0.5, 2.0, name="max_grad_norm"),                        # Gradient clipping: 0.5 to 2.0
    Real(0.0, 0.2, name="label_smoothing"),                      # Label smoothing: 0.0 to 0.2
    Real(0.0001, 0.1, prior="log-uniform", name="weight_decay"), # Weight decay: 0.0001 to 0.1
    Integer(5, 50, name="warmup_steps"),                         # Warmup steps: 5 to 50
]


def create_trial_result(trial_num, params, eval_loss):
    """Build the record stored for one finished trial.

    Args:
        trial_num: Identifier of the trial (objective() passes its running
            trial counter).
        params: Mapping of hyperparameter name to the sampled value.
        eval_loss: Loss obtained for the trial; stored unchanged.

    Returns:
        A dict with the keys ``"trial"``, ``"hyperparameters"`` and
        ``"eval_loss"``. Hyperparameter values are normalised to plain
        Python types:

        * ``target_modules`` becomes a list of module names (a
          comma-separated string is split, any other iterable is copied);
        * other strings (categorical choices) are kept as they are;
        * floats become ``float``;
        * everything else is converted with ``int()``.
    """
    hyperparameters = {}

    for param_name, param_value in params.items():
        if param_name == "target_modules":
            # Sampled as a comma-separated string; also accept a ready-made
            # list/tuple instead of failing on the missing .split().
            if isinstance(param_value, str):
                hyperparameters[param_name] = param_value.split(',')
            else:
                hyperparameters[param_name] = list(param_value)
        elif isinstance(param_value, str):
            # Categorical string choices must not go through int().
            hyperparameters[param_name] = param_value
        elif isinstance(param_value, float):
            hyperparameters[param_name] = float(param_value)
        else:
            # Integer-like values (possibly NumPy integers) -> plain int so the
            # record can be serialised with json.
            hyperparameters[param_name] = int(param_value)

    return {
        "trial": trial_num,
        "hyperparameters": hyperparameters,
        "eval_loss": eval_loss,
    }


def save_results_to_files(trial_results, variant_name, results_dir):
    """Write the trial results to ``<variant_name>.json`` and ``<variant_name>.csv``.

    Args:
        trial_results: Iterable of trial records, each a dict with the keys
            ``"trial"``, ``"hyperparameters"`` (a dict) and ``"eval_loss"``.
        variant_name: Stem used for both output file names.
        results_dir: Directory to write into (``str`` or ``Path``); it is
            created if it does not exist.

    Returns:
        The trial records sorted by ascending ``eval_loss``.

    The JSON file holds ``best_result`` (lowest loss, or ``null`` when there
    are no trials) and ``all_trials``. The CSV file holds one row per trial
    and is only written when there is at least one trial.
    """
    results_dir = Path(results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)

    # Sort results by eval_loss (lowest to highest)
    sorted_results = sorted(trial_results, key=lambda x: x['eval_loss'])

    # Save JSON file
    json_file = results_dir / f"{variant_name}.json"
    with open(json_file, 'w', encoding='utf-8') as f:
        json.dump({
            "best_result": sorted_results[0] if sorted_results else None,
            "all_trials": sorted_results,
        }, f, indent=2)

    if not sorted_results:
        return sorted_results

    # CSV columns come from the trials themselves (union of hyperparameter
    # names, in first-seen order) rather than from module-level state, so the
    # file always matches the data that is actually being written.
    param_names = list(dict.fromkeys(
        name
        for trial in sorted_results
        for name in trial['hyperparameters']
    ))
    fieldnames = ['trial'] + param_names + ['eval_loss']

    # Save CSV file (only all_trials, excluding best_result)
    csv_file = results_dir / f"{variant_name}.csv"
    with open(csv_file, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        for trial in sorted_results:
            row = {'trial': trial['trial'], 'eval_loss': trial['eval_loss']}

            for param_name in param_names:
                param_value = trial['hyperparameters'].get(param_name, '')

                # target_modules is stored as a list; flatten it back to a
                # comma-separated string so it fits in one CSV cell.
                if param_name == "target_modules" and isinstance(param_value, list):
                    row[param_name] = ','.join(param_value)
                else:
                    row[param_name] = param_value

            writer.writerow(row)

    return sorted_results


@use_named_args(space)
def objective(**params):
    """Run one training trial with the sampled hyperparameters.

    Called by gp_minimize; ``use_named_args`` turns the sampled point into
    keyword arguments named after the dimensions of ``space``.

    The function builds a fresh train config, overrides it with the sampled
    values, trains a new LoRA model on the module-level datasets, records the
    outcome in ``trial_results`` and returns the value produced by ``train``,
    which gp_minimize minimises.

    Args:
        **params: Hyperparameter name -> sampled value.

    Returns:
        The loss returned by ``train`` for this trial.

    Raises:
        Whatever model construction or ``train`` raises. GPU memory is
        released and the W&B run is closed before the exception propagates.
    """
    # Only trial_counter is rebound here; the other module globals are read
    # (or mutated in place) and need no global declaration.
    global trial_counter
    trial_counter += 1

    train_config = TrainParser().parse_args()

    train_config['variant_name'] = f"{variant_name}"
    train_config['extra_info'] = f"{trial_counter}"
    train_config['seed'] = 42
    train_config["hyperparameters_search"] = True

    # Apply the sampled values to the config as plain Python types.
    for param_name, param_value in params.items():
        if param_name == "target_modules":
            # Sampled as a comma-separated string; the config expects a list.
            if isinstance(param_value, str):
                train_config[param_name] = param_value.split(',')
            else:
                train_config[param_name] = list(param_value)
        elif isinstance(param_value, str):
            # Categorical string choices must not go through int().
            train_config[param_name] = param_value
        elif isinstance(param_value, float):
            train_config[param_name] = float(param_value)
        else:
            train_config[param_name] = int(param_value)

    wandb_key = os.getenv("WANDB_API_KEY")
    wandb_logger = WandbLogger(wandb_key, train_config, mode="train")

    try:
        # Free whatever the previous trial left behind before loading a model.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = LoRAModelManager(config=train_config, device_map=device)

        try:
            loss = train(
                trainConfig=train_config,
                model=model,
                dataset_train=dataset_train,
                dataset_val=dataset_val,
                wandb_logger=wandb_logger
            )
        finally:
            # Drop the model even when training fails so the next trial does
            # not start with the memory still occupied.
            del model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
    finally:
        # Close the W&B run on success and on failure alike; otherwise a
        # crashed trial leaves its run open.
        if wandb_logger:
            wandb_logger.finish()

    # Record the trial only once it has produced a loss.
    trial_results.append(create_trial_result(trial_counter, params, loss))

    return loss

def main():
    """Prepare the datasets, run the hyperparameter search and save the results.

    Steps:
        1. Seed the RNGs and build the conversational train/validation
           datasets. A temporary model is instantiated because its tokenizer
           is needed for the conversion; it is discarded afterwards.
        2. Run ``gp_minimize`` over ``space`` with ``objective``.
        3. Write every recorded trial to ``results/hyperparms_search`` as JSON
           and CSV and print the best one.

    The results are written in a ``finally`` block, so trials that completed
    before a crash or a keyboard interrupt are still persisted.
    """
    global dataset_train, dataset_val, variant_name

    set_seed(42)
    fake_train_config = TrainConfig(
        experience_name="temp",
        output_path=Path("temp"),
        retriever=None,
        repo_id="huggyllama/llama-7b"
    )

    # Temporary model: only its tokenizer is used, for dataset conversion.
    model = LoRAModelManager(config=fake_train_config)
    dataset_config = ApibenchDataConfig()

    dataset_json_train = load_dataset_json(dataset_config.train_set)
    dataset_train = convert_to_conversational(
        raw_data=dataset_json_train,
        config=fake_train_config,
        tokenizer=model.tokenizer
    )

    dataset_json_val = load_dataset_json(dataset_config.val_set)
    dataset_val = convert_to_conversational(
        raw_data=dataset_json_val,
        config=fake_train_config,
        tokenizer=model.tokenizer
    )

    # Clean up the temporary model
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    variant_name = "explanation"

    results_dir = Path("results/hyperparms_search")

    try:
        # The optimiser's return value is not needed: objective() records each
        # trial in trial_results as it goes.
        gp_minimize(
            objective,
            dimensions=space,
            n_calls=20,        # total trials
            n_random_starts=5, # random warmup points
            random_state=42,
        )
    finally:
        # Persist whatever has been collected, even if the search was aborted.
        results_dir.mkdir(parents=True, exist_ok=True)
        sorted_results = save_results_to_files(trial_results, variant_name, results_dir)

    print(f"Best hyperparameters: {sorted_results[0]['hyperparameters'] if sorted_results else 'None'}")
    print(f"Best eval loss: {sorted_results[0]['eval_loss'] if sorted_results else 'None'}")

    

# Script entry point. NOTE: this module uses package-relative imports
# (``from .utils ...``), so it has to be executed as part of its package
# (e.g. with ``python -m``) rather than as a standalone file.
if __name__ == "__main__":
    main()
