import os
import sys
import torch
from torch.utils.tensorboard import SummaryWriter
from utils.sweeper import Sweeper

class Parameters:
    """Hyper-parameters and output paths for one configuration of a sweep.

    The configuration is selected by ``args.config_idx`` from the sweep file
    ``args.config_file`` (via ``Sweeper``). Each value of the resulting
    ``cfg`` dict is copied onto an attribute. The constructor also builds a
    descriptive ``savetag`` used in output filenames, creates the result
    folders and opens a TensorBoard ``SummaryWriter``.

    Args:
        args: object exposing ``config_file`` and ``config_idx``.
            ``args.config_idx`` is normalized to ``int`` in place
            (``int(float(...))`` also accepts strings such as ``"3.0"``).

    Side effects:
        - Prints the selected configuration ID.
        - Exits the process with status 1 if the ID exceeds the sweep's
          ``num_combinations``.
        - Creates folders under ``Results/``.
        - Opens a ``SummaryWriter`` in ``Results/tensorboard/<savetag>``,
          which is not closed by this class.
    """

    def __init__(self, args):
        args.config_idx = int(float(args.config_idx))
        sweeper = Sweeper(args.config_file)
        # NOTE(review): the config is generated before the range check below.
        # If ``generate_config_for_idx`` itself rejects (or wraps) out-of-range
        # IDs, the check never triggers - confirm against Sweeper.
        cfg = sweeper.generate_config_for_idx(args.config_idx)
        if args.config_idx > sweeper.config_dicts['num_combinations']:
            print("Config ID is invalid", flush=True)
            # Non-zero status so job schedulers / shell scripts see the
            # failure; a bare ``sys.exit()`` would report success (status 0).
            sys.exit(1)
        print("Running configuration ID: {}".format(args.config_idx), flush=True)
        self.num_agents = cfg['num_agents']
        self.env_name = cfg['env_type']
        self.T = cfg['episode_max_length']

        # Transitive algo params
        self.agent_type = cfg['agent_type']
        self.rollout_size = cfg['rollout_size']
        self.num_evals = cfg['num_evals']
        self.total_frames = int(cfg['total_frames'])
        self.evo_popn_size = cfg['evo_popn_size']
        self.actualize = cfg['actualize']
        self.seed = cfg['seed']
        # NOTE: the sweep config key is literally spelled 'parameter_shapping'.
        self.ps = cfg['parameter_shapping']
        self.global_update = cfg['global_update']
        self.learning_start = cfg['learning_start']
        self.autoreward_iterations = cfg['autoreward_iterations']
        self.ratio = cfg['ratio']

        self.tracedump = False
        # NOTE(review): hard-coded, machine-specific absolute path - consider
        # moving it into the sweep config.
        self.trace_path = '/workdisk/football/merl/Results/traces/'

        # Env domain
        # Environment name; also used in ``savetag`` and the result folder name.
        self.config = cfg['env_name']
        self.render = cfg['render']
        self.frameskip = cfg['frameskip']

        # Fairly stable algo params
        self.lineage_alpha = cfg['lineage_alpha']
        self.hidden_size = cfg["hidden_size"]
        self.algo_name = cfg['algo']
        self.actor_lr = cfg['actor_lr']
        self.critic_lr = cfg['critic_lr']
        self.tau = cfg['tau']
        self.init_w = cfg['weight_init']
        self.gradperstep = cfg['gradperstep']
        self.gamma = cfg["gamma"]
        self.batch_size = cfg['batch_size']
        self.buffer_size = cfg['buffer_size']
        self.reward_scaling = cfg["reward_scaling"]
        self.target_update_interval = 1
        self.alpha = cfg['alpha']

        # NeuroEvolution stuff
        self.scheme = cfg['scheme']  # 'multipoint' vs 'standard'
        self.crossover_prob = cfg['crossover_prob']
        self.mutation_prob = cfg['mutation_prob']
        self.extinction_prob = cfg['extinction_prob']  # Probability of extinction event
        self.extinction_magnitude = cfg['extinction_magnitude']  # Probability of extinction for each genome, given an extinction event
        self.weight_clamp = cfg['weight_clamp']
        self.mut_distribution = cfg['mut_distribution']  # 1-Gaussian, 2-Laplace, 3-Uniform
        self.lineage_depth = cfg['lineage_depth']
        self.ccea_reduction = cfg['ccea_reduction']
        self.num_anchors = cfg['num_anchors']
        self.elite_ratio = cfg['elite_ratio']
        self.kill_ratio = cfg['kill_ratio']
        self.num_blends = int(0.15 * self.evo_popn_size)  # 15% of the evolutionary population
        self.num_test = 20
        self.test_gap = 1

        # Save filenames: encode the key hyper-parameters so runs are distinguishable.
        self.savetag = (
            cfg['save_tag']
            + str(self.agent_type) + '_'
            + 'pop' + str(self.evo_popn_size)
            + '_roll' + str(self.rollout_size)
            + '_' + str(self.env_name) + '_' + str(self.config)
            + '_alpha' + str(self.alpha)
            + '_seed' + str(self.seed)
            + '_control' + str(self.num_agents)
            + '_its' + str(self.autoreward_iterations)
            + '_actor_lr' + str(self.actor_lr)
            + '_critic_lr' + str(self.critic_lr)
            + '_batch_size' + str(self.batch_size)
            + '_ratio' + str(self.ratio)
            + '_fskip' + str(self.frameskip)
            + ('_alz' if self.actualize else '')
            + ('_multipoint' if self.scheme == 'multipoint' else '')
            + ('_ps' if self.ps else '')
            + ('_global' if self.global_update else '')
        )

        self.save_foldername = 'Results/' + str(self.env_name) + '_' + str(self.config)
        # gfootball runs are additionally separated by the number of controlled agents.
        if self.env_name == 'gfootball':
            self.save_foldername += '_control' + str(self.num_agents) + '/'
        else:
            self.save_foldername += '/'
        self.metric_save = self.save_foldername  # no separate 'metrics/' subfolder
        self.model_save = self.save_foldername + 'models/'
        self.aux_save = self.save_foldername + 'auxiliary/'
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        for folder in (self.save_foldername, self.metric_save, self.model_save, self.aux_save):
            os.makedirs(folder, exist_ok=True)

        self.critic_fname = 'critic_' + self.savetag
        self.actor_fname = 'actor_' + self.savetag
        self.log_fname = 'reward_' + self.savetag
        self.best_fname = 'best_' + self.savetag

        # NOTE(review): this writer is never closed here; callers should close()
        # it at the end of a run so pending events are flushed.
        self.writer = SummaryWriter(log_dir='Results/tensorboard/' + self.savetag)
