import numpy as np
import pandas as pd
from pathlib import Path
from omegaconf import OmegaConf
from os import makedirs, listdir
from os.path import join, isdir, abspath, dirname, exists
from typing import List, Callable, Union, Dict
import pandas as pd
import shutil
# import papermill as pm
from typing import List, Callable
import pandas as pd
import os
import warnings
from pathlib import Path
import boto3
from botocore import UNSIGNED
from botocore.config import Config
from io import StringIO, BytesIO
import torch
import re

# Tolerate several copies of the Intel OpenMP runtime being loaded in one process.
# NOTE(review): presumably a workaround for the "OMP: Error #15 ... libiomp5" abort that
# can occur when torch and numpy/MKL each bring their own runtime — confirm still needed.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")



# File names (inside ``<search_dir>/parameters/``) resolved by the single-file special
# prefixes of ``find_filepath_from_prefix``.
_PARAMETER_FILES = {
    "grad": "student_cumulative_gradients.pth",
    "residuals": "student_cumulative_residuals.pth",
    "sharpness": "student_sharpness.pth",
    "co_activations": "student_co_activations.pth",
}

# Result-dict key -> pattern of that task's per-epoch student checkpoint file name.
# The captured group is the epoch number, used to sort checkpoints chronologically.
_CHECKPOINT_PATTERNS = {
    "task_0_checkpoints": re.compile(r"studenttask_0_epoch_(\d+)\.pth"),
    "task_1_checkpoints": re.compile(r"studenttask_1_epoch_(\d+)\.pth"),
}


def _find_checkpoints(search_dir, s3: bool, bucket: str) -> Dict[str, Union[List, None]]:
    """Collect the per-task student checkpoints under ``<search_dir>/parameters/``, sorted by epoch."""
    if s3:
        s3_client = get_s3_client()
        parameters_prefix = f"{search_dir.rstrip('/')}/parameters/"
        try:
            paginator = s3_client.get_paginator("list_objects_v2")
            candidates = [
                (obj["Key"].split("/")[-1], obj["Key"])
                for page in paginator.paginate(Bucket=bucket, Prefix=parameters_prefix)
                for obj in page.get("Contents", [])
            ]
        except Exception as err:
            # Best effort (as before): report "no checkpoints" — but no longer silently.
            warnings.warn(f"Could not list checkpoints under s3://{bucket}/{parameters_prefix}: {err}")
            candidates = []
    else:
        parameters_dir = Path(search_dir) / "parameters"
        candidates = [(p.name, p) for p in parameters_dir.glob("*.pth")] if parameters_dir.exists() else []

    found = {key: [] for key in _CHECKPOINT_PATTERNS}
    for filename, location in candidates:
        for key, pattern in _CHECKPOINT_PATTERNS.items():
            # fullmatch (not match) so e.g. "studenttask_0_epoch_3.pth.tmp" is not picked up
            match = pattern.fullmatch(filename)
            if match:
                found[key].append((int(match.group(1)), location))
                break

    # Sort numerically by epoch: listing order is arbitrary locally and lexicographic on S3
    # (which would put epoch_10 before epoch_2). Empty lists are reported as None.
    return {
        key: [location for _, location in sorted(entries, key=lambda entry: entry[0])] or None
        for key, entries in found.items()
    }


def _find_parameter_file(filename: str, search_dir, s3: bool, bucket: str):
    """Return the path/key of ``<search_dir>/parameters/<filename>`` if it exists, else ``[]``."""
    if s3:
        s3_client = get_s3_client()
        key = f"{search_dir.rstrip('/')}/parameters/{filename}"
        try:
            s3_client.head_object(Bucket=bucket, Key=key)
        except Exception:
            # Missing object (404) or any other lookup failure -> "not found". Unlike a bare
            # ``except:``, KeyboardInterrupt / SystemExit still propagate.
            return []
        return key
    local_file = Path(search_dir) / "parameters" / filename
    return local_file if local_file.exists() else []


def _find_by_prefix(prefix: str, search_dir, s3: bool, bucket: str):
    """Return the file(s) directly inside ``search_dir`` whose name starts with ``prefix``."""
    if s3:
        # Normalise to a folder prefix so e.g. "exp1" cannot also match keys of "exp10/".
        folder = search_dir if not search_dir or search_dir.endswith("/") else f"{search_dir}/"
        paginator = get_s3_client().get_paginator("list_objects_v2")
        # Delimiter="/" restricts the listing to the folder's direct children, mirroring the
        # non-recursive local glob below (sub-folders such as parameters/ are not searched).
        pages = paginator.paginate(Bucket=bucket, Prefix=folder, Delimiter="/")
        matches = [
            obj["Key"]
            for page in pages
            for obj in page.get("Contents", [])
            if obj["Key"].split("/")[-1].startswith(prefix)
        ]
    else:
        matches = sorted(Path(search_dir).glob(f"{prefix}*"))
    # A unique match is returned bare; zero or several matches are returned as a list.
    return matches[0] if len(matches) == 1 else matches


def find_filepath_from_prefix(prefix: str, search_dir: str, s3: bool = False, bucket: str = None) -> Union[str, List[str], Dict[str, List[str]], None]:
    """
    Look for files having a precise prefix, either locally or in S3.

    Special prefixes (looked up inside the ``parameters/`` sub-folder of ``search_dir``):

    - ``"checkpoints"``: student checkpoints named ``studenttask_<t>_epoch_<n>.pth``
      (t = 0 or 1). Returns ``{"task_0_checkpoints": ..., "task_1_checkpoints": ...}``
      where each value is a list sorted by epoch, or ``None`` if that task has none.
    - ``"grad"`` / ``"residuals"`` / ``"sharpness"`` / ``"co_activations"``: the single file
      ``student_cumulative_gradients.pth`` / ``student_cumulative_residuals.pth`` /
      ``student_sharpness.pth`` / ``student_co_activations.pth``. Returns its path/key,
      or ``[]`` when it does not exist.

    Any other prefix matches files directly inside ``search_dir`` whose name starts with
    ``prefix``: a single match is returned as-is, zero or several matches as a list.

    Locally the results are ``pathlib.Path`` objects; in S3 mode they are object keys (str).

    Args:
        prefix (str): identifier for files to look for
        search_dir (str): folder or S3 prefix to search under
        s3 (bool): whether to search in S3
        bucket (str): S3 bucket name (required if s3=True)

    Returns:
        dict for "checkpoints"; otherwise a single path/key, or a (possibly empty) list.

    Raises:
        ValueError: if ``s3`` is True and no ``bucket`` is given.
    """
    if s3 and not bucket:
        raise ValueError("Bucket name must be provided in S3 mode.")

    if prefix == "checkpoints":
        return _find_checkpoints(search_dir, s3, bucket)
    if prefix in _PARAMETER_FILES:
        return _find_parameter_file(_PARAMETER_FILES[prefix], search_dir, s3, bucket)
    return _find_by_prefix(prefix, search_dir, s3, bucket)



# def save_eval_nb(filepath: str, case: str)->None:
#     """
#     Runs and save a copy of an evaluation notebook inside an experiment folder
#     Args:
#         filepath (str): folder of the experiments
#         case (str): specific experiment
#     """
    
#     notebooks_dir = join(dirname(dirname(abspath(__file__))),"notebooks")
#     template_notebook = join(notebooks_dir,"evaluation_templates", "eval_MNIST_v2.ipynb")  # The original template

#     working_notebook = join(filepath, case, f"nb_eval_{case}.ipynb")  # Copy of the template in current directory

#     # Step 1: Copy the template notebook
#     shutil.copy(template_notebook, working_notebook)

#     # Step 2: Execute the copied notebook in place
#     pm.execute_notebook(
#         working_notebook,  # Execute this file (modified version of template)
#         working_notebook,  # Save results in the same file
#         parameters={"expath": filepath, "experiment": case}  # Inject parameters
#     )

#     print(f"Execution complete. Updated notebook saved as {working_notebook}")



def get_cf(df_res: pd.DataFrame, config: dict, burn_in: int = 20) -> tuple:
    """
    Calculates catastrophic forgetting (CF) metrics for the current experiment.

    Accuracies are averaged over two positional row windows of ``df_res``:
    rows ``[burn_in, epochs_t0)`` (training on task 0, after a burn-in) and rows
    ``[epochs_t0 + burn_in, end)`` (training on task 1, after the same burn-in).
    Presumably one row per epoch — TODO confirm against the code writing the results.
    A window that is empty (e.g. ``epochs_t0 <= burn_in``) yields NaN.

    Args:
        df_res (pd.DataFrame): results table; must contain ``h0_metric`` and may contain
            ``h0_student_vs_teacher``, ``h0_teacher_vs_true`` and ``h1_teacher_vs_true``.
        config (dict): experiment configuration (dict-like, e.g. OmegaConf) providing
            ``epochs_t0``.
        burn_in (int, optional): leading rows skipped in each window. Defaults to 20.

    Returns:
        tuple: ``((cf_hard, a0_hard, a1_hard), (cf_soft, a0_soft, a1_soft),
        (teacher_a0, teacher_a1))`` with ``cf = a0 - a1``. Soft and teacher entries are NaN
        when their columns are missing; teacher values are means over all rows.
    """

    def calculate_accuracy_drop(df, label, epoch_t0):
        series = df[label]
        # explicit positional windows (.iloc), independent of the frame's index labels
        a0 = series.iloc[burn_in:epoch_t0].mean()
        a1 = series.iloc[epoch_t0 + burn_in:].mean()
        return a0 - a1, a0, a1

    epoch_t0 = config["epochs_t0"]

    # hard-labels (required column)
    hard_metrics = calculate_accuracy_drop(df_res, "h0_metric", epoch_t0)

    # soft-labels (optional column)
    soft_label = "h0_student_vs_teacher"
    if soft_label in df_res.columns:
        soft_metrics = calculate_accuracy_drop(df_res, soft_label, epoch_t0)
    else:
        soft_metrics = (np.nan, np.nan, np.nan)

    # teacher-vs-true metrics of both heads, averaged over the whole run (optional columns)
    if "h0_teacher_vs_true" in df_res.columns and "h1_teacher_vs_true" in df_res.columns:
        teachers_metrics = (df_res["h0_teacher_vs_true"].mean(), df_res["h1_teacher_vs_true"].mean())
    else:
        teachers_metrics = (np.nan, np.nan)

    return hard_metrics, soft_metrics, teachers_metrics


# bottom functions________________________________________________________________________________________________

def get_runs_metrics(filepath: str, level_folders: List[str], s3: bool=False, bucket: str=None)->pd.DataFrame:
    """
    Loops over the level_folders and extracts the forgetting metrics of each run:

    - locates the single ``config*`` and ``result*`` files inside each folder,
    - reads the result dataframe (CSV) and the experiment configuration (OmegaConf),
    - passes both to ``get_cf()`` to calculate forgetting.

    A folder is skipped, with a printed message, when its config/result file is missing or
    ambiguous, cannot be read, or ``get_cf()`` fails on it. This applies in local mode as
    well as in S3 mode, so one broken run does not abort the whole sweep.

    Args:
        filepath (str): location of the folders to loop over
        level_folders (List[str]): name of the folders to loop over
        s3 (bool, optional): Using s3? Defaults to False.
        bucket (str, optional): s3 bucket name, usually `scipi1-public`. Defaults to None.

    Returns:
        pd.DataFrame: one row per processed folder with columns ``bottom_level``, ``cf``,
        ``a0_hard``, ``a1_hard``, ``cf_soft``, ``a0_soft``, ``a1_soft``, ``teacher_a0``,
        ``teacher_a1``.
    """

    cols = {
        "bottom_level"  : [],
        "cf"            : [],
        "a0_hard"       : [],
        "a1_hard"       : [],
        "cf_soft"       : [],
        "a0_soft"       : [],
        "a1_soft"       : [],
        "teacher_a0"    : [],
        "teacher_a1"    : [],
        }

    s3_client = get_s3_client() if s3 else None
    source = " from S3" if s3 else ""  # keeps the original S3 error-message wording

    def read_text_from_s3(key):
        """Fetch an S3 object and wrap its UTF-8 text in a file-like object."""
        obj = s3_client.get_object(Bucket=bucket, Key=key)
        return StringIO(obj["Body"].read().decode("utf-8"))

    # loop over the bottom folders
    for case in level_folders:

        # update internal filepath
        filepath_ = join(filepath, case) + "/"

        # find files
        config_path = find_filepath_from_prefix("config", filepath_, s3, bucket)
        result_path = find_filepath_from_prefix("result", filepath_, s3, bucket)

        if isinstance(config_path, list) or isinstance(result_path, list):
            print(f"Multiple/no files found for {case}. Expected only one. Skipping this case.")
            continue

        # get df and config file
        try:
            df_res = pd.read_csv(read_text_from_s3(result_path) if s3 else result_path)
        except Exception as e:
            print(f"Failed to read {result_path}{source}: {e}")
            continue
        try:
            config = OmegaConf.load(read_text_from_s3(config_path) if s3 else config_path)
        except Exception as e:
            print(f"Failed to read {config_path}{source}: {e}")
            continue

        try:
            (cf_hard, a0_hard, a1_hard), (cf_soft, a0_soft, a1_soft), (teacher_a0, teacher_a1) = get_cf(df_res, config)
        except Exception as e:
            print(f"Failed to compute forgetting metrics for {case}: {e}")
            continue

        # values are listed in the same order as the keys of ``cols``
        row = (case, cf_hard, a0_hard, a1_hard, cf_soft, a0_soft, a1_soft, teacher_a0, teacher_a1)
        for column, value in zip(cols, row):
            cols[column].append(value)

    return pd.DataFrame.from_dict(cols)



def get_cumulative_loss(filepath: str, level_folders: List[str], s3: bool=False, bucket: str=None)->pd.DataFrame:
    """
    Loops over the level_folders and sums the ``train_loss`` stored in each task's
    per-epoch student checkpoints.

    A folder is processed only if it has exactly one ``config*`` file and at least one
    checkpoint. The config file itself is not read; only its presence is required.
    Checkpoints are loaded from S3 when ``s3`` is True, or from local disk otherwise.

    Args:
        filepath (str): location of the folders to loop over
        level_folders (List[str]): name of the folders to loop over
        s3 (bool, optional): Using s3? Defaults to False.
        bucket (str, optional): s3 bucket name, usually `scipi1-public`. Defaults to None.

    Returns:
        pd.DataFrame: one row per processed folder with columns ``bottom_level``,
        ``train_loss_task_1`` (sum over the task_0 checkpoints) and ``train_loss_task_2``
        (sum over the task_1 checkpoints). A task without checkpoints yields NaN.
    """

    cols = {
        "bottom_level": [],
        "train_loss_task_1": [],
        "train_loss_task_2": [],
        }

    s3_client = get_s3_client() if s3 else None

    def load_checkpoint(location):
        """Load one checkpoint (S3 key or local path) onto the CPU."""
        if s3:
            obj = s3_client.get_object(Bucket=bucket, Key=location)
            # Use BytesIO for binary PyTorch files
            return torch.load(BytesIO(obj["Body"].read()), map_location="cpu")
        return torch.load(location, map_location="cpu")

    def summed_train_loss(checkpoints):
        """Sum ``train_loss`` over a task's checkpoints; NaN when the task has none."""
        if not checkpoints:
            return np.nan
        return sum(load_checkpoint(ckpt)["train_loss"] for ckpt in checkpoints)

    # loop over the bottom folders
    for case in level_folders:

        # update internal filepath
        filepath_ = join(filepath, case) + "/"

        # find files
        config_path = find_filepath_from_prefix("config", filepath_, s3, bucket)
        checkpoints_dict = find_filepath_from_prefix("checkpoints", filepath_, s3, bucket)

        if isinstance(config_path, list) or not isinstance(checkpoints_dict, dict):
            print(f"Multiple/no config files found or no checkpoints dict for {case}. Skipping this case.")
            continue

        task_0_checkpoints = checkpoints_dict.get("task_0_checkpoints")
        task_1_checkpoints = checkpoints_dict.get("task_1_checkpoints")

        if task_0_checkpoints is None and task_1_checkpoints is None:
            print(f"No checkpoints found for {case}. Skipping this case.")
            continue

        try:
            train_loss_task_1 = summed_train_loss(task_0_checkpoints)
            train_loss_task_2 = summed_train_loss(task_1_checkpoints)
        except Exception as e:
            print(f"Failed to load checkpoints for case {case}: {e}")
            continue

        cols["bottom_level"].append(case)
        cols["train_loss_task_1"].append(train_loss_task_1)
        cols["train_loss_task_2"].append(train_loss_task_2)

    return pd.DataFrame.from_dict(cols)




def get_cumulative_grads(filepath: str, level_folders: List[str], s3: bool=False, bucket: str=None, eta: float=0.01)->pd.DataFrame:
    """
    Loops over the level_folders and summarises the cumulative gradients of the two tasks.

    For each folder, ``parameters/student_cumulative_gradients.pth`` is loaded. The code
    indexes it as ``data["cumulative_gradients"][0]`` (task 1) and ``[1]`` (task 2), each a
    mapping from parameter name to tensor; the exact schema is not confirmed here.
    For every parameter name, both gradients are flattened, scaled by ``eta`` and multiplied
    element-wise. The sum, mean and std of that product become the columns
    ``<name>_sum``, ``<name>_mean`` and ``<name>_std``.

    Args:
        filepath (str): location of the folders to loop over
        level_folders (List[str]): name of the folders to loop over
        s3 (bool, optional): Using s3? Defaults to False.
        bucket (str, optional): s3 bucket name, usually `scipi1-public`. Defaults to None.
        eta (float, optional): scale applied to each cumulative gradient (presumably the
            learning rate — confirm). Defaults to 0.01, the previously hard-coded value.

    Returns:
        pd.DataFrame: one row per successfully processed folder: ``bottom_level`` plus the
        statistics above. A folder lacking a parameter gets NaN in that parameter's columns.
    """

    stats = [
        ("sum", np.sum),
        ("mean", np.mean),
        ("std", np.std),
        ]

    rows = []
    columns = {"bottom_level": None}  # ordered set of column names, first-seen order

    s3_client = get_s3_client() if s3 else None

    # loop over the bottom folders
    for case in level_folders:

        # update internal filepath
        filepath_ = join(filepath, case) + "/"

        # find gradient checkpoint file
        grad_path = find_filepath_from_prefix("grad", filepath_, s3, bucket)

        if grad_path is None or isinstance(grad_path, list):
            print(f"Gradient file not found for {case}. Skipping this case.")
            continue

        try:
            # load gradient data onto the CPU (files saved from GPU then load anywhere)
            if s3:
                obj_res = s3_client.get_object(Bucket=bucket, Key=grad_path)
                # Use BytesIO for binary PyTorch files
                data = torch.load(BytesIO(obj_res["Body"].read()), map_location="cpu")
            else:
                data = torch.load(grad_path, map_location="cpu")

            grads_task_1 = data["cumulative_gradients"][0]
            grads_task_2 = data["cumulative_gradients"][1]

            # Build the whole row before committing it: a failure half-way through must not
            # leave the result columns with different lengths.
            row = {"bottom_level": case}
            for key in grads_task_1.keys():
                grad_task_1 = eta * torch.flatten(grads_task_1[key])
                grad_task_2 = eta * torch.flatten(grads_task_2[key])

                # element-wise product of the two tasks' scaled cumulative gradients
                grads_prod = torch.mul(grad_task_1, grad_task_2).detach().cpu().numpy()

                for stat, fun in stats:
                    row[f"{key}_{stat}"] = fun(grads_prod)

        except Exception as e:
            print(f"Failed to read gradient file {grad_path} for case {case}: {e}")
            continue

        rows.append(row)
        columns.update(dict.fromkeys(row))

    return pd.DataFrame(rows, columns=list(columns))


def get_head_phi(filepath: str, level_folders: List[str], s3: bool=False, bucket: str=None)->pd.DataFrame:
    """
    Compare the two heads (``h0`` and ``h1``) of each run's student model.

    For every folder, the first entry of the task_0 checkpoint list is loaded. The original
    author notes that the heads are frozen during training, so one checkpoint suffices.
    From ``model_state_dict["h0.weight"]`` and ``["h1.weight"]`` two values are computed:

    - ``head_phi_similarity``: ``torch.cosine_similarity(h0, h1, dim=1)``, keeping only the
      first element;
    - ``head_frob_prod``: the Frobenius inner product ``sum(h0 * h1)``.

    A folder is skipped, with a printed message, when it has no checkpoints, lacks either
    head weight, or fails to load or process.

    Args:
        filepath (str): location of the folders to loop over
        level_folders (List[str]): name of the folders to loop over
        s3 (bool, optional): Using s3? Defaults to False.
        bucket (str, optional): s3 bucket name, usually `scipi1-public`. Defaults to None.

    Returns:
        pd.DataFrame: columns ``bottom_level``, ``head_phi_similarity``, ``head_frob_prod``.
    """
    results = {
        "bottom_level": [],
        "head_phi_similarity": [],
        "head_frob_prod": [],
    }
    client = get_s3_client() if s3 else None

    for case in level_folders:
        case_dir = join(filepath, case) + "/"

        found = find_filepath_from_prefix("checkpoints", case_dir, s3, bucket)
        if not isinstance(found, dict):
            print(f"No checkpoints dict found for {case}. Skipping this case.")
            continue

        first_task = found.get("task_0_checkpoints")
        if not first_task:
            print(f"No task_0 checkpoints found for {case}. Skipping this case.")
            continue

        try:
            checkpoint_path = first_task[0]
            if s3:
                response = client.get_object(Bucket=bucket, Key=checkpoint_path)
                # binary PyTorch payload -> in-memory buffer
                checkpoint = torch.load(BytesIO(response["Body"].read()))
            else:
                checkpoint = torch.load(checkpoint_path)

            state = checkpoint.get("model_state_dict", {})
            if "h0.weight" not in state or "h1.weight" not in state:
                print(f"Missing h0.weight or h1.weight in model_state_dict for {case}. Skipping this case.")
                continue

            w0 = state["h0.weight"]
            w1 = state["h1.weight"]
            similarity = torch.cosine_similarity(w0, w1, dim=1)[0].item()
            frob_prod = torch.sum(w0 * w1).item()
        except Exception as e:
            print(f"Failed to process checkpoint {checkpoint_path} for case {case}: {e}")
            continue

        results["bottom_level"].append(case)
        results["head_phi_similarity"].append(similarity)
        results["head_frob_prod"].append(frob_prod)

    return pd.DataFrame().from_dict(results)


def get_cumulative_residuals(filepath: str, level_folders: List[str], s3: bool=False, bucket: str=None)->pd.DataFrame:
    """
    Loops over the level_folders and extracts the norm of each task's cumulative residuals.

    For each folder, ``parameters/student_cumulative_residuals.pth`` is loaded, and
    ``torch.norm`` is taken of its entries ``[0]`` (task 1) and ``[1]`` (task 2).
    A folder whose file is missing or cannot be processed is skipped with a printed message.

    Args:
        filepath (str): location of the folders to loop over
        level_folders (List[str]): name of the folders to loop over
        s3 (bool, optional): Using s3? Defaults to False.
        bucket (str, optional): s3 bucket name, usually `scipi1-public`. Defaults to None.

    Returns:
        pd.DataFrame: columns ``bottom_level``, ``cumulative_residual_norm_task_1``,
        ``cumulative_residual_norm_task_2``.
    """

    cols = {
        "bottom_level": [],
        "cumulative_residual_norm_task_1": [],
        "cumulative_residual_norm_task_2": [],
    }

    s3_client = get_s3_client() if s3 else None

    # loop over the bottom folders
    for case in level_folders:

        # update internal filepath
        filepath_ = join(filepath, case) + "/"

        # find residuals checkpoint file
        residuals_path = find_filepath_from_prefix("residuals", filepath_, s3, bucket)

        if residuals_path is None or isinstance(residuals_path, list):
            print(f"Residuals file not found for {case}. Skipping this case.")
            continue

        try:
            # load residuals data onto the CPU
            if s3:
                obj_res = s3_client.get_object(Bucket=bucket, Key=residuals_path)
                # Use BytesIO for binary PyTorch files
                data = torch.load(BytesIO(obj_res["Body"].read()), map_location="cpu")
            else:
                data = torch.load(residuals_path, map_location="cpu")

            # Compute both norms before touching ``cols``: appending the case first and then
            # failing would leave the columns with different lengths.
            norm_task_1 = torch.norm(data[0]).cpu().item()
            norm_task_2 = torch.norm(data[1]).cpu().item()

        except Exception as e:
            print(f"Failed to read residuals file {residuals_path} for case {case}: {e}")
            continue

        cols["bottom_level"].append(case)
        cols["cumulative_residual_norm_task_1"].append(norm_task_1)
        cols["cumulative_residual_norm_task_2"].append(norm_task_2)

    return pd.DataFrame.from_dict(cols)


def get_sam(filepath: str, level_folders: List[str], s3: bool=False, bucket: str=None)->pd.DataFrame:
    """
    Loops over the level_folders and extracts the sharpness values at the start and end of
    each task.

    ``parameters/student_sharpness.pth`` is indexed here as ``data[task][epoch][layer]``
    for task 0 and 1; the schema is inferred from that indexing, not confirmed. For each task,
    the values at epoch 0 and at the largest recorded epoch become the columns
    ``sam_task_<t>_<layer>_start`` and ``sam_task_<t>_<layer>_end``, where ``t`` is the
    task index + 1.

    A folder whose file is missing or cannot be processed is skipped with a printed message.
    Folders lacking a layer get NaN in that layer's columns.

    Args:
        filepath (str): location of the folders to loop over
        level_folders (List[str]): name of the folders to loop over
        s3 (bool, optional): Using s3? Defaults to False.
        bucket (str, optional): s3 bucket name, usually `scipi1-public`. Defaults to None.

    Returns:
        pd.DataFrame: dataframe containing sharpness data for each case
    """

    rows = []
    columns = {"bottom_level": None}  # ordered set of column names, first-seen order

    s3_client = get_s3_client() if s3 else None

    # loop over the bottom folders
    for case in level_folders:

        # update internal filepath
        filepath_ = join(filepath, case) + "/"

        # find sharpness checkpoint file
        sharpness_path = find_filepath_from_prefix("sharpness", filepath_, s3, bucket)

        if sharpness_path is None or isinstance(sharpness_path, list):
            print(f"Sharpness file not found for {case}. Skipping this case.")
            continue

        try:
            # load sharpness data onto the CPU
            if s3:
                obj_res = s3_client.get_object(Bucket=bucket, Key=sharpness_path)
                # Use BytesIO for binary PyTorch files
                data = torch.load(BytesIO(obj_res["Body"].read()), map_location="cpu")
            else:
                data = torch.load(sharpness_path, map_location="cpu")

            # Build the whole row first so a failure cannot leave columns of unequal length.
            row = {"bottom_level": case}
            for task in (0, 1):
                per_epoch = data[task]
                last_epoch = max(per_epoch.keys())
                for layer, value in per_epoch[0].items():
                    row[f"sam_task_{task + 1}_{layer}_start"] = value
                for layer, value in per_epoch[last_epoch].items():
                    row[f"sam_task_{task + 1}_{layer}_end"] = value

        except Exception as e:
            print(f"Failed to read sharpness file {sharpness_path} for case {case}: {e}")
            continue

        rows.append(row)
        columns.update(dict.fromkeys(row))

    return pd.DataFrame(rows, columns=list(columns))


def get_feat_part(filepath: str, level_folders: List[str], s3: bool=False, bucket: str=None)->pd.DataFrame:
    """
    Loops over the level_folders and extracts the feature participation ratio (PR) from the
    co-activations checkpoints, at the start and end of each task.

    ``parameters/student_co_activations.pth`` is indexed here as
    ``data[task][epoch][layer]`` -> matrix, for task 0 and 1. The matrices are presumably
    feature covariances; the schema is inferred from that indexing. For each task, the PR at
    epoch 0 and at the largest recorded epoch become the columns ``pr_task_<t>_<layer>_start``
    and ``pr_task_<t>_<layer>_end``, where ``t`` is the task index + 1.

    A folder whose file is missing or cannot be processed is skipped with a printed message.
    Folders lacking a layer get NaN in that layer's columns.

    Args:
        filepath (str): location of the folders to loop over
        level_folders (List[str]): name of the folders to loop over
        s3 (bool, optional): Using s3? Defaults to False.
        bucket (str, optional): s3 bucket name, usually `scipi1-public`. Defaults to None.

    Returns:
        pd.DataFrame: dataframe containing feature participation data for each case
    """

    def get_pr(C: torch.Tensor) -> float:
        """
        Helper to calculate the participation ratio PR = trace(C)^2 / sum(C**2).

        Args:
            C (torch.Tensor): features covariance matrix
        Returns:
            float: PR
        """
        tr = torch.trace(C)
        fro2 = (C**2).sum()
        return ((tr * tr) / fro2).item()

    rows = []
    columns = {"bottom_level": None}  # ordered set of column names, first-seen order

    s3_client = get_s3_client() if s3 else None

    # loop over the bottom folders
    for case in level_folders:

        # update internal filepath
        filepath_ = join(filepath, case) + "/"

        # find co-activations checkpoint file
        co_activations_path = find_filepath_from_prefix("co_activations", filepath_, s3, bucket)

        if co_activations_path is None or isinstance(co_activations_path, list):
            print(f"Co-activations file not found for {case}. Skipping this case.")
            continue

        try:
            # load co-activations data onto the CPU
            if s3:
                obj_res = s3_client.get_object(Bucket=bucket, Key=co_activations_path)
                # Use BytesIO for binary PyTorch files
                data = torch.load(BytesIO(obj_res["Body"].read()), map_location="cpu")
            else:
                data = torch.load(co_activations_path, map_location="cpu")

            # Build the whole row first so a failure cannot leave columns of unequal length.
            row = {"bottom_level": case}
            for task in (0, 1):
                per_epoch = data[task]
                last_epoch = max(per_epoch.keys())
                for layer, matrix in per_epoch[0].items():
                    row[f"pr_task_{task + 1}_{layer}_start"] = get_pr(matrix)
                for layer, matrix in per_epoch[last_epoch].items():
                    row[f"pr_task_{task + 1}_{layer}_end"] = get_pr(matrix)

        except Exception as e:
            print(f"Failed to read co-activations file {co_activations_path} for case {case}: {e}")
            continue

        rows.append(row)
        columns.update(dict.fromkeys(row))

    return pd.DataFrame(rows, columns=list(columns))


def get_weight_displacement(filepath: str, level_folders: List[str], s3: bool=False, bucket: str=None,
                            first_epoch: int=0, last_epoch: int=49)->pd.DataFrame:
    """
    Loops over the level_folders and measures, per parameter, how much the weight updates
    of task 0 and task 1 overlap ("collide").

    For each case and each task, the student checkpoints whose file names contain
    ``studenttask_<t>_epoch_<first_epoch>.pth`` and ``studenttask_<t>_epoch_<last_epoch>.pth``
    are loaded, and the element-wise displacement ``|W_last - W_first|`` of every tensor in
    their ``model_state_dict`` is computed. The value reported for a parameter is
    ``sum(|dW_task0| * |dW_task1|)``.

    Cases with missing or unprocessable checkpoints are skipped with a warning. Folders that
    expose no checkpoints at all are skipped silently.

    Args:
        filepath (str): location of the folders to loop over
        level_folders (List[str]): name of the folders to loop over
        s3 (bool, optional): Using s3? Defaults to False.
        bucket (str, optional): s3 bucket name, usually `scipi1-public`. Defaults to None.
        first_epoch (int, optional): epoch of the "before" checkpoint of each task. Defaults to 0.
        last_epoch (int, optional): epoch of the "after" checkpoint of each task. Defaults to 49.

    Returns:
        pd.DataFrame: one row per processed case, with a ``bottom_level`` column and one
        ``<param>_collision_weights`` column per non-head parameter. A parameter absent from
        a case shows up as NaN instead of misaligning the columns.
    """
    s3_client = get_s3_client() if s3 else None

    def _filename(checkpoint_path) -> str:
        # S3 keys are plain strings; local entries may be str or Path
        return checkpoint_path.split('/')[-1] if s3 else Path(checkpoint_path).name

    def _load_state_dict(checkpoint_path) -> dict:
        # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only machines
        if s3:
            obj_res = s3_client.get_object(Bucket=bucket, Key=checkpoint_path)
            checkpoint = torch.load(BytesIO(obj_res["Body"].read()), map_location="cpu")
        else:
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
        return checkpoint["model_state_dict"]

    rows = []

    # loop over the bottom folders
    for case in level_folders:

        # update internal filepath
        filepath_ = join(filepath, case) + "/"

        # find checkpoint files
        checkpoints_dict = find_filepath_from_prefix("checkpoints", filepath_, s3, bucket)
        if not isinstance(checkpoints_dict, dict):
            continue
        if (checkpoints_dict.get("task_0_checkpoints") is None
                and checkpoints_dict.get("task_1_checkpoints") is None):
            continue

        try:
            # (task, "first" | "last") -> model_state_dict
            states = {}
            for task in (0, 1):
                first_name = f"studenttask_{task}_epoch_{first_epoch}.pth"
                last_name = f"studenttask_{task}_epoch_{last_epoch}.pth"
                for checkpoint_path in checkpoints_dict.get(f"task_{task}_checkpoints") or []:
                    filename = _filename(checkpoint_path)
                    if first_name in filename:
                        states[(task, "first")] = _load_state_dict(checkpoint_path)
                    elif last_name in filename:
                        states[(task, "last")] = _load_state_dict(checkpoint_path)

            missing = [f"task_{t}_{stage}" for t in (0, 1) for stage in ("first", "last")
                       if (t, stage) not in states]
            if missing:
                warnings.warn(f"get_weight_displacement: skipping {case!r}, missing checkpoints {missing}")
                continue

            # Build the complete row before storing it, so a failure half-way through
            # cannot leave the output columns with unequal lengths.
            row = {"bottom_level": case}
            for param, w_first_0 in states[(0, "first")].items():
                # skip heads -- NOTE(review): assumes only head parameters start with 'h'
                if param.startswith('h'):
                    continue
                disp_0 = torch.abs(w_first_0 - states[(0, "last")][param])
                disp_1 = torch.abs(states[(1, "first")][param] - states[(1, "last")][param])
                row[param + "_collision_weights"] = torch.mul(disp_0, disp_1).sum().item()

        except Exception as e:
            # best-effort extraction: report and move on to the next case
            warnings.warn(f"get_weight_displacement: skipping {case!r} ({type(e).__name__}: {e})")
            continue

        rows.append(row)

    if not rows:
        # keep the schema of an empty result stable
        return pd.DataFrame({"bottom_level": []})
    # missing keys across rows become NaN; columns keep first-appearance order
    return pd.DataFrame(rows)


def get_everything(filepath: str, level_folders: List[str], s3: bool=False, bucket: str=None)->pd.DataFrame:
    """
    Combines all per-case analyses into a single DataFrame, one row per experiment case.

    This function serves as a convenience wrapper that:
    - Extracts weight-displacement "collision" data using get_weight_displacement()
    - Extracts catastrophic forgetting metrics using get_runs_metrics()
    - Extracts gradient dot products using get_cumulative_grads()
    - Extracts cumulative loss using get_cumulative_loss()
    - Extracts head phi similarity using get_head_phi()
    - Extracts cumulative residuals using get_cumulative_residuals()
    - Extracts sharpness data using get_sam()
    - Extracts feature participation data using get_feat_part()
    - Merges all results on the case identifier column ``bottom_level``

    Rows are joined by case, not by position. An extractor that skipped a case therefore
    yields NaN for that case instead of shifting other cases' values. When two extractors
    produce a column with the same name, the first one (in the order listed above, starting
    from get_runs_metrics) wins.

    Args:
        filepath (str): location of the folders to loop over
        level_folders (List[str]): name of the folders to loop over
        s3 (bool, optional): Using s3? Defaults to False.
        bucket (str, optional): s3 bucket name, usually `scipi1-public`. Defaults to None.

    Returns:
        pd.DataFrame: merged dataframe with one row per case that any extractor returned
    """

    df_checkpoints = get_weight_displacement(filepath, level_folders, s3, bucket)
    df_metrics = get_runs_metrics(filepath, level_folders, s3, bucket)
    df_grad = get_cumulative_grads(filepath, level_folders, s3, bucket)
    df_loss = get_cumulative_loss(filepath, level_folders, s3, bucket)
    df_head_phi = get_head_phi(filepath, level_folders, s3, bucket)
    df_residuals = get_cumulative_residuals(filepath, level_folders, s3, bucket)
    df_sam = get_sam(filepath, level_folders, s3, bucket)
    df_feat_part = get_feat_part(filepath, level_folders, s3, bucket)

    # precedence order for clashing column names (first wins)
    dataframes = [df_metrics, df_grad, df_loss, df_head_phi, df_residuals, df_sam, df_feat_part, df_checkpoints]

    merged = None
    for frame in dataframes:
        if frame is None or "bottom_level" not in frame.columns:
            if frame is not None and not frame.empty:
                warnings.warn("get_everything: dropping a result without a 'bottom_level' column")
            continue

        # object keys everywhere avoid dtype-mismatch errors when a frame is empty
        frame = frame.astype({"bottom_level": object})

        if merged is None:
            merged = frame
            continue

        # keep-first semantics for duplicated column names
        new_cols = ["bottom_level"] + [c for c in frame.columns if c not in merged.columns]
        merged = merged.merge(frame[new_cols], on="bottom_level", how="outer")

    if merged is None:
        return pd.DataFrame({"bottom_level": []})
    return merged



# condition to identify the bottom level________________________________________________________________________________________________

def has_logs_subfolder(directory: str, s3: bool = False, bucket: str = None) -> bool:
    """
    Check whether ``directory`` has an immediate subfolder named exactly ``logs``.

    Used as the bottom-level predicate when walking a sweep tree.

    Args:
        directory (str): local folder path, or S3 key prefix when ``s3`` is True
        s3 (bool, optional): look in S3 instead of the local filesystem. Defaults to False.
        bucket (str, optional): S3 bucket name, required when ``s3`` is True. Defaults to None.

    Returns:
        bool: True if a ``logs`` subfolder exists directly under ``directory``
    """
    target_folder = "logs"

    if s3:
        # With Delimiter='/', CommonPrefixes lists children only if the prefix ends with '/'
        prefix = directory if (not directory or directory.endswith('/')) else directory + '/'

        s3_client = get_s3_client()
        paginator = s3_client.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter='/')

        # Compare the last key component exactly, like the local branch does;
        # endswith() would also accept e.g. "catalogs/" or "old_logs/".
        return any(
            common["Prefix"].rstrip('/').split('/')[-1] == target_folder
            for page in pages
            for common in page.get("CommonPrefixes", [])
        )

    return any(
        entry.name == target_folder and entry.is_dir()
        for entry in Path(directory).iterdir()
    )


# main recursive function________________________________________________________________________________________________

def get_df_recursive(filepath: str, bottom_action: Callable, is_bottom: Callable, s3: bool=False, bucket: str=None, lev: int=0)->pd.DataFrame:
    """
    Walks a folder tree (local or S3) depth-first and builds one dataframe from its bottom levels.

    At each level the immediate subfolders are listed. If ``is_bottom`` accepts the FIRST
    subfolder, ``bottom_action(filepath, level_folders, s3=s3, bucket=bucket)`` is called once
    for the whole level and its result is returned. Otherwise the function recurses into
    every subfolder and tags each child dataframe with a ``level_<lev>`` column holding the
    subfolder name.

    N.B. only the first subfolder is tested, so all siblings are assumed to sit at the same depth.

    Args:
        filepath (str): level path, if user input, starting level (S3: key prefix)
        bottom_action (Callable): called as ``bottom_action(filepath, level_folders, s3=..., bucket=...)``
            at the bottom level; expected to return a pd.DataFrame
        is_bottom (Callable): predicate on a subfolder path; called as ``is_bottom(path)`` locally
            and ``is_bottom(path, s3=True, bucket=bucket)`` on S3
        s3 (bool): AWS s3 flag
        bucket (str): S3 bucket name, used when ``s3`` is True
        lev (int, optional): Current level, leave default value. Defaults to 0.

    Returns:
        pd.DataFrame: multi-level dataframe; an empty dataframe when a level has no
        subfolders or no child produced any data
    """

    # list the folders of the current level
    if s3:
        s3_client = get_s3_client()
        paginator = s3_client.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=bucket, Prefix=filepath, Delimiter='/')
        level_folders = [
            common["Prefix"].rstrip('/').split('/')[-1]
            for page in pages
            for common in page.get("CommonPrefixes", [])
        ]
    else:
        level_folders = [d for d in listdir(filepath) if isdir(join(filepath, d))]

    # empty level: nothing to evaluate (previously an IndexError on the local branch)
    if not level_folders:
        return pd.DataFrame()

    def child_path(case: str) -> str:
        # S3 prefixes need the trailing '/' so the next listing returns children
        if s3:
            return f"{filepath.rstrip('/')}/{case}/"
        return join(filepath, case)

    # check if bottom level is reached, condition might change for other applications
    first_child = child_path(level_folders[0])
    if s3:
        reached_bottom = is_bottom(first_child, s3=s3, bucket=bucket)
    else:
        reached_bottom = is_bottom(first_child)

    if reached_bottom:
        if not s3:
            print("reached bottom level")
        return bottom_action(filepath, level_folders, s3=s3, bucket=bucket)

    # recurse into every sweep folder and tag the results with this level's folder name
    frames = []
    for case in level_folders:
        df_temp = get_df_recursive(child_path(case), bottom_action, is_bottom, s3=s3, bucket=bucket, lev=lev + 1)
        if df_temp is None:
            continue
        # update sweep columns
        df_temp[f"level_{lev}"] = case
        frames.append(df_temp)

    # one concat at the end instead of re-copying the accumulated frame on every iteration
    return pd.concat(frames) if frames else pd.DataFrame()




def eval_sweeps(filepath: str, outpath: str, s3: bool = False, bucket: str = "scipi1-public"):
    """
    Evaluates a sweep experiment and writes the result to ``<outpath>/eval_sweeps.csv``.

    The sweep tree under ``filepath`` is walked with get_df_recursive. A level counts as the
    bottom when its first subfolder contains a ``logs`` folder (has_logs_subfolder), and
    get_everything is run there.

    Args:
        filepath (str): root of the sweep (local folder, or S3 key prefix when ``s3`` is True)
        outpath (str): local folder where the CSV is written; created if missing
        s3 (bool, optional): read the sweep from S3. Defaults to False.
        bucket (str, optional): S3 bucket, only used when ``s3`` is True. Defaults to "scipi1-public".
    """
    df = get_df_recursive(filepath=filepath, bottom_action=get_everything, is_bottom=has_logs_subfolder, s3=s3, bucket=bucket)
    if df is None:
        # no bottom level produced any data; still write a (empty) CSV
        df = pd.DataFrame()
    makedirs(outpath, exist_ok=True)
    df.to_csv(join(outpath, "eval_sweeps.csv"))



# helpers________________________________________________________________________________________________
def get_s3_client(public_only: bool = True):
    """
    Build a boto3 S3 client.

    Args:
        public_only (bool, optional): when True, requests are sent unsigned (anonymous
            access, no credentials needed); when False, boto3's default credential
            chain is used. Defaults to True.

    Returns:
        boto3 S3 client
    """
    if not public_only:
        return boto3.client("s3")
    anonymous_config = Config(signature_version=UNSIGNED)
    return boto3.client("s3", config=anonymous_config)


if __name__ == "__main__":

    s3 = False

    ROOT_DIR = dirname(dirname(abspath(__file__)))
    outpath = join(ROOT_DIR, "experiments", "evaluations", "three_layer_relu_mlp_mnist_a_sweep")

    # exist_ok avoids the check-then-create race of exists() followed by makedirs()
    makedirs(outpath, exist_ok=True)

    if s3:
        # NOTE(review): this S3 sweep (linear_mnist) differs from the local one
        # (three_layer_relu_mlp_mnist_a_sweep) while outpath stays the same -- confirm intended
        filepath = "cf_relu_nets/linear_mnist/combinations/"
    else:
        filepath = join(ROOT_DIR, "experiments", "training", "three_layer_relu_mlp_mnist_a_sweep", "combinations")

    eval_sweeps(filepath, outpath, s3=s3)
