import time
from argparse import Namespace
from pathlib import Path
from typing import Dict, List, Tuple

import pandas as pd
from tqdm import tqdm

from lm_understanding.explanations.explanations import LocalExplanationSet
from lm_understanding.question_template import (TemplateDataset,
                                                TemplateModelBehavior)


def make_run_name() -> str:
    """Return a run identifier: the current Unix time in whole seconds, as a string."""
    seconds_since_epoch = int(time.time())
    return f'{seconds_since_epoch}'


def template_path(config: Namespace, new: bool = False) -> Path:
    """Return the directory for the task's templates.

    With ``new=True`` a fresh timestamped run directory name is used;
    otherwise the path points at the ``_combined`` directory.
    """
    if new:
        leaf = make_run_name()
    else:
        leaf = '_combined'
    return Path('templates', config.task.name, leaf)


def dataset_path(config: Namespace, new: bool = False) -> Path:
    """Return the dataset directory for the configured model, task and shift.

    With ``new=True`` the path uses ``config.model.name`` and a fresh
    timestamped run name; otherwise it uses
    ``config.model.dataset_model_name`` and the ``final_dataset`` directory.
    """
    if new:
        run_name, model_name = make_run_name(), config.model.name
    else:
        run_name, model_name = 'final_dataset', config.model.dataset_model_name
    return Path('datasets', model_name, config.task.name,
                config.distributional_shift.name, run_name)


def model_behavior_path(config: Namespace, new: bool = False) -> Path:
    """Return the model-behavior results directory for model, task and shift.

    With ``new=True`` the leaf is a fresh timestamped run name; otherwise
    it is ``final_behavior``.
    """
    leaf = make_run_name() if new else 'final_behavior'
    segments = (
        'model_behavior_results',
        config.model.name,
        config.task.name,
        config.distributional_shift.name,
        leaf,
    )
    return Path(*segments)


def baseline_results_path(config: Namespace) -> Path:
    """Return the baseline results directory for the configured model, task and shift."""
    model = config.model.name
    task = config.task.name
    shift = config.distributional_shift.name
    return Path('baseline_results', model, task, shift)


def load_local_explanations(config: Namespace) -> List[LocalExplanationSet]:
    """Load the saved local explanation sets for every template directory.

    The explanation name is ``config.explanation.name`` when the config has an
    ``explanation`` entry, else ``config.baseline.explanation.name``. For each
    entry under ``model_behavior_path(config)``, the CSV
    ``explanations/<name>.csv`` is read into a ``LocalExplanationSet``; entries
    without that CSV are skipped. If ``explanations/<name>_global.txt`` also
    exists, its text is attached as ``global_explanation``.

    Returns:
        The explanation sets, ordered by template directory path.
    """
    explanation_name = config.explanation.name if 'explanation' in config else config.baseline.explanation.name
    explanations = []
    # iterdir() order is filesystem-dependent; sort so the result is reproducible.
    for template_dir in sorted(model_behavior_path(config).iterdir()):
        explanation_dir = template_dir / 'explanations'
        explanation_path = explanation_dir / f'{explanation_name}.csv'
        if not explanation_path.exists():
            continue
        explanation = LocalExplanationSet.from_df(pd.read_csv(explanation_path))
        global_explanation_path = explanation_dir / f'{explanation_name}_global.txt'
        if global_explanation_path.exists():
            explanation.global_explanation = global_explanation_path.read_text()
        explanations.append(explanation)
    return explanations


def load_model_behaviors_from_dataset(config: Namespace) -> Tuple[List[TemplateModelBehavior], Dict]:
    """Load model behaviors for the dataset templates that survive config filtering.

    The dataset at ``dataset_path(config)`` is loaded and filtered with
    ``filter_from_config(config)``. Then each directory under
    ``<dataset_path>/all_dataset_behavior`` is loaded as a
    ``TemplateModelBehavior``, provided its name is one of the kept
    ``template_ids``. Directories are visited in sorted order. Loading stops
    once ``config.dataset_construction.templates_per_topic`` behaviors have
    been loaded; the check runs after each load.

    Returns:
        A tuple ``(model_behaviors, stats)``. ``stats['filtering_kept_fraction']``
        is the fraction of templates kept by filtering, or 0.0 if the dataset
        had no templates.
    """
    root = dataset_path(config)
    dataset = TemplateDataset.load(root)
    initial_templates = len(dataset.templates)
    dataset = dataset.filter_from_config(config)
    kept_templates = len(dataset.templates)
    # Build the membership set once instead of re-reading/scanning it per directory.
    kept_ids = set(dataset.template_ids)
    # iterdir() order is arbitrary; sorting makes the templates_per_topic cut reproducible.
    behavior_dirs = sorted((root / 'all_dataset_behavior').iterdir())
    model_behaviors = []
    for model_behavior_dir in tqdm(behavior_dirs, desc='Loading Model Behavior'):
        if not model_behavior_dir.is_dir():
            continue
        if model_behavior_dir.name not in kept_ids:
            continue
        model_behaviors.append(TemplateModelBehavior.load(model_behavior_dir))
        if len(model_behaviors) >= config.dataset_construction.templates_per_topic:
            break
    # Guard against an empty dataset, which previously raised ZeroDivisionError.
    kept_fraction = kept_templates / initial_templates if initial_templates else 0.0
    return model_behaviors, dict(filtering_kept_fraction=kept_fraction)

def load_model_behaviors(config: Namespace) -> Tuple[List[TemplateModelBehavior], Dict]:
    """Load saved model behaviors from ``model_behavior_path(config)``.

    Each subdirectory is loaded as a ``TemplateModelBehavior``; non-directory
    entries are skipped. Subdirectories are visited in sorted order. Loading
    stops once ``config.dataset_construction.templates_per_topic`` behaviors
    have been loaded; the check runs after each load.

    Returns:
        A tuple ``(model_behaviors, {})``. The empty dict keeps the return shape
        parallel to ``load_model_behaviors_from_dataset``.
    """
    behavior_path = model_behavior_path(config)
    # iterdir() order is arbitrary; sorting makes the templates_per_topic cut reproducible.
    behavior_dirs = sorted(behavior_path.iterdir())
    model_behaviors = []
    for model_behavior_dir in tqdm(behavior_dirs, desc='Loading Model Behavior'):
        if not model_behavior_dir.is_dir():
            continue
        model_behaviors.append(TemplateModelBehavior.load(model_behavior_dir))
        if len(model_behaviors) >= config.dataset_construction.templates_per_topic:
            break
    return model_behaviors, {}
