import pandas as pd
import numpy as np
import warnings
import pathlib
import torch
import random
import json
import copy
import os

# Choose the Pool implementation depending on whether CUDA is available.
if torch.cuda.is_available():
    from torch.multiprocessing import Pool
    # Force the 'spawn' start method, overriding any method set earlier
    # (force=True). Presumably needed so worker processes can use CUDA --
    # TODO confirm.
    torch.multiprocessing.set_start_method('spawn', force=True)
else:
    from multiprocessing.pool import Pool

from src.rl_main import ReinforcementLearning

# Suppress every warning for the whole process, not just for this module.
warnings.filterwarnings('ignore')


class RLTuning:
    """Hyperparameter tuning harness built around ``ReinforcementLearning`` runs.

    ``run_tuning`` expands a dict of candidate values into one case per
    combination, runs the cases in a process pool, and each worker writes a
    per-case CSV named ``<experiment>_case_<case_id>.csv`` into the raw
    folder. ``tuning_cleanup`` then merges those CSVs into
    ``<experiment>_results_raw.csv`` and a per-case summary
    ``<experiment>_results.csv`` in the output folder.
    """

    # Number of trailing steps of each run used to compute the reported metrics.
    FINAL_WINDOW_STEPS = 1000

    def __init__(self, use_wandb=False, wandb_entity='< WANDB ENTITY >', wandb_project='< WANDB PROJECT >'):
        """Store the wandb settings and create the output folders.

        Creates ``./tuning`` and ``./tuning/raw`` (relative to the current
        working directory) if they do not exist yet.

        Args:
            use_wandb: forwarded to every ``ReinforcementLearning`` instance.
            wandb_entity: forwarded to every ``ReinforcementLearning`` instance.
            wandb_project: forwarded to every ``ReinforcementLearning`` instance.
        """
        # The following attributes are populated by run_tuning().
        self.experiment = ''
        self.agent = None
        self.env = None
        self.state_representation = None
        self.num_runs = None
        self.max_steps = None
        self.num_episodes = None

        self.use_wandb = use_wandb
        self.wandb_entity = wandb_entity
        self.wandb_project = wandb_project

        # Folders are kept as plain strings after creation.
        self.output_folder = pathlib.Path("./tuning")
        self.output_folder.mkdir(parents=True, exist_ok=True)
        self.output_folder = str(self.output_folder)

        self.raw_folder = pathlib.Path("./tuning/raw")
        self.raw_folder.mkdir(parents=True, exist_ok=True)
        self.raw_folder = str(self.raw_folder)

    @staticmethod
    def set_random_seed(seed):
        """Seed torch (CPU and CUDA), NumPy and ``random`` with ``seed``.

        Also switches cuDNN to deterministic mode and disables its benchmark
        mode.
        """
        torch.manual_seed(seed)  # Sets seed for CPU operations
        torch.cuda.manual_seed_all(seed)  # Sets seed for current GPU and all future GPUs
        np.random.seed(seed)  # Sets seed for NumPy
        random.seed(seed)  # Sets seed for Python's random module
        torch.backends.cudnn.deterministic = True  # Ensures deterministic behavior
        torch.backends.cudnn.benchmark = False  # Disables cudnn benchmark for reproducibility

    def _raw_case_files(self):
        """Return ``{case_id: path}`` for this experiment's raw result files.

        Only files named exactly ``<experiment>_case_<integer>.csv`` are
        returned, so results of other experiments whose name merely contains
        this experiment's name are not picked up.
        """
        prefix = self.experiment + '_case_'
        suffix = '.csv'

        found = {}
        for name in os.listdir(self.raw_folder):
            if not (name.startswith(prefix) and name.endswith(suffix)):
                continue
            case_token = name[len(prefix):-len(suffix)]
            if not (case_token.isascii() and case_token.isdigit()):
                continue
            path = os.path.join(self.raw_folder, name)
            if os.path.isfile(path):
                found[int(case_token)] = path
        return found

    def run_experiment_continuing(self, case):
        """Run one tuning case ``self.num_runs`` times and write its raw CSV.

        Each run performs ``self.max_steps`` steps; whenever a step reports a
        terminal state a new episode is started with the next seed. Metrics
        are computed over the last ``FINAL_WINDOW_STEPS`` tracked steps and
        written to ``<raw_folder>/<experiment>_case_<case_id>.csv``.

        Args:
            case: dict with a ``'case_id'`` and a ``'params'`` dict. The
                params must contain ``'epsilon'``, ``'discount'`` and
                ``'step_size_value'`` (a number, or ``'1/n'`` for a step size
                of ``1 / (step + 1)``). Any other parameter whose name
                contains ``'step_size'`` is treated as a multiplier of the
                base step size. Parameters whose name contains ``'agent_'``
                are set as attributes on the agent copy.

        Returns:
            None. Errors are reported on stdout (with traceback) instead of
            being raised, so one failing case does not abort the whole pool.
        """
        params = case['params']

        results = {
            'avg_reward': [],
            'cvar': [],
            'avg_reward_type': [],
            'avg_reward_estimate': [],
            'seed': [],
        }
        for param in params:
            results[param] = []

        # work on copies so the case-specific settings do not leak to other cases
        agent = copy.deepcopy(self.agent)
        env = copy.deepcopy(self.env)

        # dynamically update agent instance variables
        for key, value in params.items():
            if 'agent_' in key:
                setattr(agent, key.replace('agent_', ''), value)

        try:
            epsilon = params['epsilon']
            discount = params['discount']
            base_step_size = params['step_size_value']

            # multipliers are constant for the whole case, so build them once
            step_size_multipliers = {
                name.replace('step_size_', ''): multiplier
                for name, multiplier in params.items()
                if 'step_size' in name and name != 'step_size_value'
            }

            # run case
            for run in range(self.num_runs):
                rl = ReinforcementLearning(
                    experiment=self.experiment + '_case_' + str(case['case_id']) + '_run_' + str(run),
                    agent_class=agent,
                    environment_class=env,
                    state_representation=self.state_representation,
                    track_data=True,
                    use_wandb=self.use_wandb,
                    wandb_entity=self.wandb_entity,
                    wandb_project=self.wandb_project,
                )
                rl.agent.load_pytorch_networks()

                # start first episode; seeds are 1000 * run + episode index
                episode = 0
                self.set_random_seed(1000 * run + episode)
                last_state, last_action = rl.rl_start(seed=1000 * run + episode, epsilon=epsilon)
                for step_n in range(self.max_steps):
                    # set step sizes
                    if base_step_size == '1/n':
                        step_value = 1 / (step_n + 1)
                    else:
                        step_value = base_step_size

                    use_step_size = {'value': step_value}
                    for name, multiplier in step_size_multipliers.items():
                        use_step_size[name] = multiplier * step_value

                    _reward, state, action, terminal = rl.rl_step(
                        last_state,
                        last_action,
                        step_size=use_step_size,
                        epsilon=epsilon,
                        discount=discount,
                    )

                    last_state = state
                    last_action = action

                    if terminal:
                        # start next episode
                        episode += 1
                        last_state, last_action = rl.rl_start(seed=1000 * run + episode, epsilon=epsilon)

                # get experiment data, restricted to the final window of steps
                results_df = rl.get_data().iloc[-self.FINAL_WINDOW_STEPS:]
                rewards = results_df['reward']
                final_var = rewards.quantile(rl.agent.var_quantile)

                results['avg_reward'].append(rewards.mean())
                # CVaR: mean of the rewards at or below the VaR quantile
                results['cvar'].append(rewards[rewards <= final_var].mean())
                results['avg_reward_type'].append('per_step')
                results['avg_reward_estimate'].append(rl.agent.avg_reward)
                results['seed'].append(run)

                for param, value in params.items():
                    results[param].append(value)

                # cleanup run
                rl.rl_end()

            # get output df
            output_df = pd.DataFrame(results)
            output_df['case_id'] = case['case_id']

            output_df.to_csv(
                os.path.join(self.raw_folder, self.experiment + '_case_' + str(case['case_id']) + '.csv'),
                index=False,
            )
            return
        except Exception:
            # Only catch real errors (not KeyboardInterrupt/SystemExit) and show
            # the traceback so the failure can actually be diagnosed.
            import traceback

            print('Error tuning case ' + str(case['case_id']) + ':\n' + str(case))
            traceback.print_exc()
            return

    @staticmethod
    def expand_cases(cases, param, values):
        """Return one copy of every case for every candidate value of ``param``.

        The input list and its case dicts are left untouched. Cases are
        visited from last to first, which is the order the result has always
        had and therefore keeps the case_id-to-parameters assignment made by
        ``run_tuning`` stable.

        Args:
            cases: list of ``{'case_id': ..., 'params': {...}}`` dicts.
            param: name of the parameter to add to each case.
            values: iterable of candidate values for ``param``.

        Returns:
            A new list with ``len(cases) * len(values)`` independent cases.
        """
        values = list(values)

        expanded_cases = []
        for case in reversed(cases):
            for value in values:
                new_case = copy.deepcopy(case)
                new_case['params'][param] = copy.deepcopy(value)
                expanded_cases.append(new_case)
        return expanded_cases

    def tuning_cleanup(self):
        """Merge the raw per-case CSVs and write the summary files.

        Writes ``<experiment>_results_raw.csv`` (all runs of all cases) and
        ``<experiment>_results.csv`` (one row per case) to the output folder.
        Per case, numeric columns are averaged over the runs, the ``*_std``
        columns hold the standard deviation over the runs, and
        ``avg_reward_estimate_error`` holds the largest absolute difference
        between the agent's estimate and the measured average reward.
        Non-numeric columns take the value of the first run. Nothing is
        written when no raw results exist.
        """
        case_files = self._raw_case_files()

        # read in case_id order so the merged file is deterministic
        frames = [pd.read_csv(case_files[case_id]) for case_id in sorted(case_files)]
        frames = [frame for frame in frames if len(frame) > 0]
        if not frames:
            print('no raw results found for ' + self.experiment + ', skipping cleanup')
            return

        output_df = pd.concat(frames, ignore_index=True)
        output_df.to_csv(os.path.join(self.output_folder, self.experiment + '_results_raw.csv'), index=False)

        # '1/n' step sizes cannot be averaged: remember which cases used them,
        # put a numeric placeholder in, and restore the label after aggregating.
        one_over_n_cases = set()
        if 'step_size_value' in output_df.columns:
            is_one_over_n = output_df['step_size_value'].astype(str) == '1/n'
            one_over_n_cases = set(output_df.loc[is_one_over_n, 'case_id'])
            output_df.loc[is_one_over_n, 'step_size_value'] = 0
            output_df['step_size_value'] = pd.to_numeric(output_df['step_size_value'])

        # get average results by case
        output_df['avg_reward_std'] = output_df['avg_reward']
        output_df['cvar_std'] = output_df['cvar']
        output_df['avg_reward_estimate_std'] = output_df['avg_reward_estimate']
        output_df['avg_reward_estimate_error'] = (output_df['avg_reward_estimate'] - output_df['avg_reward']).abs()

        agg_dict = {}
        for col in output_df.columns:
            # 'seed' identifies a run and has no meaning in the per-case summary
            if col in ('case_id', 'avg_reward_type', 'seed'):
                continue
            if pd.api.types.is_numeric_dtype(output_df[col]):
                agg_dict[col] = 'mean'
            else:
                # parameters are constant within a case, so any row will do
                agg_dict[col] = 'first'

        agg_dict['avg_reward_type'] = 'first'
        agg_dict['avg_reward_std'] = 'std'
        agg_dict['cvar_std'] = 'std'
        agg_dict['avg_reward_estimate_std'] = 'std'
        agg_dict['avg_reward_estimate_error'] = 'max'

        averaged_df = output_df.groupby(by='case_id').agg(agg_dict).reset_index()

        if one_over_n_cases:
            # Restore by case id rather than by "value == 0", so that a real
            # step size of 0 is not relabelled as '1/n'.
            averaged_df['step_size_value'] = averaged_df['step_size_value'].astype(object)
            averaged_df.loc[averaged_df['case_id'].isin(one_over_n_cases), 'step_size_value'] = '1/n'

        averaged_df.to_csv(os.path.join(self.output_folder, self.experiment + '_results.csv'), index=False)

    def run_tuning(self, experiment, experiment_type, agent, env, tuning_params, num_runs, max_steps=None, num_episodes=None,
                   state_representation=None, continue_run=False, n_cores=5):
        """Run every tuning case in a process pool and summarise the results.

        Args:
            experiment: experiment name, used as prefix of all files written.
            experiment_type: ``'continuing'`` or ``'episodic'``; selects the
                worker method (``run_experiment_continuing`` /
                ``run_experiment_episodic``).
            agent: agent object, deep-copied for every case.
            env: environment object, deep-copied for every case.
            tuning_params: dict mapping parameter name to a list of candidate
                values; one case is created per combination. Ignored when
                ``continue_run`` is true.
            num_runs: number of runs per case.
            max_steps: number of steps per run (used by continuing cases).
            num_episodes: stored on the instance for the worker methods.
            state_representation: forwarded to ``ReinforcementLearning``.
            continue_run: when true, reload the cases saved by a previous call
                and only run those without a raw result file yet.
            n_cores: number of worker processes.
        """
        self.experiment = experiment
        self.agent = agent
        self.env = env
        self.state_representation = state_representation
        self.num_runs = num_runs
        self.max_steps = max_steps
        self.num_episodes = num_episodes

        # Resolve the worker before any pool exists, so a missing method
        # cannot leave worker processes behind.
        if experiment_type == 'continuing':
            runner = self.run_experiment_continuing
        elif experiment_type == 'episodic':
            runner = self.run_experiment_episodic
        else:
            runner = None

        cases_file = os.path.join(self.output_folder, self.experiment + '_cases.json')

        if continue_run:
            # load auto-generated cases
            with open(cases_file, 'r') as file:
                all_cases = json.load(file)

            # keep only the cases that have no raw result file yet
            finished_cases = set(self._raw_case_files())
            cases = [case for case in all_cases if case['case_id'] not in finished_cases]
        else:
            # auto-generate cases: one per combination of candidate values
            cases = [{'case_id': None, 'params': {}}]
            for param, values in tuning_params.items():
                cases = self.expand_cases(cases, param, values)

            for counter, case in enumerate(cases, start=1):
                case['case_id'] = counter

            # save cases dict
            with open(cases_file, 'w') as file:
                json.dump(cases, file)

        num_cases = len(cases)
        if num_cases > 0:
            print('total number of ' + experiment + ' cases: ' + str(num_cases))
            print('Number of cores: ' + str(n_cores))
            print('starting multiprocessing...')

            if runner is None:
                print("ERROR: unknown experiment_type " + repr(experiment_type)
                      + " (expected 'continuing' or 'episodic').")
            else:
                p = Pool(processes=n_cores)
                try:
                    p.map(runner, cases)
                finally:
                    # always release the worker processes, even if map() raised
                    p.close()
                    p.join()
        else:
            print('no cases to run!')

        # run tuning cleanup
        self.tuning_cleanup()

        print('finished tuning!')
