'''Create dot pattern portraits

Author: guangzhi XU (xugzhi1987@gmail.com)
Update time: 2021-12-22 22:13:43.
'''

import numpy as np
import cv2
import matplotlib.pyplot as plt

# Path of the input image read by cv2.imread in the main script.
IMG_FILE = './dot_pattern.jpg'
# Side length, in pixels, of each square cell; one disk-shaped dot is drawn
# per cell, and the image is split into CELL x CELL tiles.
CELL = 80
# Amount, in pixels, by which the dot radius is shrunk below CELL//2, leaving
# a gap between neighbouring dots.
MARGIN = 6

def asStride(arr, sub_shape, stride):
    '''Return a read-only view of <arr> cut into strided sub-windows.

    Args:
        arr (ndarray): rank-2 or rank-3 input of shape (m1, n1) or (m1, n1, c).
        sub_shape (tuple): window size (m2, n2); only the first two entries
            are used.
        stride (int): step between window origins, shared by the y- and
            x-dimensions.
    Returns:
        subs (view): window view of shape
            (1 + (m1-m2)//stride, 1 + (n1-n2)//stride, m2, n2[, c]).
            No data is copied and the view is not writeable.

    See also skimage.util.shape.view_as_windows()
    '''
    row_step, col_step = arr.strides[:2]
    height, width = arr.shape[:2]
    win_h, win_w = sub_shape[:2]

    # number of window origins along each axis
    n_rows = (height - win_h) // stride + 1
    n_cols = (width - win_w) // stride + 1

    # outer two axes jump by <stride> elements, inner two walk one element
    # at a time; any trailing (channel) axis keeps its original layout
    window_shape = (n_rows, n_cols, win_h, win_w) + arr.shape[2:]
    window_strides = ((row_step * stride, col_step * stride, row_step, col_step)
                      + arr.strides[2:])

    return np.lib.stride_tricks.as_strided(
        arr, shape=window_shape, strides=window_strides, writeable=False)

def padArray(var, pad_y, pad_x):
    '''Pad array with 0s at the bottom and on the right.

    Args:
        var (ndarray): 2d or 3d ndarray. Padding is done on the first 2 dimensions.
        pad_y (int): number of rows to pad at bottom. Expected to be >= 0.
        pad_x (int): number of columns to pad at right. Expected to be >= 0.
    Returns:
        var_pad (ndarray): 2d or 3d ndarray with 0s padded along the first 2
            dimensions. If no padding is requested, <var> itself is returned;
            otherwise a new float64 array is returned.
    '''
    if pad_y == 0 and pad_x == 0:
        return var

    rows, cols = var.shape[:2]
    var_pad = np.zeros((rows + pad_y, cols + pad_x) + var.shape[2:])
    # Use explicit end indices: a negative slice like [:-pad_y] becomes [:0]
    # (an empty slice) when pad_y is 0, which breaks the assignment when only
    # one of the two dimensions needs padding.
    var_pad[:rows, :cols] = var

    return var_pad

def adjustVibrancy(image, increment=0.5):
    '''Adjust color vibrancy of image

    Args:
        image (ndarray): image in (H, W, C) shape. Color channel as the last
            dimension and in range [0, 255]. Integer input (e.g. uint8) is
            converted to float before any arithmetic.
    Keyword Args:
        increment (float): color vibrancy adjustment, clipped to [-1, 1].
            > 0 for enhance vibrancy, < 0 otherwise. If 0, no effects.
    Returns:
        image (ndarray): adjusted float image, clipped to [0, 255].
    '''
    # Work in float: with uint8 input, max_val + min_val below would wrap
    # around (e.g. 200 + 100 -> 44) and corrupt the lightness L.
    image = np.asarray(image, dtype=float)
    increment = np.clip(increment, -1, 1)
    min_val = np.min(image, axis=-1)
    max_val = np.max(image, axis=-1)
    delta = (max_val - min_val) / 255.
    L = 0.5 * (max_val + min_val) / 255.
    S = np.minimum(L, 1 - L) * 0.5 * delta
    # per-pixel lightness in [0, 255], broadcast over the channel axis
    L255 = L[:, :, None] * 255.

    if increment > 0:
        alpha = np.maximum(S, 1 - increment)
        # alpha is 0 only when increment == 1 and S == 0 (a gray pixel, whose
        # channels all equal its lightness). Use a gain of 0 there instead of
        # 1/0 - 1 = inf, which would produce NaN (0 * inf) in the output.
        gain = np.divide(1., alpha, out=np.ones_like(alpha), where=alpha > 0) - 1
        image = image + (image - L255) * gain[:, :, None]
    else:
        # scale each channel's distance from the lightness by (1 + increment)
        image = L255 + (image - L255) * (1 + increment)

    image = np.clip(image, 0, 255)

    return image

def screenBlend(image1, image2, alpha=1):
    '''Combine two layers with the "screen" blend mode.

    Args:
        image1 (ndarray): top layer, values in [0, 255].
        image2 (ndarray): bottom layer, values in [0, 255].
    Keyword Args:
        alpha (float): multiplier applied to the normalized top layer before
            blending. Default 1.
    Returns:
        res (ndarray): blended image, scaled back to and clipped at [0, 255].

    Reference: https://photoblogstop.com/photoshop/photoshop-blend-modes-explained
    '''
    top = alpha * (image1 / 255.)
    bottom = image2 / 255.
    # screen = invert both layers, multiply them, invert the product
    inverted_product = (1 - top) * (1 - bottom)
    return np.clip(255 * (1 - inverted_product), 0, 255)

def applyStencil(image, stencil):
    '''Mask an image with a stencil repeated across it.

    The stencil is tiled over the image starting from the top-left corner;
    tiles that overhang the right or bottom edge are cropped.

    Args:
        image (ndarray): image in (H, W, C) shape, channel last, values in
            [0, 255].
        stencil (ndarray): 2-D stencil of shape (h, w).
    Returns:
        out (ndarray): a copy of <image> multiplied in place by the tiled
            stencil, so it keeps <image>'s dtype.
    '''
    height, width = image.shape[:2]
    st_h, st_w = stencil.shape
    # ceiling division: number of stencil repeats needed to cover the image
    reps_y = -(-height // st_h)
    reps_x = -(-width // st_w)
    mask = np.tile(stencil, (reps_y, reps_x))[:height, :width]

    out = np.copy(image)
    out[:height, :width, :] *= mask[:, :, None]
    return out

def pixelateImage(image, cell):
    '''Pixelate an image by replacing every cell x cell tile with its mean.

    Args:
        image (ndarray): image in (H, W) or (H, W, C) shape.
        cell (int): side length, in pixels, of the square tiles.
    Returns:
        pixelate (ndarray): float image with the same shape as <image>, in
            which every tile holds the mean of the original tile. Tiles that
            overhang the bottom/right edge are averaged together with the zero
            padding added to reach a whole number of tiles.
    '''
    height, width = image.shape[:2]

    # pad so that both dimensions become whole multiples of <cell>
    pad = padArray(image, cell - height % cell, cell - width % cell)

    # non-overlapping (cell, cell) windows: (ny, nx, cell, cell[, C])
    pad_strided = asStride(pad, (cell, cell), cell)
    print(f'{pad_strided.shape = }')

    # pixelate image: one mean per tile, expanded back over the tile
    cell_means = np.mean(pad_strided, axis=(2, 3), keepdims=True)
    pixelate = cell_means * np.ones(pad_strided.shape)

    # (ny, nx, cell, cell[, C]) -> (ny, cell, nx, cell[, C]) -> padded shape.
    # swapaxes (instead of a fixed 5-axis transpose) also handles 2-D images.
    pixelate = pixelate.swapaxes(1, 2).reshape(pad.shape)

    # get back to the original shape
    pixelate = pixelate[:height, :width]

    return pixelate

def stride2Image(strided, image_shape):
    '''Reassemble a grid of non-overlapping windows back into an image.

    This undoes a window view whose stride equals the window size (e.g.
    asStride(arr, (m2, n2), m2) with m2 == n2 and dimensions that are exact
    multiples of the window size).

    Args:
        strided (ndarray): windows in (ny, nx, m2, n2) or (ny, nx, m2, n2, C)
            shape.
        image_shape (tuple): target shape, (ny*m2, nx*n2) or
            (ny*m2, nx*n2, C).
    Returns:
        res (ndarray): array of shape <image_shape>.
    '''
    # (ny, nx, m2, n2[, C]) -> (ny, m2, nx, n2[, C]): put each window-row
    # index next to the in-window row index so flattening yields image rows.
    # swapaxes works for both 4-D (grayscale) and 5-D (color) input.
    res = np.swapaxes(strided, 1, 2)
    res = res.reshape(image_shape)
    return res

def plotImages(imgs, nrows, ncols, titles=None, colorbar=False):
    '''Show one or more images on a grid of subplots.

    Args:
        imgs (ndarray or list/tuple of ndarray): image(s) with values in
            [0, 255]; each is divided by 255 before display.
        nrows (int): number of subplot rows.
        ncols (int): number of subplot columns.
    Keyword Args:
        titles (list or None): per-image titles, in the same order as
            <imgs>. If None, no titles are set.
        colorbar (bool): if True, attach a colorbar to every subplot.
    Returns:
        fig (matplotlib Figure): the 12x10-inch, 100-dpi figure.
    '''
    # accept a single image as well as a sequence of images
    if isinstance(imgs, (list, tuple)):
        img_list = imgs
    else:
        img_list = [imgs]

    fig = plt.figure(figsize=(12, 10), dpi=100)

    for pos, img in enumerate(img_list, start=1):
        ax = fig.add_subplot(nrows, ncols, pos)
        mappable = ax.imshow(img / 255.)
        ax.axis('off')
        if titles is not None:
            ax.set_title(titles[pos - 1])
        if colorbar:
            fig.colorbar(mappable, ax=ax)

    fig.tight_layout()
    return fig


#-------------Main---------------------------------
if __name__=='__main__':

    def saveFigure(fig, plot_save_name):
        '''Save <fig> to "<plot_save_name>.png" at 100 dpi with a tight bbox.'''
        print('\n# <dot_pattern>: Save figure to', plot_save_name)
        fig.savefig(plot_save_name+'.png', dpi=100, bbox_inches='tight')

    # Read in image. cv2.imread returns None (it does not raise) when the
    # file is missing or unreadable, so fail early with a clear message.
    image = cv2.imread(IMG_FILE)
    if image is None:
        raise FileNotFoundError(f'Cannot read image file: {IMG_FILE}')
    print(f'{image.shape = }')

    # BGR to RGB
    image = image[:, :, ::-1]

    # Create dot patterns: binary disks centred in a CELL x CELL grid
    yy, xx = np.mgrid[-CELL//2 : CELL//2, -CELL//2 : CELL//2]
    dist2 = xx**2 + yy**2
    # disk of radius CELL//2 (touches the cell edges)
    stencil1 = (dist2 <= (CELL//2)**2).astype('uint8')
    # disk of radius CELL//2 - MARGIN (leaves a gap between adjacent dots)
    stencil = (dist2 <= (CELL//2 - MARGIN)**2).astype('uint8')

    fig, axes = plt.subplots(nrows=1, ncols=2)
    stencil_panels = [(stencil1, 'stencil with size %d' %CELL),
                      (stencil, 'stencil with margin %d' %MARGIN)]
    for ax, (stencil_ii, title_ii) in zip(axes, stencil_panels):
        cs = ax.imshow(stencil_ii, interpolation='nearest')
        ax.set_title(title_ii)
        fig.colorbar(cs, ax=ax, orientation='horizontal')
    fig.show()
    saveFigure(fig, 'disk_stencil')

    # pad image size to the integer multiples of stencil
    pad = padArray(image, CELL - image.shape[0] % CELL, CELL - image.shape[1] % CELL)
    print(f'{pad.shape = }')

    # create strided view of the padded image: (ny, nx, CELL, CELL, C)
    pad_strided = asStride(pad, stencil.shape, CELL)
    print(f'{pad_strided.shape = }')

    # pixelate image: replace every cell by its mean color
    cell_means = np.mean(pad_strided, axis=(2,3), keepdims=True)
    print(f'{cell_means.shape = }')

    pixelate = cell_means * np.ones(pad_strided.shape)
    print(f'{pixelate.shape = }')

    pixelate_img = stride2Image(pixelate, pad.shape)
    pixelate_img = pixelate_img[:image.shape[0], :image.shape[1]]

    fig = plotImages([image, pixelate_img], 1, 2, ['Original', 'Pixelated'])
    fig.show()
    saveFigure(fig, 'pixelated_compare')

    # apply the dot stencil to every cell
    filtered = pixelate * stencil[:, :, None]
    print(f'{filtered.shape = }')

    # get back to the original shape
    filtered = stride2Image(filtered, pad.shape)
    filtered = filtered[:image.shape[0], :image.shape[1]]
    print(f'{filtered.shape = }')

    fig = plotImages([filtered,], 1, 1, ['dot pattern'])
    fig.show()
    saveFigure(fig, 'dot_pattern_img')

    # mimic a vibrancy enhancement
    filtered_vib = adjustVibrancy(filtered, 0.3)
    image_vib = adjustVibrancy(image, 0.3)

    fig = plotImages([image, image_vib], 1, 2,
            ['(a) Original', '(b) Vibrancy enhanced'])
    fig.show()
    saveFigure(fig, 'ori_vib_compare')

    # mimic a screen blend mode (blending the layer with itself brightens it)
    filtered_vib_srn = screenBlend(filtered_vib, filtered_vib, 0.7)

    fig = plotImages([image, filtered, filtered_vib, filtered_vib_srn], 2, 2,
            ['(a) Original', '(b) dot pattern', '(c) Vibrancy enhanced', '(d) Brightness enhanced'])
    fig.show()
    saveFigure(fig, 'ori_dot_vib_screen_compare')

    # save image: back to BGR for OpenCV as an explicit contiguous uint8
    # array (screenBlend has already clipped the values to [0, 255])
    final_bgr = np.ascontiguousarray(
        np.round(filtered_vib_srn[:, :, ::-1]).astype('uint8'))
    cv2.imwrite('dot_pattern_final.jpg', final_bgr)

