import os
import pandas as pd
import shutil
import torch


# Other utilities
def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy, in percent, for every k in ``topk``.

    Args:
        output: score tensor whose dimension 1 is ranked by ``topk``; row i
            holds the class scores for sample i.
        target: tensor of ground-truth class indices, one per row of
            ``output`` (its first dimension is taken as the sample count).
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 0-dim tensor per k: the percentage of samples whose
        target index is among the k highest-scoring entries of its row.
    """
    largest_k = max(topk)
    n_samples = target.size(0)

    # Indices of the largest_k best classes per sample, transposed so that
    # row j holds every sample's j-th best guess: shape (largest_k, n_samples).
    top_idx = output.topk(largest_k, 1, True, True)[1].t()
    hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))

    scale = 100.0 / n_samples
    # A sample is counted for k if any of its first k guesses matched.
    return [hits[:k].reshape(-1).float().sum().mul_(scale) for k in topk]


class AverageMeter:
    """Keep the most recent value together with a weighted running mean.

    Attributes:
        val: the value passed to the latest ``update`` call.
        sum: running total of ``val * n`` over all updates.
        count: running total of the weights ``n``.
        avg: ``sum / count`` as of the latest update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every tracked statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        # Augmented assignment kept on purpose: tensor accumulators are
        # updated in place after the first call, exactly as before.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def save_checkpoint(state, is_best, checkpoint, model_best):
    """Write ``state`` to ``checkpoint`` with ``torch.save``.

    When ``is_best`` is truthy, the freshly written checkpoint file is also
    copied byte-for-byte to ``model_best``.
    """
    torch.save(state, checkpoint)
    if not is_best:
        return
    shutil.copyfile(checkpoint, model_best)


def _format_train_record(info, distill, triple_ce, eval_naive_encoders):
    """Return ``(console_text, csv_columns)`` for a training-phase record.

    Flag precedence: ``eval_naive_encoders`` wins over ``distill``, which wins
    over ``triple_ce``; with no flag set the plain single-loss layout is used.
    Raises ``KeyError`` if ``info`` lacks a key the chosen layout prints.
    """
    if eval_naive_encoders:
        result = ('Flow Loss {} '
                  'Flow Prec@1 {}\n'
                  'RGB Loss {} '
                  'RGB Prec@1 {}\n'
                  'Loss {} '
                  'Prec@1 {} '
                  'Prec@5 {} \n'
                  'lr {}').format(info['Flow Loss'], info['Flow Prec@1'], info['RGB Loss'],
                                  info['RGB Prec@1'], info['Loss'], info['Prec@1'], info['Prec@5'],
                                  info['lr'])
        column_names = ['Epoch', 'Flow Loss', 'Flow Prec@1', 'RGB Loss', 'RGB Prec@1', 'Loss',
                        'Prec@1', 'Prec@5', 'lr']
    elif distill:
        result = ('Time {batch_time} '
                  'Data {data_time} \n'
                  'CE Loss {ce_loss} '
                  'RGB Distill Loss {rgb_distill_loss} '
                  'Flow Distill Loss {flow_distill_loss}\n'
                  'Prec@1 {top1} '
                  'Prec@5 {top5}\n'
                  'LR {lr}\n').format(batch_time=info['Batch Time'], data_time=info['Data Time'],
                                      ce_loss=info['CE Loss'],
                                      rgb_distill_loss=info['RGB Distill Loss'],
                                      flow_distill_loss=info['Flow Distill Loss'],
                                      top1=info['Prec@1'], top5=info['Prec@5'], lr=info['lr'])
        column_names = ['Epoch', 'Batch Time', 'Data Time', 'CE Loss', 'RGB Distill Loss',
                        'Flow Distill Loss', 'Prec@1', 'Prec@5', 'lr']
    elif triple_ce:
        result = ('Time {batch_time} '
                  'Data {data_time}\n'
                  'Joint Loss {ce_loss} '
                  'Flow Loss {flow_ce_loss} '
                  'RGB Loss {rgb_ce_loss}\n'
                  'Prec@1 {top1} '
                  'Prec@5 {top5}\n'
                  'LR {lr}\n').format(batch_time=info['Batch Time'], data_time=info['Data Time'],
                                      ce_loss=info['Joint Loss'], flow_ce_loss=info['Flow Loss'],
                                      rgb_ce_loss=info['RGB Loss'],
                                      top1=info['Prec@1'], top5=info['Prec@5'], lr=info['lr'])
        column_names = ['Epoch', 'Batch Time', 'Data Time', 'Joint Loss', 'Flow Loss', 'RGB Loss',
                        'Prec@1', 'Prec@5', 'lr']
    else:
        result = ('Time {batch_time} '
                  'Data {data_time} \n'
                  'Loss {loss} '
                  'Prec@1 {top1} '
                  'Prec@5 {top5}\n'
                  'LR {lr}\n').format(batch_time=info['Batch Time'], data_time=info['Data Time'],
                                      loss=info['Loss'], top1=info['Prec@1'], top5=info['Prec@5'],
                                      lr=info['lr'])
        column_names = ['Epoch', 'Batch Time', 'Data Time', 'Loss', 'Prec@1', 'Prec@5', 'lr']
    return result, column_names


def _format_test_record(info, eval_naive_encoders):
    """Return ``(console_text, csv_columns)`` for a test-phase record.

    Raises ``KeyError`` if ``info`` lacks a key the chosen layout prints.
    """
    if eval_naive_encoders:
        result = ('Loss {} '
                  'Prec@1 {} '
                  'Prec@5 {}\n'
                  'Flow Prec@1 {} '
                  'RGB Prec@1 {} ').format(info['Loss'], info['Prec@1'], info['Prec@5'],
                                           info['Flow Prec@1'], info['RGB Prec@1'])
        column_names = ['Epoch', 'Loss', 'Prec@1', 'Prec@5', 'Flow Prec@1', 'RGB Prec@1']
    else:
        result = ('Time {batch_time} \n'
                  'Loss {loss} '
                  'Prec@1 {top1} '
                  'Prec@5 {top5} \n').format(batch_time=info['Batch Time'], loss=info['Loss'],
                                             top1=info['Prec@1'], top5=info['Prec@5'])
        column_names = ['Epoch', 'Batch Time', 'Loss', 'Prec@1', 'Prec@5']
    return result, column_names


def record_info(info, filename, mode, distill=False, triple_ce=False, eval_naive_encoders=False):
    """Print a summary of ``info`` and append it to a CSV log.

    Args:
        info: mapping of column name -> values, handed to
            ``pd.DataFrame.from_dict``. It must contain ``'Epoch'`` plus every
            key the selected layout prints. Presumably each value is a list
            with one entry per logged row -- confirm against callers.
        filename: CSV path. If it does not exist yet it is created with a
            header row; otherwise rows are appended without a header.
        mode: ``'train'`` or ``'test'``; selects the console/CSV layout.
        distill: train only -- log CE + RGB/Flow distillation losses.
        triple_ce: train only -- log joint, flow and RGB CE losses.
        eval_naive_encoders: log per-stream (Flow/RGB) loss and precision.
            Takes precedence over the other flags in both modes.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'test'``.
        KeyError: if ``info`` lacks a key the selected layout needs.
    """
    if mode == 'train':
        result, column_names = _format_train_record(info, distill, triple_ce,
                                                    eval_naive_encoders)
    elif mode == 'test':
        result, column_names = _format_test_record(info, eval_naive_encoders)
    else:
        # Previously fell through and crashed with UnboundLocalError on `df`.
        raise ValueError("mode must be 'train' or 'test', got {!r}".format(mode))

    print(result)

    df = pd.DataFrame.from_dict(info)
    # Write the header only when creating the file; later calls append rows.
    write_header = not os.path.isfile(filename)
    df.to_csv(filename, mode='w' if write_header else 'a', header=write_header,
              index=False, columns=column_names)


