"""Generate chord progressions from a chord sequence.
    Author: Joey.Zhu
    Email: joey8273@qq.com
    Date: 2024/5/3
"""
from tqdm import tqdm
import numpy as np

def chord_extraction(raw_chords, min_l=4, max_l=8):
    """Find the most frequently repeated chord sub-progression.

    For every window length ``j`` in ``[min_l, max_l]`` and every end position
    ``i``, ``dp[i][j - min_l]`` is the length of a chain of repeats of the
    window ``raw_chords[i - j:i]``. The window links to the nearest earlier,
    non-overlapping copy of itself and adds one to that copy's count. Every
    entry starts at 1.

    Time complexity: O(n^2 * (max_l - min_l + 1) * max_l) in the worst case,
    since each candidate comparison is a slice comparison of length ``j``.

    Args:
        raw_chords (list): chord sequence (typically already merged).
        min_l (int): minimum length of the extracted chord progression.
        max_l (int): maximum length of the extracted chord progression.

    Returns:
        tuple: ``(progression, count)``. ``progression`` is the slice of
        ``raw_chords`` with the highest repeat count, and ``count`` is that
        count (1 when nothing repeats). Ties go to the shorter window, then
        to the earlier end position.

    Raises:
        AssertionError: if ``len(raw_chords) <= 2 * min_l``.
    """
    n = len(raw_chords)
    if n <= 2 * min_l:
        # Raised explicitly (not via `assert`) so the check survives
        # `python -O`. The AssertionError type and message are kept so
        # existing callers are unaffected.
        raise AssertionError('input sequnce too short')

    dp = [[1] * (max_l - min_l + 1) for _ in range(n + 1)]

    for i in range(2 * min_l, n + 1):
        for j in range(min_l, max_l + 1):
            if 2 * j > i:
                # No room for a non-overlapping earlier copy. Longer windows
                # cannot fit either.
                break
            window = raw_chords[i - j:i]  # invariant across the k search
            for k in range(j, i - j + 1):
                if window == raw_chords[i - k - j:i - k]:
                    dp[i][j - min_l] = dp[i - k][j - min_l] + 1
                    break

    # Scan window lengths (shortest first), then end positions (earliest
    # first). The strict `>` keeps the first maximum found.
    start, end = 0, 0
    max_n = 0
    for length in range(min_l, max_l + 1):
        for pos in range(min_l, n + 1):
            if dp[pos][length - min_l] > max_n:
                start, end = pos - length, pos
                max_n = dp[pos][length - min_l]

    return raw_chords[start:end], max_n

def merge(raw, ietmpreature=20, ignore='NA'):
    """Merge consecutive duplicate chords and compute a weight per chord.

    Args:
        raw (iterable): raw chord labels, in time order.
        ietmpreature (int | float): softmax temperature applied to the chord
            counts. Higher values give smoother (more uniform) weights. The
            misspelled parameter name is kept for backward compatibility.
        ignore (str): chord label to drop entirely before merging and
            counting.

    Returns:
        tuple:
            chords (list): ``raw`` with ``ignore`` entries removed and runs
                of identical consecutive chords collapsed to one entry. An
                ignored entry does not break a run.
            softmax_dict (dict): maps each distinct chord, in first-seen
                order, to the temperature softmax of its occurrence count in
                the filtered ``raw``.
    """
    from collections import Counter
    from itertools import groupby

    kept = [item for item in raw if item != ignore]
    # groupby collapses runs without a sentinel value, so a chord literally
    # named "last" is no longer mistaken for the start marker and dropped.
    chords = [chord for chord, _ in groupby(kept)]

    cnt = Counter(kept)  # preserves first-seen insertion order
    values = np.array(list(cnt.values()))
    # NOTE: the choice of temperature could be worth a comparative experiment.
    softmax_values = softmax_with_temperature(values, ietmpreature)
    softmax_dict = {key: softmax_values[i] for i, key in enumerate(cnt)}
    return chords, softmax_dict

def softmax_with_temperature(logits, temperature=1.0):
    """Return the temperature-scaled softmax of ``logits``.

    The higher the temperature, the smoother (closer to uniform) the output.
    Normalisation is over all elements of ``logits``.

    Args:
        logits (array-like): numeric scores.
        temperature (float): divisor applied to the logits before ``exp``.

    Returns:
        numpy.ndarray: non-negative weights with the same shape as ``logits``
        that sum to 1. Empty input gives an empty array.
    """
    scaled = np.asarray(logits) / temperature
    if scaled.size == 0:
        return scaled
    # Shifting by the max does not change the softmax result, but it keeps
    # exp() from overflowing to inf (which would turn the output into nan)
    # for large counts or small temperatures.
    exp_logits = np.exp(scaled - scaled.max())
    return exp_logits / np.sum(exp_logits)

def chord_extraction_with_power(raw_chords, power, min_l=4, max_l=8):
    """Find the repeated chord sub-progression with the highest weighted score.

    Version 2 of ``chord_extraction``. The dynamic programme is the same, but
    a repeat does not add a flat 1. Each link from the window
    ``raw_chords[i - j:i]`` to its nearest earlier, non-overlapping copy adds
    the mean ``power`` of the chords in that window. Every entry starts at 1.

    Args:
        raw_chords (list): chord sequence (typically already merged).
        power (dict): maps each chord to a numeric weight. It is only looked
            up for chords inside windows that repeat.
        min_l (int): minimum length of the extracted chord progression.
        max_l (int): maximum length of the extracted chord progression.

    Returns:
        tuple: ``(progression, score)`` for the highest-scoring window. Ties
        go to the shorter window, then to the earlier end position. If
        ``len(raw_chords) <= 2 * min_l``, a message is printed and
        ``(raw_chords, -1)`` is returned.

    Raises:
        KeyError: if a chord inside a repeated window is missing from
            ``power``.
    """
    n = len(raw_chords)
    if n <= 2 * min_l:
        # Soft failure: print a message and hand the input back with a
        # sentinel score of -1.
        print("chord_progression too short")
        return raw_chords, -1

    dp = [[1] * (max_l - min_l + 1) for _ in range(n + 1)]

    for i in range(2 * min_l, n + 1):
        for j in range(min_l, max_l + 1):
            if 2 * j > i:
                # No room for a non-overlapping earlier copy. Longer windows
                # cannot fit either.
                break
            window = raw_chords[i - j:i]  # invariant across the k search
            for k in range(j, i - j + 1):
                if window == raw_chords[i - k - j:i - k]:
                    gain = sum(power[chord] for chord in window) / j
                    dp[i][j - min_l] = dp[i - k][j - min_l] + gain
                    break

    # Scan window lengths (shortest first), then end positions (earliest
    # first). The strict `>` keeps the first maximum found.
    start, end = 0, 0
    max_n = 0
    for length in range(min_l, max_l + 1):
        for pos in range(min_l, n + 1):
            if dp[pos][length - min_l] > max_n:
                start, end = pos - length, pos
                max_n = dp[pos][length - min_l]

    return raw_chords[start:end], max_n


def generate_chords_progression(raw_chords, version=2):
    """Extract the dominant chord progression from a raw chord sequence.

    The sequence is first passed through ``merge`` with its defaults. Entries
    equal to 'NA' are dropped and consecutive duplicates are collapsed.

    Args:
        raw_chords (list): raw chord labels, in time order.
        version (int): 2 (default) scores repeats weighted by the softmaxed
            chord frequencies from ``merge``, via
            ``chord_extraction_with_power``. 1 counts plain repeats via
            ``chord_extraction``.

    Returns:
        tuple: ``(progression, score)`` as returned by the selected extractor.

    Raises:
        ValueError: if ``version`` is neither 1 nor 2. Previously this
            silently returned None.
    """
    if version == 2:
        chords, chord2weight = merge(raw_chords)
        return chord_extraction_with_power(chords, chord2weight)
    if version == 1:
        chords, _ = merge(raw_chords)
        return chord_extraction(chords)
    raise ValueError(f'unsupported version: {version!r} (expected 1 or 2)')