import numpy as np

import torch
import torch.nn as nn

from datetime import datetime, timedelta
from timeit import default_timer as default_timer
from string import Formatter

def calc_L2_norm(f, x1, x2, nPoints=100000, device=None):
    """Return the root-mean-square of ``f`` sampled on a uniform grid over [x1, x2].

    Parameters
    ----------
    f : callable
        Called once with a 1-D torch tensor of sample points; must return a
        tensor of values. If ``device`` is not given and ``f`` has a
        ``device`` attribute, the sample points are created on that device.
    x1, x2 : float
        Interval end points; ``x2`` must be greater than ``x1``. Both end
        points are included in the sample.
    nPoints : int
        Number of equally spaced sample points (at least 2).
    device : torch.device or str, optional
        Device on which to create the sample points.

    Returns
    -------
    numpy.float64
        sqrt(mean(f(x)**2)) over the samples. This approximates the L2 norm of
        f on [x1, x2] divided by sqrt(x2 - x1), i.e. sqrt(1/(x2-x1) * int f^2).

    Raises
    ------
    ValueError
        If ``x2 <= x1`` or ``nPoints < 2``.
    """
    if device is None and hasattr(f, 'device'):
        device = f.device

    if not (x2 > x1):
        raise ValueError("x2 must be greater than x1")
    if nPoints < 2:
        raise ValueError("nPoints must be at least 2")

    # linspace yields exactly nPoints samples, including both end points;
    # a float-step arange can gain or lose a point to rounding.
    x = torch.linspace(x1, x2, nPoints, device=device)
    y = f(x)

    return np.sqrt(torch.mean(torch.square(y)).item())



def randperm_batch(size, nElements, device=None):
    """Return a batch of independent random permutations of 0, ..., nElements-1.

    size may be:
      * ()  -- returns a 1-D tensor that is a random permutation of
               0, ..., nElements-1;
      * a sequence (d0, d1, ..., dk) (tuple, list or torch.Size) -- returns a
               tensor P of shape (d0, ..., dk, nElements) such that each
               P[i0, ..., ik, :] is a random permutation of 0, ..., nElements-1;
      * an integer s -- same as size=(s,), i.e. a tensor of shape (s, nElements).

    The caller's ``size`` object is never modified.
    device: device on which the random numbers (and hence the result) are created.
    """
    if isinstance(size, (list, tuple)):
        # Build a new tuple: `size += ...` would extend a caller's list in place.
        shape = tuple(size) + (nElements,)
    else:
        shape = (size, nElements)

    # Sorting i.i.d. uniform samples along the last axis gives a random
    # permutation per row (ties between samples are possible but unlikely).
    rand_nums = torch.rand(size=shape, device=device)
    return torch.argsort(rand_nums, dim=-1)


def is_sorted(lst, key=lambda x: x):
    """Return True if the items of ``lst`` are in non-decreasing order of ``key``.

    ``lst`` may be any iterable (list, tuple, string, generator, ...). It is
    traversed once, and ``key`` is evaluated exactly once per item. Empty and
    single-item inputs count as sorted. Keys are compared only with ``<``, so
    equal neighbours are allowed.
    """
    previous = None
    have_previous = False
    for item in lst:
        current = key(item)
        if have_previous and current < previous:
            return False
        previous, have_previous = current, True
    return True

def mem_report():
    """Return a one-line report of the total, used and free CUDA memory.

    The numbers come from torch.cuda.mem_get_info() (in bytes, for the current
    CUDA device) and are printed with thousands separators.
    """
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    used_bytes = total_bytes - free_bytes
    return f'CUDA memory: Total: {total_bytes:,}  Used: {used_bytes:,}  Free: {free_bytes:,}'


class Log:
    """Print text to a log file and to the screen simultaneously.

    Basic usage example:
    > log = utils.Log(fname='logs/file.log')
    > log('Started running')
    > for i in range(10):
    >     log.write('.')
    > log(' done')
    > log('Result: %g' % (result))
    """

    def __init__(self, fname=None, screen=True):
        # fname:  Path of the log file. The file is created (or truncated) here.
        #         If None or '', the log is not saved to a file.
        # screen: Default for whether text is also printed to the screen.
        #         This can be overridden per call (see write()).
        self.fname = fname
        self.screen = screen

        if self._file_enabled():
            # Opening in 'w' mode starts every Log with an empty file.
            with open(fname, 'w', encoding='utf-8'):
                pass

    def _file_enabled(self):
        # File output is disabled when fname is None or the empty string.
        return (self.fname is not None) and (self.fname != '')

    # Note: __call__() and write() differ only in the default value of newline.
    def __call__(self, text='', newline=True, screen=None):
        self.write(text, newline, screen)

    def write(self, text='', newline=False, screen=None):
        # text:    Any object; it is converted with str().
        # newline: Whether to append an end-of-line character to the text.
        #          This affects both screen output and file output.
        # screen:  Overrides the default 'screen' option given at construction.
        #          Tells whether to print the text to the screen.
        text_str = str(text)
        suffix = '\n' if newline else ''

        if screen is None:
            screen = self.screen

        if screen:
            print(text_str, end=suffix)

        # Same check as in __init__, so fname='' disables file output instead
        # of failing on open('').
        if self._file_enabled():
            # Reopen in append mode per write so the file is not held open
            # between calls; `with` closes it even if the write raises.
            with open(self.fname, 'a', encoding='utf-8') as f:
                f.write(text_str + suffix)


def nowstr(time=None, mode='default'):
    """Return a timestamp string for ``time`` (a datetime; defaults to now).

    mode selects the layout:
        'default' -> 2023-05-07 16:43:38
        'fname'   -> 2023-05-07_16.43.38   (safe to use in file names)
    Any other mode raises Exception('Invalid mode').
    """
    layouts = (
        ('default', "%Y-%m-%d %H:%M:%S"),
        ('fname', "%Y-%m-%d_%H.%M.%S"),
    )
    stamp = datetime.now() if time is None else time

    for name, layout in layouts:
        if mode == name:
            return stamp.strftime(layout)

    raise Exception('Invalid mode')


def dt2str(time):
    """Format a duration as a human-readable string.

    ``time`` is a datetime.timedelta or a (possibly non-integer) number of
    seconds. A number is first wrapped in a timedelta, and the result is
    formatted by strftimedelta using its automatically chosen format.
    """
    delta = time if isinstance(time, timedelta) else timedelta(seconds=time)
    return strftimedelta(delta)


class Timer:
    """Measure elapsed time, estimate the time remaining (ETA), and report both
    in a human-readable format.

    Usage example:

    import time
    t = Timer()
    nIter = 70

    # Iteration zero
    i=0; part_done = i/nIter
    print('%d: Time elapsed: %s  ETA: %s  (Seconds: elapsed %g, ETA %g)' % (i, t.str(), t.etastr(part_done), t(), t.eta(part_done)) )

    for i in range(1,nIter+1):
        time.sleep(1)
        part_done = i/nIter
        print('%d: Time elapsed: %s  ETA: %s  (Seconds: elapsed %g, ETA %g)' % (i, t.str(), t.etastr(part_done), t(), t.eta(part_done)) )
        # Note that in the above line, t() is typically called a few fractions of a second after t.str(), and the same goes for
        # t.eta() vs. t.etastr(). Thus, it is ok that the times reported vary slightly between each of those pairs.
    """

    def __init__(self):
        # Reference point for all measurements, from timeit.default_timer.
        self.start = default_timer()

    def __call__(self):
        """Return the seconds elapsed since construction or the last reset()."""
        now = default_timer()
        return now - self.start

    def reset(self):
        """Restart the measurement from the current moment."""
        self.start = default_timer()

    def str(self):
        """Return the elapsed time formatted by dt2str."""
        elapsed = self()
        return dt2str(elapsed)

    def eta(self, part_done):
        """Estimate the seconds remaining, given the completed fraction of the work.

        The total time is extrapolated linearly as elapsed / part_done, and the
        part (1 - part_done) of it still to come is returned. Returns np.nan
        when part_done == 0, since there is nothing to extrapolate from.
        """
        if part_done == 0:
            return np.nan
        projected_total = self() / part_done
        return projected_total * (1 - part_done)

    def etastr(self, part_done):
        """Return eta(part_done) formatted by dt2str, or 'n/a' when part_done == 0."""
        return 'n/a' if part_done == 0 else dt2str(self.eta(part_done))


def strftimedelta(tdelta, fmt=None):
    """Convert a duration to a custom-formatted string, just like the strftime()
    method does for datetime.datetime objects.

    Based on tomatoeshift's post on StackOverflow:
    https://stackoverflow.com/questions/538666/format-timedelta-to-string/63198084#63198084

    tdelta may be of class datetime.timedelta, or a (possibly non-integer) number
    of seconds. Numbers are rounded to whole microseconds.

    fmt is a str.format-style template. Available fields, from coarsest to finest:
        Y (years of 365.24 days), m (months of 30.44 days), W (weeks), D (days),
        H (hours), M (minutes), S (seconds), mS (milliseconds), µS (microseconds).
    Each field is optional. Each requested field receives what is left after the
    coarser requested fields have been taken out. All fields are integers,
    except S when it is the finest requested field: it then also carries the
    fractional part of the second (so e.g. '{S:.3f}' works).

    Examples for 5 days, 8 hours, 4 minutes and 2.032 seconds:
        '{D:02}d {H:02}h {M:02}m {S:02.0f}s {mS:03.0f}ms' --> '05d 08h 04m 02s 032ms'
        '{D}d {H}:{M:02}:{S:02.0f}.{mS:03.0f}'            --> '5d 8:04:02.032'
        '{D}d {H}h{M:02}m{S:02.0f}s {mS:01.0f}ms'         --> '5d 8h04m02s 32ms'
        '{W}w {D}d {H}:{M:02}:{S:02.0f}'                  --> '0w 5d 8:04:02'
        '{D:2}d {H:2}:{M:02}:{S:02.0f}'                   --> ' 5d  8:04:02'
        '{H}h {S:.0f}s'                                   --> '128h 242s'
    By default, the format is chosen automatically from:
        '5d 8h04m02s 32ms' / '8h04m02s 32ms' / '4m02s 32ms' / '2s 32ms'

    Negative durations are formatted by their magnitude with a leading '-'.
    Raises ValueError if tdelta is a non-finite number (nan or inf).
    """
    import math

    # Work in whole microseconds. Every unit below is an exact integer number of
    # microseconds, so integer divmod avoids float truncation
    # (e.g. 0.3 // 0.001 == 299.0 in floating point).
    if isinstance(tdelta, timedelta):
        total_us = tdelta // timedelta(microseconds=1)
    else:
        seconds = float(tdelta)
        if not math.isfinite(seconds):
            raise ValueError('Cannot format a non-finite duration: %r' % (tdelta,))
        total_us = round(seconds * 1_000_000)

    sign = '-' if total_us < 0 else ''
    total_us = abs(total_us)

    if fmt is None:
        if total_us >= 86400 * 1_000_000:
            fmt = '{D}d {H}h{M:02}m{S:02.0f}s {mS:01.0f}ms'
        elif total_us >= 3600 * 1_000_000:
            fmt = '{H}h{M:02}m{S:02.0f}s {mS:01.0f}ms'
        elif total_us >= 60 * 1_000_000:
            fmt = '{M:01}m{S:02.0f}s {mS:01.0f}ms'
        else:
            fmt = '{S:01.0f}s {mS:01.0f}ms'

    # (field name, length in microseconds), from coarsest to finest.
    unit_us = (
        ('Y', 31_556_736_000_000),   # 365.24 days
        ('m', 2_630_016_000_000),    # 30.44 days
        ('W', 604_800_000_000),
        ('D', 86_400_000_000),
        ('H', 3_600_000_000),
        ('M', 60_000_000),
        ('S', 1_000_000),
        ('mS', 1_000),
        ('µS', 1),
    )

    formatter = Formatter()
    desired_fields = {parsed[1] for parsed in formatter.parse(fmt)}
    requested = [(name, size) for name, size in unit_us if name in desired_fields]

    values = {}
    remainder_us = total_us
    for index, (name, size) in enumerate(requested):
        is_finest = index == len(requested) - 1
        if name == 'S' and is_finest:
            # No finer field will show the sub-second part, so keep it in S.
            # If mS/µS are requested, S must stay whole: otherwise a format
            # like '{S:02.0f}' rounds 2.7 s up to '03' while mS shows 700.
            values[name] = remainder_us / 1_000_000
            remainder_us = 0
        else:
            values[name], remainder_us = divmod(remainder_us, size)

    return sign + formatter.format(fmt, **values)
