import os
import spacy
from config import parse_args, ExperimentConfig
from data.dataset import load_or_create_datasets
from models.training import fine_tune_model, load_model
from models.evaluation import evaluate_model
from utils.metrics import plot_domain_impact, generate_comparison_report
import torch 

def _run_sanity_check(model, tokenizer, training_data, num_examples=5):
    """Greedy-decode the first few training prompts and print the outputs.

    This is a quick eyeball check that the model produces something sensible
    on examples it was trained on; nothing is scored or returned.

    Args:
        model: Model exposing ``device`` and ``generate``.
        tokenizer: Tokenizer used to encode prompts and decode generations.
        training_data: Sized iterable of examples; each example is indexed
            with the keys ``"text"`` (prompt) and ``"object"`` (expected answer).
        num_examples: Maximum number of examples to run.
    """
    # Local import so this block does not depend on the file's import header.
    from itertools import islice

    print("\nRunning validation check on training examples...")
    if training_data is None or len(training_data) == 0:
        print("No training examples available; skipping validation check.")
        return

    # islice iterates over examples, so it works for plain lists as well as
    # dataset containers whose slice is not itself a sequence of examples.
    for i, example in enumerate(islice(training_data, num_examples), start=1):
        prompt = example["text"]
        expected = example["object"]

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                # Pass the mask explicitly so generate() does not have to
                # infer it from pad tokens.
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=20,
                num_return_sequences=1,
                do_sample=False,
            )

        # The generated sequence starts with the prompt tokens; strip them so
        # only the newly generated continuation is decoded.
        input_length = inputs["input_ids"].size(1)
        generated_tokens = outputs[0][input_length:]
        generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        print(f"Validation Example {i}:")
        print(f"  Prompt: {prompt}")
        print(f"  Expected: {expected}")
        print(f"  Generated: {generated_text}")
        print()


def _print_evaluation_summary(results):
    """Print in-domain, cross-domain and per-template accuracy from ``results``.

    Args:
        results: Mapping returned by ``evaluate_model``. The keys read here are
            ``'in_domain'``, ``'cross_domain'`` (each with ``'accuracy'``,
            ``'correct'``, ``'total'``) and the optional ``'template_types'``.
    """
    print("\n--- EVALUATION RESULTS ---")

    if 'in_domain' in results and 'accuracy' in results['in_domain']:
        in_domain = results['in_domain']
        print(f"In-Domain Accuracy: {in_domain['accuracy']*100:.2f}%")
        print(f"  - Correct: {in_domain['correct']} / {in_domain['total']}")
    else:
        print("Warning: In-domain results are incomplete or missing")
        print("Available keys in results['in_domain']:", list(results.get('in_domain', {}).keys()))

    if 'cross_domain' in results and 'accuracy' in results['cross_domain']:
        cross_domain = results['cross_domain']
        print(f"\nCross-Domain Accuracy: {cross_domain['accuracy']*100:.2f}%")
        print(f"  - Correct: {cross_domain['correct']} / {cross_domain['total']}")
    else:
        # Mirror the in-domain branch so a missing section is never silent.
        print("\nWarning: Cross-domain results are incomplete or missing")
        print("Available keys in results['cross_domain']:", list(results.get('cross_domain', {}).keys()))

    if 'template_types' in results:
        print("\nTemplate-Specific Performance:")
        # Only these template types are reported; others present in the
        # results are ignored.
        template_types = ["variations", "paraphrased_synonym"]
        for template_type in template_types:
            if template_type not in results['template_types']:
                continue
            template_data = results['template_types'][template_type]

            # A missing domain section or accuracy value is reported as 0%.
            in_acc = template_data.get('in_domain', {}).get('accuracy', 0) * 100
            cross_acc = template_data.get('cross_domain', {}).get('accuracy', 0) * 100
            drop = in_acc - cross_acc

            print(f"  {template_type:<25}: {in_acc:.2f}% → {cross_acc:.2f}% (Drop: {drop:.2f}%)")


def run_experiment(config: ExperimentConfig):
    """Run one experiment end to end: data, model, sanity check, evaluation.

    Steps, in order:
      1. Create ``config.output_dir`` if needed.
      2. Load or create the training/testing datasets.
      3. Obtain a model: load from ``config.load_model_path`` if set; otherwise
         load the base ``config.model_name`` when ``config.test_only`` is set;
         otherwise fine-tune ``config.model_name`` on the training data.
      4. Unless ``config.test_only`` is set, print generations for a few
         training examples as a sanity check.
      5. Evaluate on the testing data and print a summary.

    Args:
        config: Experiment configuration. Attributes read here are
            ``experiment_name``, ``model_name``, ``output_dir``,
            ``load_model_path``, ``test_only`` and ``batch_size``.

    Returns:
        The results object returned by ``evaluate_model``.
    """
    banner = '=' * 80
    print(f"\n{banner}")
    print(f"Starting experiment: {config.experiment_name}")
    print(f"Model name: {config.model_name}")
    print(f"{banner}\n")

    # create output directories
    os.makedirs(config.output_dir, exist_ok=True)

    # load or create the datasets
    training_data, testing_data = load_or_create_datasets(config)

    # load or train the model; an explicit checkpoint path takes priority
    # over test_only.
    if config.load_model_path:
        print(f"Loading fine-tuned model from {config.load_model_path}")
        model, tokenizer = load_model(config.load_model_path)
    elif config.test_only:
        print(f"Test-only mode: Loading base model {config.model_name} without fine-tuning")
        model, tokenizer = load_model(config.model_name)
    else:
        print(f"Fine-tuning model {config.model_name}")
        model, tokenizer = fine_tune_model(config.model_name, training_data, config)

    # Every path uses left padding for generation.
    # NOTE(review): presumably needed for batched generation inside
    # evaluate_model - confirm.
    tokenizer.padding_side = "left"

    # sanity check on training examples (skipped in test-only mode)
    if not config.test_only:
        _run_sanity_check(model, tokenizer, training_data)

    # eval model
    print("\nStarting model evaluation...")
    results = evaluate_model(
        model,
        tokenizer,
        testing_data,
        model_name=config.experiment_name,
        batch_size=config.batch_size,
        output_dir=config.output_dir,
    )

    _print_evaluation_summary(results)

    print(f"\nResults saved to {config.output_dir}")
    print(f"{banner}\n")

    return results

def main():
    """Parse command-line arguments into a config and run the experiment."""
    config = parse_args()
    # The returned results are not needed here; run_experiment already
    # prints the summary.
    run_experiment(config)


# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()