import json
import argparse
import random
import os
from typing import Dict, List, Tuple, Any, Optional
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import re
from datasets import load_dataset, Dataset
from flan_v2_templates_subset_variations import DOMAIN_TEMPLATES

# Template variation types applied to every domain when generating the dataset.
# Kept as a list (not a tuple): the command-line entry point narrows it in place.
TEMPLATE_TYPES = [
    "exact",
    "synonym",
    "antonym",
    "semantic_paraphrase",
    "disfluent",
]

class FlanTemplates:
    """Look up ``(input_template, output_template)`` pairs per FLAN domain.

    ``self.domain_templates`` maps a domain name (e.g. ``"cot_esnli"``) to a
    dict of variation type -> list of template pairs. A variation type that a
    domain does not define falls back to that domain's ``"exact"`` templates.
    """

    def __init__(self):
        self.domain_templates = DOMAIN_TEMPLATES

    def get_templates(self, domain: str, variation_type: str = "exact") -> List[Tuple[str, str]]:
        """Return the template pairs for *domain* and *variation_type*.

        Args:
            domain: key into ``self.domain_templates`` (e.g. ``"cot_ecqa"``).
            variation_type: variation to fetch; unknown values fall back to
                ``"exact"``.

        Returns:
            The list of ``(input_template, output_template)`` pairs.

        Raises:
            KeyError: if *domain* is missing, or if the fallback ``"exact"``
                variation is missing for it.
        """
        variations = self.domain_templates[domain]
        if variation_type in variations:
            return variations[variation_type]
        return variations["exact"]

    def get_cot_esnli_templates(self, variation_type="exact") -> List[Tuple[str, str]]:
        """cot_esnli (input_template, output_template) pairs."""
        return self.get_templates("cot_esnli", variation_type)

    def get_cot_ecqa_templates(self, variation_type="exact") -> List[Tuple[str, str]]:
        """cot_ecqa (input_template, output_template) pairs."""
        return self.get_templates("cot_ecqa", variation_type)

    def get_cot_sensemaking_templates(self, variation_type="exact") -> List[Tuple[str, str]]:
        """cot_sensemaking (input_template, output_template) pairs."""
        return self.get_templates("cot_sensemaking", variation_type)

    def get_sentiment140_templates(self, variation_type="exact") -> List[Tuple[str, str]]:
        """sentiment140 (input_template, output_template) pairs."""
        return self.get_templates("sentiment140", variation_type)

    def get_newsroom_templates(self, variation_type="exact") -> List[Tuple[str, str]]:
        """newsroom (input_template, output_template) pairs."""
        return self.get_templates("newsroom", variation_type)


class DatasetGenerator:
    """Build FLAN-style (input, output) examples from sampled Hugging Face rows.

    Each ``generate_*_entities`` method samples up to ``sample_size`` rows from
    one source dataset and converts them to "entity" dicts whose keys are the
    placeholder names templates reference (``question``, ``answer``,
    ``options_``, ``chain_of_thought``, ``text``). Templates are rendered with
    ``str.format(**entity)``.
    """

    def __init__(self, sample_size=100):
        """
        Args:
            sample_size: maximum number of rows sampled from each source dataset.
        """
        self.sample_size = sample_size
        self.templates = FlanTemplates()
        # Loaded splits keyed by "<name>_<config>_<split>" so repeated loads are free.
        self.cached_datasets = {}

    def load_huggingface_dataset(self, dataset_name: str, dataset_config: Optional[str] = None, split: str = "train"):
        """Load (and memoize) one split of a Hugging Face dataset.

        Args:
            dataset_name: hub id, e.g. ``"esnli/esnli"``.
            dataset_config: optional configuration name.
            split: split to load.

        Returns:
            The loaded split. Exceptions raised by ``load_dataset`` propagate.
        """
        cache_key = f"{dataset_name}_{dataset_config or 'default'}_{split}"

        if cache_key in self.cached_datasets:
            return self.cached_datasets[cache_key]

        # NOTE(security): trust_remote_code=True allows the dataset repository to
        # run its own loading code; only use it with dataset sources you trust.
        if dataset_config:
            dataset = load_dataset(dataset_name, dataset_config, split=split, trust_remote_code=True)
        else:
            dataset = load_dataset(dataset_name, split=split, trust_remote_code=True)

        self.cached_datasets[cache_key] = dataset

        print(f"Successfully loaded {cache_key} with {len(dataset)} examples")
        return dataset

    def _sample_indices(self, dataset) -> List[int]:
        """Return up to ``sample_size`` distinct random row indices of *dataset*."""
        k = max(0, min(len(dataset), self.sample_size))
        return random.sample(range(len(dataset)), k)

    def generate_esnli_entities(self) -> List[Dict[str, Any]]:
        """Generate entities for the cot_esnli templates.

        Each entity has ``question``, ``answer`` (a natural-language NLI label),
        ``options_`` and ``chain_of_thought``. Rows are skipped when they have no
        label, an unrecognised integer label, or no explanation.
        """
        dataset = self.load_huggingface_dataset("esnli/esnli", split="train")
        if dataset is None or len(dataset) == 0:
            return []

        # Convert label ints to natural language for flexible eval.
        label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}
        options = ["entailment", "neutral", "contradiction"]
        options_str = ", ".join(options)

        entities = []
        for idx in self._sample_indices(dataset):
            example = dataset[idx]

            if "label" not in example:
                continue
            label = example["label"]
            if isinstance(label, int):
                if label not in label_map:
                    # e.g. a "no gold label" marker; mapping it to "neutral"
                    # would silently mislabel the row.
                    continue
                answer = label_map[label]
            else:
                answer = label

            premise = example.get("premise", "")
            hypothesis = example.get("hypothesis", "")
            question = f"Premise: {premise}\nHypothesis: {hypothesis}\nDoes the premise entail the hypothesis?"

            # Prefer explanation_1; fall back to a generic "explanation" column.
            if "explanation_1" in example:
                chain_of_thought = example.get("explanation_1", "")
            else:
                chain_of_thought = example.get("explanation", "")

            # The CoT templates need a rationale, so rows without one are dropped.
            if not chain_of_thought:
                continue

            entities.append({
                "question": question,
                "answer": answer,
                "options_": options_str,
                "chain_of_thought": chain_of_thought,
            })

        return entities

    def generate_ecqa_entities(self) -> List[Dict[str, Any]]:
        """Generate entities for the cot_ecqa templates.

        Each entity has ``question``, ``answer``, ``options_`` (the row's choices
        joined by ", ", or "A, B, C, D, E" when none are present) and
        ``chain_of_thought``. Rows without an explanation are skipped.
        """
        dataset = self.load_huggingface_dataset("tasksource/ecqa", split="train")
        if dataset is None or len(dataset) == 0:
            return []

        entities = []
        for idx in self._sample_indices(dataset):
            example = dataset[idx]

            question = example.get("question", "")
            answer = example.get("answer", "")
            explanation = example.get("explanation", "")
            choices = example.get("choices", [])

            if isinstance(choices, list) and len(choices) > 0:
                # str() guards against non-string choice values breaking join().
                options_str = ", ".join(str(choice) for choice in choices)
            else:
                options_str = "A, B, C, D, E"

            if not explanation:
                continue

            entities.append({
                "question": question,
                "answer": answer,
                "options_": options_str,
                "chain_of_thought": explanation,
            })

        return entities

    # TODO: cot_sensemaking has no entity generator yet. An earlier draft sampled
    # source/target/rationale fields from "pharaouk/CoT-Collection".
    # generate_full_dataset skips any domain whose generator is not defined.

    def generate_sentiment140_entities(self) -> List[Dict[str, Any]]:
        """Generate entities for the sentiment140 templates.

        Each entity has ``text``, ``answer`` (positive / negative / neutral) and
        ``options_``. Rows with no text column, no sentiment column, or an
        unmapped sentiment code are skipped.
        """
        dataset = self.load_huggingface_dataset("stanfordnlp/sentiment140", split="train")
        if dataset is None or len(dataset) == 0:
            return []

        sentiment_map = {
            0: "negative",
            1: "neutral",
            2: "neutral",
            4: "positive",
        }
        options = ["positive", "negative", "neutral"]
        options_str = ", ".join(options)

        entities = []
        for idx in self._sample_indices(dataset):
            example = dataset[idx]

            if "text" in example:
                text = example["text"]
            elif "Tweet" in example:
                text = example["Tweet"]
            else:
                continue

            if "sentiment" in example:
                sentiment_val = example["sentiment"]
            elif "label" in example:
                sentiment_val = example["label"]
            else:
                continue

            if sentiment_val not in sentiment_map:
                # Unknown codes are skipped rather than silently labelled neutral.
                continue

            entities.append({
                "text": text,
                "answer": sentiment_map[sentiment_val],
                "options_": options_str,
            })

        return entities

    def _extract_required_fields(self, input_template, output_template):
        """Return the set of top-level field names referenced by either template.

        Parses with ``string.Formatter`` so the result matches what
        ``str.format`` will look up. Escaped braces (``{{``/``}}``) are ignored.
        Format specs and conversions (``{x:>5}``, ``{x!r}``) yield ``x``.
        Attribute/index access (``{x.y}``, ``{x[0]}``) yields ``x``. Fields
        nested inside a format spec are included. Positional fields (``{}``,
        ``{0}``) yield ``""`` / digit names, which no entity provides.

        Raises:
            ValueError: if a template is not a valid format string.
        """
        from string import Formatter

        formatter = Formatter()
        fields = set()
        pending = [input_template, output_template]
        while pending:
            template = pending.pop()
            for _literal, field_name, format_spec, _conversion in formatter.parse(template):
                if field_name is None:
                    continue
                # Keep only the root name: "x.attr" / "x[0]" -> "x".
                fields.add(re.split(r"[.\[]", field_name, maxsplit=1)[0])
                if format_spec:
                    pending.append(format_spec)
        return fields

    def apply_templates_to_entities(self, entities: List[Dict[str, Any]], templates: List[Tuple[str, str]], dataset_name: str) -> List[Dict[str, Any]]:
        """Render every applicable template for every entity.

        Args:
            entities: dicts produced by a ``generate_*_entities`` method.
            templates: ``(input_template, output_template)`` pairs.
            dataset_name: label stored on each example and shown in the progress bar.

        Returns:
            One dict per (entity, template) pair whose placeholders are all
            present in the entity. Each dict has keys ``dataset``,
            ``template_id`` (index into *templates*), ``input``, ``output`` and
            ``entity``. Malformed templates are reported and skipped.
        """
        # Required fields depend only on the template, so parse each one once.
        usable_templates = []
        for template_idx, (input_template, output_template) in enumerate(templates):
            try:
                required_fields = self._extract_required_fields(input_template, output_template)
            except ValueError as err:
                print(f"Skipping malformed template {template_idx} for {dataset_name}: {err}")
                continue
            usable_templates.append((template_idx, input_template, output_template, required_fields))

        result = []
        for entity in tqdm(entities, desc=f"Applying templates to {dataset_name}"):
            for template_idx, input_template, output_template, required_fields in usable_templates:
                if not required_fields.issubset(entity):
                    continue

                result.append({
                    "dataset": dataset_name,
                    "template_id": template_idx,
                    "input": input_template.format(**entity),
                    "output": output_template.format(**entity),
                    "entity": entity,
                })

        return result

    def generate_full_dataset(self) -> List[Dict[str, Any]]:
        """Generate the full dataset with all entity-template combinations.

        For each domain with an entity generator, every variation listed in the
        module-level ``TEMPLATE_TYPES`` is applied. An error in one domain is
        printed and that domain is skipped.
        """
        all_data = []

        print(f"Generating dataset with sample_size={self.sample_size}")

        # (domain name, entity generator method name, template getter)
        dataset_specs = [
            ("cot_esnli", "generate_esnli_entities", self.templates.get_cot_esnli_templates),
            ("cot_ecqa", "generate_ecqa_entities", self.templates.get_cot_ecqa_templates),
            ("cot_sensemaking", "generate_sensemaking_entities", self.templates.get_cot_sensemaking_templates),
            ("sentiment140", "generate_sentiment140_entities", self.templates.get_sentiment140_templates),
        ]

        template_variations = TEMPLATE_TYPES

        for dataset_name, generator_name, template_getter in dataset_specs:
            # Looked up by name so a domain without a generator is skipped
            # instead of crashing the whole run with AttributeError.
            entity_generator = getattr(self, generator_name, None)
            if entity_generator is None:
                print(f"No entity generator for {dataset_name}, skipping...")
                continue

            try:
                print(f"Generating entities for {dataset_name}...")
                entities = entity_generator()

                if not entities:
                    print(f"No entities found for {dataset_name}, skipping...")
                    continue

                print(f"Applying templates to entities for {dataset_name}...")

                for variation_type in template_variations:
                    print(f"  Processing {variation_type} templates...")
                    templates = template_getter(variation_type=variation_type)

                    dataset_examples = self.apply_templates_to_entities(
                        entities,
                        templates,
                        f"{dataset_name}_{variation_type}",
                    )

                    print(f"  Generated {len(dataset_examples)} examples for {dataset_name}_{variation_type}")
                    all_data.extend(dataset_examples)

            except Exception as e:
                # Per-domain boundary: report and continue with the next domain.
                print(f"Error processing {dataset_name}: {e}")

        print(f"Total examples generated: {len(all_data)}")
        return all_data

    def save_to_csv(self, output_file: str = "flan_v2_dataset.csv"):
        """Generate the full dataset, write it to *output_file* and print per-dataset counts.

        Args:
            output_file: destination CSV path.

        Returns:
            The generated examples, including their ``entity`` dicts, which are
            not written to the CSV.
        """
        data = self.generate_full_dataset()

        columns = ["dataset", "template_id", "input", "output"]
        # Passing columns= keeps the header (and df["dataset"]) valid even when
        # no examples were generated.
        df = pd.DataFrame(
            [{column: item[column] for column in columns} for item in data],
            columns=columns,
        )

        df.to_csv(output_file, index=False)

        print(f"Dataset saved to {output_file} with {len(data)} examples")

        print("\nDataset counts:")
        for dataset, count in df["dataset"].value_counts().items():
            print(f"{dataset}: {count} examples")

        return data


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate a FLAN v2 template-variation dataset.")
    parser.add_argument('--sample_size', type=int, default=20)
    parser.add_argument('--output_dir', type=str, default='./')
    parser.add_argument('--file_prefix', type=str, default='flan_v2_dataset')
    parser.add_argument('--template_variations', type=str, default='all',
                        help="comma-separated subset of the template variation types, or 'all'")
    # Selects which writers run below; 'csv' is the only supported format.
    parser.add_argument('--formats', type=str, default='csv',
                        help="comma-separated output formats (supported: csv)")

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    if args.template_variations != 'all':
        # Strip whitespace so "exact, synonym" works like "exact,synonym".
        requested_variations = [v.strip() for v in args.template_variations.split(',') if v.strip()]
        unknown = [v for v in requested_variations if v not in TEMPLATE_TYPES]
        if unknown:
            print(f"Unknown template variation '{unknown[0]}'. Using default variations.")
        elif not requested_variations:
            print("No template variations given. Using default variations.")
        else:
            # Mutate in place: the generator reads the module-level list.
            TEMPLATE_TYPES.clear()
            TEMPLATE_TYPES.extend(requested_variations)

    print(f"Using template variations: {', '.join(TEMPLATE_TYPES)}")

    generator = DatasetGenerator(sample_size=args.sample_size)

    formats = {fmt.strip().lower() for fmt in args.formats.split(',') if fmt.strip()}
    unsupported = formats - {'csv'}
    if unsupported:
        print(f"Ignoring unsupported output format(s): {', '.join(sorted(unsupported))}")

    data = None

    if 'csv' in formats:
        output_file = os.path.join(args.output_dir, f"{args.file_prefix}.csv")
        data = generator.save_to_csv(output_file)