import os
import math
import torch
import numpy as np
from scipy.signal import butter, filtfilt, lfilter
from torch.autograd import Variable
from collections import Counter

def function_var_to_cutoff(var):
    """Map a variance to a cutoff value via a piecewise logistic curve.

    The curve is evaluated on ``log(var)``; its two branches meet at
    ``log(var) == -2``, where both evaluate to 1000:

    * ``log(var) >= -2``: ``500 + 1000 / (1 + exp(log(var) + 2))``,
      falling from 1000 towards 500 as the variance grows.
    * ``log(var) <  -2``: ``2000 / (1 + exp(2 * (log(var) + 2)))``,
      rising from 1000 towards 2000 as the variance shrinks.

    Parameters: var (float) -- strictly positive value (``math.log`` raises
                               ``ValueError`` otherwise)
    Returns: float -- the cutoff (presumably a frequency for the low-pass
                      filters in this module -- confirm with the caller)
    """
    log_variance = math.log(var)
    offset = log_variance + 2
    if log_variance < -2:
        # Small-variance branch: steeper logistic with a 2000 ceiling.
        return 2000 / (1 + math.exp(2 * offset))
    # Large-variance branch: logistic sitting on a floor of 500.
    return 500 + 1000 / (1 + math.exp(offset))

# def function_var_to_cutoff(var):
#     log_var = math.log(var)
#     if log_var>=-2:
#         return 500+2000/(1+math.exp(log_var+2))
#     else:
#         return 3000/(1+math.exp(2*(log_var+2)))

def butter_lowpass(cutoff, fs, order=5):
    """Design a digital low-pass Butterworth filter.

    Parameters:
        cutoff -- cutoff frequency, in the same units as ``fs``
        fs -- sampling frequency
        order (int) -- filter order passed to ``scipy.signal.butter``
    Returns: (b, a) -- numerator and denominator coefficient arrays
    """
    nyquist = fs * 0.5
    # scipy expects the cutoff as a fraction of the Nyquist frequency.
    numer, denom = butter(order, cutoff / nyquist, btype='low', analog=False)
    return numer, denom

def butter_lowpass_lfilt(data, cutoff, fs, order=5):
    """Low-pass filter ``data`` with a single forward pass of a Butterworth filter.

    The coefficients come from ``butter_lowpass`` and are applied with
    ``scipy.signal.lfilter``.

    Parameters:
        data -- sequence/array of samples to filter
        cutoff -- cutoff frequency, in the same units as ``fs``
        fs -- sampling frequency
        order (int) -- filter order
    Returns: the filtered signal as returned by ``lfilter``
    """
    numer, denom = butter_lowpass(cutoff, fs, order=order)
    return lfilter(numer, denom, data)

def gaussian_filter(angular_batch, window):
    """Return the Gaussian-weighted average of a batch of values.

    The weights follow a Gaussian with standard deviation ``window`` centred
    on index ``len(angular_batch) // 2 - 1`` and are normalised to sum to one
    before being applied.

    Parameters:
        angular_batch -- sequence/array of values to average
        window (float) -- standard deviation of the kernel, in samples
    Returns: the weighted sum of ``angular_batch``
    """
    n = len(angular_batch)
    centre = n // 2 - 1
    offsets = np.arange(n) - centre
    weights = np.exp(-(offsets ** 2) / (2 * (window ** 2)))
    weights = weights / weights.sum()
    return (angular_batch * weights).sum()

def pca_analysis(input, k = 5, q = 6, centered = True, normlaized=True):
    """Project rows of a 2-D tensor onto its top-k principal directions.

    The principal directions come from ``torch.pca_lowrank`` (a randomized
    algorithm, so results depend on the torch RNG state). ``q`` is first
    capped by both dimensions of ``input`` and must remain >= ``k``.

    Each output row holds the absolute projections on the top ``k``
    directions followed by the norm of what is left after removing those
    projections. Note that the projections are taken on ``input`` as given,
    even when ``centered`` is True (centering only affects the PCA itself).

    Parameters:
        input (Tensor) -- 2-D tensor, one sample per row
        k (int) -- number of principal directions to keep
        q (int) -- rank estimate handed to ``torch.pca_lowrank``
        centered (bool) -- forwarded as ``center`` to ``torch.pca_lowrank``
        normlaized (bool) -- if True, L2-normalise each row and square it,
                             so every row sums to 1 (rows of all zeros
                             produce NaN)
    Returns: (e_vars, e_vecs) -- tensor of shape (rows, k + 1) and the
                                 (features, k) matrix of directions
    """
    # 1. estimate the principal directions and keep the leading k of them
    rank = min(q, input.size(-2), input.size(-1))
    assert rank >= k, 'q is an overestimation of k so q should be no-less than k'
    basis = torch.pca_lowrank(input, rank, center=centered, niter=10)[2][:, :k]

    # 2. per-row projections plus the norm of the unexplained remainder
    projections = input.mm(basis)
    remainder = (input - projections.mm(basis.t())).norm(dim=1, keepdim=True)
    magnitudes = torch.cat((projections, remainder), dim=1).abs()
    if not normlaized:
        return magnitudes, basis

    # Squared entries of a unit-norm row add up to one.
    unit_rows = magnitudes / magnitudes.norm(dim=1, keepdim=True)
    return unit_rows ** 2, basis

def calculate_angle(v1, v2):
    """Return the angle between two tensors, in degrees, as a Python float.

    The cosine is the element-wise dot product divided by the product of the
    norms; ``1e-8`` is added to the denominator so zero vectors do not divide
    by zero, and the cosine is clamped before ``acos`` so rounding error
    cannot push it outside the valid domain.

    Parameters: v1, v2 (Tensor) -- tensors of matching (broadcastable) shape
    Returns: float -- angle in degrees
    """
    margin = 1e-16
    cosine = (v1 * v2).sum() / (v1.norm() * v2.norm() + 1e-8)
    cosine = cosine.clamp(-1.0 + margin, 1.0 - margin)
    degrees = cosine.acos() / math.pi * 180
    return degrees.item()

def mkdir(path):
    '''create a single empty directory if it didn't exist
    Parameters: path (str) -- a single directory path
    Missing parent directories are created as well. Calling this on a
    directory that already exists is a no-op.
    Raises: FileExistsError if path exists but is not a directory'''
    # exist_ok avoids the race between checking for the directory and
    # creating it when another process creates it in between.
    os.makedirs(path, exist_ok=True)

def mkdirs(paths):
    '''create empty directories if they don't exist
    Parameters: paths (str, or list/tuple of str) -- a single directory path
                or a list/tuple of directory paths'''
    # A list can never be a str, so one container check is enough; tuples
    # are accepted too instead of being passed to mkdir() as a single path.
    if isinstance(paths, (list, tuple)):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)

def rmdirs(paths):
    '''delete the regular files directly inside a directory, then the directory
    Parameters: paths (str) -- a single directory path; nothing happens if it
                does not exist
    Sub-directories are not removed, so os.rmdir raises OSError if the
    directory still contains any after the files are deleted.'''
    if not os.path.exists(paths):
        return
    entries = (os.path.join(paths, name) for name in os.listdir(paths))
    for entry in entries:
        if os.path.isfile(entry):
            os.remove(entry)
    os.rmdir(paths)

class Dictionary(object):
    """Vocabulary mapping words to integer ids and back, with usage counts."""

    def __init__(self):
        self.word2idx = {}        # word -> id
        self.idx2word = []        # id -> word (ids are positions in this list)
        self.counter = Counter()  # id -> number of add_word calls for that id
        self.total = 0            # total number of add_word calls

    def add_word(self, word):
        """Register one occurrence of ``word`` and return its id.

        Unseen words receive the next free id; every call increments the
        word's count and the running total.
        """
        token_id = self.word2idx.get(word)
        if token_id is None:
            token_id = len(self.idx2word)
            self.idx2word.append(word)
            self.word2idx[word] = token_id
        self.counter[token_id] += 1
        self.total += 1
        return token_id

    def __len__(self):
        """Number of distinct words seen so far."""
        return len(self.idx2word)


class Corpus(object):
    """Token-id tensors for the train/valid/test splits of a text corpus.

    ``path`` is a directory containing ``train.txt``, ``valid.txt`` and
    ``test.txt``. The three files are tokenized in that order and share a
    single ``Dictionary``.
    """

    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenizes a text file.

        Every line is split on whitespace and terminated with an ``'<eos>'``
        token; each token is registered in ``self.dictionary``.

        Parameters: path (str) -- text file to read, decoded as UTF-8
        Returns: 1-D ``torch.long`` tensor with one id per token
        Raises: FileNotFoundError (from ``open``) if ``path`` does not exist
        """
        ids = []
        # Explicit encoding keeps the result independent of the platform's
        # default locale. add_word returns the id, so one pass suffices.
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                words = line.split() + ['<eos>']
                ids.extend(self.dictionary.add_word(word) for word in words)

        return torch.tensor(ids, dtype=torch.long)

def repackage_hidden(h):
    """Detach hidden states from their autograd history.

    A tensor is returned detached; any other iterable is processed element
    by element, recursively, and returned as a tuple.
    """
    if not isinstance(h, torch.Tensor):
        return tuple(map(repackage_hidden, h))
    return h.detach()


# Starting from sequential data, batchify arranges the dataset into columns.
# For instance, with the alphabet as the sequence and batch size 4, we'd get
# ┌ a g m s ┐
# │ b h n t │
# │ c i o u │
# │ d j p v │
# │ e k q w │
# └ f l r x ┘.
# These columns are treated as independent by the model, which means that the
# dependence of e.g. 'g' on 'f' cannot be learned, but allows more efficient
# batch processing.

def batchify(data, bsz, device):
    """Arrange sequential data into ``bsz`` parallel columns.

    Elements that do not fill a whole row of ``bsz`` columns are dropped from
    the end. Column ``j`` of the result holds the ``j``-th contiguous slice
    of the trimmed data.

    Parameters:
        data (Tensor) -- sequence to split along dimension 0
        bsz (int) -- number of columns (batch size)
        device -- target passed to ``Tensor.to``
    Returns: contiguous tensor of shape (len(data) // bsz, bsz) on ``device``
    """
    rows = data.size(0) // bsz
    # Keep only the prefix that divides evenly into bsz columns.
    trimmed = data.narrow(0, 0, rows * bsz)
    # Each of the bsz slices becomes a row, then is transposed into a column.
    columns = trimmed.view(bsz, -1).t()
    return columns.contiguous().to(device)


# get_batch subdivides the source data into chunks of length args_bptt.
# The index i marks the (exclusive) end of the input chunk, so if source is
# equal to the example output of the batchify function, with a bptt-limit
# of 2, we'd get the following two tensors for i = 2 (the target is returned
# flattened to one dimension):
# ┌ a g m s ┐ ┌ b h n t ┐
# └ b h n t ┘ └ c i o u ┘
# Note that despite the name of the function, the subdivision of data is not
# done along the batch dimension (i.e. dimension 1), since that was handled
# by the batchify function. The chunks are along dimension 0, corresponding
# to the seq_len dimension in the LSTM.

def get_batch(source, i, bptt=None):
    """Return the input chunk ending just before row ``i`` and its targets.

    ``data`` is ``source[i - bptt:i]`` and ``target`` is the same window
    shifted forward by one row, flattened to 1-D. When fewer than ``bptt``
    rows precede ``i`` the window is shortened to start at row 0, so that
    ``data`` and ``target`` always cover the same number of rows (a negative
    slice start would otherwise wrap around to the end of ``source``).

    Parameters:
        source (Tensor) -- batchified data, sequence along dimension 0
        i (int) -- exclusive end row of the input chunk; expected to satisfy
                   0 <= i < len(source) so the shifted target window fits
        bptt (int or None) -- chunk length; None falls back to the
                   module-level ``args_bptt``, which must be defined
                   elsewhere in this module
    Returns: (data, target) -- the input rows and the flattened target rows
    """
    if bptt is None:
        bptt = args_bptt
    start = max(i - bptt, 0)
    data = source[start:i]
    target = source[start + 1:i + 1].view(-1)
    return data, target