import os
import time
import torch
import random
import argparse
import tracemalloc
import numpy as np
import pandas as pd
from pathlib import Path
from omegaconf import OmegaConf
import pickle


def measure_performance(func):
    """
    Decorator that prints the wall-clock time and traced memory of each call.

    Memory is measured with tracemalloc. If tracing is already active when the
    wrapped function is called (for example in a nested decorated call), the
    existing trace is reused and left running, so the reported peak then covers
    everything since that trace was started.

    Parameters:
    func (callable): The function to instrument.

    Returns:
    callable: A wrapper with the same signature and return value as func.
    """
    import functools

    @functools.wraps(func)  # keep func's __name__/__doc__ on the wrapper
    def wrapper(*args, **kwargs):
        # Only start (and later stop) tracing if nobody else is tracing already;
        # otherwise an inner call would tear down the outer call's trace.
        started_here = not tracemalloc.is_tracing()
        if started_here:
            tracemalloc.start()
        # perf_counter is monotonic, unlike time.time(), so the interval cannot
        # be distorted by system clock adjustments.
        start_time = time.perf_counter()
        try:
            result = func(*args, **kwargs)
            end_time = time.perf_counter()
            current, peak = tracemalloc.get_traced_memory()
        finally:
            # Always release the tracer, even when func raises.
            if started_here:
                tracemalloc.stop()
        print(f"Time taken: {end_time - start_time:.2f} seconds")
        print(f"Current memory usage: {current / 10**6:.2f} MB; Peak: {peak / 10**6:.2f} MB")
        return result
    return wrapper

def parse_args():
    """
    Builds the run configuration from the YAML config file and the command line.

    The config file given by --config is loaded first, then the values of the
    named CLI flags (including their defaults) are merged on top of it, and
    finally every positional key=value override is applied with
    OmegaConf.update.

    Returns:
    The merged OmegaConf configuration.

    Raises:
    ValueError: If a positional override does not contain '='.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config',  type=str, default='../configs/config.yaml')
    parser.add_argument('--data-root-dir',  type=str, default='./data')
    parser.add_argument('--context-length', '-l',  type=int, default=20)
    parser.add_argument('--attr-method', '-m',  type=str, default=None)
    parser.add_argument('--num-seeds', '-n', default=1, type=int, help="Number of seeds to test.")
    parser.add_argument('overrides', nargs='*', help="Any key=value arguments to override config values")

    flags = parser.parse_args()
    config = OmegaConf.load(flags.config)

    cli_conf = OmegaConf.create({
        "data_root_dir": flags.data_root_dir,
        "context_length": flags.context_length,
        "attr_method": flags.attr_method,
        "num_seeds": flags.num_seeds,
    })

    # Later arguments win in OmegaConf.merge, so CLI flags take precedence
    # over values from the config file.
    args = OmegaConf.merge(config, cli_conf)

    for override in flags.overrides:
        # Split on the first '=' only, so values may themselves contain '='
        # (e.g. path=/tmp/a=b).
        key, sep, value = override.partition('=')
        if not sep:
            raise ValueError(
                f"Invalid override {override!r}: expected the form key=value"
            )
        # The value is handed to OmegaConf.update as the raw string.
        OmegaConf.update(args, key, value)
    return args

def create_output_directory(args):
    """
    Ensures the output directory for the current experiment exists.

    The directory is named "<dataset name>_<model name>_<context length>" and
    lives under args.experiment.output_root_dir.

    Parameters:
    args: Configuration exposing experiment.output_root_dir, dataset.name,
          model.name and context_length.

    Returns:
    Path: The experiment directory.
    """
    tag_parts = [args.dataset.name, args.model.name, str(args.context_length)]
    experiment_dir = Path(args.experiment.output_root_dir) / '_'.join(tag_parts)
    if experiment_dir.is_dir():
        return experiment_dir
    os.makedirs(experiment_dir, exist_ok=True)
    return experiment_dir

def generate_seeds(num_seeds: int):
    """
    Produces a reproducible list of seeds.

    The global `random` generator is re-seeded with 2024 on every call, so the
    same num_seeds always yields the same list. Note that this resets the state
    of the global generator as a side effect.

    Parameters:
    num_seeds (int): How many seeds to draw.

    Returns:
    list: num_seeds integers in the range [0, 2**32 - 1].
    """
    random.seed(2024)
    upper_bound = 2**32 - 1
    seeds = []
    for _ in range(num_seeds):
        seeds.append(random.randint(0, upper_bound))
    return seeds

def set_seed(seed: int):
    """
    Seeds the random number generators of random, numpy and torch.

    When CUDA is available the CUDA generators are seeded as well. cuDNN is
    switched to deterministic mode with benchmarking disabled.

    Parameters:
    seed (int): The seed value to set.
    """
    # Seed the Python, NumPy and torch generators in that order.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    cuda = torch.cuda
    if cuda.is_available():
        cuda.manual_seed(seed)
        cuda.manual_seed_all(seed)  # covers the multi-GPU case

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_dataset_attributes(dataset_name: str):
    """
    Returns the hard-coded attributes of a supported dataset.

    Parameters:
    dataset_name (str): Either "HarryPotter" or "MothRadioHour".

    Returns:
    tuple: (subject_idxs, rois, dataset_dir, roi_file). For "MothRadioHour"
           rois and roi_file are None.

    Raises:
    ValueError: If dataset_name is not one of the supported datasets.
    """
    if dataset_name == 'HarryPotter':
        subject_idxs = ["F", "H", "I", "J", "K", "L", "M", "N"]
        rois = ["PostTemp", "AntTemp", "AngularG", "IFG", "MFG", "IFGorb", "pCingulate", "dmpfc"]
        dataset_dir = Path("data") / dataset_name / "fMRI"
        roi_file = dataset_dir / "HP_subj_roi_inds.npy"
        return subject_idxs, rois, dataset_dir, roi_file
    if dataset_name == 'MothRadioHour':
        subject_idxs = ["01", "02", "03", "04", "05", "06", "07", "08", "09"]
        dataset_dir = Path("/BRAIN/ssms/work/") / dataset_name
        return subject_idxs, None, dataset_dir, None
    # Raise explicitly instead of using `assert`: assertions are stripped under
    # `python -O`, which would make an unknown name silently return None.
    raise ValueError("Dataset not supported")

def convert_dict_to_df(corr_scores: dict):
    """
    Flattens nested correlation scores into two long-format DataFrames.

    Parameters:
    corr_scores (dict): Mapping brain_area -> model -> scores, where scores
                        holds 'line' and 'scatter' entries, each with parallel
                        'x' (context lengths) and 'y' (correlations) sequences.

    Returns:
    tuple: (line_df, scatter_df), with one row per (brain area, model, point).
    """
    line_rows = []
    scatter_rows = []

    for area, per_model in corr_scores.items():
        for model_name, entry in per_model.items():
            # One row per point of the 'line' series.
            line_part = entry['line']
            line_rows.extend(
                {
                    'Model': model_name,
                    'brain_area': area,
                    'Context length': ctx_len,
                    'Correlation': line_part['y'][pos],
                }
                for pos, ctx_len in enumerate(line_part['x'])
            )

            # One row per point of the 'scatter' series; both values are looked
            # up by position in the parallel 'y' and 'x' sequences.
            scatter_part = entry['scatter']
            scatter_rows.extend(
                {
                    'Model': model_name,
                    'brain_area': area,
                    'Correlation': scatter_part['y'][pos],
                    'Context length': scatter_part['x'][pos],
                }
                for pos, _ in enumerate(scatter_part['x'])
            )

    return pd.DataFrame(line_rows), pd.DataFrame(scatter_rows)

def load_pickled_data(file_path: Path):
    """
    Loads and returns the object stored in a pickle file.

    Parameters:
    file_path (Path): Location of the pickle file to read.

    Returns:
    The unpickled object, whatever type was stored in the file.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(file_path, mode="rb") as handle:
        return pickle.load(handle)
