import torch, time
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys
import copy
def count_improvement(base_result, new_result, weight):
    """Return the mean relative improvement of ``new_result`` over ``base_result``.

    For every task in ``base_result`` the per-metric relative change
    ``(-1)**w * (base - new) / (base + 1e-4)`` is averaged over that task's
    metrics, printed, and the per-task values are then averaged over tasks.
    With ``w == 0`` a decrease of the metric counts as an improvement, with
    ``w == 1`` an increase does.

    Args:
        base_result: mapping task name -> sequence of baseline metric values.
        new_result: mapping task name -> sequence of new metric values
            (must contain every key of ``base_result``, same lengths).
        weight: mapping task name -> sequence of sign exponents, one per metric.

    Returns:
        The mean per-task improvement, or ``0.0`` when ``base_result`` is empty.
    """
    per_task = []
    for task in base_result:
        base = np.asarray(base_result[task])
        new = np.asarray(new_result[task])
        sign = (-1) ** np.asarray(weight[task])
        # 1e-4 guards against division by a zero baseline.
        task_improvement = (sign * (base - new) / (base + 1e-4)).mean()
        per_task.append(task_improvement)
        # Log this task's own value, not the running total across tasks.
        print(task, 'improvement', task_improvement)
    print('\n')
    if not per_task:
        return 0.0
    return sum(per_task) / len(per_task)

class _PerformanceMeter(object):
    def __init__(self, task_dict, device, base_result=None):
        
        self.task_dict = copy.deepcopy(task_dict)
        self.task_num = len(self.task_dict)
        self.task_name = list(self.task_dict.keys())
        self.device=device
        
        self.weight = {task: self.task_dict[task]['weight'] for task in self.task_name}
        self.base_result = base_result
        self.best_result = {'improvement': -1e+2, 'epoch': 0, 'result': 0}
        
        self.losses = {task: self.task_dict[task]['loss_fn'] for task in self.task_name}
        self.metrics = {task: self.task_dict[task]['metrics_fn'] for task in self.task_name}
        
        self.results = {task:[] for task in self.task_name}
        self.loss_item = np.zeros(self.task_num)
        
        self.has_val = False
        
        self._init_display()
        self.improvement = None

        if self.device.type != 'cpu':
            self.start_data = torch.cuda.Event(enable_timing=True)
            self.end_data = torch.cuda.Event(enable_timing=True)
        
    def record_time(self, mode='begin'):
        if mode == 'begin':
            if self.device.type == 'cpu':
                self.beg_time = time.time()   
            else:
                self.start_data.record()
        elif mode == 'end':
            if self.device.type == 'cpu':
                self.end_time = time.time()
            else:
                self.end_data.record()
                torch.cuda.synchronize()
        else:
            raise ValueError('No support time mode {}'.format(mode))
        
    def update(self, preds, gts, batch_data, task_name=None, ds_to_sub_end=None):
        with torch.no_grad():
            if task_name is None:
                for tn, task in enumerate(self.task_name):
                    if task=='trigger' and ds_to_sub_end is not None:
                        self.metrics[task].update_fun(preds[task], gts[task], ds_to_sub_end, batch_data)
                    else:
                        self.metrics[task].update_fun(preds[task], gts[task], batch_data)
            else:
                self.metrics[task_name].update_fun(preds, gts, data)
        
    def get_score(self):
        with torch.no_grad():
            for tn, task in enumerate(self.task_name):
                self.results[task] = self.metrics[task].score_fun()
                self.loss_item[tn] = self.losses[task]._average_loss()
    
    def _init_display(self):
        print('='*40)
        print('LOG FORMAT | ', end='')
        for tn, task in enumerate(self.task_name):
            print(task+'_LOSS ', end='')
            for m in self.task_dict[task]['metrics']:
                print(m+' ', end='')
            print('| ', end='')
        print('TIME')
    
    def display(self, mode, epoch,flag=0, writer=None, global_step=None):
        if (epoch == 0 and self.base_result is None and mode==('val' if self.has_val else 'test') ) or self.base_result is None:
            self.base_result = self.results
        if mode == 'train':
            print('Epoch: {:04d} | '.format(epoch), end='')
            self._init_display()
        if not self.has_val and mode == 'test':
            self._update_best_result(self.results, epoch)
        if self.has_val and mode != 'train':
            self._update_best_result_by_val(self.results, epoch, mode)
        if mode == 'train':
            p_mode = 'TRAIN'
        elif mode == 'val':
            p_mode = 'VAL'
        else:
            p_mode = 'TEST'
        print('{}:flag:{} '.format(p_mode, flag), end='')
        for tn, task in enumerate(self.task_name):
            print('{:.4f} '.format(self.loss_item[tn]), end='')
            if writer is not None:
                # writer.add_scalar(mode+'/'+task+'_loss', self.loss_item[tn], epoch)
                # writer.close()
                if global_step is not None:
                    writer.add_scalar(mode+'/'+task+'_loss', self.loss_item[tn], global_step)
                    writer.close()
            for i in range(len(self.results[task])):
                print('{:.7f} '.format(self.results[task][i]), end='')
                if writer is not None:
                    # writer.add_scalar(mode+'/'+task+'_'+self.task_dict[task]['metrics'][i], self.results[task][i], epoch)
                    # writer.close()
                    if global_step is not None:
                        writer.add_scalar(mode+'/'+task+'_'+self.task_dict[task]['metrics'][i], self.results[task][i], global_step)
                        writer.close()
            print('| ', end='')
        if self.device.type == 'cpu':
            print('Time: {:.4f}'.format(self.end_time-self.beg_time), end='')
        else:
            print('Time: {:.4f}'.format(self.start_data.elapsed_time(self.end_data)/1000), end='')
        print(' | ', end='') if mode!='test' else print()
        print('\n')
        
    def display_best_result(self):
        print('='*40)
        print('Best Result: Epoch {}, result {}'.format(self.best_result['epoch'], self.best_result['result']))
        print('='*40)
        
    def _update_best_result_by_val(self, new_result, epoch, mode):
        if mode == 'val':
            # improvement = count_improvement(self.base_result, new_result, self.weight)
            # self.improvement = improvement
            # if improvement > self.best_result['improvement']:
            # self.best_result['improvement'] = improvement
            self.best_result['epoch'] = epoch
        else:
            if epoch == self.best_result['epoch']:
                self.best_result['result'] = new_result
        
    def _update_best_result(self, new_result, epoch):
        # improvement = count_improvement(self.base_result, new_result, self.weight)
        # self.improvement = improvement
        # if improvement > self.best_result['improvement']:
        # self.best_result['improvement'] = improvement
        self.best_result['epoch'] = epoch
        self.best_result['result'] = new_result
        
    def reinit(self):
        for task in self.task_name:
            self.losses[task]._reinit()
            self.metrics[task].reinit()
        self.loss_item = np.zeros(self.task_num)
        self.results = {task:[] for task in self.task_name}


