import pickle
import os
import random
from typing import Dict, List, Tuple, Any, Optional
import pandas as pd
from collections import defaultdict
from datasets import load_dataset
from sklearn.model_selection import train_test_split

from data.templates import create_predicate_templates, generate_cross_domain_pairs
# from data.templates_n_3 import create_predicate_templates, generate_cross_domain_pairs


def get_predicate_domains():
    """Return the fixed grouping of Wikidata predicate IDs into coarse domains.

    Returns:
        A ``(domains, predicate_to_domain)`` tuple where ``domains`` maps each
        domain name to its list of predicate IDs and ``predicate_to_domain`` is
        the inverse lookup from predicate ID to domain name. Fresh dicts are
        built on every call, so callers may mutate the result freely.
    """
    domains = {
        "LOCATION": ["P17", "P19", "P20", "P27", "P30", "P36"],
        "PERSON": ["P106", "P166", "P39", "P69", "P103", "P140"],
        "ORGANIZATION": ["P31", "P159", "P112", "P127", "P138", "P527"],
        "CREATIVE_WORK": ["P50", "P136", "P57", "P86", "P144", "P495"]
    }

    # Inverse index: predicate id -> owning domain name.
    predicate_to_domain = {
        pid: domain_name
        for domain_name, pids in domains.items()
        for pid in pids
    }
    return domains, predicate_to_domain

def generate_training_testing_data(df):
    """Build training prompts and evaluation prompts from LAMA-style rows.

    Only rows whose ``predicate_id`` has an entry in
    ``create_predicate_templates()`` are used. Each such predicate's rows are
    split 80/20 with ``train_test_split(..., random_state=42)``. A predicate
    with a single row cannot be split, so that row goes to training only.

    * Each training row becomes one prompt. The prompt is rendered with a
      randomly chosen entry of ``template_info['variations']``. The choice is
      drawn from the global ``random`` module; seed it beforehand if the
      training text must be reproducible.
    * Each test row becomes five in-domain prompts, emitted in this order:
      ``synonym``, ``antonym``, ``exact`` (``variations[0]``), ``disfluent``
      and ``semantic_paraphrasing``.

    Cross-domain prompts are produced by ``generate_cross_domain_tests`` over
    all kept rows. They are then annotated with ``source_domain`` and
    ``target_domain``.

    Args:
        df: DataFrame with at least the columns ``predicate_id``,
            ``sub_label``, ``obj_label``, ``sub_uri`` and ``obj_uri``.

    Returns:
        ``(training_pairs, testing_pairs)``. ``training_pairs`` is a list of
        dicts. ``testing_pairs`` is ``{"in_domain": [...], "cross_domain": [...]}``.
    """
    predicate_templates = create_predicate_templates()
    cross_domain_pairs = generate_cross_domain_pairs()
    _, predicate_to_domain = get_predicate_domains()

    # Evaluation templates in the order they are emitted for every test row;
    # "exact" reuses the first training variation.
    eval_templates = (
        ("synonym", lambda info: info['synonym']),
        ("antonym", lambda info: info['antonym']),
        ("exact", lambda info: info['variations'][0]),
        ("disfluent", lambda info: info['disfluent']),
        ("semantic_paraphrasing", lambda info: info['semantic_paraphrasing']),
    )

    # Gather the needed fields of every templated row, grouped by predicate.
    # to_dict("records") avoids the per-row Series construction of iterrows().
    fields = ['sub_label', 'obj_label', 'sub_uri', 'obj_uri', 'predicate_id']
    predicate_examples = defaultdict(list)
    for record in df[fields].to_dict('records'):
        if record['predicate_id'] in predicate_templates:
            predicate_examples[record['predicate_id']].append(record)

    training_pairs = []
    in_domain_tests = []
    for predicate, examples in predicate_examples.items():
        template_info = predicate_templates[predicate]
        domain = predicate_to_domain.get(predicate, "OTHER")

        if len(examples) < 2:
            # train_test_split raises when the training side would be empty,
            # which is the case for a single row; keep it for training only.
            train_examples, test_examples = examples, []
        else:
            train_examples, test_examples = train_test_split(
                examples, test_size=0.2, random_state=42
            )

        for example in train_examples:
            subject = example['sub_label']
            # Same RNG draw as random.randint(0, len(variations) - 1).
            variation_func = random.choice(template_info['variations'])
            training_pairs.append({
                "predicate_id": predicate,
                "text": f"{variation_func(subject)}",
                "subject": subject,
                "object": example['obj_label'],
                "template": "variations",
                "domain": domain,
            })

        for example in test_examples:
            subject = example['sub_label']
            for template_name, pick_template in eval_templates:
                in_domain_tests.append({
                    "predicate_id": predicate,
                    "text": f"{pick_template(template_info)(subject)}",
                    "subject": subject,
                    "expected_answer": example['obj_label'],
                    "template": template_name,
                    "subject_uri": example['sub_uri'],
                    "object_uri": example['obj_uri'],
                    "domain": domain,
                })

    all_examples = [ex for examples in predicate_examples.values() for ex in examples]
    cross_domain_tests = generate_cross_domain_tests(
        all_examples, predicate_templates, cross_domain_pairs
    )

    # Attach domain information to every cross-domain test.
    for test in cross_domain_tests:
        test["source_domain"] = predicate_to_domain.get(test["source_predicate"], "OTHER")
        test["target_domain"] = predicate_to_domain.get(test["target_predicate"], "OTHER")

    testing_pairs = {
        "in_domain": in_domain_tests,
        "cross_domain": cross_domain_tests,
    }
    return training_pairs, testing_pairs


def generate_cross_domain_tests(domain_examples, predicate_templates, cross_domain_pairs):
    """Render target-predicate subjects with the *source* predicate's templates.

    For every distinct ``(source_pred, target_pred)`` pair in
    ``cross_domain_pairs`` (first-seen order), each example whose
    ``predicate_id`` equals ``target_pred`` yields five tests. Each test uses
    one of the source predicate's templates, in this order: ``synonym``,
    ``antonym``, ``exact`` (``variations[0]``), ``disfluent``,
    ``semantic_paraphrasing``. The expected answer is the target example's
    ``obj_label``.

    Args:
        domain_examples: Dicts with ``predicate_id``, ``sub_label`` and
            ``obj_label`` keys.
        predicate_templates: Maps predicate id -> dict of template callables
            (keys as listed above; ``variations`` is a list).
        cross_domain_pairs: Iterable of ``(source_pred, target_pred, template_func)``.

    Returns:
        A list of test dicts with keys ``source_predicate``, ``target_predicate``,
        ``text``, ``subject``, ``expected_answer`` and ``template``.

    Raises:
        KeyError: if a source predicate with at least one matching target
            example has no entry in ``predicate_templates``.
    """
    # NOTE(review): the third element (template_func) of each cross_domain_pairs
    # entry is never used -- prompts come from the source predicate's templates.
    # Confirm this is intended.
    pair_keys = dict.fromkeys((src, tgt) for src, tgt, _ in cross_domain_pairs)
    if not pair_keys:
        return []

    # Index examples by predicate once, instead of rescanning all examples per pair.
    examples_by_predicate = defaultdict(list)
    for example in domain_examples:
        examples_by_predicate[example['predicate_id']].append(example)

    # Evaluation templates in emission order; "exact" is the first variation.
    eval_templates = (
        ("synonym", lambda tpl: tpl['synonym']),
        ("antonym", lambda tpl: tpl['antonym']),
        ("exact", lambda tpl: tpl['variations'][0]),
        ("disfluent", lambda tpl: tpl['disfluent']),
        ("semantic_paraphrasing", lambda tpl: tpl['semantic_paraphrasing']),
    )

    cross_domain_tests = []
    for source_pred, target_pred in pair_keys:
        target_examples = examples_by_predicate.get(target_pred)
        if not target_examples:
            continue

        source_template = predicate_templates[source_pred]
        for example in target_examples:
            subject = example['sub_label']
            for template_name, pick_template in eval_templates:
                cross_domain_tests.append({
                    "source_predicate": source_pred,
                    "target_predicate": target_pred,
                    "text": f"{pick_template(source_template)(subject)}",
                    "subject": subject,
                    "expected_answer": example['obj_label'],
                    "template": template_name,
                })
    return cross_domain_tests


def load_trex_dataset(cache_dir: Optional[str] = None) -> pd.DataFrame:
    """Fetch the ``train`` split of the Hugging Face ``facebook/lama`` dataset.

    Args:
        cache_dir: Passed straight through to ``datasets.load_dataset``.

    Returns:
        The split converted to a ``pandas.DataFrame``.
    """
    # NOTE(review): the function name suggests the T-REx subset; presumably that
    # is what this split/config yields -- confirm against the dataset card.
    raw_split = load_dataset(
        'facebook/lama',
        split='train',
        cache_dir=cache_dir,
    )
    frame = pd.DataFrame(raw_split)
    return frame

def balance_domains(data_items, domain_key="domain", target_per_domain=None):
    """Balance items across domains to prevent domain collapse.

    Items are grouped by ``item.get(domain_key, "OTHER")``. Each group is then
    capped at ``target_per_domain`` items. A group larger than the cap is
    randomly down-sampled with the global ``random`` module. A group at or
    under the cap is kept whole and in its original order. Output groups
    appear in first-seen domain order.

    Args:
        data_items: Iterable of dict-like items.
        domain_key: Key holding each item's domain label.
        target_per_domain: Per-domain cap. When ``None``, the cap is the upper
            median of the group sizes, i.e. ``sorted(sizes)[len(sizes) // 2]``.

    Returns:
        A new list containing the balanced items.
    """
    domain_to_items = defaultdict(list)

    # Group items by domain; items without the key fall into "OTHER".
    for item in data_items:
        domain_to_items[item.get(domain_key, "OTHER")].append(item)

    if not domain_to_items:
        return []

    # No explicit cap: use the (upper) median group size, so the smaller
    # domains are kept intact and only the larger ones are trimmed.
    if target_per_domain is None:
        sorted_counts = sorted(len(items) for items in domain_to_items.values())
        target_per_domain = sorted_counts[len(sorted_counts) // 2]

    balanced_items = []
    for items in domain_to_items.values():
        if len(items) <= target_per_domain:
            balanced_items.extend(items)
        else:
            # Sample down to the cap when over it.
            balanced_items.extend(random.sample(items, target_per_domain))

    return balanced_items

def _group_items(items, key_fn):
    """Group ``items`` by ``key_fn(item)`` into an insertion-ordered dict of lists."""
    groups = defaultdict(list)
    for item in items:
        groups[key_fn(item)].append(item)
    return groups


def _sample_per_predicate(items, fraction):
    """Randomly keep ``fraction`` of each predicate's items, at least 1 and at most all.

    ``random.sample`` is called for every group, even when the whole group is
    kept, so each group comes back in shuffled order.
    """
    sampled = []
    for group in _group_items(items, lambda item: item["predicate_id"]).values():
        # Clamp so sample_percentage > 1 cannot make random.sample raise.
        size = min(len(group), max(1, int(len(group) * fraction)))
        sampled.extend(random.sample(group, size))
    return sampled


def sample_dataset(
    training_data: List[Dict[str, Any]],
    testing_data: Dict[str, List[Dict[str, Any]]],
    sample_percentage: float = 0.3,
    random_seed: int = 42,
    balance_domains_count: Optional[int] = None,
) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
    """Domain-balance and down-sample training and test data.

    All randomness goes through the global ``random`` module, which is
    reseeded with ``random_seed`` on entry.

    1. ``training_data`` and ``testing_data["in_domain"]`` are balanced on the
       ``"domain"`` key with ``balance_domains``. The cap is
       ``balance_domains_count`` when truthy; otherwise ``balance_domains``
       picks its upper-median default.
    2. ``sample_percentage`` of each predicate's training and in-domain items
       is kept, with at least one item per predicate.
    3. Cross-domain tests are grouped by (source domain, target domain), and
       each group is capped at the upper-median group size. Then, for each
       (source predicate, target predicate) pair, only tests whose subject
       appears in the sampled in-domain set are kept, and
       ``sample_percentage`` of those is drawn.

    The caller's ``testing_data`` dict is not modified.

    Args:
        training_data: Training dicts with ``domain`` and ``predicate_id`` keys.
        testing_data: ``{"in_domain": [...], "cross_domain": [...]}``.
            In-domain items need ``domain``, ``predicate_id`` and ``subject``
            keys. Cross-domain items need ``source_predicate``,
            ``target_predicate`` and ``subject`` keys.
        sample_percentage: Fraction of each group to keep.
        random_seed: Seed applied to the global ``random`` module.
        balance_domains_count: Optional explicit per-domain cap.

    Returns:
        ``(sampled_training, {"in_domain": [...], "cross_domain": [...]})``.
    """
    random.seed(random_seed)
    _, predicate_to_domain = get_predicate_domains()

    # 1. Domain balancing. `or None` keeps the original truthiness check:
    #    both 0 and None fall back to balance_domains' median-based cap.
    domain_cap = balance_domains_count or None
    balanced_training = balance_domains(training_data, "domain", domain_cap)
    balanced_in_domain = balance_domains(testing_data["in_domain"], "domain", domain_cap)

    # 2. Per-predicate fractional sampling (otherwise the dataset is huge).
    sampled_training = _sample_per_predicate(balanced_training, sample_percentage)
    sampled_in_domain = _sample_per_predicate(balanced_in_domain, sample_percentage)

    sampled_subjects = {item["subject"] for item in sampled_in_domain}

    # 3a. Cap each (source domain, target domain) bucket at the upper-median size.
    by_domain_pair = _group_items(
        testing_data["cross_domain"],
        lambda item: (
            predicate_to_domain.get(item["source_predicate"], "OTHER"),
            predicate_to_domain.get(item["target_predicate"], "OTHER"),
        ),
    )
    balanced_cross_domain = []
    if by_domain_pair:
        pair_counts = sorted(len(items) for items in by_domain_pair.values())
        target_per_pair = pair_counts[len(pair_counts) // 2]
        for items in by_domain_pair.values():
            if len(items) <= target_per_pair:
                balanced_cross_domain.extend(items)
            else:
                balanced_cross_domain.extend(random.sample(items, target_per_pair))

    # 3b. Per predicate pair: keep only subjects seen in the sampled in-domain
    #     set, then sample a fraction of those.
    by_predicate_pair = _group_items(
        balanced_cross_domain,
        lambda item: (item["source_predicate"], item["target_predicate"]),
    )
    sampled_cross_domain = []
    for examples in by_predicate_pair.values():
        filtered_examples = [ex for ex in examples if ex["subject"] in sampled_subjects]
        if not filtered_examples:
            continue

        sample_size = max(1, int(len(filtered_examples) * sample_percentage))
        if len(filtered_examples) > sample_size:
            sampled_cross_domain.extend(random.sample(filtered_examples, sample_size))
        else:
            sampled_cross_domain.extend(filtered_examples)

    sampled_testing = {
        "in_domain": sampled_in_domain,
        "cross_domain": sampled_cross_domain,
    }
    return sampled_training, sampled_testing

def load_or_create_datasets(config):
    """Load cached train/test data for ``config``, or build and cache it.

    Three pickle files under ``config.dataset_cache_dir`` hold the cached data;
    their names are keyed by ``config.experiment_name``. If any of them is
    missing, the data is rebuilt:

    1. The dataset is loaded via ``load_trex_dataset``.
    2. Prompts are generated and then balanced/sampled.
    3. The result is written as CSV to ``config.output_dir`` and pickled to
       the cache.

    Finally, each test split is randomly truncated to
    ``config.max_test_examples`` items when that value is positive.

    Attributes read from ``config``: ``experiment_name``,
    ``dataset_cache_dir``, ``cache_dir``, ``sample_percentage``,
    ``output_dir``, ``max_test_examples`` and, optionally,
    ``target_domain_count``.

    Returns:
        ``(training_data, {"in_domain": [...], "cross_domain": [...]})``.
    """
    cache_version = f"v7_balanced_{config.experiment_name}"
    training_cache = os.path.join(config.dataset_cache_dir, f"training_data_{cache_version}.pkl")
    testing_in_domain_cache = os.path.join(config.dataset_cache_dir, f"in_domain_test_data_{cache_version}.pkl")
    testing_cross_domain_cache = os.path.join(config.dataset_cache_dir, f"cross_domain_test_data_{cache_version}.pkl")

    os.makedirs(config.dataset_cache_dir, exist_ok=True)

    cache_files = (training_cache, testing_in_domain_cache, testing_cross_domain_cache)
    if all(os.path.exists(path) for path in cache_files):
        # NOTE: pickle.load can execute arbitrary code; these files are assumed
        # to have been written by this function into a trusted local directory.
        with open(training_cache, 'rb') as f:
            training_data = pickle.load(f)
        with open(testing_in_domain_cache, 'rb') as f:
            in_domain_test_data = pickle.load(f)
        with open(testing_cross_domain_cache, 'rb') as f:
            cross_domain_test_data = pickle.load(f)

        testing_data = {
            "in_domain": in_domain_test_data,
            "cross_domain": cross_domain_test_data
        }
    else:
        df = load_trex_dataset(cache_dir=config.cache_dir)

        training_data_full, testing_data_full = generate_training_testing_data(df)

        target_count_per_domain = getattr(config, 'target_domain_count', None)

        training_data, testing_data = sample_dataset(
            training_data_full,
            testing_data_full,
            sample_percentage=config.sample_percentage,
            random_seed=42,
            balance_domains_count=target_count_per_domain
        )

        # Summary of the domain distribution in training. An empty training
        # list gives a DataFrame without a 'domain' column, so guard it.
        training_df = pd.DataFrame(training_data)
        domains, _ = get_predicate_domains()
        if 'domain' in training_df.columns:
            domain_counts = training_df['domain'].value_counts().to_dict()
        else:
            domain_counts = {}

        print("Domain distribution in training data:")
        for domain in domains:
            print(f"  {domain}: {domain_counts.get(domain, 0)}")

        # CSV exports go to output_dir, which is not guaranteed to exist yet.
        os.makedirs(config.output_dir, exist_ok=True)
        training_df.to_csv(os.path.join(config.output_dir, "training_data_balanced.csv"), index=False)
        pd.DataFrame(testing_data["in_domain"]).to_csv(os.path.join(config.output_dir, "in_domain_test_data_balanced.csv"), index=False)
        pd.DataFrame(testing_data["cross_domain"]).to_csv(os.path.join(config.output_dir, "cross_domain_test_data_balanced.csv"), index=False)

        with open(training_cache, 'wb') as f:
            pickle.dump(training_data, f)
        with open(testing_in_domain_cache, 'wb') as f:
            pickle.dump(testing_data["in_domain"], f)
        with open(testing_cross_domain_cache, 'wb') as f:
            pickle.dump(testing_data["cross_domain"], f)

    max_test = config.max_test_examples
    if max_test and max_test > 0:
        # Use a dedicated, seeded RNG. The global `random` is only reseeded on
        # the cache-miss path, so a cache hit would otherwise pick a different
        # test subset on every run.
        rng = random.Random(42)
        for split in ("in_domain", "cross_domain"):
            if len(testing_data[split]) > max_test:
                testing_data[split] = rng.sample(testing_data[split], max_test)

    return training_data, testing_data