import os
import torch
import pandas as pd
from typing import Dict, List, Any, Tuple
from collections import defaultdict

def evaluate_model(model, tokenizer, testing_data, model_name, batch_size=16, output_dir="."):
    """Evaluate a generative model on in-domain and cross-domain prompt sets.

    Every test prompt is completed with greedy decoding (see ``process_batch``)
    and scored as correct when its expected answer appears, case-insensitively,
    in the completion. Counts are kept overall, per template type, and per
    "source domain -> template domain" pair. A few correct/incorrect examples
    are kept per category.

    Args:
        model: object exposing ``eval()``, ``device`` and ``generate(...)``
            (presumably a Hugging Face causal LM — confirm with callers).
        tokenizer: callable tokenizer with ``decode``; must support
            ``padding=True`` for batches.
        testing_data: mapping with "in_domain" and "cross_domain" lists of item
            dicts. Items carry "text", "expected_answer", "subject", an
            optional "template", and either "predicate_id" (in-domain) or
            "source_predicate"/"target_predicate" (cross-domain).
        model_name: used to name the per-item CSV files.
        batch_size: maximum number of prompts per ``generate`` call.
        output_dir: created if missing; per-item CSVs are written under
            ``output_dir/details``.

    NOTE(review): the code in view fills ``results`` and ``pair_results`` but
    never returns them; check the remainder of the function for what it returns.
    """
    model.eval()
    os.makedirs(output_dir, exist_ok=True)

    # Only these template types get per-template counters; items whose template
    # is not listed still count toward the overall in/cross-domain totals.
    used_template_types = ["variations", "paraphrased_synonym", "antonym", "disfluent", "exact"]

    # Aggregate accumulator shared (via closure) with the nested helpers below.
    results = {
        "in_domain": {"correct": 0, "total": 0},
        "cross_domain": {"correct": 0, "total": 0},
        # NOTE(review): never populated in the visible code; pair-level counts
        # are accumulated in the local ``pair_results`` instead.
        "domain_pairs": {},
        "template_types": {},
        # Up to a handful of illustrative examples per category (see collect_example).
        "examples": {
            "in_domain_correct": [],
            "in_domain_incorrect": [],
            "cross_domain_correct": [],
            "cross_domain_incorrect": []
        }
    }
    
    # One in-domain and one cross-domain counter per tracked template type.
    for template_type in used_template_types:
        results["template_types"][template_type] = {
            "in_domain": {"correct": 0, "total": 0},
            "cross_domain": {"correct": 0, "total": 0}
        }

    def get_domain_from_predicate(predicate_id):
        """Return the "category (relation)" label for a property id.

        Ids that are not in the table map to "unknown".
        """
        labels_by_property = {
            # location relations
            "P17": "location (country)",
            "P19": "location (place of birth)",
            "P27": "location (country of citizenship)",
            "P30": "location (continent)",
            "P36": "location (capital)",
            "P131": "location (located in admin entity)",
            # person relations
            "P106": "person (occupation)",
            "P166": "person (award received)",
            "P569": "person (DOB)",
            "P570": "person (DOD)",
            "P1412": "person (languages spoken)",
            # organization relations
            "P31": "organization (instance of)",
            "P159": "organization (headquarter loc)",
            "P176": "organization (manufacturer)",
            "P452": "organization (industry)",
            "P749": "organization (parent org)",
            # creative-work relations
            "P50": "creative_work (author)",
            "P57": "creative_work (director)",
            "P86": "creative_work (composer)",
            "P136": "creative_work (genre)",
            "P577": "creative_work (publication date)",
        }
        try:
            return labels_by_property[predicate_id]
        except KeyError:
            return "unknown"

    def collect_example(category, item, generated, expected, is_correct, max_per_category=5):
        """Store ``item`` as an illustrative example under ``category``.

        The example is appended only while the category holds fewer than
        ``max_per_category`` entries, and only when ``is_correct`` agrees with
        the category (``*_correct`` wants a truthy value, ``*_incorrect`` a
        falsy one).

        Cross-domain items (with "source_predicate" and "target_predicate") get
        separate source/template domains. Other items use "predicate_id" for
        both.

        Returns:
            True if the example was appended, False otherwise.
        """
        bucket = results["examples"][category]
        if len(bucket) >= max_per_category:
            return False

        wants_correct = {
            "in_domain_correct": True,
            "in_domain_incorrect": False,
            "cross_domain_correct": True,
            "cross_domain_incorrect": False,
        }.get(category)
        if wants_correct is None or bool(is_correct) != wants_correct:
            return False

        if "source_predicate" in item and "target_predicate" in item:
            source_domain = get_domain_from_predicate(item["source_predicate"])
            template_domain = get_domain_from_predicate(item["target_predicate"])
        else:
            source_domain = get_domain_from_predicate(item["predicate_id"])
            template_domain = source_domain

        template_kind = item.get("template", "unknown")
        bucket.append({
            "source_domain": source_domain,
            "template_domain": template_domain,
            "template_type": template_kind,
            "subject": item["subject"],
            "prompt": item["text"],
            "expected": expected,
            "generated": generated,
        })
        return True
    
    
    def process_batch(batch_items, domain_type):
        """Greedily complete one batch of prompts and score each completion.

        Args:
            batch_items: test items; each must provide "text" and
                "expected_answer" and may provide "template". Example collection
                also reads "subject" and the predicate id fields.
            domain_type: "in_domain" or "cross_domain"; selects which counters in
                ``results`` are updated and which example categories are filled.

        Returns:
            One dict per generated row, in input order, with keys "expected",
            "generated", "full_response", "is_correct" and "template_type".

        A completion is correct when the stripped, lower-cased expected answer
        is a non-empty substring of the lower-cased completion. A row that fails
        to decode or score is logged and counted as incorrect. It still counts
        in ``results``, so those totals agree with callers that count
        ``len(batch_items)`` per batch.
        """
        import logging

        logger = logging.getLogger(__name__)

        batch_prompts = [item["text"].rstrip() for item in batch_items]
        batch_expected = [item["expected_answer"] for item in batch_items]
        batch_templates = [item.get("template", "unknown") for item in batch_items]

        # NOTE(review): slicing ``output[prompt_length:]`` below assumes a
        # decoder-only model whose generate() output echoes the prompt. Batched
        # decoder-only generation also needs ``tokenizer.padding_side == "left"``;
        # with right padding, shorter prompts are continued after pad tokens.
        # Confirm how the caller configures the tokenizer.
        inputs = tokenizer(batch_prompts, return_tensors="pt", padding=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=16,
                min_new_tokens=1,
                do_sample=False,  # greedy decoding
            )

        # Every row is padded to the same length, so the prompt occupies the
        # first ``prompt_length`` positions of each output row.
        prompt_length = inputs["input_ids"].size(1)
        category_prefix = "in_domain" if domain_type == "in_domain" else "cross_domain"

        batch_results = []
        for i, output in enumerate(outputs):
            item = batch_items[i]
            expected_answer = batch_expected[i]
            template_type = batch_templates[i]

            try:
                generated_text = tokenizer.decode(
                    output[prompt_length:], skip_special_tokens=True
                ).strip()
                # str() tolerates non-string answers (e.g. numeric years).
                # strip() mirrors the stripped completion.
                # An empty answer is a substring of everything, so it never counts.
                expected = str(expected_answer).strip().lower()
                is_correct = bool(expected) and expected in generated_text.lower()
            except Exception:
                logger.exception(
                    "Failed to score %s item %d; counting it as incorrect", domain_type, i
                )
                generated_text, is_correct = "", False
            else:
                # Example collection is diagnostic only. A malformed item (e.g.
                # missing "subject") must not change how the prediction is scored.
                category = f"{category_prefix}_{'correct' if is_correct else 'incorrect'}"
                try:
                    collect_example(category, item, generated_text, expected_answer, is_correct)
                except Exception:
                    logger.exception(
                        "Could not record example for %s item %d", domain_type, i
                    )

            results[domain_type]["total"] += 1
            if is_correct:
                results[domain_type]["correct"] += 1

            # Only template types listed in ``used_template_types`` have counters.
            template_stats = results["template_types"].get(template_type)
            if template_stats is not None:
                template_stats[domain_type]["total"] += 1
                if is_correct:
                    template_stats[domain_type]["correct"] += 1

            batch_results.append({
                "expected": expected_answer,
                "generated": generated_text,
                "full_response": generated_text,
                "is_correct": is_correct,
                "template_type": template_type,
            })

        return batch_results
    
    # "<source domain>-><template domain>" -> template type -> {"correct", "total"}.
    # NOTE(review): not read in the visible code; presumably later code folds it
    # into results["domain_pairs"] — confirm.
    pair_results = {}
    # One row per in-domain item: the item's own fields plus domain, template
    # type, generated text and correctness; written to CSV further down.
    in_domain_details = []
    
    # Group in-domain items by (domain label, template type) so every batch
    # below shares a single domain and template type.
    predicate_template_to_examples = defaultdict(list)
    for item in testing_data["in_domain"]:
        predicate = item["predicate_id"]
        template_type = item.get("template", "unknown")
        domain = get_domain_from_predicate(predicate)
        key = (domain, template_type)
        predicate_template_to_examples[key].append(item)
    
    for (domain, template_type), examples in predicate_template_to_examples.items():
        # In-domain pairs map a domain onto itself.
        pair_key = f"{domain}->{domain}"
        if pair_key not in pair_results:
            pair_results[pair_key] = defaultdict(lambda: {"correct": 0, "total": 0})
        
        for i in range(0, len(examples), batch_size):
            batch = examples[i:i+batch_size]
            batch_results = process_batch(batch, "in_domain")
            
            # Counts every item in the batch; items process_batch could not score
            # come back with is_correct=False and are counted as incorrect here.
            pair_results[pair_key][template_type]["total"] += len(batch)
            pair_results[pair_key][template_type]["correct"] += sum(r["is_correct"] for r in batch_results)

            # Relies on process_batch returning one result per item, in input order.
            for j, res in enumerate(batch_results):
                in_domain_details.append({
                    **batch[j],
                    "domain": domain,
                    "template_type": template_type,
                    "generated": res["generated"],
                    "correct": res["is_correct"]
                })
    
    # One row per cross-domain item: the item's own fields plus source/template
    # domains, template type, generated text and correctness.
    cross_domain_details = []
    # Items grouped by (source domain, template domain, template type) so every
    # batch below is homogeneous in all three.
    cross_domain_keys_to_examples = defaultdict(list)
    
    for item in testing_data["cross_domain"]:
        source_predicate = item["source_predicate"]
        target_predicate = item["target_predicate"]
        template_type = item.get("template", "unknown")
        source_domain = get_domain_from_predicate(source_predicate)
        target_domain = get_domain_from_predicate(target_predicate)
        key = (source_domain, target_domain, template_type)
        cross_domain_keys_to_examples[key].append(item)
    
    for (source_domain, target_domain, template_type), examples in cross_domain_keys_to_examples.items():
        pair_key = f"{source_domain}->{target_domain}"
        if pair_key not in pair_results:
            pair_results[pair_key] = defaultdict(lambda: {"correct": 0, "total": 0})
        
        for i in range(0, len(examples), batch_size):
            batch = examples[i:i+batch_size]
            batch_results = process_batch(batch, "cross_domain")
            
            # Counts every item in the batch; items process_batch could not score
            # come back with is_correct=False and are counted as incorrect here.
            pair_results[pair_key][template_type]["total"] += len(batch)
            pair_results[pair_key][template_type]["correct"] += sum(r["is_correct"] for r in batch_results)
            
            # Relies on process_batch returning one result per item, in input order.
            for j, res in enumerate(batch_results):
                cross_domain_details.append({
                    **batch[j],
                    "source_domain": source_domain,
                    "template_domain": target_domain,
                    "template_type": template_type,
                    "generated": res["generated"],
                    "correct": res["is_correct"]
                })
    
    # Per-item results are written to <output_dir>/details as CSV.
    details_dir = os.path.join(output_dir, "details")
    os.makedirs(details_dir, exist_ok=True)

    # Model ids such as "org/model" contain path separators. Left as-is, they
    # would point the CSV into a non-existent sub-directory and to_csv would
    # raise, so replace them in the filename component.
    safe_model_name = str(model_name).replace("/", "_").replace(os.sep, "_")
    pd.DataFrame(in_domain_details).to_csv(
        os.path.join(details_dir, f"in_domain_results_{safe_model_name}.csv"), index=False
    )
    pd.DataFrame(cross_domain_details).to_csv(
        os.path.join(details_dir, f"cross_domain_results_{safe_model_name}.csv"), index=False
    )
