import os
import sys
sys.path.append('../')
sys.path.append('../src/sweeps/')

import numpy as np

import pandas as pd
import pickle
from glob import glob


import wandb
# from config_wandb import entity
import importlib.util

class WandBDownloader:
    """Fetch the runs of a Weights & Biases project (or one sweep) as flat records.

    Constructing an instance immediately queries the W&B API and flattens
    every run into a dict; the result is stored in ``self.info_runs`` and can
    be turned into a DataFrame with :meth:`get_dataframe`.

    Args:
        sweep_file: Optional path to a Python file that defines a module-level
            ``project_name``. When given, it overrides ``project_name``.
        entity: W&B entity used to build the API path.
        sweep_id: Optional sweep id. When set, only that sweep's runs are
            fetched; otherwise all runs of the project are fetched.
        project_name: W&B project name, used when ``sweep_file`` is None.
    """

    def __init__(self, sweep_file = None, entity=None, sweep_id=None, project_name=None):
        self.entity = entity
        self.sweep_id = sweep_id
        self.sweep_file = sweep_file
        self.project_name = project_name
        # A sweep file, when provided, is the source of truth for the project name.
        if sweep_file is not None:
            self.get_project_name(sweep_file)
        self.fetch_runs(entity, self.project_name, sweep_id=sweep_id)
        self.info_runs = self.process_runs()

    def get_project_name(self, file_path):
        """Load ``file_path`` as a module and store its ``project_name``.

        NOTE: the file is *executed* to read the attribute, so only point this
        at trusted sweep files.

        Args:
            file_path: Path to a Python file defining ``project_name``.

        Raises:
            AttributeError: If the module does not define ``project_name``.
        """
        spec = importlib.util.spec_from_file_location("listops_module", file_path)
        listops_module = importlib.util.module_from_spec(spec)

        # Running the module body is what makes ``project_name`` available.
        spec.loader.exec_module(listops_module)

        project_name = listops_module.project_name
        print(f"Project name: {project_name}")
        self.project_name = project_name

    def fetch_runs(self, entity, project_name, sweep_id=None):
        """Create the API client and select the runs to process.

        Sets ``self.api``, ``self.sweep`` (None when no ``sweep_id`` is given)
        and ``self.runs``.

        Args:
            entity: W&B entity.
            project_name: W&B project name.
            sweep_id: Optional sweep id restricting the runs to one sweep.
        """
        self.api = wandb.Api()
        if sweep_id:
            self.sweep = self.api.sweep(f"{entity}/{project_name}/{sweep_id}")
            self.runs = self.sweep.runs
        else:
            self.sweep = None
            self.runs = self.api.runs(f"{entity}/{project_name}")

    def process_runs(self):
        """Flatten every run in ``self.runs`` into one dict per run.

        Each record merges, in this order (later keys win): the run's summary
        values, its config entries whose key does not start with ``_``, and
        the run's ``name`` and ``id``.

        Returns:
            list[dict]: One record per run, in iteration order.
        """
        runs = self.runs

        info_runs = []
        for i, run in enumerate(runs):
            # Lightweight progress indicator, overwritten in place via '\r'.
            if i % 50 == 0:
                print(f"Run {i}: {run.name} ({run.id})", end='\t\t\t\r')

            # Copy the summary dict before extending it: updating the object
            # returned by ``_json_dict`` directly would write config/name/id
            # into the run's own summary data.
            # (Per the original author's note, ``_json_dict`` is used to omit
            # large files.)
            record = dict(run.summary._json_dict)

            # Hyperparameters; keys starting with '_' are special values.
            record.update(
                {k: v for k, v in run.config.items() if not k.startswith('_')})

            record['name'] = run.name  # human-readable run name
            record['id'] = run.id      # unique run identifier
            info_runs.append(record)

        print(f"Processed {len(runs)} runs.")
        return info_runs

    def get_dataframe(self, clean=True):
        """Return the processed runs as a DataFrame.

        A fixed set of key columns is moved to the front (those that exist;
        missing ones are skipped instead of raising ``KeyError``).

        Args:
            clean: When True, drop rows whose ``accuracy_final`` is NaN. If no
                run reported ``accuracy_final`` at all, the result is empty.

        Returns:
            pandas.DataFrame: One row per run.
        """
        df = pd.DataFrame(self.info_runs)

        # Bring the following columns to the beginning, when present.
        cols = ['name', 'ops', 'n_embed', 'n_layer', 'accuracy_final', 'number_of_parameters']
        present = set(df.columns)
        leading = [col for col in cols if col in present]
        df = df[leading + [col for col in df.columns if col not in leading]]

        if not clean:
            return df
        if 'accuracy_final' not in present:
            # No run has the metric, so every row counts as missing it.
            return df.iloc[0:0]
        return df.dropna(subset=['accuracy_final'])

class DataFrameProcessor:
    """Annotate a runs DataFrame with parsed ops and locally saved model files.

    Construction runs the whole pipeline on ``df``:

    1. parse the function list out of each row's ``data_file`` name
       (columns ``ops`` and ``ops_stripped``),
    2. index the ``model*.pt`` files found under every ``save_path``,
    3. match each run to its model file by name and timestamp
       (column ``model_file``),
    4. derive the results-file path (column ``results_file``) and drop the
       rows that have no model file.

    The four columns are added to the DataFrame that was passed in (in place).
    After construction ``self.df`` is a new frame holding only the rows with a
    matched model file.

    Columns read: ``data_file``, ``save_path``, ``name``, ``_timestamp``.
    """

    _MODEL_PREFIX = 'model_'
    _MODEL_SUFFIX = '.pt'

    def __init__(self, df):
        self.df = df
        self.fill_ops_column(df)
        self.get_all_model_files(df)
        self.get_model_files()
        self.get_results()

    def strip_num(self, func_name):
        """Drop the last ``_``-separated token of ``func_name``.

        Intended to remove a trailing number (``'add_1'`` -> ``'add'``), but
        the token is removed whether or not it is numeric. Names without an
        underscore are returned unchanged.
        """
        head, sep, _ = func_name.rpartition('_')
        return head if sep else func_name

    @staticmethod
    def _parse_funcs(data_file):
        """Return the function names embedded in a data-file name, or None.

        The names are the comma-separated text between ``_funcs`` and
        ``_depth``, with the first and last character of that segment removed
        (presumably surrounding brackets -- confirm against the data-file
        naming scheme). None is returned when ``data_file`` is not a string or
        does not contain ``_funcs``.
        """
        if not isinstance(data_file, str):
            return None
        parts = data_file.split('_funcs')
        if len(parts) < 2:
            return None
        return parts[1].split('_depth')[0][1:-1].split(',')

    @classmethod
    def _run_name_from_model_file(cls, path):
        """Return the run name encoded in a ``model_<name>.pt`` file path.

        Removes the exact prefix/suffix. ``str.lstrip``/``str.rstrip`` must
        not be used here: they strip *character sets*, which mangles names
        such as ``model_deep_test.pt``.
        """
        name = os.path.basename(path)
        if name.endswith(cls._MODEL_SUFFIX):
            name = name[:-len(cls._MODEL_SUFFIX)]
        if name.startswith(cls._MODEL_PREFIX):
            name = name[len(cls._MODEL_PREFIX):]
        return name

    def fill_ops_column(self, df):
        """Fill ``ops`` and ``ops_stripped`` from each row's ``data_file``.

        ``ops`` is the comma-joined function list; ``ops_stripped`` is the
        same list with :meth:`strip_num` applied to every entry. Rows whose
        ``data_file`` cannot be parsed keep their existing ``ops`` value (or
        None) and get ``ops_stripped = None``.

        Args:
            df: DataFrame with a ``data_file`` column; modified in place.
        """
        n_rows = len(df)
        ops_values = list(df['ops']) if 'ops' in df.columns else [None] * n_rows
        stripped_values = [None] * n_rows

        for pos, row in enumerate(df.itertuples()):
            funcs_to_use = self._parse_funcs(row.data_file)
            if funcs_to_use is None:
                continue
            ops_values[pos] = ','.join(funcs_to_use)
            stripped_values[pos] = ','.join(self.strip_num(func) for func in funcs_to_use)

        # Positional whole-column assignment; stays correct even when the
        # index contains duplicate labels.
        df['ops'] = pd.Series(ops_values, index=df.index, dtype=object)
        df['ops_stripped'] = pd.Series(stripped_values, index=df.index, dtype=object)

    def get_all_model_files(self, df):
        """Index the ``model*.pt`` files under every distinct ``save_path``.

        Stores ``self.all_model_files``: run name -> ``{'time_stamp': mtime,
        'file': path}``. If two directories contain a file with the same run
        name, the one visited last wins.

        Args:
            df: DataFrame with a ``save_path`` column (missing values skipped).
        """
        all_model_files = {}
        # NOTE(review): ``save_path`` is used as a glob pattern prefix, so
        # glob metacharacters in it are interpreted rather than matched
        # literally -- confirm save paths never contain them.
        for pth in df['save_path'].dropna().unique():
            for f in glob(os.path.join(pth, 'model*.pt')):
                all_model_files[self._run_name_from_model_file(f)] = {
                    'time_stamp': os.path.getmtime(f),
                    'file': f,
                }
        self.all_model_files = all_model_files

    def get_model_files(self, dt=120):
        """Match every run in ``self.df`` to a file in ``self.all_model_files``.

        A run matches when a model file with the same name exists and its
        modification time is within ``dt`` seconds of the run's
        ``_timestamp``. The matched path (or None) is written to the
        ``model_file`` column; the successful matches are also kept in
        ``self.model_files`` (run name -> path).

        Args:
            dt: Maximum allowed difference, in seconds, between the file's
                mtime and the run's ``_timestamp``.
        """
        df = self.df
        all_model_files = self.all_model_files

        matched = []
        model_files = {}
        c = 0  # running count of name matches rejected by the time check
        for name, run_time in zip(df['name'], df['_timestamp']):
            entry = all_model_files.get(name)
            if entry is None:
                print(f"No match found for {name} in all_model_files")
                matched.append(None)
                continue

            time_diff = abs(entry['time_stamp'] - run_time)
            if time_diff < dt:
                matched.append(entry['file'])
                model_files[name] = entry['file']
            else:
                print(f"{c}) Time diff: {time_diff:.3g}: {name}")
                c += 1
                matched.append(None)

        # Explicit object dtype so unmatched rows stay None.
        df['model_file'] = pd.Series(matched, index=df.index, dtype=object)
        self.model_files = model_files

    def get_results(self):
        """Add ``results_file`` and keep only the rows that have a model file.

        Some model files are not on this machine; those rows have no
        ``model_file`` and are dropped from ``self.df`` (which is rebound to
        the filtered frame).
        """
        # Explicit row loop instead of ``df.apply(axis=1)``: the latter returns
        # a DataFrame rather than a Series when the frame is empty.
        results = [self.get_results_file(row) for _, row in self.df.iterrows()]
        self.df['results_file'] = pd.Series(results, index=self.df.index, dtype=object)

        print('Before dropping, number of rows:', len(self.df))
        self.df = self.df.dropna(subset=['model_file'])
        print('After dropping, number of rows:', len(self.df))

    def get_results_file(self,row):
        """Return the results-file path that belongs to a row's model file.

        The path is derived from ``row['model_file']`` by swapping the
        ``model_`` filename prefix for ``results_`` and the ``.pt`` extension
        for ``.pkl``. Only the filename is rewritten, never the directory.
        The existence check is on the *model* file; the results file itself
        is not checked.

        Args:
            row: Mapping/Series with a ``model_file`` entry.

        Returns:
            str or None: The derived path, or None when the row has no model
            file or the model file does not exist.
        """
        model_file = row['model_file']
        if not isinstance(model_file, str) or not os.path.exists(model_file):
            return None

        directory, filename = os.path.split(model_file)
        if filename.endswith(self._MODEL_SUFFIX):
            filename = filename[:-len(self._MODEL_SUFFIX)] + '.pkl'
        if filename.startswith(self._MODEL_PREFIX):
            filename = 'results_' + filename[len(self._MODEL_PREFIX):]
        return os.path.join(directory, filename)

