import torch
import numpy as np
import torch.nn.functional as F
import pandas as pd
import abstract_cf.text_generation.bios.utils as bios_utils 
import abstract_cf.text_generation.goemotions.utils as goemotions_utils
import ravfogel_lm_counterfactuals.utils as ravfogel_utils 
from clearml import Task, Dataset
import pickle
import pandas as pd
import os


def load_model(model_name: str):
    """
    Load a model by name via the ravfogel utilities.

    Names starting with ``'ravfogel'`` are routed to
    ``ravfogel_utils.get_counterfactual_model`` after the leading
    ``'ravfogel_'`` prefix is stripped; every other name is passed unchanged to
    ``ravfogel_utils.load_model``.

    Args:
        model_name: Name of the model, optionally prefixed with ``'ravfogel_'``.

    Returns:
        Whatever the selected ``ravfogel_utils`` loader returns.
    """
    if model_name.startswith('ravfogel'):
        # Strip only the leading marker: str.replace would also delete any later
        # occurrence of 'ravfogel_' inside the remaining model name.
        return ravfogel_utils.get_counterfactual_model(model_name.removeprefix('ravfogel_'))
    # TODO: can we use generic transformers model loading here?
    return ravfogel_utils.load_model(model_name)



def load_dataset(task: str='profession') -> dict[str, pd.DataFrame]:
    if task == 'profession':
        return bios_utils.load_dataset()
    elif task == 'emotion':
        return goemotions_utils.load_dataset()
    elif task=='emotion_ekman':
        return goemotions_utils.load_dataset(ekman=True)
    else:
        raise ValueError(f"Unknown abstraction task: {task}")


def load_clearml_dataset(dataset_id: str | None = None, dataset_name: str | None = None, project_name: str | None = None) -> pd.DataFrame:
    """
    Load a dataset from ClearML by its ID or name and read its first CSV file.

    Args:
        dataset_id: The ID of the ClearML dataset to load. Takes precedence over
            ``dataset_name``/``project_name`` when given.
        dataset_name: The name of the ClearML dataset to load.
        project_name: The project name where the dataset is stored (required if using dataset_name).

    Returns:
        A pandas DataFrame read from the alphabetically first ``.csv`` file found
        directly inside the downloaded dataset directory (sub-directories are
        not searched).

    Raises:
        ValueError: If neither ``dataset_id`` nor both ``dataset_name`` and
            ``project_name`` are provided, or if the downloaded directory
            contains no ``.csv`` file.
    """
    # Get the dataset object using either ID or name+project
    if dataset_id:
        dataset_obj = Dataset.get(dataset_id=dataset_id)
    elif dataset_name and project_name:
        dataset_obj = Dataset.get(
            dataset_project=project_name,
            dataset_name=dataset_name,
            only_completed=True,
            auto_create=False
        )
    else:
        raise ValueError("Either dataset_id or both dataset_name and project_name must be provided")

    # Download the dataset to a local path
    local_dataset_path = dataset_obj.get_local_copy()

    # os.listdir yields entries in arbitrary order; sort so that "the first CSV"
    # is the same file on every machine and every run.
    csv_files = sorted(f for f in os.listdir(local_dataset_path) if f.endswith('.csv'))

    if not csv_files:
        raise ValueError("No CSV files found in the downloaded dataset")

    # Load and return the first CSV file as a DataFrame
    csv_path = os.path.join(local_dataset_path, csv_files[0])
    return pd.read_csv(csv_path)


def sample_from_model(
    model, 
    tokenizer, 
    inputs, 
    n_samples, 
    batch_size=5, 
    max_length=80,  
    verbose=False,
    return_scores=False,
) -> tuple[list[str], list[int]] | tuple[list[str], list[int], torch.Tensor]:
    samples = []
    scores = [] if return_scores else None
    if verbose:
        import tqdm
        iteration_range = tqdm.trange(0, n_samples, batch_size)
    else:
        iteration_range = range(0, n_samples, batch_size)
    for i in iteration_range:
        batch_output = model.generate(
            **inputs,
            do_sample=True, 
            num_return_sequences=min(batch_size, n_samples - i), 
            max_length=max_length,
            temperature=1,
            top_k=200,
            output_scores=True,
            return_dict_in_generate=True,
            pad_token_id=tokenizer.pad_token_id,
        )
        # print the length of the generated sequences
        samples.extend(batch_output.sequences)
        if return_scores:
            batch_scores = torch.stack(batch_output.scores)
            scores.append(batch_scores)

    decoded_samples = tokenizer.batch_decode(samples, skip_special_tokens=True)
    if return_scores:
        scores = torch.cat(scores, dim=1).permute(1, 0, 2)
        return decoded_samples, samples, scores
    return decoded_samples, samples


def importance_counterfactual_posterior_estimate(
    all_candidate_probabilities: np.array,
    abstraction_value_index_cf_samples: np.array, 
) -> np.array: 
    """
    Estimate normalised weights over candidates from sampled indices.

    For every candidate's probability vector, the probabilities found at each
    sampled index are added up; the resulting per-candidate totals are then
    divided by their overall sum.

    Args:
        all_candidate_probabilities: Iterable of indexable probability vectors,
            one per candidate; indexed elements must support ``.item()``.
        abstraction_value_index_cf_samples: Iterable of indices into each
            probability vector (presumably the abstraction value observed for
            each counterfactual sample — confirm with the caller).

    Returns:
        A numpy array with one normalised weight per candidate.
    """

    def _mass_on_samples(candidate_probs):
        # Accumulate left to right so the floating-point result matches a
        # plain running sum.
        total = 0
        for label_idx in abstraction_value_index_cf_samples:
            total += candidate_probs[label_idx].item()
        return total

    unnormalised = np.array(
        [_mass_on_samples(candidate_probs) for candidate_probs in all_candidate_probabilities]
    )
    return unnormalised / sum(unnormalised)


# Helper that downloads a pickled artifact from a ClearML task and unpickles it
def fetch_all_samples_artifact(task_id: str, artifact_name: str='experiment_data.pkl') -> dict:
    """
    Download a task artifact from ClearML and return its unpickled content.

    Args:
        task_id: ID of the ClearML task that owns the artifact.
        artifact_name: Key of the artifact in the task's ``artifacts`` mapping.

    Returns:
        The object stored in the pickle file (annotated as a dict).
    """
    # Resolve the task, then pull the requested artifact to local storage.
    artifact = Task.get_task(task_id=task_id).artifacts[artifact_name]
    local_pickle_path = artifact.get_local_copy()

    # NOTE(review): unpickling executes arbitrary code embedded in the file;
    # only use this with artifacts produced by trusted tasks.
    with open(local_pickle_path, 'rb') as pickle_file:
        return pickle.load(pickle_file)


def get_samples_from_acf_task(task_id: str) -> pd.DataFrame:
    """
    Build a DataFrame of prompts from the samples artifact of a ClearML task.

    Args:
        task_id: ID of the ClearML task whose samples artifact is loaded.

    Returns:
        A DataFrame with one row per sample, built from each sample's
        ``'prompts'`` entry and indexed by sample id (index named ``'id'``).
    """
    samples_by_id = fetch_all_samples_artifact(task_id)

    # Keep only the nested 'prompts' entry of every sample.
    prompts_by_id = {
        sample_id: sample_info['prompts']
        for sample_id, sample_info in samples_by_id.items()
    }

    # orient='index' turns each key into a row label.
    prompts_frame = pd.DataFrame.from_dict(prompts_by_id, orient='index')
    prompts_frame.index.name = 'id'
    return prompts_frame

