# Licensed under a 3-clause BSD style license - see LICENSE.rst

from astropy import log
from astropy.wcs import WCS
import numpy as np
import psutil

from sofia_redux.toolkit.image.combine import combine_images
from sofia_redux.toolkit.utilities.fits import hdinsert
from sofia_redux.toolkit.utilities.func import stack
from sofia_redux.toolkit.resampling.resample import Resample
from sofia_redux.toolkit.image.warp import warp_image

__all__ = ['coadd']


def _target_xy(header, outwcs):
    """
    Convert the header target position to pixel coordinates.

    The target RA is read from TGTRA (multiplied by 15 to go from
    hours to degrees) and the Dec from TGTDEC.  If either keyword is
    absent, or both values are close to zero, no position is computed.

    Parameters
    ----------
    header : astropy.io.fits.Header
        Header holding the target position keywords (TGTRA, TGTDEC).
        It is not modified.
    outwcs : astropy.wcs.WCS
        WCS defining the pixel frame to transform into.  For a
        2-axis WCS the world coordinates are passed as (RA, Dec);
        otherwise they are passed as (0, Dec, RA).

    Returns
    -------
    x, y : float, float
        Target x and y position in the provided WCS, or (None, None)
        if no usable target position is present in the header.
    """
    ra_hours = header.get('TGTRA', None)
    dec_deg = header.get('TGTDEC', None)

    # a missing keyword means there is nothing to convert
    if ra_hours is None or dec_deg is None:
        return None, None

    # RA and Dec both ~0 is treated as an unset target position
    if np.allclose([ra_hours, dec_deg], 0):
        return None, None

    # convert from hours to degrees
    ra_deg = ra_hours * 15.0

    if outwcs.wcs.naxis == 2:
        x_pix, y_pix = outwcs.wcs_world2pix(ra_deg, dec_deg, 0)
    else:
        # first world axis is filled with 0; its pixel value is unused
        _, y_pix, x_pix = outwcs.wcs_world2pix(0, dec_deg, ra_deg, 0)
    return x_pix, y_pix


def coadd(hdr_list, data_list, var_list, exp_list,
          method='mean', weighted=True, robust=True, sigma=8.0,
          maxiters=5, spectral=False, cube=False, wcskey=' ',
          rotate=True, fit_order=2, window=7.0, smoothing=2.0,
          adaptive_algorithm=None, edge_threshold=0.7,
          reference='first'):
    """
    Coadd total intensity or spectral images.

    The WCS is used to transform images into a common coordinate
    system.  By default, the reference field is the WCS for the first
    data set provided.  Optionally, the reference may be corrected
    for target position motion, as for a non-sidereal target.

    For coadd methods 'mean' or 'median', each image is interpolated into
    the reference frame, then all images are combined using the chosen
    statistic.  Note that this method may be memory-intensive for large
    fields.  The coadd method 'resample' uses locally weighted polynomial
    surface fits to resample data onto the output grid
    (see `sofia_redux.toolkit.resampling` for more information).  Exposure
    maps are always generated by interpolating and summing individual maps.

    The output may be either a 2D image, for either spectral or imaging
    data, or a 3D spectral cube (cube = True).  If cube is set, then the
    method is always 'resample'.

    Input files with no good data (every pixel NaN in flux or error)
    are skipped with a warning.

    Parameters
    ----------
    hdr_list : list of astropy.io.fits.Header
        List of headers associated with the data to combine.
        The first header in the list is used as the reference.
    data_list : list of numpy.ndarray of float
        List of flux arrays to combine.
    var_list : list of numpy.ndarray of float
        List of variance arrays associated with flux arrays.
    exp_list : list of numpy.ndarray of float
        List of exposure time maps associated with the flux arrays.
        These arrays are not modified.
    method : {'mean', 'median', 'resample'}, optional
        Method for combining data into the output map.  For 'mean' or
        'median', data is interpolated into the output grid, then
        combined with the selected statistic.  For 'resample', data
        are sampled onto the output grid with locally weighted polynomial
        fits.
    weighted : bool, optional
        If set, input flux values will be weighted by the variance values,
        for 'mean' or 'resample' methods.
    robust : bool, optional
        If set, input flux values will be sigma-clipped before combining,
        for 'mean' or 'median' methods.
    sigma : float, optional
        The sigma value to use for clipping if the `robust` option is set.
    maxiters : int, optional
        Maximum number of sigma-clipping iterations to use if the
        `robust` option is set.
    spectral : bool, optional
        If not set, any dimensions higher than 2 in the input WCS will
        be ignored. This is required for compatibility with old-style
        FORCAST imaging data (pipeline version < 2.0).
    cube : bool, optional
        If set, spectral data is assembled into a 3D cube (nw, ny, nx)
        instead of a 2D spectral image (nw, ny).
    wcskey : str, optional
        Indicates the WCS to use for assembling data.  If ' ', the primary
        WCS is used.  For spectral data, the alternate WCS with key 'A'
        is expected.
    rotate : bool, optional
        If set, data is rotated to set North up and East left, in RA/Dec
        coordinates.  This option is not recommended for 2D spectral images.
    fit_order : int, optional
        The polynomial fit order to use with the 'resample' method.
    window : float, optional
        The local fitting window (in pixels) to use with the 'resample'
        method.
    smoothing : float, optional
        The Gaussian smoothing radius (in pixels) to use with the 'resample'
        method.
    adaptive_algorithm : {'scaled', 'shaped', None}, optional
        Algorithm for adaptive smoothing kernel.  If scaled, only the
        size is allowed to vary.  If shaped, the kernel shape and
        rotation may also vary.
    edge_threshold : float, optional
        Used to determine how much of the image edges should be masked,
        Specified as a fraction of the fit window; lower values clip more
        pixels.
    reference : {'first', 'target'}, optional
        If set to 'target', the output coordinates for each input file
        will be corrected for target motion, as recorded in the TGTRA
        and TGTDEC keywords. This is necessary to correct for the motion
        of non-sidereal targets.  If TGTRA/DEC are not found, no correction
        will be made.

    Returns
    -------
    header : astropy.io.fits.Header
        The output header with appropriate WCS.
    flux : numpy.ndarray
        The output flux image or cube.
    variance : numpy.ndarray
        The output variance image or cube.
    expmap : numpy.ndarray
        The exposure map associated with the flux. This array
        will always be 2D, even for cube outputs.

    Raises
    ------
    ValueError
        If `cube` is set and the selected WCS is not 3D or is the
        primary WCS, or if none of the inputs contain any good data.
    """

    # cube is only supported for spectral data
    if cube:
        spectral = True

    # reference all data to the first file
    out_header = hdr_list[0].copy()

    # set reference angle to zero if it isn't already
    key = wcskey.strip().upper()
    if rotate:
        for wkey in [f'CROTA2{key}',
                     f'PC1_1{key}', f'PC1_2{key}',
                     f'PC2_1{key}',
                     f'PC2_2{key}', f'PC2_3{key}',
                     f'PC3_2{key}', f'PC3_3{key}']:
            if wkey in out_header:
                if wkey == f'CROTA2{key}':
                    out_header[wkey] = 0.0
                else:
                    del out_header[wkey]

        # swap RA to east-left if needed
        ra = f'CDELT1{key}'
        if not cube and ra in out_header and out_header[ra] > 0:
            out_header[ra] *= -1

    # turn down logging to avoid FITS warning for 3D coord sys;
    # the level is restored even if the WCS cannot be constructed
    olevel = log.level
    log.setLevel('ERROR')
    try:
        if not spectral:
            outwcs = WCS(out_header, key=wcskey, naxis=2)
        else:
            outwcs = WCS(out_header, key=wcskey)
    finally:
        log.setLevel(olevel)

    wcs_dim = outwcs.wcs.naxis
    if cube and wcs_dim < 3:
        msg = 'WCS is not 3D. Cannot make cube.'
        log.error(msg)
        raise ValueError(msg)

    if cube:
        # expectation is that 3D coord was in a secondary WCS --
        # we don't handle it if not
        if key == '':
            msg = ('Unexpected input WCS condition. '
                   'Cannot fix output header.')
            log.error(msg)
            raise ValueError(msg)

        method = 'resample'
        if 'SLTW_PIX' not in out_header:
            log.warning('Slit width not in header; output flux '
                        'may not be conserved.')
        float_slitw = out_header.get('SLTW_PIX', 1.0)
        slit_width = int(np.round(float_slitw))
    else:
        float_slitw = 1.0
        slit_width = 1

    # if referencing to a target RA/Dec (e.g. for nonsidereal targets),
    # get the target position in reference x, y coordinates
    tgt_x, tgt_y = None, None
    if reference == 'target':
        tgt_x, tgt_y = _target_xy(out_header, outwcs)
        if None in (tgt_x, tgt_y):
            msg = 'Missing TGTRA or TGTDEC; cannot reference to target.'
            log.warning(msg)

    out_coord_x = []
    out_coord_y = []
    out_coord_w = []
    flxvals = []
    errvals = []
    expvals = []
    corners = []
    for (hdr, flux, var, exp) in zip(hdr_list, data_list, var_list, exp_list):
        # input wcs
        if not spectral:
            inwcs = WCS(hdr, key=wcskey, naxis=2)
        else:
            inwcs = WCS(hdr, key=wcskey)

        # assemble flux, error, and exposure map values
        ny, nx = flux.shape
        err = np.sqrt(var)
        good = ~np.isnan(flux) & ~np.isnan(err)
        if not np.any(good):
            log.warning(f"No good data in "
                        f"{hdr.get('FILENAME', 'UNKNOWN')}; skipping.")
            continue
        if method == 'resample':
            flxvals.append(flux[good])
            errvals.append(err[good])
        else:
            flxvals.append(flux)
            errvals.append(err)
        if cube:
            # exposure value is at one wavelength only, with
            # slit width size, plus two zero columns for padding.
            # Copy the slice first: it is a view, and zeroing the
            # padding columns in place would modify the caller's array.
            expval = exp[:, 0:slit_width + 2].copy()
            expval[:, 0] = 0
            expval[:, -1] = 0
            expvals.append(expval)
        else:
            expvals.append(exp)

        # index values for resampling
        yin, xin = np.meshgrid(np.arange(ny), np.arange(nx), indexing='ij')
        yin = yin[good]
        xin = xin[good]
        xamin, xamax = np.argmin(xin), np.argmax(xin)
        yamin, yamax = np.argmin(yin), np.argmax(yin)

        # corner values for interpolation
        if cube:
            in_corner = [[xin[xamin], xin[xamin],
                          xin[xamax], xin[xamax]],
                         [yin[yamin], yin[yamax],
                          yin[yamin], yin[yamax]],
                         [-slit_width / 2 + 0.5, -slit_width / 2 + 0.5,
                          slit_width / 2 - 0.5, slit_width / 2 - 0.5]]
        else:
            in_corner = [[xin[xamin], xin[xamin],
                          xin[xamax], xin[xamax]],
                         [yin[yamin], yin[yamax],
                          yin[yamin], yin[yamax]]]

        # transform all coords to reference WCS
        if wcs_dim == 2:
            wxy = inwcs.wcs_pix2world(xin, yin, 0)
            oxy = outwcs.wcs_world2pix(*wxy, 0)
            cxy = inwcs.wcs_pix2world(*in_corner, 0)
            out_corner = outwcs.wcs_world2pix(*cxy, 0)
        else:
            wxy = inwcs.wcs_pix2world(xin, yin, 0, 0)
            oxy = outwcs.wcs_world2pix(*wxy, 0)
            if cube:
                cxy = inwcs.wcs_pix2world(*in_corner, 0)
                out_corner = outwcs.wcs_world2pix(*cxy, 0)
                # ra, dec corners
                in_corner = [in_corner[2], in_corner[1]]
                # correct for slit width offset in not-yet
                # existent 3rd dimension
                out_corner = np.array([out_corner[2] - slit_width / 2,
                                       out_corner[1]])
            else:
                cxy = inwcs.wcs_pix2world(*in_corner, 0, 0)
                out_corner = outwcs.wcs_world2pix(*cxy, 0)[0:2]

        # correct all coordinates for target movement
        x_off, y_off = 0., 0.
        if None not in [tgt_x, tgt_y]:
            upd_x, upd_y = _target_xy(hdr, outwcs)
            if None in [upd_x, upd_y]:
                log.warning(f"Missing target RA/Dec in file "
                            f"{hdr.get('FILENAME', 'UNKNOWN')}.")
            else:
                x_off = tgt_x - upd_x
                y_off = tgt_y - upd_y

        if cube and wcs_dim == 3:
            # assuming crval1=wavelength, crval2=dec, crval3=ra
            out_coord_w.append(oxy[0])
            out_coord_y.append(oxy[1] + y_off)
            out_coord_x.append(oxy[2] + x_off)
        else:
            out_coord_x.append(oxy[0] + x_off)
            out_coord_y.append(oxy[1] + y_off)

        out_corner[0] += x_off
        out_corner[1] += y_off
        corners.append((in_corner, out_corner))

    # every input was skipped: fail with a clear message rather than
    # an obscure stacking error on empty lists
    if not flxvals:
        msg = 'No good data found in any input; cannot coadd.'
        log.error(msg)
        raise ValueError(msg)

    # output grid shape
    stk_coord_x = np.hstack(out_coord_x)
    minx, maxx = np.min(stk_coord_x), np.max(stk_coord_x)
    stk_coord_y = np.hstack(out_coord_y)
    miny, maxy = np.min(stk_coord_y), np.max(stk_coord_y)

    # shift coordinates to new grid
    stk_coord_x -= minx
    stk_coord_y -= miny

    # stack coordinates for output grid
    if cube:
        stk_coord_w = np.hstack(out_coord_w)
        minw, maxw = np.min(stk_coord_w), np.max(stk_coord_w)
        out_shape = (int(np.ceil(maxw) - np.floor(minw) + 1),
                     int(np.ceil(maxy) - np.floor(miny) + 1),
                     int(np.ceil(maxx) - np.floor(minx)) + 1)
        stk_coord_w -= minw
        coordinates = stack(stk_coord_x, stk_coord_y, stk_coord_w)

        xout = np.arange(out_shape[2], dtype=np.float64)
        yout = np.arange(out_shape[1], dtype=np.float64)
        wout = np.arange(out_shape[0], dtype=np.float64)
        grid = xout, yout, wout

        # fix header reference pixel for new min value in w and x
        out_header['CRPIX1' + key] -= minw
        out_header['CRPIX2' + key] -= miny
        out_header['CRPIX3' + key] -= minx
    else:
        out_shape = (int(np.ceil(maxy) - np.floor(miny) + 1),
                     int(np.ceil(maxx) - np.floor(minx)) + 1)

        coordinates = stack(stk_coord_x, stk_coord_y)

        xout = np.arange(out_shape[1], dtype=np.float64)
        yout = np.arange(out_shape[0], dtype=np.float64)
        grid = xout, yout

        # fix header reference pixel
        out_header['CRPIX1' + key] -= minx
        out_header['CRPIX2' + key] -= miny

        # also fix primary coordinates for 2D spectrum
        if key != '' and wcs_dim > 2:
            out_header['CRPIX1'] -= minx
            out_header['CRPIX2'] -= miny

    log.info('Output shape: {}'.format(out_shape))

    # use local polynomial fits to resample and coadd data
    if method == 'resample':
        flxvals = np.hstack(flxvals)
        errvals = np.hstack(errvals)

        if cube:
            edge_threshold = (edge_threshold, edge_threshold, 0)
            window = (window, window, 2.0)
            smoothing = (smoothing, smoothing, 1.0)
            if adaptive_algorithm in ['scaled', 'shaped']:
                adaptive_threshold = (1.0, 1.0, 0.0)
            else:
                adaptive_threshold = None
                adaptive_algorithm = None
        else:
            if adaptive_algorithm in ['scaled', 'shaped']:
                adaptive_threshold = 1.0
            else:
                adaptive_threshold = None
                adaptive_algorithm = None

        # leave one core free; cpu_count may return None if the
        # number of CPUs cannot be determined
        n_cpu = psutil.cpu_count()
        max_cores = None if n_cpu is None else n_cpu - 1
        if max_cores is not None and max_cores < 2:  # pragma: no cover
            max_cores = None

        log.info('Setting up output grid.')
        resampler = Resample(
            coordinates, flxvals, error=errvals,
            window=window, order=fit_order, fix_order=False)

        log.info('Resampling flux data.')
        flux, std = resampler(
            *grid, smoothing=smoothing, edge_threshold=edge_threshold,
            adaptive_threshold=adaptive_threshold,
            adaptive_algorithm=adaptive_algorithm,
            edge_algorithm='distribution', get_error=True,
            error_weighting=weighted, jobs=max_cores)
        var = std**2

        log.info('Interpolating and summing exposure maps.')
        if cube:
            expmap = np.zeros(out_shape[1:], dtype=float)
        else:
            expmap = np.zeros(out_shape, dtype=float)
        for i, expval in enumerate(expvals):
            inx, iny = corners[i][0]
            outx, outy = corners[i][1]
            outx -= minx
            outy -= miny
            exp_out = warp_image(
                expval, inx, iny, outx, outy,
                output_shape=expmap.shape, cval=0,
                order=1, interpolation_order=1)
            expmap += exp_out
    else:
        # interpolate corners for approximate warp solution
        log.info('Interpolating all images.')

        flx = []
        vr = []
        expmap = np.zeros(out_shape)
        for i, (flxval, errval, expval) in \
                enumerate(zip(flxvals, errvals, expvals)):
            inx, iny = corners[i][0]
            outx, outy = corners[i][1]
            outx -= minx
            outy -= miny

            # flux image
            flx.append(
                warp_image(flxval, inx, iny, outx, outy,
                           output_shape=out_shape, cval=np.nan,
                           order=1, interpolation_order=1))

            # var image
            vr.append(
                warp_image(errval**2, inx, iny, outx, outy,
                           output_shape=out_shape, cval=np.nan,
                           order=1, interpolation_order=0))

            # exposure map image
            exp_out = warp_image(
                expval, inx, iny, outx, outy,
                output_shape=out_shape, cval=0,
                order=1, interpolation_order=0)
            expmap += exp_out

        if len(flx) > 1:
            log.info('{}-combining images.'.format(method.title()))
            flux, var = combine_images(
                flx, variance=vr, method=method, weighted=weighted,
                robust=robust, sigma=sigma, maxiters=maxiters)
        else:
            flux, var = flx[0], vr[0]

    if cube:
        # reconstruct as primary wcs: drop the old primary keywords,
        # then move the alternate keywords into their place.  The two
        # sets are handled separately so that a primary keyword that
        # happens to end in the alternate key letter is never mistaken
        # for an alternate keyword.
        primary_keys = ['CTYPE1', 'CTYPE2', 'CUNIT1', 'CUNIT2',
                        'CRPIX1', 'CRPIX2', 'CRVAL1', 'CRVAL2',
                        'CDELT1', 'CDELT2', 'CROTA2', 'SPECSYS']
        alt_bases = ['CTYPE1', 'CTYPE2', 'CTYPE3',
                     'CUNIT1', 'CUNIT2', 'CUNIT3',
                     'CRPIX1', 'CRPIX2', 'CRPIX3',
                     'CRVAL1', 'CRVAL2', 'CRVAL3',
                     'CDELT1', 'CDELT2', 'CDELT3',
                     'RADESYS', 'EQUINOX', 'SPECSYS']
        tmp = out_header.copy()
        for wkey in primary_keys:
            if wkey in out_header:
                del out_header[wkey]
        for base in alt_bases:
            wkey = f'{base}{key}'
            if wkey in out_header:
                del out_header[wkey]
            if wkey in tmp:
                # swap coords 1 and 3 (to make it wave, RA, Dec)
                new_key = base.replace('3', '9999')
                new_key = new_key.replace('1', '3').replace('9999', '1')
                hdinsert(out_header, new_key, tmp[wkey], tmp.comments[wkey])

    # fix source position estimate too
    if 'SRCPOSX' in out_header and 'SRCPOSY' in out_header:
        coord = ([out_header['SRCPOSX']],
                 [out_header['SRCPOSY']])
        first_wcs = WCS(hdr_list[0], naxis=2)
        out_wcs = WCS(out_header, naxis=2)
        sxy = first_wcs.wcs_pix2world(*coord, 0)
        new_xy = out_wcs.wcs_world2pix(*sxy, 0)
        out_header['SRCPOSX'] = new_xy[0][0]
        out_header['SRCPOSY'] = new_xy[1][0]

    if cube:
        # correct flux for pixel size change
        # before: pixel x slit width in pixels
        # after: pixel x pixel
        flux /= float_slitw
        var /= float_slitw**2

    return out_header, flux, var, expmap
