import numpy as np
from numba import prange

from .._common import jitted, norm2d

# Sentinel traveltime: initial value of every node and of the 1D source-line
# buffer, and the value given to a rejected 2D operator result.
Big = 1e5
# Tolerance used to decide whether the source coincides with a grid line.
eps = 1e-15
# Nodes more than this many indices away from the source cell (along Z or X)
# use plane-wave operators; nodes closer than that use the spherical operator.
epsin = 5


@jitted("f8(i4, i4, f8, f8, f8, f8, f8)")
def t_ana(i, j, dz, dx, zsa, xsa, vzero):
    """
    Return the homogeneous-model traveltime from the source to node (i, j).

    The time is the slowness ``vzero`` multiplied by the Euclidean distance
    between the node and the source position ``(zsa, xsa)``. Both positions
    are expressed in grid units and scaled by the spacings ``dz`` and ``dx``.
    """
    offset_z = dz * (i - zsa)
    offset_x = dx * (j - xsa)
    return vzero * (offset_z ** 2.0 + offset_x ** 2.0) ** 0.5


@jitted("UniTuple(f8, 3)(i4, i4, f8, f8, f8, f8, f8)")
def t_anad(i, j, dz, dx, zsa, xsa, vzero):
    """
    Return the homogeneous-model time at node (i, j) and its spatial derivatives.

    Returns
    -------
    tuple of float
        ``(t, dt/dz, dt/dx)``. Both derivatives are reported as 0 when ``t``
        is not strictly positive (e.g. when the node coincides with the source).
    """
    tana = t_ana(i, j, dz, dx, zsa, xsa, vzero)

    if tana > 0.0:
        # d(s * r)/dz = s * (z - zs) / r = s**2 * (z - zs) / t
        factor = vzero ** 2.0 / tana
        return tana, (i - zsa) * dz * factor, (j - xsa) * dx * factor

    return tana, 0.0, 0.0


@jitted("f8(f8, f8, f8, f8, f8, f8, f8, f8, f8, f8, f8, f8, f8, i4, i4)")
def delta(
    t1, tauv, taue, tauev, t0c, tzc, txc, dzi, dxi, dz2i, dx2i, vzero, vref, sgntz, sgntx
):
    """
    Solve the quadratic of the spherical (perturbation) operator.

    The unknown ``tau`` is the time perturbation at the current node relative
    to the analytical time ``t0c``. ``tauv``, ``taue`` and ``tauev`` are the
    perturbations at the upwind neighbours, and ``tzc``/``txc`` are the
    analytical time derivatives at the node. The equation is solved as
    ``a * tau**2 + b * tau + c = 0``.

    Returns
    -------
    float
        ``t0c + (-b + sqrt(disc)) / (2 * a)`` when the discriminant ``disc`` is
        non-negative, otherwise the fallback value ``t1``.
    """
    ta = tauev + taue - tauv
    tb = tauev - taue + tauv

    # Upwind-signed analytical gradient terms, reused by b and c
    gx = sgntx * txc * dxi
    gz = sgntz * tzc * dzi

    a = dz2i + dx2i
    b = 4.0 * (gx + gz) - 2.0 * (ta * dx2i + tb * dz2i)
    c = (
        (ta ** 2.0 * dx2i)
        + (tb ** 2.0 * dz2i)
        - 4.0 * (gx * ta + gz * tb)
        + 4.0 * (vzero ** 2.0 - vref ** 2.0)
    )
    disc = b ** 2.0 - 4.0 * a * c

    if disc >= 0.0:
        return 0.5 * (disc ** 0.5 - b) / a + t0c

    return t1


@jitted(
    "void(f8[:, :], i4[:, :, :], f8[:, :], UniTuple(f8, 6), f8, f8, f8, f8, f8, i4, i4, i4, i4, i4, i4, i4, i4, b1)"
)
def sweep(
    tt,
    ttsgn,
    slow,
    dargs,
    zsi,
    xsi,
    zsa,
    xsa,
    vzero,
    i,
    j,
    sgnvz,
    sgnvx,
    sgntz,
    sgntx,
    nz,
    nx,
    grad,
):
    """
    Sweep in given direction.

    Update node ``tt[i, j]`` in place with the minimum of its current value,
    the two 1D (refracted) operators and one 2D operator. The 2D operator is a
    plane-wave stencil when the node is more than ``epsin`` indices from the
    source cell, otherwise a spherical perturbation stencil. The upwind
    neighbours are ``tt[i - sgntz, j]``, ``tt[i, j - sgntx]`` and
    ``tt[i - sgntz, j - sgntx]``.

    Parameters
    ----------
    tt : array_like
        Traveltime grid of shape (nz, nx), updated in place.
    ttsgn : array_like
        Upwind direction signs per node. Written only when ``grad`` is True
        and ``tt[i, j]`` decreased.
    slow : array_like
        Slowness grid, one value per cell, i.e. shape (nz - 1, nx - 1) as
        allocated by ``fteik2d``.
    dargs : tuple of float
        ``(dz, dx, dzi, dxi, dz2i, dx2i)``. ``sweep2d`` builds these as the
        spacings, their inverses and their inverse squares.
    zsi, xsi : float
        Indices of the source cell.
    zsa, xsa : float
        Source position in grid units.
    vzero : float
        Slowness at the source.
    i, j : int
        Node being updated.
    sgnvz, sgnvx : int
        Offsets (0 or 1) subtracted from ``i``/``j`` to pick the slowness cell.
    sgntz, sgntx : int
        Sweep direction (+1 or -1) along Z and X.
    nz, nx : int
        Node counts of ``tt``.
    grad : bool
        Whether to record the upwind direction in ``ttsgn``.
    """
    dz, dx, dzi, dxi, dz2i, dx2i = dargs
    # Index of the slowness cell lying between the node and its upwind side
    i1 = i - sgnvz
    j1 = j - sgnvx

    # Get local times of surrounding points
    tv = tt[i - sgntz, j]
    te = tt[i, j - sgntx]
    tev = tt[i - sgntz, j - sgntx]

    # 1D operators (refracted times)
    # First dimension (Z axis): smallest slowness of the two cells sharing the
    # edge, with column indices clamped to the slowness grid
    vref = min(slow[i1, max(j - 1, 0)], slow[i1, min(j, nx - 2)])
    t1d1 = tv + dz * vref

    # Second dimension (X axis): same idea, clamping row indices
    vref = min(slow[max(i - 1, 0), j1], slow[min(i, nz - 2), j1])
    t1d2 = te + dx * vref

    t1d = min(t1d1, t1d2)

    # 2D operators; Big means "no valid 2D solution"
    t2d = Big
    vref = slow[i1, j1]

    # Choose plane wave or spherical
    # Test for plane wave (node more than epsin indices from the source cell)
    if np.abs(i - zsi) > epsin or np.abs(j - xsi) > epsin:
        # 4 points operator if possible, otherwise do three points
        if tv <= te + dx * vref and te <= tv + dz * vref and te >= tev and tv >= tev:
            ta = tev + te - tv
            tb = tev - te + tv
            t2d = (
                (tb * dz2i + ta * dx2i)
                + (4.0 * vref ** 2.0 * (dz2i + dx2i) - dz2i * dx2i * (ta - tb) ** 2.0)
                ** 0.5
            ) / (dz2i + dx2i)

        # Two 3 points operators: the Z-gradient is estimated from te/tev
        # (spacing dz), and the node is extrapolated along X
        elif (
            te - tev <= dz ** 2.0 * vref / (dx ** 2.0 + dz ** 2.0) ** 0.5
            and te - tev > 0.0
        ):
            t2d = te + dx * (vref ** 2.0 - ((te - tev) / dz) ** 2.0) ** 0.5

        # X-gradient estimated from tv/tev (spacing dx), extrapolated along Z
        elif (
            tv - tev <= dx ** 2.0 * vref / (dx ** 2.0 + dz ** 2.0) ** 0.5
            and tv - tev > 0.0
        ):
            t2d = tv + dz * (vref ** 2.0 - ((tv - tev) / dx) ** 2.0) ** 0.5

    # Test for spherical
    else:
        # Do spherical operator if conditions ok
        if tv < te + dx * vref and te < tv + dz * vref and te >= tev and tv >= tev:
            # Work with perturbations tau = t - t_analytical at each neighbour
            t0c, tzc, txc = t_anad(i, j, dz, dx, zsa, xsa, vzero)
            tauv = tv - t_ana(i - sgntz, j, dz, dx, zsa, xsa, vzero)
            taue = te - t_ana(i, j - sgntx, dz, dx, zsa, xsa, vzero)
            tauev = tev - t_ana(i - sgntz, j - sgntx, dz, dx, zsa, xsa, vzero)

            t2d = delta(
                t2d,
                tauv,
                taue,
                tauev,
                t0c,
                tzc,
                txc,
                dzi,
                dxi,
                dz2i,
                dx2i,
                vzero,
                vref,
                sgntz,
                sgntx,
            )
            # Reject solutions earlier than an upwind neighbour
            if t2d < tv or t2d < te:
                t2d = Big

    # Select minimum time
    t0 = tt[i, j]
    tt[i, j] = min(t0, t1d, t2d)

    # Compute gradient according to minimum time direction: record which
    # upwind neighbour(s) produced the new value (exact float equality is
    # intended, the minimum is one of these very values)
    if grad and tt[i, j] != t0:
        if tt[i, j] == t1d1:
            ttsgn[i, j, 0] = sgntz
            ttsgn[i, j, 1] = 0
        elif tt[i, j] == t1d2:
            ttsgn[i, j, 0] = 0
            ttsgn[i, j, 1] = sgntx
        else:
            ttsgn[i, j, 0] = sgntz
            ttsgn[i, j, 1] = sgntx


@jitted("void(f8[:, :], i4[:, :, :], f8[:, :], f8, f8, f8, f8, f8, f8, f8, i4, i4, b1)")
def sweep2d(tt, ttsgn, slow, dz, dx, zsi, xsi, zsa, xsa, vzero, nz, nx, grad):
    """
    Run one full sweeping pass over the traveltime grid.

    Columns are visited first in increasing, then in decreasing ``j``. Within
    each column the rows are visited in increasing ``i`` and then in
    decreasing ``i``. Every visit delegates the local update to :func:`sweep`.
    """
    inv_dz = 1.0 / dz
    inv_dx = 1.0 / dx
    dargs = (dz, dx, inv_dz, inv_dx, inv_dz / dz, inv_dx / dx)

    for sgntx in (1, -1):
        # Forward sweeps read the cell behind the node (offset 1), backward
        # sweeps the cell at the node (offset 0)
        sgnvx = (sgntx + 1) // 2
        if sgntx > 0:
            jbeg, jend = 1, nx
        else:
            jbeg, jend = nx - 2, -1

        for j in range(jbeg, jend, sgntx):
            for sgntz in (1, -1):
                sgnvz = (sgntz + 1) // 2
                if sgntz > 0:
                    ibeg, iend = 1, nz
                else:
                    ibeg, iend = nz - 2, -1

                for i in range(ibeg, iend, sgntz):
                    sweep(
                        tt,
                        ttsgn,
                        slow,
                        dargs,
                        zsi,
                        xsi,
                        zsa,
                        xsa,
                        vzero,
                        i,
                        j,
                        sgnvz,
                        sgnvx,
                        sgntz,
                        sgntx,
                        nz,
                        nx,
                        grad,
                    )


@jitted("Tuple((f8[:, :], f8[:, :, :], f8))(f8[:, :], f8, f8, f8, f8, i4, b1)")
def fteik2d(slow, dz, dx, zsrc, xsrc, nsweep=2, grad=False):
    """
    Calculate traveltimes given a 2D slowness model.

    Parameters
    ----------
    slow : array_like, shape (nz, nx)
        Slowness of each grid cell.
    dz, dx : float
        Grid spacing along Z (axis 0) and X (axis 1).
    zsrc, xsrc : float
        Source coordinates, within ``[0, dz * nz]`` and ``[0, dx * nx]``.
    nsweep : int, optional, default 2
        Number of sweeping passes.
    grad : bool, optional, default False
        If True, also compute normalized traveltime gradients.

    Returns
    -------
    tt : array_like, shape (nz + 1, nx + 1)
        Traveltimes at grid nodes.
    ttgrad : array_like
        Normalized traveltime gradients of shape (nz + 1, nx + 1, 2), or an
        empty (0, 0, 0) array when ``grad`` is False.
    vzero : float
        Slowness of the cell containing the source.

    Raises
    ------
    ValueError
        If the source lies outside the model.
    """
    # Parameters
    nz, nx = np.shape(slow)

    # Check inputs
    condz = 0.0 <= zsrc <= dz * nz
    condx = 0.0 <= xsrc <= dx * nx
    if not (condz and condx):
        raise ValueError("source out of bound")

    # Convert src to grid position
    zsa = zsrc / dz
    xsa = xsrc / dx

    # Cell containing the source. A source on the far edge (zsa == nz or
    # xsa == nx, or marginally above after the division) has no cell of its
    # own, so clamp to the last cell. Subtracting a tiny epsilon from zsa is
    # not enough (100.0 - 1e-15 == 100.0 in double precision) and would let
    # slow and tt be indexed out of bounds.
    zsi = min(int(zsa), nz - 1)
    xsi = min(int(xsa), nx - 1)
    vzero = slow[zsi, xsi]

    # Allocate work array: one more node than cells along each axis
    nz += 1
    nx += 1
    tt = np.full((nz, nx), Big, dtype=np.float64)

    if grad:
        ttgrad = np.zeros((nz, nx, 2), dtype=np.float64)
        ttsgn = np.zeros((nz, nx, 2), dtype=np.int32)

    else:
        ttgrad = np.empty((0, 0, 0), dtype=np.float64)
        ttsgn = np.empty((0, 0, 0), dtype=np.int32)

    # Do our best to initialize source: fractional offsets of the source from
    # the nodes of its cell
    dzu = np.abs(zsa - float(zsi))
    dzd = 1.0 - dzu
    dxw = np.abs(xsa - float(xsi))
    dxe = 1.0 - dxw

    # Source seems close enough to a grid point in Z and X direction
    dzv_min = min(dzu, dzd)
    dzh_min = min(dxw, dxe)
    if dzv_min < eps and dzh_min < eps:
        zsa = np.round(zsa)
        xsa = np.round(xsa)
        iflag = 1

    # At least one of coordinates not close to any grid point in Z and X direction
    elif dzv_min > eps or dzh_min > eps:
        zsa = np.round(zsa) if dzv_min < eps else zsa
        xsa = np.round(xsa) if dzh_min < eps else xsa
        iflag = 2

    # Oops we are lost, not sure this happens - fix src to nearest grid point
    else:
        zsa = np.round(zsa)
        xsa = np.round(xsa)
        iflag = 3

    # We know where src is - start first propagation
    if iflag == 2:
        td = np.full(max(nz, nx), Big, dtype=np.float64)

        # Offsets after snapping. When a coordinate was rounded up onto the
        # next grid line, dzd (resp. dxe) becomes exactly 0: the source then
        # lies on row zsi + 1 (resp. column xsi + 1), which is skipped below
        # to avoid dividing by zero.
        dzu = np.abs(zsa - float(zsi))
        dzd = 1.0 - dzu
        dxw = np.abs(xsa - float(xsi))
        dxe = 1.0 - dxw

        # First initialize 4 points around source
        iterables = (
            (zsi, xsi),
            (zsi + 1, xsi),
            (zsi, xsi + 1),
            (zsi + 1, xsi + 1),
        )
        for i, j in iterables:
            tt[i, j], tzc, txc = t_anad(i, j, dz, dx, zsa, xsa, vzero)

            if grad:
                ttgrad[i, j, 0] = tzc
                ttgrad[i, j, 1] = txc

        # NOTE(review): sweep2d hands delta dzi = 1/dz and dz2i = 1/dz**2,
        # whereas the loops below (step dzd * dz) use dzi = 1/dzd and
        # dz2i = dz/dzd**2 (same pattern for X). These agree only when the
        # spacing is 1 - confirm against the reference FTeik implementation.

        # Propagate along X (increasing j) on the rows bracketing the source
        dxi = 1.0 / dx
        dx2i = dxi / dx
        td[xsi + 1] = vzero * dxe * dx
        for j in range(xsi + 2, nx):
            vref = slow[zsi, j - 1]
            td[j] = td[j - 1] + dx * vref
            tauv = td[j] - vzero * np.abs(j - xsa) * dx
            tauev = td[j - 1] - vzero * np.abs(j - xsa - 1.0) * dx

            # Row zsi + 1
            if dzd > 0.0:
                dzi = 1.0 / dzd
                dz2i = dz / dzd / dzd
                taue = tt[zsi + 1, j - 1] - t_ana(zsi + 1, j - 1, dz, dx, zsa, xsa, vzero)
                t0c, tzc, txc = t_anad(zsi + 1, j, dz, dx, zsa, xsa, vzero)
                tt[zsi + 1, j] = delta(
                    tt[zsi + 1, j],
                    tauv,
                    taue,
                    tauev,
                    t0c,
                    tzc,
                    txc,
                    dzi,
                    dxi,
                    dz2i,
                    dx2i,
                    vzero,
                    vref,
                    1,
                    1,
                )
                if grad:
                    ttsgn[zsi + 1, j, 0] = 1
                    ttsgn[zsi + 1, j, 1] = 1

            # Row zsi
            if dzu > 0.0:
                dzi = 1.0 / dzu
                dz2i = dz / dzu / dzu
                taue = tt[zsi, j - 1] - t_ana(zsi, j - 1, dz, dx, zsa, xsa, vzero)
                t0c, tzc, txc = t_anad(zsi, j, dz, dx, zsa, xsa, vzero)
                tt[zsi, j] = delta(
                    tt[zsi, j],
                    tauv,
                    taue,
                    tauev,
                    t0c,
                    tzc,
                    txc,
                    dzi,
                    dxi,
                    dz2i,
                    dx2i,
                    vzero,
                    vref,
                    -1,
                    1,
                )
                if grad:
                    ttsgn[zsi, j, 0] = -1
                    ttsgn[zsi, j, 1] = 1

        # Propagate along X (decreasing j) on the rows bracketing the source
        td[xsi] = vzero * dxw * dx
        for j in range(xsi - 1, -1, -1):
            vref = slow[zsi, j]
            td[j] = td[j + 1] + dx * vref
            tauv = td[j] - vzero * np.abs(j - xsa) * dx
            tauev = td[j + 1] - vzero * np.abs(j - xsa + 1.0) * dx

            # Row zsi + 1
            if dzd > 0.0:
                dzi = 1.0 / dzd
                dz2i = dz / dzd / dzd
                taue = tt[zsi + 1, j + 1] - t_ana(zsi + 1, j + 1, dz, dx, zsa, xsa, vzero)
                t0c, tzc, txc = t_anad(zsi + 1, j, dz, dx, zsa, xsa, vzero)
                tt[zsi + 1, j] = delta(
                    tt[zsi + 1, j],
                    tauv,
                    taue,
                    tauev,
                    t0c,
                    tzc,
                    txc,
                    dzi,
                    dxi,
                    dz2i,
                    dx2i,
                    vzero,
                    vref,
                    1,
                    -1,
                )
                if grad:
                    ttsgn[zsi + 1, j, 0] = 1
                    ttsgn[zsi + 1, j, 1] = -1

            # Row zsi: the horizontal upwind neighbour is (zsi, j + 1), the
            # same row, as in the increasing-j loop above
            if dzu > 0.0:
                dzi = 1.0 / dzu
                dz2i = dz / dzu / dzu
                taue = tt[zsi, j + 1] - t_ana(zsi, j + 1, dz, dx, zsa, xsa, vzero)
                t0c, tzc, txc = t_anad(zsi, j, dz, dx, zsa, xsa, vzero)
                tt[zsi, j] = delta(
                    tt[zsi, j],
                    tauv,
                    taue,
                    tauev,
                    t0c,
                    tzc,
                    txc,
                    dzi,
                    dxi,
                    dz2i,
                    dx2i,
                    vzero,
                    vref,
                    -1,
                    -1,
                )
                if grad:
                    ttsgn[zsi, j, 0] = -1
                    ttsgn[zsi, j, 1] = -1

        # Propagate along Z (increasing i) on the columns bracketing the source
        dzi = 1.0 / dz
        dz2i = dzi / dz
        td[:] = Big
        td[zsi + 1] = vzero * dzd * dz
        for i in range(zsi + 2, nz):
            vref = slow[i - 1, xsi]
            td[i] = td[i - 1] + dz * vref
            taue = td[i] - vzero * np.abs(i - zsa) * dz
            tauev = td[i - 1] - vzero * np.abs(i - zsa - 1.0) * dz

            # Column xsi + 1
            if dxe > 0.0:
                dxi = 1.0 / dxe
                dx2i = dx / dxe / dxe
                tauv = tt[i - 1, xsi + 1] - t_ana(i - 1, xsi + 1, dz, dx, zsa, xsa, vzero)
                t0c, tzc, txc = t_anad(i, xsi + 1, dz, dx, zsa, xsa, vzero)
                tt[i, xsi + 1] = delta(
                    tt[i, xsi + 1],
                    tauv,
                    taue,
                    tauev,
                    t0c,
                    tzc,
                    txc,
                    dzi,
                    dxi,
                    dz2i,
                    dx2i,
                    vzero,
                    vref,
                    1,
                    1,
                )
                if grad:
                    ttsgn[i, xsi + 1, 0] = 1
                    ttsgn[i, xsi + 1, 1] = 1

            # Column xsi
            if dxw > 0.0:
                dxi = 1.0 / dxw
                dx2i = dx / dxw / dxw
                tauv = tt[i - 1, xsi] - t_ana(i - 1, xsi, dz, dx, zsa, xsa, vzero)
                t0c, tzc, txc = t_anad(i, xsi, dz, dx, zsa, xsa, vzero)
                tt[i, xsi] = delta(
                    tt[i, xsi],
                    tauv,
                    taue,
                    tauev,
                    t0c,
                    tzc,
                    txc,
                    dzi,
                    dxi,
                    dz2i,
                    dx2i,
                    vzero,
                    vref,
                    1,
                    -1,
                )
                if grad:
                    ttsgn[i, xsi, 0] = 1
                    ttsgn[i, xsi, 1] = -1

        # Propagate along Z (decreasing i) on the columns bracketing the source
        td[zsi] = vzero * dzu * dz
        for i in range(zsi - 1, -1, -1):
            vref = slow[i, xsi]
            td[i] = td[i + 1] + dz * vref
            taue = td[i] - vzero * np.abs(i - zsa) * dz
            tauev = td[i + 1] - vzero * np.abs(i - zsa + 1.0) * dz

            # Column xsi + 1
            if dxe > 0.0:
                dxi = 1.0 / dxe
                dx2i = dx / dxe / dxe
                tauv = tt[i + 1, xsi + 1] - t_ana(i + 1, xsi + 1, dz, dx, zsa, xsa, vzero)
                t0c, tzc, txc = t_anad(i, xsi + 1, dz, dx, zsa, xsa, vzero)
                tt[i, xsi + 1] = delta(
                    tt[i, xsi + 1],
                    tauv,
                    taue,
                    tauev,
                    t0c,
                    tzc,
                    txc,
                    dzi,
                    dxi,
                    dz2i,
                    dx2i,
                    vzero,
                    vref,
                    -1,
                    1,
                )
                if grad:
                    ttsgn[i, xsi + 1, 0] = -1
                    ttsgn[i, xsi + 1, 1] = 1

            # Column xsi
            if dxw > 0.0:
                dxi = 1.0 / dxw
                dx2i = dx / dxw / dxw
                tauv = tt[i + 1, xsi] - t_ana(i + 1, xsi, dz, dx, zsa, xsa, vzero)
                t0c, tzc, txc = t_anad(i, xsi, dz, dx, zsa, xsa, vzero)
                tt[i, xsi] = delta(
                    tt[i, xsi],
                    tauv,
                    taue,
                    tauev,
                    t0c,
                    tzc,
                    txc,
                    dzi,
                    dxi,
                    dz2i,
                    dx2i,
                    vzero,
                    vref,
                    -1,
                    -1,
                )
                if grad:
                    ttsgn[i, xsi, 0] = -1
                    ttsgn[i, xsi, 1] = -1

    else:
        # Source on (or snapped to) a grid node
        tt[int(zsa), int(xsa)] = 0.0

    # Start sweeping
    for _ in range(nsweep):
        sweep2d(tt, ttsgn, slow, dz, dx, zsi, xsi, zsa, xsa, vzero, nz, nx, grad)

    if grad:
        # One-sided differences along the upwind directions recorded during
        # sweeping, then normalized to unit length
        for i in range(nz):
            for j in range(nx):
                sgntz = ttsgn[i, j, 0]
                if sgntz != 0:
                    t1 = tt[i - sgntz, j]
                    ttgrad[i, j, 0] = sgntz * (tt[i, j] - t1) / dz

                sgntx = ttsgn[i, j, 1]
                if sgntx != 0:
                    t1 = tt[i, j - sgntx]
                    ttgrad[i, j, 1] = sgntx * (tt[i, j] - t1) / dx

                # Normalize gradients
                gn = norm2d(ttgrad[i, j, 0], ttgrad[i, j, 1])
                if gn > 0.0:
                    ttgrad[i, j] /= gn

    return tt, ttgrad, vzero


@jitted(
    "Tuple((f8[:, :, :], f8[:, :, :, :], f8[:]))(f8[:, :], f8, f8, f8[:], f8[:], i4, b1)",
    parallel=True,
)
def fteik2d_vectorized(slow, dz, dx, zsrc, xsrc, nsweep=2, grad=False):
    """
    Calculate traveltimes for several sources in parallel.

    Each source ``(zsrc[k], xsrc[k])`` is solved independently by
    :func:`fteik2d`. The per-source results are stacked along a new leading
    axis of length ``len(zsrc)``.
    """
    nsrc = len(zsrc)
    nz, nx = slow.shape
    nnodes_z = nz + 1
    nnodes_x = nx + 1

    tt = np.empty((nsrc, nnodes_z, nnodes_x), dtype=np.float64)
    if grad:
        ttgrad = np.empty((nsrc, nnodes_z, nnodes_x, 2), dtype=np.float64)
    else:
        ttgrad = np.empty((nsrc, 0, 0, 0), dtype=np.float64)
    vzero = np.empty(nsrc, dtype=np.float64)

    for k in prange(nsrc):
        tt[k], ttgrad[k], vzero[k] = fteik2d(
            slow, dz, dx, zsrc[k], xsrc[k], nsweep, grad
        )

    return tt, ttgrad, vzero


@jitted
def solve2d(slow, dz, dx, src, nsweep=2, grad=False):
    """
    Solve the 2D Eikonal equation.

    A 1D ``src`` is a single ``(z, x)`` source. Any other ``src`` is indexed
    as ``src[:, 0]`` / ``src[:, 1]`` (one source per row) and dispatched to the
    parallel solver.
    """
    if src.ndim == 1:
        return fteik2d(slow, dz, dx, src[0], src[1], nsweep, grad)

    return fteik2d_vectorized(slow, dz, dx, src[:, 0], src[:, 1], nsweep, grad)
