import os
import json
from collections import defaultdict

import numpy as np
import torch
from tqdm import tqdm

class Trainer(object):
    """Run meta-training or meta-evaluation of a meta-learner over a dataset.

    The meta-learner is expected to expose ``_num_updates``, ``adapt``,
    ``step``, ``step_nonmaml``, ``measure`` and ``state_dict`` (these are the
    only members this class touches). Iterating ``meta_dataset`` must yield
    ``(train_tasks, val_tasks)`` pairs.
    """

    # Prefixes of the four measurement groups, in the order they are written
    # to the evaluation results.
    _STAT_PREFIXES = ('pre_val', 'pre_train', 'post_val', 'post_train')

    def __init__(self, meta_learner, meta_dataset, log_interval,
                 save_interval, model_type, save_folder, total_iter):
        """Store the run configuration.

        Args:
            meta_learner: object providing the adaptation/step/measure API.
            meta_dataset: iterable of ``(train_tasks, val_tasks)`` batches.
            log_interval: log every ``log_interval`` iterations (and on the
                first iteration).
            save_interval: checkpoint every ``save_interval`` iterations (and
                on the first iteration) while training.
            model_type: string embedded in checkpoint file names.
            save_folder: directory receiving checkpoints and ``results.json``.
            total_iter: stop after this many iterations.
        """
        self._meta_learner = meta_learner
        self._meta_dataset = meta_dataset
        self._log_interval = log_interval
        self._save_interval = save_interval
        self._model_type = model_type
        self._save_folder = save_folder
        self._total_iter = total_iter
        # Without inner-loop updates the learner is driven through
        # ``step_nonmaml`` instead of ``adapt`` + ``step``.
        self._task_wise = not meta_learner._num_updates > 0

    def run(self, is_training):
        """Iterate over the dataset, training or evaluating the meta-learner.

        While training, checkpoints are written to the save folder. While
        evaluating, per-batch measurements are collected for every batch,
        printed as ``mean +- 95% confidence interval`` and dumped to
        ``results.json`` in the save folder.

        Args:
            is_training: ``True`` to train, ``False`` to evaluate.
        """
        eval_stats = None
        if not is_training:
            eval_stats = {prefix: defaultdict(list)
                          for prefix in self._STAT_PREFIXES}

        fast_lr_list = []
        # Tracked explicitly so an empty dataset reports 0 batches instead of
        # failing on an unbound loop variable.
        num_batches = 0
        for i, (train_tasks, val_tasks) in enumerate(
                iter(self._meta_dataset), start=1):
            num_batches = i
            should_log = (i % self._log_interval == 0 or i == 1)

            # Save model
            if (i % self._save_interval == 0 or i == 1) and is_training:
                self._save_checkpoint(i)

            # Reset every iteration so values from an earlier batch can never
            # leak into this batch's evaluation statistics.
            pre_val = post_val = post_train = None

            if self._task_wise:
                pre_train = self._meta_learner.step_nonmaml(
                    (train_tasks,), is_training)
                if should_log:
                    self.log_output(i, pre_train_measurements=pre_train,
                                    post_train_measurements=None)
            else:
                pre_train, adapted_params, fast_lrs = (
                    self._meta_learner.adapt(train_tasks, val_tasks))
                fast_lr_list += fast_lrs
                step_output = self._meta_learner.step(
                    adapted_params, val_tasks, is_training)
                # In evaluation the measurements are the second element of
                # the value returned by ``step``.
                post_val = step_output if is_training else step_output[1]

                # Evaluation needs these for *every* batch, not only the
                # batches that happen to be logged.
                if should_log or not is_training:
                    pre_val = self._meta_learner.measure(tasks=val_tasks)
                    post_train = self._meta_learner.measure(
                        tasks=train_tasks, adapted_params_list=adapted_params)

                if should_log:
                    self.log_output(i, pre_val, pre_train,
                                    post_val, post_train, fast_lr_list)

            # Collect evaluation statistics over full dataset
            if not is_training:
                batch_measurements = (pre_val, pre_train, post_val, post_train)
                for prefix, measurements in zip(self._STAT_PREFIXES,
                                                batch_measurements):
                    # Task-wise mode only produces ``pre_train`` measurements.
                    if measurements is None:
                        continue
                    for key, value in measurements.items():
                        eval_stats[prefix][key].append(value)

            if i >= self._total_iter:
                break

        # Compute evaluation statistics assuming all batches were the same size
        if not is_training:
            self._report_evaluation(num_batches, eval_stats)

    def _save_checkpoint(self, iteration):
        """Write the meta-learner's ``state_dict`` for ``iteration`` to disk."""
        save_name = 'maml_{0}_{1}.pt'.format(self._model_type, iteration)
        save_path = os.path.join(self._save_folder, save_name)
        with open(save_path, 'wb') as f:
            torch.save(self._meta_learner.state_dict(), f)

    def _report_evaluation(self, num_batches, eval_stats):
        """Print the collected evaluation statistics and save them as JSON."""
        results = {'num_batches': num_batches}
        for prefix in self._STAT_PREFIXES:
            for key, value in sorted(eval_stats[prefix].items()):
                results['{}_{}'.format(prefix, key)] = value

        print('Evaluation results:')
        for key, value in sorted(results.items()):
            if not isinstance(value, int):
                print('{}: {} +- {}'.format(
                    key, np.mean(value),
                    self.compute_confidence_interval(value)))
            else:
                print('{}: {}'.format(key, value))

        results_path = os.path.join(self._save_folder, 'results.json')
        with open(results_path, 'w') as f:
            json.dump(results, f, default=self._json_default)

    @staticmethod
    def _json_default(obj):
        """Convert numpy/torch values to plain Python for ``json.dump``.

        Only consulted for objects the JSON encoder cannot serialise itself.

        Raises:
            TypeError: if ``obj`` is not a numpy scalar, numpy array or tensor.
        """
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, (np.ndarray, torch.Tensor)):
            return obj.tolist()
        raise TypeError('Object of type {} is not JSON serializable'.format(
            type(obj).__name__))

    def compute_confidence_interval(self, value):
        """
        Compute 95% +- confidence intervals over tasks
        change 1.960 to 2.576 for 99% +- confidence intervals
        """
        return np.std(value) * 1.960 / np.sqrt(len(value))

    def train(self):
        """Run the training loop."""
        self.run(is_training=True)

    def eval(self):
        """Run the evaluation loop."""
        self.run(is_training=False)

    def log_output(self, iteration, pre_val_measurements=None, pre_train_measurements=None,
                   post_val_measurements=None, post_train_measurements=None, 
                   fast_lr_list=None):
        """Print a one-line summary of the measurements for ``iteration``.

        Each measurement argument is either ``None`` (skipped) or a mapping
        from metric name to a number; metrics are printed sorted by name.
        If ``fast_lr_list`` is given, its mean is appended to the line.
        """
        # Order matters: it fixes the order of the sections in the log line.
        sections = (
            (pre_val_measurements, '{} meta_val before: {:.3f} '),
            (pre_train_measurements, '{} meta_train before: {:.3f} '),
            (post_train_measurements, '{} meta_train after: {:.3f} '),
            (post_val_measurements, '{} meta_val after: {:.3f} '),
        )
        parts = ['Iteration: {}/{} '.format(iteration, self._total_iter)]
        for measurements, template in sections:
            if measurements is None:
                continue
            parts.extend(template.format(key, value)
                         for key, value in sorted(measurements.items()))
        if fast_lr_list is not None:
            parts.append('Alpha* Mean is:{}'.format(np.mean(fast_lr_list)))
        print(''.join(parts))
