from src.portfolio_optimization.predictor import NetworkPredictor as PortfolioNetworkPredictor
from src.portfolio_optimization.portfolio_problem import PortfolioProblem

from src.vopf.predictor import NetworkPredictor as VOPFNetworkPredictor
from src.vopf.vopf_problem import VOPFProblem

from src.runner import Runner
from collections import defaultdict
import numpy as np
import pickle
import wandb
import torch
import yaml


# Registry used by Experiment: maps a problem-class key to the pair
# (problem class, predictor class) that gets instantiated for it.
PROBLEM_CLASSES_DICT = dict(
    OPF=(VOPFProblem, VOPFNetworkPredictor),
    portfolio_optimization=(PortfolioProblem, PortfolioNetworkPredictor),
)


class Experiment:

    def __init__(self, path_to_config, problem_class, use_wandb=True, config_name='config.yml'):
        """Build an experiment from a YAML config and try to resume saved state.

        Loads the config, creates the problem / predictor / runner, initializes
        the (optional) W&B logger and finally attempts ``self.load()``. If
        loading fails the experiment starts with empty histories.

        Args:
            path_to_config: Directory containing the config file. It is also
                passed to ``save``/``load`` and used as the W&B ``dir``.
            problem_class: Key of ``PROBLEM_CLASSES_DICT`` selecting the
                problem/predictor pair.
            use_wandb: When false, no W&B run is created and nothing is logged
                to W&B.
            config_name: File name of the YAML config inside ``path_to_config``.

        Raises:
            AssertionError: If ``problem_class`` is not a known key.
        """
        if problem_class not in PROBLEM_CLASSES_DICT:
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped under ``python -O``; exception type and message are the
            # same as before.
            raise AssertionError('Unknown problem class %s' % problem_class)

        self.problem_class = problem_class
        self.use_wandb = use_wandb
        self.path_to_config = path_to_config
        self.config_name = config_name
        print('Initializing experiment:')
        print('Path: %s' % self.path_to_config)
        print()

        self.training_epochs = 0
        self.logger_initialized = False
        self._load_config()
        self._initialize_components()
        self._initialize_logger()
        self.training_history = defaultdict(list)
        self.validate_history = defaultdict(list)
        self.test_history = defaultdict(list)
        try:
            self.load()
        except Exception as err:
            # ``Exception`` (not a bare except) so Ctrl-C / SystemExit still
            # propagate. load() may have failed after updating some of the
            # bookkeeping, so reset it to match the "starting from zero"
            # message. Predictor/optimizer state that load() may already have
            # restored cannot be undone from here.
            self.training_epochs = 0
            self.training_history = defaultdict(list)
            self.validate_history = defaultdict(list)
            self.test_history = defaultdict(list)
            print('Could not load experiment, starting from zero!')
            print('Reason: %r' % (err,))

    def _load_config(self):
        """Parse the YAML config file and expose its sections as attributes.

        Reads ``<path_to_config>/<config_name>`` with ``yaml.safe_load`` and
        stores the whole mapping as ``full_config`` plus the entries
        ``problem_config``, ``predictor_config``, ``runner_config``,
        ``training_config`` and ``wandb_id``. A missing entry raises
        ``KeyError``.
        """
        config_file = '%s/%s' % (self.path_to_config, self.config_name)
        with open(config_file, 'r') as stream:
            parsed = yaml.safe_load(stream)

        self.full_config = parsed
        self.problem_config = parsed['problem_config']
        self.predictor_config = parsed['predictor_config']
        self.runner_config = parsed['runner_config']
        self.training_config = parsed['training_config']
        self.wandb_id = parsed['wandb_id']

    def _initialize_components(self):
        """Instantiate the problem, its predictor and the runner.

        The classes are looked up in ``PROBLEM_CLASSES_DICT`` by
        ``self.problem_class``; constructor keyword arguments come from the
        corresponding config sections.
        """
        problem_cls, predictor_cls = PROBLEM_CLASSES_DICT[self.problem_class]

        self.problem = problem_cls(**self.problem_config)
        print('Created problem instance!')

        # The predictor is sized from the problem it will be trained on.
        n_variables = self.problem.n_variables
        observation_size = self.problem.observation_size
        self.predictor = predictor_cls(n_variables, observation_size, **self.predictor_config)
        print('Initialized predictor!')

        self.runner = Runner(self.problem, self.predictor, **self.runner_config)
        print('Created runner!')

    def _initialize_logger(self):
        """Create (or resume) the W&B run used for logging.

        Does nothing except print a notice when ``use_wandb`` is false.
        Otherwise starts a run with ``resume='allow'`` under the configured
        project / id, registers ``training_epochs`` as the step metric for all
        other metrics, and marks the logger as initialized.
        """
        if not self.use_wandb:
            print('use_wandb=False, hence not initializing logger!')
            return
        self.logger = wandb.init(name=self.full_config['run_name'], config=self.full_config,
                                 resume='allow', project=self.full_config['wandb_project_name'],
                                 id=self.full_config['wandb_id'],
                                 dir=self.path_to_config)
        self.logger.define_metric("training_epochs")
        self.logger.define_metric("*", step_metric="training_epochs")
        # The flag is created as False in __init__; record that a run now exists.
        self.logger_initialized = True
        if self.logger.resumed:
            # Reporting the resumed epoch is best-effort: the summary may not
            # contain 'training_epochs' yet. Only ordinary exceptions are
            # swallowed, so Ctrl-C / SystemExit still propagate.
            try:
                resumed_epochs = self.logger.summary['training_epochs']
            except Exception:
                print('Resumed existing W&B run, training_epochs unknown')
            else:
                print('Resumed existing W&B run, training_epochs=%s' % resumed_epochs)
            
    def log_training_epoch(self, epoch_results):
        for key, val in epoch_results.items():
            if self.problem_class == 'bipartite_matching' and key in ['W_sq', 'w_lin', 'x_u']:
                continue
            self.training_history[key].append(np.mean(val, axis=0))

        if self.use_wandb:
            log_dict = dict()
            log_dict['train/reward'] = np.mean(epoch_results['reward'], axis=0)
            log_dict['train/regret'] = np.mean(epoch_results['regret'], axis=0)
            log_dict['train/r_max'] = np.mean(epoch_results['r_max'], axis=0)
            log_dict['train/reward_loss'] = np.mean(epoch_results['reward_loss'], axis=0)
            log_dict['train/proximity'] = np.mean(epoch_results['proximity'], axis=0)
            log_dict['train/projection_distance'] = np.mean(epoch_results['projection_distance'], axis=0)
            log_dict['train/solution_distance'] = np.mean(epoch_results['solution_distance'], axis=0)
            log_dict['train/mse_loss'] = np.mean(epoch_results['mse_loss'], axis=0)
            log_dict['training_epochs'] = self.training_epochs
            self.logger.log(log_dict)

        self.training_epochs += 1

    def log_validate_epoch(self, epoch_results):

        for key, val in epoch_results.items():
            if self.problem_class == 'bipartite_matching' and key in ['W_sq', 'w_lin', 'x_u']:
                continue
            self.validate_history[key].append(np.mean(val, axis=0))

        if self.use_wandb:
            log_dict = dict()
            log_dict['validate/reward'] = np.mean(epoch_results['reward'], axis=0)
            log_dict['validate/regret'] = np.mean(epoch_results['regret'], axis=0)
            log_dict['validate/r_max'] = np.mean(epoch_results['r_max'], axis=0)
            log_dict['validate/reward_loss'] = np.mean(epoch_results['reward_loss'], axis=0)
            log_dict['validate/proximity'] = np.mean(epoch_results['proximity'], axis=0)
            log_dict['validate/projection_distance'] = np.mean(epoch_results['projection_distance'], axis=0)
            log_dict['validate/solution_distance'] = np.mean(epoch_results['solution_distance'], axis=0)
            log_dict['validate/mse_loss'] = np.mean(epoch_results['mse_loss'], axis=0)
            log_dict['training_epochs'] = self.training_epochs
            self.logger.log(log_dict)

    def log_test_epoch(self, epoch_results):
        for key, val in epoch_results.items():
            if self.problem_class == 'bipartite_matching' and key in ['W_sq', 'w_lin', 'x_u']:
                continue
            self.test_history[key].append(np.mean(val, axis=0))

    @property
    def should_early_stop(self):
        early_stopping_th = self.training_config['early_stopping_th']
        if self.training_epochs <= early_stopping_th:
            return False
        for ind in range(1, early_stopping_th + 1):
            val_reward = self.validate_history['reward'][-ind]
            val_reward_before = self.validate_history['reward'][-ind - 1]
            if val_reward_before <= val_reward:
                #  Reward before was smaller, meaning the step was good, hence no need to early stop.
                return False
        print('Early stopping!')
        return True

    @property
    def is_finished(self):
        if self.training_epochs >= self.training_config['n_epochs']:
            return True
        elif self.training_config['early_stopping'] and self.should_early_stop:
            return True
        else:
            return False

    def save(self):
        with open(self.path_to_config + '/training_history.pickle', 'wb') as f:
            pickle.dump(self.training_history, f)
        with open(self.path_to_config + '/validate_history.pickle', 'wb') as f:
            pickle.dump(self.validate_history, f)
        with open(self.path_to_config + '/test_history.pickle', 'wb') as f:
            pickle.dump(self.test_history, f)
        self.predictor.save(self.path_to_config)
        torch.save(self.runner.optimizer.state_dict(), self.path_to_config + "optimizer")

    def load(self):
        with open(self.path_to_config + '/training_history.pickle', 'rb') as f:
            self.training_history = pickle.load(f)
        self.training_epochs = len(self.training_history['reward'])
        print('Loaded training epochs %s' % self.training_epochs)
        with open(self.path_to_config + '/validate_history.pickle', 'rb') as f:
            self.validate_history = pickle.load(f)
        with open(self.path_to_config + '/test_history.pickle', 'rb') as f:
            self.test_history = pickle.load(f)
        self.predictor.load(self.path_to_config)
        self.runner.optimizer.load_state_dict(torch.load(self.path_to_config + "optimizer"))

    def run(self):
        """Run the train / validate / test loop until ``is_finished``.

        Whenever ``training_epochs`` is a multiple of
        ``training_config['validate_each']`` a validation epoch and a test
        epoch are run and logged before the training epoch. After the loop the
        W&B run (if any) is finished and the experiment is saved. On
        ``KeyboardInterrupt`` the W&B run is finished but nothing is saved.
        """

        def run_phase(mode, **extra):
            # All phases share the same batch / gradient / normalisation settings.
            settings = self.training_config
            return self.runner.run_epoch(batch_size=settings['batch_size'],
                                         track_gradients=settings['track_gradients'],
                                         normalize_losses=settings['normalize_losses'],
                                         mode=mode, **extra)

        try:
            while not self.is_finished:
                settings = self.training_config

                if self.training_epochs % settings['validate_each'] == 0:
                    # Validation run
                    self.log_validate_epoch(run_phase('validate'))
                    if self.training_epochs % settings['print_each'] == 0:
                        latest_regret = self.validate_history['regret'][-1:]
                        print('Avg. validate regret in last epochs is %s' % (np.mean(latest_regret)))
                    # Test run
                    self.log_test_epoch(run_phase('test'))

                # Training epoch
                # NOTE(review): `<=` makes epochs 0..pretrain_for (inclusive)
                # pretraining epochs - confirm the inclusive bound is intended.
                if self.runner.pretrain_for:
                    pretraining = self.training_epochs <= self.runner.pretrain_for
                else:
                    pretraining = False
                self.log_training_epoch(run_phase('train', pretraining=pretraining))

                # log_training_epoch has already advanced training_epochs here.
                if self.training_epochs % settings['print_each'] == 0:
                    print('Running epoch %d . . . ' % self.training_epochs)
                    latest_regret = self.training_history['regret'][-1:]
                    print('Avg. train regret in last epochs is %s' % (np.mean(latest_regret)))
                self.predictor.post_epoch_routine()

            if self.use_wandb:
                self.logger.finish()
            self.save()

        except KeyboardInterrupt:
            print('Interrupted')
            if self.use_wandb:
                self.logger.finish()
