
import pandas as pd
import numpy as np
import os
import re
from scipy import interpolate
from collections import defaultdict
from tqdm import tqdm
import wandb


def find_common_divisor(spacings):
    """
    Return the greatest common divisor of a set of spacings, if meaningful.

    The spacings are cast to integers (fractional parts are truncated by the
    numpy cast) and flattened, so a scalar or nested sequence is accepted.

    Args:
        spacings: Scalar or array-like of spacing values.

    Returns:
        int or None:
            * None when no spacings are given.
            * The (integer-cast) spacing itself when exactly one is given.
            * The GCD of all spacings when it is greater than 1.
            * None when the GCD is 0 or 1, i.e. there is no shared period.
    """
    from functools import reduce
    from math import gcd

    # tolist() yields plain Python ints, so every return path has the same
    # type instead of mixing numpy scalars and ints.
    values = np.asarray(spacings, dtype=int).ravel().tolist()

    if not values:
        return None
    if len(values) == 1:
        return values[0]

    # An integer GCD divides every input exactly by definition, so no
    # additional "is each spacing a close multiple" check is required.
    common_divisor = reduce(gcd, values)

    # A divisor of 0 or 1 carries no information about a repeating pattern.
    return common_divisor if common_divisor > 1 else None

# Fixed resampling grid used when the CSV path contains 'permuted_MNIST':
# every row is evaluated on np.arange(0, 630, target_freq), which is 63 points
# at the default target_freq of 10, independent of sampling_freq.
# NOTE(review): presumably that export was logged on a different step grid than
# the other tasks -- confirm, because target points beyond the last original
# sample are filled by linear extrapolation.
_PERMUTED_MNIST_TAG = 'permuted_MNIST'
_PERMUTED_MNIST_SPAN = 630
_PERMUTED_MNIST_POINTS = 63


def _agent_name_from_column(col):
    """Return the agent name encoded in a CSV column header, or None for MIN/MAX columns."""
    column_name = col.replace('Scratch_', '') if col.startswith('Scratch_') else col

    # Columns ending in MIN / MAX are not run data and are ignored.
    if column_name.endswith(('MIN', 'MAX')):
        return None

    # "C_MAML" contains an underscore, so it (and "MAML") is matched explicitly
    # before falling back to "everything before the first underscore".
    for prefix in ("C_MAML", "MAML"):
        if column_name.startswith(prefix):
            return prefix

    match = re.match(r'[^_]+', column_name)
    # No match means the name is empty or starts with an underscore.
    return match.group(0) if match else "Unknown"


def _interpolate_rows(rows, sampling_freq, target_freq, use_permuted_grid):
    """Linearly resample every (run, row) series of a (num_runs, num_rows, num_steps) array."""
    num_runs, num_rows, num_steps = rows.shape

    # x positions of the original samples: 0, sampling_freq, 2*sampling_freq, ...
    original_x = np.arange(0, num_steps * sampling_freq, sampling_freq)

    if use_permuted_grid:
        target_x = np.arange(0, _PERMUTED_MNIST_SPAN, target_freq)
        width = _PERMUTED_MNIST_POINTS
    else:
        # x positions 0, target_freq, ... that do not exceed the last sample.
        last_x = original_x[-1]
        target_x = np.arange(0, last_x + 1, target_freq)
        target_x = target_x[target_x <= last_x]
        width = int(num_steps * sampling_freq / target_freq)

    # The nominal width can be smaller than the resampled row (for example
    # when sampling_freq < target_freq); never allocate less than is written.
    width = max(width, len(target_x))

    # Positions past len(target_x) are never written and therefore stay 0.
    interpolated = np.zeros((num_runs, num_rows, width))
    for run_idx in range(num_runs):
        for row_idx in range(num_rows):
            f = interpolate.interp1d(
                original_x,
                rows[run_idx, row_idx],
                kind='linear',
                bounds_error=False,
                fill_value="extrapolate",
            )
            interpolated[run_idx, row_idx, :len(target_x)] = f(target_x)
    return interpolated


def process_wandb_csv(csv_path, output_dir, num_steps, num_rows, do_interpolate=False, sampling_freq=25, target_freq=10):
    """
    Process a CSV file with wandb run data and organize it by agent.

    Every column except the first (typically "Step") is assigned to an agent
    based on its header; columns whose name ends in MIN or MAX are ignored.
    For each agent the columns are stacked into an array of shape
    (num_runs, n_samples). The first ``num_rows * num_steps`` samples are
    reshaped to (num_runs, num_rows, num_steps) and saved to
    ``<output_dir>/<agent>/<agent>.npy``. Agents with fewer samples than that
    are skipped with a warning and get no directory. NaN values in the CSV are
    kept as they are.

    When ``do_interpolate`` is true, each row is additionally resampled by
    linear interpolation from a step spacing of ``sampling_freq`` to a spacing
    of ``target_freq`` and saved to
    ``<output_dir>/<agent>/<agent>_interpolate.npy``. The last axis of that
    array has ``int(num_steps * sampling_freq / target_freq)`` entries (63 for
    paths containing 'permuted_MNIST'), widened if the resampled row is longer;
    entries the resampled row does not reach are left at 0.

    Args:
        csv_path (str): Path to the CSV file
        output_dir (str): Directory to save the output numpy files
        num_steps (int): Number of steps to include in each row
        num_rows (int): Number of rows in the output array
        do_interpolate (bool): Whether to interpolate the data
        sampling_freq (int): The frequency at which data was sampled in the CSV (e.g., every 25 steps)
        target_freq (int): The target frequency for interpolation (e.g., every 10 steps)

    Returns:
        None. Results are written to disk and progress is printed.
    """
    df = pd.read_csv(csv_path)

    # Group the raw column values by agent.
    agent_data = defaultdict(list)
    for col in df.columns[1:]:
        agent_name = _agent_name_from_column(col)
        if agent_name is None:
            continue
        agent_data[agent_name].append(df[col].values)

    os.makedirs(output_dir, exist_ok=True)

    total_steps_needed = num_rows * num_steps
    use_permuted_grid = _PERMUTED_MNIST_TAG in str(csv_path)

    # NOTE(review): a disabled debug branch that used to live in this loop
    # carried a hard-coded wandb API key inside a comment. It has been removed;
    # rotate that key if the file was ever shared or committed.
    for agent_name, data_list in tqdm(agent_data.items(), desc= 'going through agents'):
        # One row per CSV column (run) of this agent.
        data_array = np.array(data_list)
        num_runs = data_array.shape[0]

        if data_array.shape[1] < total_steps_needed:
            print(f"Warning: Not enough data for agent {agent_name}. Skipping.")
            continue

        agent_dir = os.path.join(output_dir, agent_name)
        os.makedirs(agent_dir, exist_ok=True)

        # (num_runs, total_steps) -> (num_runs, num_rows, num_steps); samples
        # beyond total_steps_needed are dropped.
        final_array = data_array[:, :total_steps_needed].reshape(num_runs, num_rows, num_steps)

        output_path = os.path.join(agent_dir, f"{agent_name}.npy")
        np.save(output_path, final_array)
        print(f"Saved {agent_name} data to {output_path}")

        if not do_interpolate:
            continue

        interpolated_array = _interpolate_rows(final_array, sampling_freq, target_freq, use_permuted_grid)

        # Written once, after every run has been filled in.
        interp_output_path = os.path.join(agent_dir, f"{agent_name}_interpolate.npy")
        np.save(interp_output_path, interpolated_array)
        print(f"Saved interpolated {agent_name} data to {interp_output_path}")

# Example usage:
if __name__ == "__main__":
    # Tasks to process; each is read from "<task>.csv" inside fwt_path.
    tasks = ['random_label_cifar10', 'random_MNIST', 'shuffle_cifar10', 'permuted_MNIST', 'continual_cifar100', 'continual_imagenet']

    # Agent name -> legend label. Not referenced anywhere below in this script.
    legend_name_dict = {'Base': 'Base', 'CBP': 'CBP', 'CReLU': 'CReLU',
                        'DeepF': 'DeepF', 'EWC' : 'EWC', 'L2': 'L2', 'L2Init': 'L2Init', 'LayerNorm': 'LayerNorm',
                        'NeuroSync': 'NeuroSync', 'PReLU': 'PReLU', 'ReDo': 'ReDo', 'Scratch': 'Scratch',
                        'MAML': 'MAML', 'C_MAML': 'C_MAML'}

    # Per task: number of logged points per row (passed as num_steps).
    training_iterations = {
        'random_MNIST': 600,
        'random_label_cifar10': 150,
        'shuffle_cifar10': 251,
        'permuted_MNIST': 7,
        'continual_cifar100': 126,
        'continual_imagenet': 51,
    }

    # Per task: number of rows in the output array (passed as num_rows).
    number_of_rows = {
        'random_MNIST': 30,
        'random_label_cifar10': 30,
        'shuffle_cifar10': 30,
        'permuted_MNIST': 25,
        'continual_cifar100': 20,
        'continual_imagenet': 100,
    }

    # Step spacing of the CSV data and the spacing to resample to. Both are
    # passed to process_wandb_csv explicitly so that editing them here takes
    # effect (they currently equal that function's defaults).
    desired_frequency = 10
    current_frequency = 25
    mode = 'train'

    # for new train use these 3 lines
    fwt_path = r'C:\Users\khash\OneDrive\Desktop\Research-Coding\17\result_analysis\forward_transfer'
    fwt_path = os.path.join(fwt_path, mode)
    result_path = r'C:\Users\khash\OneDrive\Desktop\Research-Coding\17\processed_results\fwt_matrices'

    # for scratch use these 3 lines
    # fwt_path = r'C:\Users\khash\OneDrive\Desktop\Research-Coding\17\result_analysis\scratch_results'
    # fwt_path = os.path.join(fwt_path, mode)
    # result_path = r'C:\Users\khash\OneDrive\Desktop\Research-Coding\17\processed_results\scratch_matrices'

    result_path = os.path.join(result_path, mode)

    # tasks = ['random_MNIST']

    for task in tasks:
        task_fwt_path = os.path.join(fwt_path, f"{task}.csv")
        output_dir = os.path.join(result_path, task)

        process_wandb_csv(
            csv_path=task_fwt_path,
            output_dir=output_dir,
            num_steps=training_iterations[task],
            num_rows=number_of_rows[task],
            do_interpolate=True,
            sampling_freq=current_frequency,
            target_freq=desired_frequency,
        )



