"""Image processing/preparation functions used in examples.keras."""
import itertools
import numpy as np
from skimage import exposure, filters, transform


def stitchImage(data, positions, channel=0):
    """Reassemble one image from the tiles produced by prepareNNImages.

    Args:
        data: Array indexed as ``[tile, row, col, channel]`` with the tiles in
            the same order as ``positions["px"]``.
        positions: Tiling dict (as returned by getTilePositionsV2) providing
            the ``"px"`` tile bounds ``(row_start, col_start, row_end,
            col_end)`` and the ``"stitch"`` margin.
        channel: Index of the channel in ``data`` to stitch.

    Returns:
        A square float array whose side is the end coordinate of the last
        tile. Every tile is trimmed by ``stitch`` pixels on each side before
        it is written; pixels not covered by any trimmed tile stay 0.
    """
    margin = positions["stitch"]
    tile_bounds = positions["px"]
    # The last tiles can be shifted back onto the image edge, so the final
    # tile's end coordinate is the true image size.
    side = tile_bounds[-1][-1]
    canvas = np.zeros([side, side])
    # A slice ending at -0 would be empty, so use None when there is no margin.
    trim_end = None if margin == 0 else -margin
    for tile_index, bounds in enumerate(tile_bounds):
        rows = slice(bounds[0] + margin, bounds[2] - margin)
        cols = slice(bounds[1] + margin, bounds[3] - margin)
        canvas[rows, cols] = data[
            tile_index, margin:trim_end, margin:trim_end, channel
        ]
    return canvas


def prepareNNImages(bact_img, ftsz_img, model, bacteria=False):
    """Preprocess raw iSIM images before running them through the neural network.

    Args:
        bact_img: 2D image that becomes channel 0 of the network input.
        ftsz_img: 2D image that becomes channel 1 of the network input
            (the FtsZ/Drp1 channel, which gets a background subtraction).
        model: Keras-style model; the tile size is read from
            ``model.layers[0].input_shape[0][1]``. If that attribute lookup
            fails, ``model`` itself is used as the tile size (``None`` means
            the network takes untiled input).
        bacteria: If True, channel 1 is also rescaled to 0-255 per tile.

    Returns:
        tuple: ``(inputData, positions)``.

        With a tile size, ``inputData`` is a 4D uint8 array of shape
        ``(positions["n"] ** 2, size, size, 2)``, one tile per entry.
        Without one (``None``), ``inputData`` is the whole preprocessed image
        cropped to multiples of 4 and combined with ``np.stack``, giving shape
        ``(1, H, W, 2, 1)``.

        ``positions`` is the tiling dict from getTilePositionsV2, or a
        single-tile dict when the image is not larger than the tile size.
    """
    # Set iSIM specific values
    pixelCalib = 56  # nm per pixel
    sig = 121.5 / 81  # in pixel
    resizeParam = pixelCalib / 81  # no unit
    try:
        nnImageSize = model.layers[0].input_shape[0][1]
    except AttributeError:
        # Not a model object: the argument itself is the tile size (or None).
        nnImageSize = model
    positions = None

    # Preprocess the images
    if nnImageSize is None or ftsz_img.shape[1] > nnImageSize:
        # Adjust to 81nm/px; this leaves an image smaller than the original.
        bact_img = transform.rescale(bact_img, resizeParam)
        ftsz_img = transform.rescale(ftsz_img, resizeParam)

        # Gaussian smoothing; the FtsZ channel additionally gets a background
        # subtraction (narrow Gaussian minus a 5x wider one).
        bact_img = filters.gaussian(bact_img, sig, preserve_range=True)
        ftsz_img = filters.gaussian(
            ftsz_img, sig, preserve_range=True
        ) - filters.gaussian(ftsz_img, sig * 5, preserve_range=True)

        # Tiling
        if nnImageSize is not None:
            positions = getTilePositionsV2(ftsz_img, nnImageSize)
            contrastMax = 255
        else:
            contrastMax = 1

        # Contrast: bact_img uses its mean as the lower bound of the input
        # range, so below-mean pixels end up at 0.
        ftsz_img = exposure.rescale_intensity(
            ftsz_img, (np.min(ftsz_img), np.max(ftsz_img)), out_range=(0, contrastMax)
        )
        bact_img = exposure.rescale_intensity(
            bact_img, (np.mean(bact_img), np.max(bact_img)), out_range=(0, contrastMax)
        )

    else:
        # The image is not larger than one tile: use it as a single tile.
        positions = {
            "px": [(0, 0, ftsz_img.shape[1], ftsz_img.shape[1])],
            "n": 1,
            "overlap": 0,
            "stitch": 0,
        }

    # Put into format for the network
    if nnImageSize is not None:
        # (H, W) -> (1, H, W, 1) per channel, then join into (1, H, W, 2).
        ftsz_img = ftsz_img.reshape(1, ftsz_img.shape[0], ftsz_img.shape[1], 1)
        bact_img = bact_img.reshape(1, bact_img.shape[0], bact_img.shape[1], 1)
        inputDataFull = np.concatenate((bact_img, ftsz_img), axis=3)

        # Cut out every tile and collect them in one array
        inputData = np.zeros(
            (positions["n"] ** 2, nnImageSize, nnImageSize, 2), dtype=np.uint8
        )
        for i, position in enumerate(positions["px"]):
            inputData[i, :, :, :] = inputDataFull[
                :, position[0] : position[2], position[1] : position[3], :
            ]
            if bacteria:
                inputData[i, :, :, 1] = exposure.rescale_intensity(
                    inputData[i, :, :, 1],
                    (0, np.max(inputData[i, :, :, 1])),
                    out_range=(0, 255),
                )

            # Channel 0 is stretched to the full 0-255 range tile by tile.
            inputData[i, :, :, 0] = exposure.rescale_intensity(
                inputData[i, :, :, 0],
                (0, np.max(inputData[i, :, :, 0])),
                out_range=(0, 255),
            )
    else:
        # Unlike the tiled path, there is no tile-wise rescale_intensity for
        # channel 0 (the mito channel) here.
        # Crop both sides to multiples of 4; the image does not have to be square.
        cropPixels = (
            bact_img.shape[0] - bact_img.shape[0] % 4,
            bact_img.shape[1] - bact_img.shape[1] % 4,
        )
        bact_img = bact_img[0 : cropPixels[0], 0 : cropPixels[1]]
        ftsz_img = ftsz_img[0 : cropPixels[0], 0 : cropPixels[1]]

        positions = getTilePositionsV2(bact_img, 128)
        bact_img = bact_img.reshape(1, bact_img.shape[0], bact_img.shape[1], 1)
        ftsz_img = ftsz_img.reshape(1, ftsz_img.shape[0], ftsz_img.shape[1], 1)
        # NOTE(review): np.stack adds a new axis to these (1, H, W, 1) arrays,
        # giving (1, H, W, 2, 1) instead of the (1, H, W, 2) of the tiled
        # path. Kept as-is because callers may rely on it — confirm.
        inputData = np.stack((bact_img, ftsz_img), 3)

    return inputData, positions


def prepare_wo_tiling(images: np.ndarray) -> np.ndarray:
    """Preprocess a stack of images for the network without tiling it.

    Every (channel, z_slice) plane is Gaussian-smoothed and rescaled to
    [0, 1]. Channel 1 (Drp1/FtsZ) additionally gets a background subtraction
    and is rescaled from its min..max range. All other channels are rescaled
    from their mean..max range, so below-mean pixels end up at 0.

    Args:
        images: 4D array indexed as ``[row, col, channel, z_slice]``.

    Returns:
        Float array of shape ``(H', W', n_channels, n_slices)``, where H' and
        W' are the input height and width cropped down to multiples of 4.
        Input with no channels or no slices yields an empty array of that shape.
    """
    sig = 121.5 / 81
    out_range = (0, 1)

    # Smoothing keeps the plane shape, so the cropped size is known up front.
    crop_rows = images.shape[0] - images.shape[0] % 4
    crop_cols = images.shape[1] - images.shape[1] % 4
    # Allocate before the loop so empty channel/z axes still return a result.
    prep_images = np.empty(
        [crop_rows, crop_cols, images.shape[-2], images.shape[-1]]
    )

    for z_slice in range(images.shape[-1]):
        for channel in range(images.shape[-2]):
            raw = images[:, :, channel, z_slice]
            image = filters.gaussian(raw, sig)
            if channel == 1:
                # Background subtraction for the Drp1/FtsZ channel only
                image = image - filters.gaussian(raw, sig * 5)
                in_range = (image.min(), image.max())
            else:
                in_range = (image.mean(), image.max())
            image = exposure.rescale_intensity(image, in_range, out_range=out_range)
            # Filter on the full plane first, then crop to multiples of 4.
            prep_images[:, :, channel, z_slice] = image[:crop_rows, :crop_cols]
    return prep_images


def getTilePositionsV2(image, targetSize=128):
    """Generate positions of overlapping square tiles that cover an image.

    The number of tiles per side is chosen so that only full tiles are
    needed. The search starts from ``int(image.shape[0] / targetSize) + 1``
    tiles per side. It adds one tile at a time until the overlap between
    neighbouring tiles is at least 35 pixels, or until the count reaches
    ``targetSize``. The overlap is reduced by one unless it is an even integer.

    Args:
        image (np.ndarray): Image to tile; only ``image.shape`` is used. The
            tile count is derived from ``image.shape[0]``, so the image should
            be square, ideally with a side from the series 128, 256, 512,
            1024, etc.
        targetSize (int, optional): Side length of the square tiles.
            Defaults to 128.

    Returns:
        dict: With keys
            ``"mn"``: tuple of (row, column) tile indices,
            ``"px"``: list of (row_start, col_start, row_end, col_end) pixel
            bounds, one per entry of ``"mn"``,
            ``"overlap"``: overlap between neighbouring tiles in pixels,
            ``"stitch"``: ``int(overlap / 2)``, the margin trimmed per tile
            when stitching,
            ``"n"``: number of tiles per side.
    """
    minOverlap = 35
    numberTiles = int(image.shape[0] / targetSize) + 1
    # An image smaller than one tile gives numberTiles == 1, so the search
    # below never runs. Default to no overlap so a single tile is returned
    # instead of failing on an unassigned name.
    overlap = 0
    cond = False

    # Check for the smallest overlap that gives a good result
    while not cond and numberTiles < targetSize and numberTiles > 1:
        overlap = (numberTiles * targetSize - image.shape[0]) / numberTiles - 1
        overlap = overlap - 1 if overlap % 2 else overlap
        if int(overlap) >= minOverlap:
            overlap = int(overlap)
            cond = True
        else:
            numberTiles = numberTiles + 1

    # For n x n tiles calculate the pixel positions considering the overlap
    tileIndices = tuple(itertools.product(range(numberTiles), repeat=2))
    return {
        "mn": tileIndices,
        "px": [
            calculatePixel(position, overlap, targetSize, image.shape)
            for position in tileIndices
        ],
        "overlap": overlap,
        "stitch": int(overlap / 2),
        "n": numberTiles,
    }


def calculatePixel(posMN, overlap, targetSize, shape):
    """Return the pixel bounds of the tile at row/column ``posMN``.

    Neighbouring tiles start ``targetSize - overlap`` pixels apart. A tile
    that would run past the bottom or right edge of the image is moved back
    so that it ends exactly on that edge.

    Args:
        posMN: (row, column) index of the tile.
        overlap: Overlap between neighbouring tiles in pixels, as chosen by
            getTilePositionsV2.
        targetSize: Side length of the square tiles.
        shape: Shape of the tiled image; ``shape[0]`` rows, ``shape[1]`` columns.

    Returns:
        tuple: ``(row_start, col_start, row_end, col_end)`` as ints.
    """
    stride = targetSize - overlap
    row_offset = posMN[0] * stride
    col_offset = posMN[1] * stride
    row_start, row_end = int(row_offset), int(row_offset + targetSize)
    col_start, col_end = int(col_offset), int(col_offset + targetSize)

    # Pull a tile that overhangs the right edge back onto the image.
    if col_offset + targetSize > shape[1]:
        overhang = col_end - shape[1]
        col_start -= overhang
        col_end -= overhang

    # Same for a tile that overhangs the bottom edge.
    if row_offset + targetSize > shape[0]:
        overhang = row_end - shape[0]
        row_start -= overhang
        row_end -= overhang

    return (row_start, col_start, row_end, col_end)
