import numpy as np

from tqdm import tqdm

from sklearn.feature_extraction.image import extract_patches_2d
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

from scipy.integrate import trapezoid


def get_patches(img, patch_size, max_patches=None, random_state=0):
    """Extract 2D patches from ``img`` and flatten each one into a row.

    Parameters
    ----------
    img : image passed straight to ``extract_patches_2d``.
    patch_size : patch dimensions, forwarded unchanged to
        ``extract_patches_2d``.
    max_patches : forwarded to ``extract_patches_2d``; ``None`` by default.
    random_state : forwarded to ``extract_patches_2d``.

    Returns
    -------
    2D array with one row per extracted patch; all remaining patch axes
    are collapsed into the second axis.
    """
    patches = extract_patches_2d(
        img, patch_size, max_patches=max_patches, random_state=random_state
    )
    n_patches = len(patches)
    # Collapse every per-patch axis into a single feature axis.
    return patches.reshape(n_patches, -1)

def middle_coord(n):
    """Return the flat (row-major) index of the centre of an n-by-n square.

    ``n`` is expected to be the odd side length of the square. The centre
    sits at row ``(n + 1) // 2 - 1`` and the same column; the flat index is
    ``row * n + column``.
    """
    centre = (n + 1) // 2 - 1  # zero-based row and column of the centre
    return centre * n + centre

def score_dset(x, y, random_state):
    """Return the mean 5-fold cross-validation score of a linear model.

    The model is a pipeline that standardises the features and then fits
    an ordinary ``LinearRegression``. Scoring uses ``cross_val_score``'s
    default for the estimator (presumably R^2 for a regressor -- see the
    scikit-learn docs).

    ``random_state`` is accepted so that all scoring functions share the
    signature ``(x, y, random_state)``; it is not used here.
    """
    steps = [('scaler', StandardScaler()), ('lr_model', LinearRegression())]
    model = Pipeline(steps)
    fold_scores = cross_val_score(model, x, y, cv=5)
    return np.mean(fold_scores)

def make_img_patch_dset(img, patch_size, max_patches=None, random_state=0):
    """Build a regression dataset from square patches of ``img``.

    Each flattened ``patch_size`` x ``patch_size`` patch contributes one
    sample: the target is the value at the patch's central index and the
    features are all the other entries of the flattened patch.

    NOTE(review): the index arithmetic assumes each flattened patch has
    exactly ``patch_size ** 2`` entries (i.e. a single-channel image) --
    confirm against callers.

    Returns
    -------
    (x, y) : ``x`` has one column per non-central position, ``y`` holds
        the central values.
    """
    side = (patch_size, patch_size)
    centre = middle_coord(patch_size)
    patches = get_patches(img, side, max_patches, random_state)
    # Every flat position except the centre is used as a predictor.
    neighbours = [pos for pos in range(patch_size ** 2) if pos != centre]

    features = patches[:, neighbours]
    target = patches[:, centre]
    return features, target

def residual_determinism(img, patch_size, max_patches=None, random_state=0, score_func=None):
    """Score how well a patch's centre is predicted from its surroundings.

    Builds the (neighbours -> centre) dataset for ``patch_size`` with
    ``make_img_patch_dset`` and evaluates it with ``score_func``.

    Parameters
    ----------
    img : image to sample patches from.
    patch_size : side length of the square patches.
    max_patches : forwarded to ``make_img_patch_dset``.
    random_state : forwarded both to patch extraction and to ``score_func``.
    score_func : callable ``(x, y, random_state) -> score``. Defaults to
        ``score_dset`` when ``None``. ``residual_determinism_curve`` passes
        this keyword, so it must be accepted here.

    Returns
    -------
    Whatever ``score_func`` returns for the dataset.
    """
    if score_func is None:
        score_func = score_dset
    x, y = make_img_patch_dset(img, patch_size, max_patches, random_state)
    return score_func(x, y, random_state)

# def residual_determinism_curve(img, seed=0, verbose=False, max_patches=None):
#     sidelength = img.shape[0]
#     rng = np.random.RandomState(seed)
#     #size_range = list(range(3, sidelength // 4 + 3, 2))
#     size_range = list(range(3, sidelength // 4 - 1, 2))
#     scores = []
#     if verbose:
#         _iter = tqdm(size_range)
#     else:
#         _iter = size_range
        
#     for size in _iter:
#         scores.append(residual_determinism(
#             img, size, max_patches=max_patches, random_state=rng
#         ))

#     return np.array(size_range), np.array(scores)

# def average_residual_determinism(img, seed=0, verbose=False, max_patches=None):
#     _sizes, _scores = residual_determinism_curve(img, seed=seed, verbose=verbose, max_patches=max_patches)
#     return trapezoid(_scores, x=_sizes) / (_sizes[-1] - _sizes[0])

def residual_determinism_curve(img, seed=0, verbose=False, max_patches=None, size_range=None, score_func=None):
    """Compute the residual-determinism score for a range of patch sizes.

    Parameters
    ----------
    img : image to analyse; ``img.shape[0]`` is used to derive the default
        patch sizes.
    seed : seed for the ``RandomState`` shared by all patch sizes.
    verbose : when true, wrap the iteration in a ``tqdm`` progress bar.
    max_patches : forwarded to ``make_img_patch_dset``.
    size_range : iterable of patch side lengths. Defaults to
        ``range(3, img.shape[0] // 3, 4)``.
    score_func : callable ``(x, y, random_state) -> score``. Defaults to
        ``score_dset`` when ``None``.

    Returns
    -------
    (sizes, scores) : two arrays of equal length.
    """
    if score_func is None:
        score_func = score_dset
    # One generator is shared across sizes, so its state advances between
    # successive patch extractions.
    rng = np.random.RandomState(seed)
    if size_range is None:
        sidelength = img.shape[0]
        # The original built this list but never assigned it, leaving
        # size_range as None and crashing the loop below.
        size_range = range(3, sidelength // 3, 4)
    # Materialise so one-shot iterables survive both the loop and the
    # final conversion to an array.
    size_range = list(size_range)

    _iter = tqdm(size_range) if verbose else size_range

    scores = []
    for size in _iter:
        # Same computation as residual_determinism, done here directly so
        # the scoring function is always honoured.
        x, y = make_img_patch_dset(img, size, max_patches, rng)
        scores.append(score_func(x, y, rng))

    return np.array(size_range), np.array(scores)

def average_residual_determinism(img, seed=0, verbose=False, max_patches=None, size_range=None, score_func=score_dset):
    """Return the residual-determinism score averaged over patch sizes.

    The curve from ``residual_determinism_curve`` is integrated with the
    trapezoid rule and divided by the spanned size interval
    (``sizes[-1] - sizes[0]``).

    All parameters are forwarded unchanged to
    ``residual_determinism_curve``.

    Raises
    ------
    ValueError
        If the curve contains no patch sizes.
    """
    _sizes, _scores = residual_determinism_curve(
        img, seed=seed, verbose=verbose, max_patches=max_patches, size_range=size_range, score_func=score_func
    )
    if len(_sizes) == 0:
        raise ValueError("no patch sizes to average over; size_range is empty")

    span = _sizes[-1] - _sizes[0]
    if span == 0:
        # Degenerate interval (e.g. a single patch size): dividing the
        # zero-width integral by zero would give nan, so fall back to the
        # plain mean of the scores.
        return np.mean(_scores)
    return trapezoid(_scores, x=_sizes) / span


def radial_profile(data):
    """Return the integer distance of every pixel from the array centre.

    The centre is ``(rows // 2, cols // 2)``. Euclidean distances are
    truncated to integers, so the result labels each element of the 2D
    array ``data`` with the index of the ring it falls in.
    """
    n_rows, n_cols = data.shape
    # Offsets from the centre as a column vector and a row vector;
    # broadcasting expands them to the full grid.
    row_offsets = np.arange(n_rows)[:, np.newaxis] - n_rows // 2
    col_offsets = np.arange(n_cols)[np.newaxis, :] - n_cols // 2
    distances = np.sqrt(col_offsets ** 2 + row_offsets ** 2)
    return distances.astype('int')

def radial_average(data):
    """Average ``data`` over rings of equal integer radius from its centre.

    Ring membership comes from ``radial_profile``. Entry ``k`` of the
    result is the mean of all elements whose truncated distance from the
    centre equals ``k``.
    """
    ring_ids = radial_profile(data).ravel()
    ring_sums = np.bincount(ring_ids, weights=data.ravel())
    ring_sizes = np.bincount(ring_ids)
    return ring_sums / ring_sizes

def average_pcorrelation_score(img, psize, max_patches=15_000, seed=0):
    """Return the radially averaged centre-pixel correlation, averaged over distance.

    Steps:
    1. Extract flattened ``psize`` x ``psize`` patches from ``img``.
    2. Compute the correlation matrix between patch positions and take
       the row belonging to the central position.
    3. Reshape that row to a ``psize`` x ``psize`` map and average it over
       rings of equal integer radius.
    4. Integrate the radial curve with the trapezoid rule and divide by
       the distance interval it spans.

    Distances are expressed in units of ``2 / img.shape[0]`` per pixel
    (presumably the image is taken to span a domain of length 2 -- TODO
    confirm).
    """
    spacing = 2 / img.shape[0]

    flat_patches = get_patches(img, (psize, psize), max_patches=max_patches, random_state=seed)
    # Columns of flat_patches are patch positions, so transpose to
    # correlate positions rather than patches.
    corr_matrix = np.corrcoef(flat_patches.T)
    centre_row = corr_matrix[middle_coord(psize)]
    centre_map = centre_row.reshape(psize, psize)

    profile = radial_average(centre_map)
    distances = np.arange(len(profile)) * spacing
    span = distances[-1] - distances[0]

    return trapezoid(profile, x=distances) / span