#!/usr/bin/env python
u"""
read_tide_model.py (06/2021)
Reads files for a tidal model and makes initial calculations to run tide program
Includes functions to extract tidal harmonic constants from OTIS tide models for
    given locations

Reads OTIS format tidal solutions provided by Ohio State University and ESR
    http://volkov.oce.orst.edu/tides/region.html
    https://www.esr.org/research/polar-tide-models/list-of-polar-tide-models/
    ftp://ftp.esr.org/pub/datasets/tmd/

INPUTS:
    ilon: longitude to interpolate
    ilat: latitude to interpolate
    grid_file: grid file for model
    model_file: model file containing each constituent
    EPSG: projection of tide model data

OPTIONS:
    TYPE: tidal variable to run
        z: heights
        u: horizontal transport velocities
        U: horizontal depth-averaged transport
        v: vertical transport velocities
        V: vertical depth-averaged transport
    METHOD: interpolation method
        bilinear: quick bilinear interpolation
        spline: scipy bivariate spline interpolation
        linear, nearest: scipy regular grid interpolations
    EXTRAPOLATE: extrapolate model using nearest-neighbors
    CUTOFF: extrapolation cutoff in kilometers
        set to np.inf to extrapolate for all points
    GRID: binary file type to read
        ATLAS: reading a global solution with localized solutions
        OTIS: combined global solution

OUTPUTS:
    amplitude: amplitudes of tidal constituents
    phase: phases of tidal constituents
    D: bathymetry of tide model
    constituents: list of model constituents

PYTHON DEPENDENCIES:
    numpy: Scientific Computing Tools For Python
        https://numpy.org
        https://numpy.org/doc/stable/user/numpy-for-matlab-users.html
    scipy: Scientific Tools for Python
        https://docs.scipy.org/doc/

PROGRAM DEPENDENCIES:
    convert_ll_xy.py: converts lat/lon points to and from projected coordinates
    bilinear_interp.py: bilinear interpolation of data to coordinates
    nearest_extrap.py: nearest-neighbor extrapolation of data to coordinates

UPDATE HISTORY:
    Updated 06/2021: fix tidal currents for bilinear interpolation
        check for nan points when reading elevation and transport files
    Updated 05/2021: added option for extrapolation cutoff in kilometers
    Updated 03/2021: add extrapolation check where there are no invalid points
        prevent ComplexWarning for fill values when calculating amplitudes
        can read from single constituent TPXO9 ATLAS binary files
        replaced numpy bool/int to prevent deprecation warnings
    Updated 02/2021: set invalid values to nan in extrapolation
        replaced numpy bool to prevent deprecation warning
    Updated 12/2020: added valid data extrapolation with nearest_extrap
    Updated 09/2020: set bounds error to false for regular grid interpolations
        adjust dimensions of input coordinates to be iterable
        use masked arrays with atlas models and grids. make 2' grid with nearest
    Updated 08/2020: check that interpolated points are within range of model
        replaced griddata interpolation with scipy regular grid interpolators
    Updated 07/2020: added function docstrings. separate bilinear interpolation
        update griddata interpolation. changed TYPE variable to keyword argument
    Updated 06/2020: output currents as numpy masked arrays
        use argmin and argmax in bilinear interpolation
    Updated 11/2019: interpolate heights and fluxes to numpy masked arrays
    Updated 09/2019: output as numpy masked arrays instead of nan-filled arrays
    Updated 01/2019: decode constituents for Python3 compatibility
    Updated 08/2018: added option GRID for using ATLAS outputs that
        combine both global and localized tidal solutions
        added multivariate spline interpolation option
    Updated 07/2018: added different interpolation methods
    Updated 09/2017: Adapted for Python
"""
import os
import numpy as np
import scipy.interpolate
from pyTMD.convert_ll_xy import convert_ll_xy
from pyTMD.bilinear_interp import bilinear_interp
from pyTMD.nearest_extrap import nearest_extrap
import matplotlib.pyplot as plt

#-- PURPOSE: extract tidal harmonic constants from tide models at coordinates
def extract_tidal_constants(ilon, ilat, grid_file, model_file, EPSG, TYPE='z',
    METHOD='spline', EXTRAPOLATE=False, CUTOFF=10.0, GRID='OTIS'):
    """
    Reads files for an OTIS-formatted tidal model
    Makes initial calculations to run the tide program
    Spatially interpolates tidal constituents to input coordinates

    Arguments
    ---------
    ilon: longitude to interpolate
    ilat: latitude to interpolate
    grid_file: grid file for model
    model_file: model file containing each constituent
        can be a list of files each containing a single constituent
    EPSG: projection of tide model data

    Keyword arguments
    -----------------
    TYPE: tidal variable to read
        z: heights
        u: velocities in cm/s interpolated on the x-staggered grid
        U: transports in m^2/s interpolated on the x-staggered grid
        v: velocities in cm/s interpolated on the y-staggered grid
        V: transports in m^2/s interpolated on the y-staggered grid
    METHOD: interpolation method
        bilinear: quick bilinear interpolation
        spline: scipy bivariate spline interpolation
        linear, nearest: scipy regular grid interpolations
    EXTRAPOLATE: extrapolate model using nearest-neighbors
    CUTOFF: extrapolation cutoff in kilometers
        set to np.inf to extrapolate for all points
    GRID: binary file type to read
        ATLAS: reading a global solution with localized solutions
        OTIS: combined global solution

    Returns
    -------
    amplitude: amplitudes of tidal constituents (masked array, npts x nc)
    phase: phases of tidal constituents in degrees (masked array, npts x nc)
    D: bathymetry of tide model interpolated to the input points
    constituents: list of model constituents
    """
    #-- read the OTIS-format tide grid file
    if (GRID == 'ATLAS'):
        #-- if reading a global solution with localized solutions
        x0,y0,hz0,mz0,iob,dt,pmask,local = read_atlas_grid(grid_file)
        xi,yi,hz = combine_atlas_model(x0,y0,hz0,pmask,local,VARIABLE='depth')
        mz = create_atlas_mask(x0,y0,mz0,local,VARIABLE='depth')
    else:
        #-- if reading a single OTIS solution
        xi,yi,hz,mz,iob,dt = read_tide_grid(grid_file)
    #-- adjust dimensions of input coordinates to be iterable
    #-- run wrapper function to convert coordinate systems of input lat/lon
    x,y = convert_ll_xy(np.atleast_1d(ilon),np.atleast_1d(ilat),EPSG,'F')
    #-- grid step size of tide model
    dx = xi[1] - xi[0]
    dy = yi[1] - yi[0]

    if (TYPE != 'z'):
        mz,mu,mv = Muv(hz)
        hu,hv = Huv(hz)

    #-- if global: extend limits
    GLOBAL = False
    #-- replace original values with extend arrays/matrices
    if ((xi[-1] - xi[0]) == (360.0 - dx)) & (EPSG == '4326'):
        xi = extend_array(xi, dx)
        hz = extend_matrix(hz)
        mz = extend_matrix(mz)
        #-- set global flag
        GLOBAL = True

    #-- adjust longitudinal convention of input latitude and longitude
    #-- to fit tide model convention
    if (np.min(x) < np.min(xi)) & (EPSG == '4326'):
        lt0, = np.nonzero(x < 0)
        x[lt0] += 360.0
    if (np.max(x) > np.max(xi)) & (EPSG == '4326'):
        gt180, = np.nonzero(x > 180)
        x[gt180] -= 360.0
    #-- determine if any input points are outside of the model bounds
    invalid = (x < xi.min()) | (x > xi.max()) | (y < yi.min()) | (y > yi.max())

    #-- masks zero values
    hz = np.ma.array(hz,mask=(hz==0))
    if (TYPE != 'z'):
        #-- replace original values with extend matrices
        if GLOBAL:
            hu = extend_matrix(hu)
            hv = extend_matrix(hv)
            mu = extend_matrix(mu)
            mv = extend_matrix(mv)
        #-- masks zero values
        hu = np.ma.array(hu,mask=(hu==0))
        hv = np.ma.array(hv,mask=(hv==0))

    #-- interpolate depth and mask to output points
    #-- interpolated masks are rounded down: a point is only kept as valid
    #-- if the interpolated mask value reaches 1
    if (METHOD == 'bilinear'):
        #-- use quick bilinear to interpolate values
        D = bilinear_interp(xi,yi,hz,x,y)
        mz1 = bilinear_interp(xi,yi,mz,x,y)
        mz1 = np.floor(mz1).astype(mz.dtype)
        if (TYPE != 'z'):
            mu1 = bilinear_interp(xi,yi,mu,x,y)
            mu1 = np.floor(mu1).astype(mu.dtype)
            mv1 = bilinear_interp(xi,yi,mv,x,y)
            mv1 = np.floor(mv1).astype(mv.dtype)
    elif (METHOD == 'spline'):
        #-- use scipy bivariate splines to interpolate values
        f1=scipy.interpolate.RectBivariateSpline(xi,yi,hz.T,kx=1,ky=1)
        f2=scipy.interpolate.RectBivariateSpline(xi,yi,mz.T,kx=1,ky=1)
        D = f1.ev(x,y)
        mz1 = np.floor(f2.ev(x,y)).astype(mz.dtype)
        if (TYPE != 'z'):
            f3=scipy.interpolate.RectBivariateSpline(xi,yi,mu.T,kx=1,ky=1)
            f4=scipy.interpolate.RectBivariateSpline(xi,yi,mv.T,kx=1,ky=1)
            mu1 = np.floor(f3.ev(x,y)).astype(mu.dtype)
            mv1 = np.floor(f4.ev(x,y)).astype(mv.dtype)
    else:
        #-- use scipy regular grid to interpolate values for a given method
        r1 = scipy.interpolate.RegularGridInterpolator((yi,xi),hz,
            method=METHOD,bounds_error=False)
        r2 = scipy.interpolate.RegularGridInterpolator((yi,xi),mz,
            method=METHOD,bounds_error=False,fill_value=0)
        D = r1.__call__(np.c_[y,x])
        mz1 = np.floor(r2.__call__(np.c_[y,x])).astype(mz.dtype)
        if (TYPE != 'z'):
            r3 = scipy.interpolate.RegularGridInterpolator((yi,xi),mu,
                method=METHOD,bounds_error=False,fill_value=0)
            r4 = scipy.interpolate.RegularGridInterpolator((yi,xi),mv,
                method=METHOD,bounds_error=False,fill_value=0)
            mu1 = np.floor(r3.__call__(np.c_[y,x])).astype(mu.dtype)
            mv1 = np.floor(r4.__call__(np.c_[y,x])).astype(mv.dtype)

    #-- u and v are velocities in cm/s
    if TYPE in ('v','u'):
        unit_conv = (D/100.0)
    #-- U and V are transports in m^2/s
    elif TYPE in ('V','U'):
        unit_conv = 1.0

    #-- read and interpolate each constituent
    if isinstance(model_file,list):
        constituents = [read_constituents(m)[0].pop() for m in model_file]
        nc = len(constituents)
    else:
        constituents,nc = read_constituents(model_file)
    #-- number of output data points
    npts = len(D)
    amplitude = np.ma.zeros((npts,nc))
    amplitude.mask = np.zeros((npts,nc),dtype=bool)
    ph = np.ma.zeros((npts,nc))
    ph.mask = np.zeros((npts,nc),dtype=bool)
    #-- NOTE(review): for ATLAS grids the calls to combine_atlas_model below
    #-- reassign xi and yi on every iteration. If the GLOBAL flag were set
    #-- this would discard the extended x-coordinates: confirm that the
    #-- coordinates returned here always match those of the depth grid
    for i,c in enumerate(constituents):
        if (TYPE == 'z'):
            #-- read constituent from elevation file
            if (GRID == 'ATLAS'):
                z0,zlocal = read_atlas_elevation(model_file,i,c)
                xi,yi,z=combine_atlas_model(x0,y0,z0,pmask,zlocal,VARIABLE='z')
            elif isinstance(model_file,list):
                z = read_elevation_file(model_file[i],0)
            else:
                z = read_elevation_file(model_file,i)
            #-- replace original values with extend matrices
            if GLOBAL:
                z = extend_matrix(z)
            #-- interpolate amplitude and phase of the constituent
            z1 = np.ma.zeros((npts),dtype=z.dtype)
            if (METHOD == 'bilinear'):
                #-- replace zero values with nan
                z[z==0] = np.nan
                #-- use quick bilinear to interpolate values
                z1.data[:] = bilinear_interp(xi,yi,z,x,y,dtype=np.complex128)
                #-- replace nan values with fill_value
                z1.mask = (np.isnan(z1.data) | (~mz1.astype(bool)))
                z1.data[z1.mask] = z1.fill_value
            elif (METHOD == 'spline'):
                #-- use scipy bivariate splines to interpolate values
                #-- real and imaginary components are interpolated separately
                f1 = scipy.interpolate.RectBivariateSpline(xi,yi,
                    z.real.T,kx=1,ky=1)
                f2 = scipy.interpolate.RectBivariateSpline(xi,yi,
                    z.imag.T,kx=1,ky=1)
                z1.data.real = f1.ev(x,y)
                z1.data.imag = f2.ev(x,y)
                #-- replace zero values with fill_value
                z1.mask = (~mz1.astype(bool))
                z1.data[z1.mask] = z1.fill_value
            else:
                #-- use scipy regular grid to interpolate values
                r1 = scipy.interpolate.RegularGridInterpolator((yi,xi),z,
                    method=METHOD,bounds_error=False,fill_value=z1.fill_value)
                z1.data[:] = r1.__call__(np.c_[y,x])
                #-- replace invalid values with fill_value
                z1.mask = (z1.data == z1.fill_value) | (~mz1.astype(bool))
                z1.data[z1.mask] = z1.fill_value
            #-- extrapolate data using nearest-neighbors
            if EXTRAPOLATE and np.any(z1.mask):
                #-- find invalid data points
                inv, = np.nonzero(z1.mask)
                #-- replace zero values with nan
                z[z==0] = np.nan
                #-- extrapolate points within cutoff of valid model points
                z1.data[inv] = nearest_extrap(xi,yi,z,x[inv],y[inv],
                    dtype=np.complex128,cutoff=CUTOFF,EPSG=EPSG)
                #-- replace nan values with fill_value
                z1.mask[inv] = np.isnan(z1.data[inv])
                z1.data[z1.mask] = z1.fill_value
            #-- amplitude and phase of the constituent
            amplitude.data[:,i] = np.abs(z1.data)
            amplitude.mask[:,i] = np.copy(z1.mask)
            ph.data[:,i] = np.arctan2(-np.imag(z1.data),np.real(z1.data))
            ph.mask[:,i] = np.copy(z1.mask)
        elif TYPE in ('U','u'):
            #-- read constituent from transport file
            if (GRID == 'ATLAS'):
                u0,v0,uvlocal = read_atlas_transport(model_file,i,c)
                xi,yi,u=combine_atlas_model(x0,y0,u0,pmask,uvlocal,VARIABLE='u')
            elif isinstance(model_file,list):
                u,v = read_transport_file(model_file[i],0)
            else:
                u,v = read_transport_file(model_file,i)
            #-- replace original values with extend matrices
            if GLOBAL:
                u = extend_matrix(u)
            #-- x coordinates for u transports (offset by half a grid cell)
            xu = xi - dx/2.0
            #-- interpolate amplitude and phase of the constituent
            u1 = np.ma.zeros((npts),dtype=u.dtype)
            if (METHOD == 'bilinear'):
                #-- replace zero values with nan
                u[u==0] = np.nan
                #-- use quick bilinear to interpolate values
                u1.data[:] = bilinear_interp(xu,yi,u,x,y,dtype=np.complex128)
                #-- replace nan values with fill_value
                u1.mask = (np.isnan(u1.data) | (~mu1.astype(bool)))
                u1.data[u1.mask] = u1.fill_value
            elif (METHOD == 'spline'):
                #-- use scipy bivariate splines to interpolate values
                f1 = scipy.interpolate.RectBivariateSpline(xu,yi,
                    u.real.T,kx=1,ky=1)
                f2 = scipy.interpolate.RectBivariateSpline(xu,yi,
                    u.imag.T,kx=1,ky=1)
                u1.data.real = f1.ev(x,y)
                u1.data.imag = f2.ev(x,y)
                #-- replace zero values with fill_value
                u1.mask = (~mu1.astype(bool))
                u1.data[u1.mask] = u1.fill_value
            else:
                #-- use scipy regular grid to interpolate values
                r1 = scipy.interpolate.RegularGridInterpolator((yi,xu),u,
                    method=METHOD,bounds_error=False,fill_value=u1.fill_value)
                u1.data[:] = r1.__call__(np.c_[y,x])
                #-- replace invalid values with fill_value
                u1.mask = (u1.data == u1.fill_value) | (~mu1.astype(bool))
                u1.data[u1.mask] = u1.fill_value
            #-- extrapolate data using nearest-neighbors
            if EXTRAPOLATE and np.any(u1.mask):
                #-- find invalid data points
                inv, = np.nonzero(u1.mask)
                #-- replace zero values with nan
                u[u==0] = np.nan
                #-- extrapolate points within cutoff of valid model points
                u1.data[inv] = nearest_extrap(xu,yi,u,x[inv],y[inv],
                    dtype=np.complex128,cutoff=CUTOFF,EPSG=EPSG)
                #-- replace nan values with fill_value
                u1.mask[inv] = np.isnan(u1.data[inv])
                u1.data[u1.mask] = u1.fill_value
            #-- convert units
            #-- amplitude and phase of the constituent
            amplitude.data[:,i] = np.abs(u1.data)/unit_conv
            amplitude.mask[:,i] = np.copy(u1.mask)
            ph.data[:,i] = np.arctan2(-np.imag(u1),np.real(u1))
            ph.mask[:,i] = np.copy(u1.mask)
        elif TYPE in ('V','v'):
            #-- read constituent from transport file
            if (GRID == 'ATLAS'):
                u0,v0,uvlocal = read_atlas_transport(model_file,i,c)
                #-- use the local transport solutions (not the grid depths)
                xi,yi,v=combine_atlas_model(x0,y0,v0,pmask,uvlocal,VARIABLE='v')
            elif isinstance(model_file,list):
                u,v = read_transport_file(model_file[i],0)
            else:
                u,v = read_transport_file(model_file,i)
            #-- replace original values with extend matrices
            if GLOBAL:
                v = extend_matrix(v)
            #-- y coordinates for v transports (offset by half a grid cell)
            yv = yi - dy/2.0
            #-- interpolate amplitude and phase of the constituent
            v1 = np.ma.zeros((npts),dtype=v.dtype)
            if (METHOD == 'bilinear'):
                #-- replace zero values with nan
                v[v==0] = np.nan
                #-- use quick bilinear to interpolate values
                v1.data[:] = bilinear_interp(xi,yv,v,x,y,dtype=np.complex128)
                #-- replace nan values with fill_value
                v1.mask = (np.isnan(v1.data) | (~mv1.astype(bool)))
                v1.data[v1.mask] = v1.fill_value
            elif (METHOD == 'spline'):
                #-- use scipy bivariate splines to interpolate values
                f1 = scipy.interpolate.RectBivariateSpline(xi,yv,
                    v.real.T,kx=1,ky=1)
                f2 = scipy.interpolate.RectBivariateSpline(xi,yv,
                    v.imag.T,kx=1,ky=1)
                v1.data.real = f1.ev(x,y)
                v1.data.imag = f2.ev(x,y)
                #-- replace zero values with fill_value
                v1.mask = (~mv1.astype(bool))
                v1.data[v1.mask] = v1.fill_value
            else:
                #-- use scipy regular grid to interpolate values
                r1 = scipy.interpolate.RegularGridInterpolator((yv,xi),v,
                    method=METHOD,bounds_error=False,fill_value=v1.fill_value)
                v1.data[:] = r1.__call__(np.c_[y,x])
                #-- replace invalid values with fill_value
                v1.mask = (v1.data == v1.fill_value) | (~mv1.astype(bool))
                v1.data[v1.mask] = v1.fill_value
            #-- extrapolate data using nearest-neighbors
            if EXTRAPOLATE and np.any(v1.mask):
                #-- find invalid data points
                inv, = np.nonzero(v1.mask)
                #-- replace zero values with nan
                v[v==0] = np.nan
                #-- extrapolate points within cutoff of valid model points
                #-- grid coordinates are (xi,yv): x and y are the query points
                v1.data[inv] = nearest_extrap(xi,yv,v,x[inv],y[inv],
                    dtype=np.complex128,cutoff=CUTOFF,EPSG=EPSG)
                #-- replace nan values with fill_value
                v1.mask[inv] = np.isnan(v1.data[inv])
                v1.data[v1.mask] = v1.fill_value
            #-- convert units
            #-- amplitude and phase of the constituent
            amplitude.data[:,i] = np.abs(v1.data)/unit_conv
            amplitude.mask[:,i] = np.copy(v1.mask)
            ph.data[:,i] = np.arctan2(-np.imag(v1),np.real(v1))
            ph.mask[:,i] = np.copy(v1.mask)
        #-- update mask to invalidate points outside model domain
        ph.mask[:,i] |= invalid
        amplitude.mask[:,i] |= invalid

    #-- convert phase to degrees
    phase = ph*180.0/np.pi
    phase.data[phase.data < 0] += 360.0
    #-- replace data for invalid mask values
    amplitude.data[amplitude.mask] = amplitude.fill_value
    phase.data[phase.mask] = phase.fill_value
    #-- return the interpolated values
    return (amplitude,phase,D,constituents)

#-- PURPOSE: wrapper function to extend an array
def extend_array(input_array,step_size):
    """
    Pad a 1-D coordinate array with one extra element on each side

    Arguments
    ---------
    input_array: array to extend
    step_size: step size between elements of array

    Returns
    -------
    extended: array of length n+2 with the same dtype as input_array
        laid out as [x0-step, x0, ..., xN, xN+step]
    """
    count = len(input_array)
    #-- allocate the output: every element is assigned below
    extended = np.empty((count+2),dtype=input_array.dtype)
    #-- interior elements are a copy of the original array
    extended[1:count+1] = input_array
    #-- leading and trailing elements are one step beyond the end points
    extended[0] = input_array[0] - step_size
    extended[count+1] = input_array[-1] + step_size
    return extended

#-- PURPOSE: wrapper function to extend a matrix
def extend_matrix(input_matrix):
    """
    Pad a 2-D matrix with one wrapped column on each side

    The first output column is a copy of the last input column and the
    last output column is a copy of the first input column

    Arguments
    ---------
    input_matrix: matrix to extend

    Returns
    -------
    padded: masked array of shape (ny,nx+2) with the input dtype
    """
    nrows,ncols = np.shape(input_matrix)
    padded = np.ma.zeros((nrows,ncols+2),dtype=input_matrix.dtype)
    #-- interior columns are a copy of the original matrix
    padded[:,1:-1] = input_matrix
    #-- wrap the opposite edges of the matrix into the outer columns
    padded[:,-1] = input_matrix[:,0]
    padded[:,0] = input_matrix[:,-1]
    return padded

#-- PURPOSE: read tide grid file
def read_tide_grid(input_file):
    """
    Read grid file to extract model coordinates, bathymetry, masks and indices

    Arguments
    ---------
    input_file: input grid file

    Returns
    -------
    x: x-coordinates of input grid (cell centers)
    y: y-coordinates of input grid (cell centers)
    hz: model bathymetry with shape (ny,nx)
    mz: land/water mask with shape (ny,nx)
    iob: open boundary index with shape (nob,2)
        empty list if there are no open boundary points
    dt: time step
    """
    #-- open the file within a context manager so that the file handle is
    #-- released even if the file is truncated or malformed
    with open(os.path.expanduser(input_file),'rb') as fid:
        #-- skip the 4 bytes preceding the header
        #-- (presumably a Fortran record marker)
        fid.seek(4,0)
        #-- read data as big endian
        #-- get model dimensions and limits
        nx, = np.fromfile(fid, dtype=np.dtype('>i4'), count=1)
        ny, = np.fromfile(fid, dtype=np.dtype('>i4'), count=1)
        #-- extract x and y limits (these could be latitude and longitude)
        ylim = np.fromfile(fid, dtype=np.dtype('>f4'), count=2)
        xlim = np.fromfile(fid, dtype=np.dtype('>f4'), count=2)
        dt, = np.fromfile(fid, dtype=np.dtype('>f4'), count=1)
        #-- convert longitudinal limits (if x == longitude)
        if (xlim[0] < 0) & (xlim[1] < 0) & (dt > 0):
            xlim += 360.0
        #-- create x and y arrays (these could be lon and lat values)
        #-- coordinates are offset by half a step from the grid limits
        dx = (xlim[1] - xlim[0])/nx
        dy = (ylim[1] - ylim[0])/ny
        x = np.linspace(xlim[0]+dx/2.0,xlim[1]-dx/2.0,nx)
        y = np.linspace(ylim[0]+dy/2.0,ylim[1]-dy/2.0,ny)
        #-- read nob and iob from file
        nob, = np.fromfile(fid, dtype=np.dtype('>i4'), count=1)
        if (nob == 0):
            #-- skip over the empty open boundary section to the bathymetry
            fid.seek(20,1)
            iob = []
        else:
            fid.seek(8,1)
            iob = np.fromfile(fid, dtype=np.dtype('>i4'),
                count=2*nob).reshape(nob,2)
            fid.seek(8,1)
        #-- read hz matrix
        hz = np.fromfile(fid, dtype=np.dtype('>f4'),
            count=nx*ny).reshape(ny,nx)
        fid.seek(8,1)
        #-- read mz matrix
        mz = np.fromfile(fid, dtype=np.dtype('>i4'),
            count=nx*ny).reshape(ny,nx)
    #-- return values
    return (x,y,hz,mz,iob,dt)

#-- PURPOSE: read tide grid file with localized solutions
def read_atlas_grid(input_file):
    """
    Read ATLAS grid file to extract model coordinates, bathymetry, masks and
    indices for both global and local solutions

    Arguments
    ---------
    input_file: input ATLAS grid file

    Returns
    -------
    x: x-coordinates of input ATLAS grid (cell centers)
    y: y-coordinates of input ATLAS grid (cell centers)
    hz: model bathymetry with shape (ny,nx)
    mz: land/water mask with shape (ny,nx)
    iob: open boundary index
        empty list if there are no open boundary points
    dt: time step
    pmask: global mask with shape (ny,nx)
    local: dictionary of local tidal solutions for grid variables
        keyed by the local model name as read from the file (bytes)
        lon: longitude limits of the local model
        lat: latitude limits of the local model
        depth: model bathymetry (masked where not defined)
    """
    #-- open the file within a context manager so that the file handle is
    #-- released even if the file is truncated or malformed
    with open(os.path.expanduser(input_file),'rb') as fid:
        #-- total size of the file for finding the end of the local models
        file_size = os.fstat(fid.fileno()).st_size
        fid.seek(4,0)
        #-- read data as big endian
        #-- get model dimensions and limits
        nx, = np.fromfile(fid, dtype=np.dtype('>i4'), count=1)
        ny, = np.fromfile(fid, dtype=np.dtype('>i4'), count=1)
        #-- extract latitude and longitude limits
        lats = np.fromfile(fid, dtype=np.dtype('>f4'), count=2)
        lons = np.fromfile(fid, dtype=np.dtype('>f4'), count=2)
        dt, = np.fromfile(fid, dtype=np.dtype('>f4'), count=1)
        #-- create lon and lat arrays
        #-- coordinates are offset by half a step from the grid limits
        dlon = (lons[1] - lons[0])/nx
        dlat = (lats[1] - lats[0])/ny
        x = np.linspace(lons[0]+dlon/2.0,lons[1]-dlon/2.0,nx)
        y = np.linspace(lats[0]+dlat/2.0,lats[1]-dlat/2.0,ny)
        #-- read nob and iob from file
        nob, = np.fromfile(fid, dtype=np.dtype('>i4'), count=1)
        if (nob == 0):
            #-- skip over the empty open boundary section to the bathymetry
            fid.seek(20,1)
            iob = []
        else:
            fid.seek(8,1)
            iob = np.fromfile(fid, dtype=np.dtype('>i4'),
                count=2*nob).reshape(nob,2)
            fid.seek(8,1)
        #-- read hz matrix
        hz = np.fromfile(fid, dtype=np.dtype('>f4'),
            count=nx*ny).reshape(ny,nx)
        fid.seek(8,1)
        #-- read mz matrix
        mz = np.fromfile(fid, dtype=np.dtype('>i4'),
            count=nx*ny).reshape(ny,nx)
        fid.seek(8,1)
        #-- read pmask matrix
        pmask = np.fromfile(fid, dtype=np.dtype('>i4'),
            count=nx*ny).reshape(ny,nx)
        fid.seek(4,1)
        #-- read local models
        local = {}
        #-- while the file position is not at the end of file
        while (fid.tell() < file_size):
            #-- skip the 4 bytes preceding each local model header
            fid.seek(4,1)
            #-- get local model dimensions and limits
            nx1, = np.fromfile(fid, dtype=np.dtype('>i4'), count=1)
            ny1, = np.fromfile(fid, dtype=np.dtype('>i4'), count=1)
            nd, = np.fromfile(fid, dtype=np.dtype('>i4'), count=1)
            #-- extract latitude and longitude limits of local model
            lt1 = np.fromfile(fid, dtype=np.dtype('>f4'), count=2)
            ln1 = np.fromfile(fid, dtype=np.dtype('>f4'), count=2)
            #-- extract name
            name = fid.read(20).strip()
            fid.seek(8,1)
            #-- 1-based column (iz) and row (jz) indices of the local points
            iz = np.fromfile(fid, dtype=np.dtype('>i4'), count=nd)
            jz = np.fromfile(fid, dtype=np.dtype('>i4'), count=nd)
            fid.seek(8,1)
            #-- local depths are only valid at the listed points
            depth = np.ma.zeros((ny1,nx1))
            depth.mask = np.ones((ny1,nx1),dtype=bool)
            depth.data[jz-1,iz-1] = np.fromfile(fid,
                dtype=np.dtype('>f4'),count=nd)
            depth.mask[jz-1,iz-1] = False
            fid.seek(4,1)
            #-- save to dictionary
            local[name] = dict(lon=ln1,lat=lt1,depth=depth)
    #-- return values
    return (x,y,hz,mz,iob,dt,pmask,local)

#-- PURPOSE: read list of constituents from an elevation or transport file
def read_constituents(input_file):
    """
    Read the list of constituents from an elevation or transport file

    Arguments
    ---------
    input_file: input tidal file

    Returns
    -------
    constituents: list of tidal constituent IDs
    nc: number of constituents
    """
    #-- open the file within a context manager so that the file handle is
    #-- released even if the file is truncated or malformed
    with open(os.path.expanduser(input_file),'rb') as fid:
        #-- skip the leading 4-byte integer of the header
        fid.seek(4,0)
        #-- read data as big endian
        nx,ny,nc = np.fromfile(fid, dtype=np.dtype('>i4'), count=3)
        #-- skip the x and y limits (4 values of 4 bytes each)
        fid.seek(16,1)
        #-- each constituent ID is stored in a 4-character field
        #-- decode to strings for Python3 compatibility
        fields = fid.read(nc*4).split()
    constituents = [c.decode("utf-8").rstrip() for c in fields]
    return (constituents,nc)

#-- PURPOSE: read elevation file to extract real and imaginary components for
#-- constituent
def read_elevation_file(input_file,ic):
    """
    Read elevation file to extract real and imaginary components for constituent

    Arguments
    ---------
    input_file: input elevation file
    ic: index of consituent

    Returns
    -------
    h: tidal elevation
    """
    #-- open the file
    fid = open(os.path.expanduser(input_file),'rb')
    ll, = np.fromfile(fid, dtype=np.dtype('>i4'), count=1)
    nx,ny,nc = np.fromfile(fid, dtype=np.dtype('>i4'), count=3)
    #-- extract x and y limits
    ylim = np.fromfile(fid, dtype=np.dtype('>f4'), count=2)
    xlim = np.fromfile(fid, dtype=np.dtype('>f4'), count=2)
    #-- skip records to constituent
    nskip = ic*(nx*ny*8+8) + 8 + ll - 28
    fid.seek(nskip,1)
    #-- real and imaginary components of elevation
    h = np.ma.zeros((ny,nx),dtype=np.complex64)
    h.mask = np.zeros((ny,nx),dtype=bool)
    for i in range(ny):
        temp = np.fromfile(fid, dtype=np.dtype('>f4'), count=2*nx)
        h.data.real[i,:] = temp[0:2*nx-1:2]
        h.data.imag[i,:] = temp[1:2*nx:2]
    #-- update mask for nan values
    h.mask[np.isnan(h.data)] = True
    #-- replace masked values with fill value
    h.data[h.mask] = h.fill_value
    #-- close the file
    fid.close()
    #-- return the elevation
    return h

#-- PURPOSE: read elevation file with localized solutions to extract real and
#-- imaginary components for constituent
def read_atlas_elevation(input_file,ic,constituent):
    """
    Read elevation file with localized solutions to extract real and imaginary
    components for constituent

    Arguments
    ---------
    input_file: input ATLAS elevation file
    ic: index of constituent
    constituent: tidal constituent ID

    Returns
    -------
    h: global tidal elevation as a complex masked array with shape (ny,nx)
    local: dictionary of local tidal solutions for elevation variables
        keyed by the local model name as read from the file (bytes)
        only includes local models that contain the constituent
        lon: longitude limits of the local model
        lat: latitude limits of the local model
        z: tidal elevation (masked where not defined)
    """
    #-- constituent IDs read from the file are decoded to strings:
    #-- also decode the requested constituent if it was provided as bytes
    if isinstance(constituent,bytes):
        constituent = constituent.decode("utf-8")
    #-- open the file within a context manager so that the file handle is
    #-- released even if the file is truncated or malformed
    with open(os.path.expanduser(input_file),'rb') as fid:
        #-- total size of the file for finding the end of the local models
        file_size = os.fstat(fid.fileno()).st_size
        #-- skip the leading 4-byte integer of the header
        fid.seek(4,0)
        #-- read data as big endian
        #-- convert to Python integers so that the byte offsets calculated
        #-- below cannot overflow the 32-bit integers read from the file
        nx,ny,nc = [int(n) for n in
            np.fromfile(fid, dtype=np.dtype('>i4'), count=3)]
        #-- skip the x and y limits (4 values of 4 bytes each)
        fid.seek(16,1)
        #-- skip records to constituent
        nskip = 8 + nc*4 + ic*(nx*ny*8 + 8)
        fid.seek(nskip,1)
        #-- real and imaginary components of elevation
        h = np.ma.zeros((ny,nx),dtype=np.complex64)
        h.mask = np.zeros((ny,nx),dtype=bool)
        #-- read row-by-row: values are interleaved real,imaginary pairs
        for i in range(ny):
            temp = np.fromfile(fid, dtype=np.dtype('>f4'), count=2*nx)
            h.data.real[i,:] = temp[0:2*nx-1:2]
            h.data.imag[i,:] = temp[1:2*nx:2]
        #-- skip records after constituent
        nskip = (nc-ic-1)*(nx*ny*8 + 8) + 4
        fid.seek(nskip,1)
        #-- read local models to find constituent
        local = {}
        #-- while the file position is not at the end of file
        while (fid.tell() < file_size):
            #-- skip the 4 bytes preceding each local model header
            fid.seek(4,1)
            #-- get local model dimensions and limits
            nx1,ny1,nc1,nz = [int(n) for n in
                np.fromfile(fid, dtype=np.dtype('>i4'), count=4)]
            #-- extract latitude and longitude limits of local model
            lt1 = np.fromfile(fid, dtype=np.dtype('>f4'), count=2)
            ln1 = np.fromfile(fid, dtype=np.dtype('>f4'), count=2)
            #-- extract constituents for localized solution
            #-- decode to strings so that they can be compared with the
            #-- requested constituent under Python3
            cons = fid.read(nc1*4).strip().decode("utf-8").split()
            #-- check if constituent is in list of localized solutions
            if (constituent in cons):
                ic1 = cons.index(constituent)
                #-- extract name
                name = fid.read(20).strip()
                fid.seek(8,1)
                #-- 1-based column (iz) and row (jz) indices of local points
                iz = np.fromfile(fid, dtype=np.dtype('>i4'), count=nz)
                jz = np.fromfile(fid, dtype=np.dtype('>i4'), count=nz)
                #-- skip records to constituent
                nskip = 8 + ic1*(8*nz + 8)
                fid.seek(nskip,1)
                #-- real and imaginary components of elevation
                #-- local values are only valid at the listed points
                h1 = np.ma.zeros((ny1,nx1),fill_value=np.nan,
                    dtype=np.complex64)
                h1.mask = np.ones((ny1,nx1),dtype=bool)
                temp = np.fromfile(fid, dtype=np.dtype('>f4'), count=2*nz)
                h1.data.real[jz-1,iz-1] = temp[0:2*nz-1:2]
                h1.data.imag[jz-1,iz-1] = temp[1:2*nz:2]
                h1.mask[jz-1,iz-1] = False
                #-- save constituent to dictionary
                local[name] = dict(lon=ln1,lat=lt1,z=h1)
                #-- skip records after constituent
                nskip = (nc1-ic1-1)*(8*nz + 8) + 4
                fid.seek(nskip,1)
            else:
                #-- skip records for local model if constituent not in list
                nskip = 40 + 16*nz + (nc1-1)*(8*nz + 8)
                fid.seek(nskip,1)
    #-- return the elevation
    return (h,local)

#-- PURPOSE: read transport file to extract real and imaginary components for
#-- constituent
def read_transport_file(input_file,ic):
    """
    Read transport file to extract real and imaginary components for constituent

    Arguments
    ---------
    input_file: input transport file
    ic: index of consituent

    Returns
    -------
    u: zonal tidal transport
    v: meridional zonal transport
    """
    #-- open the file
    fid = open(os.path.expanduser(input_file),'rb')
    ll, = np.fromfile(fid, dtype=np.dtype('>i4'), count=1)
    nx,ny,nc = np.fromfile(fid, dtype=np.dtype('>i4'), count=3)
    #-- extract x and y limits
    ylim = np.fromfile(fid, dtype=np.dtype('>f4'), count=2)
    xlim = np.fromfile(fid, dtype=np.dtype('>f4'), count=2)
    #-- skip records to constituent
    nskip = ic*(nx*ny*16+8) + 8 + ll - 28
    fid.seek(nskip,1)
    #-- real and imaginary components of transport
    u = np.ma.zeros((ny,nx),dtype=np.complex64)
    u.mask = np.zeros((ny,nx),dtype=bool)
    v = np.ma.zeros((ny,nx),dtype=np.complex64)
    v.mask = np.zeros((ny,nx),dtype=bool)
    for i in range(ny):
        temp = np.fromfile(fid, dtype=np.dtype('>f4'), count=4*nx)
        u.data.real[i,:] = temp[0:4*nx-3:4]
        u.data.imag[i,:] = temp[1:4*nx-2:4]
        v.data.real[i,:] = temp[2:4*nx-1:4]
        v.data.imag[i,:] = temp[3:4*nx:4]
    #-- update mask for nan values
    u.mask[np.isnan(u.data)] = True
    v.mask[np.isnan(v.data)] = True
    #-- replace masked values with fill value
    u.data[u.mask] = u.fill_value
    v.data[v.mask] = v.fill_value
    #-- close the file
    fid.close()
    #-- return the transport components
    return (u,v)

#-- PURPOSE: read transport file with localized solutions to extract real and
#-- imaginary components for constituent
def read_atlas_transport(input_file,ic,constituent):
    """
    Read transport file with localized solutions to extract real and imaginary
    components for constituent

    Arguments
    ---------
    input_file: input ATLAS transport file
    ic: index of constituent
    constituent: tidal constituent ID (text or bytes)

    Returns
    -------
    u: global zonal tidal transport
    v: global meridional tidal transport
    local: dictionary of local tidal solutions for transport variables
        keyed by the model name as read from the file (bytes)
        lon: longitude limits of the local solution
        lat: latitude limits of the local solution
        u: zonal tidal transport
        v: meridional tidal transport
    """
    #-- constituent names are read from the file as bytes: encode a text
    #-- constituent so that the comparison below can succeed
    if isinstance(constituent, str):
        target = constituent.encode('utf8')
    else:
        target = constituent
    #-- use a python integer so byte offsets are not fixed-width
    ic = int(ic)
    #-- open the file (closed on exit even if one of the reads fails)
    with open(os.path.expanduser(input_file),'rb') as fid:
        #-- total size of the file for detecting the end of the local models
        file_size = os.fstat(fid.fileno()).st_size
        #-- header integers: record length (not needed here), grid
        #-- dimensions and number of constituents
        header = np.fromfile(fid, dtype=np.dtype('>i4'), count=4)
        #-- python integers cannot overflow when computing offsets
        _,nx,ny,nc = [int(n) for n in header]
        #-- skip over the y and x limits (4 big-endian 4-byte floats)
        fid.seek(16,1)
        #-- skip records to constituent
        nskip = 8 + nc*4 + ic*(nx*ny*16 + 8)
        fid.seek(nskip,1)
        #-- real and imaginary components of transport
        u = np.ma.zeros((ny,nx),dtype=np.complex64)
        u.mask = np.zeros((ny,nx),dtype=bool)
        v = np.ma.zeros((ny,nx),dtype=np.complex64)
        v.mask = np.zeros((ny,nx),dtype=bool)
        for i in range(ny):
            #-- each row stores (u real, u imag, v real, v imag) per point
            #-- reshape raises a ValueError if the file is truncated
            temp = np.fromfile(fid, dtype=np.dtype('>f4'), count=4*nx)
            temp = temp.reshape(nx,4)
            u.data.real[i,:] = temp[:,0]
            u.data.imag[i,:] = temp[:,1]
            v.data.real[i,:] = temp[:,2]
            v.data.imag[i,:] = temp[:,3]
        #-- skip records after constituent
        nskip = (nc-ic-1)*(nx*ny*16 + 8) + 4
        fid.seek(nskip,1)
        #-- read local models to find constituent
        local = {}
        #-- while the file position is not at the end of file
        while (fid.tell() < file_size):
            #-- skip the 4 bytes preceding the local model header
            fid.seek(4,1)
            #-- get local model dimensions, number of constituents and
            #-- number of stored u and v points
            dims = np.fromfile(fid, dtype=np.dtype('>i4'), count=5)
            nx1,ny1,nc1,nu,nv = [int(n) for n in dims]
            #-- extract latitude and longitude limits of local model
            lt1 = np.fromfile(fid, dtype=np.dtype('>f4'), count=2)
            ln1 = np.fromfile(fid, dtype=np.dtype('>f4'), count=2)
            #-- extract constituents for localized solution (as bytes)
            cons = fid.read(nc1*4).strip().split()
            #-- size in bytes of the u and v records of one constituent
            nrec = 8*nu + 8*nv + 16
            #-- check if constituent is in list of localized solutions
            if (target in cons):
                ic1 = cons.index(target)
                #-- extract name
                name = fid.read(20).strip()
                fid.seek(8,1)
                #-- 1-based column and row indices of the stored u points
                iu = np.fromfile(fid, dtype=np.dtype('>i4'), count=nu)
                ju = np.fromfile(fid, dtype=np.dtype('>i4'), count=nu)
                fid.seek(8,1)
                #-- 1-based column and row indices of the stored v points
                iv = np.fromfile(fid, dtype=np.dtype('>i4'), count=nv)
                jv = np.fromfile(fid, dtype=np.dtype('>i4'), count=nv)
                #-- skip records to constituent
                fid.seek(8 + ic1*nrec,1)
                #-- real and imaginary components of u transport
                #-- points are masked unless listed in the file
                u1 = np.ma.zeros((ny1,nx1),fill_value=np.nan,
                    dtype=np.complex64)
                u1.mask = np.ones((ny1,nx1),dtype=bool)
                tmpu = np.fromfile(fid, dtype=np.dtype('>f4'), count=2*nu)
                u1.data.real[ju-1,iu-1] = tmpu[0::2]
                u1.data.imag[ju-1,iu-1] = tmpu[1::2]
                u1.mask[ju-1,iu-1] = False
                fid.seek(8,1)
                #-- real and imaginary components of v transport
                v1 = np.ma.zeros((ny1,nx1),fill_value=np.nan,
                    dtype=np.complex64)
                v1.mask = np.ones((ny1,nx1),dtype=bool)
                tmpv = np.fromfile(fid, dtype=np.dtype('>f4'), count=2*nv)
                v1.data.real[jv-1,iv-1] = tmpv[0::2]
                v1.data.imag[jv-1,iv-1] = tmpv[1::2]
                v1.mask[jv-1,iv-1] = False
                #-- save constituent to dictionary
                local[name] = dict(lon=ln1,lat=lt1,u=u1,v=v1)
                #-- skip records after constituent
                fid.seek((nc1-ic1-1)*nrec + 4,1)
            else:
                #-- skip records for local model if constituent not in list
                nskip = 56 + 16*nu + 16*nv + (nc1-1)*nrec
                fid.seek(nskip,1)
    #-- return the transport components
    return (u,v,local)

#-- PURPOSE: create a 2 arc-minute grid mask from mz and depth variables
def create_atlas_mask(xi,yi,mz,local,VARIABLE=None,SPACING=1.0/30.0):
    """
    Creates a high-resolution grid mask from model variables

    Arguments
    ---------
    xi: input x-coordinates of global tide model
    yi: input y-coordinates of global tide model
    mz: global land/water mask
    local: dictionary of local tidal solutions

    Keyword arguments
    -----------------
    VARIABLE: key for variable within each local solution
        depth: model bathymetry
    SPACING: grid spacing in degrees of the output grid and of the
        local solutions (default: 1/30 degree or 2 arc-minutes)

    Returns
    -------
    m30: high-resolution land/water mask
        masked where the mask value is 0
    """
    #-- create grid dimensions (2 arc-minute by default)
    d30 = SPACING
    x30 = np.arange(d30/2.0, 360.0+d30/2.0, d30)
    y30 = np.arange(-90.0+d30/2.0, 90.0+d30/2.0, d30)
    #-- fractional indices of the output coordinates in the global grid
    #-- (clipped so that points outside the global grid use the edge values)
    xcoords=np.clip((len(xi)-1)*(x30-xi[0])/(xi[-1]-xi[0]),0,len(xi)-1)
    ycoords=np.clip((len(yi)-1)*(y30-yi[0])/(yi[-1]-yi[0]),0,len(yi)-1)
    #-- nearest-neighbor column and row of the global grid for each output
    #-- coordinate: kept as vectors to avoid full-size index grids
    ix = np.around(xcoords).astype(np.int32)
    iy = np.around(ycoords).astype(np.int32)
    #-- interpolate global mask with nearest-neighbors
    m30 = np.ma.zeros((len(y30),len(x30)),dtype=np.int8,fill_value=0)
    m30.data[:,:] = mz[np.ix_(iy,ix)]
    #-- iterate over localized solutions to fill in high-resolution coastlines
    for val in local.values():
        #-- create latitude and longitude for local model
        ilon = np.arange(val['lon'][0]+d30/2.0,val['lon'][1]+d30/2.0,d30)
        ilat = np.arange(val['lat'][0]+d30/2.0,val['lat'][1]+d30/2.0,d30)
        #-- indices of valid (unmasked) points in the local model output
        #-- getmaskarray also handles variables without an explicit mask
        validy,validx = np.nonzero(~np.ma.getmaskarray(val[VARIABLE]))
        lon = ilon[validx]
        lat = ilat[validy]
        #-- check if model is -180:180
        lon30 = np.where(lon <= 0.0, lon + 360.0, lon)
        #-- indices within the global grid (truncated towards zero)
        ii = ((lon30 - x30[0])/d30).astype(int)
        jj = ((lat - y30[0])/d30).astype(int)
        #-- fill global mask with regional solution
        m30.data[jj,ii] = 1
    #-- return the high-resolution mask
    m30.mask = (m30.data == m30.fill_value)
    return m30

#-- PURPOSE: combines global and local atlas solutions
def combine_atlas_model(xi,yi,zi,pmask,local,VARIABLE=None,SPACING=1.0/30.0):
    """
    Combines global and local ATLAS tidal solutions into a single
    high-resolution solution

    Arguments
    ---------
    xi: input x-coordinates of global tide model
    yi: input y-coordinates of global tide model
    zi: global tide model data
    pmask: global mask (accepted for compatibility: not used here)
    local: dictionary of local tidal solutions

    Keyword arguments
    -----------------
    VARIABLE: key for variable within each local solution
        depth: model bathymetry
        z: tidal elevation
        u: zonal tidal transport
        v: meridional tidal transport
    SPACING: grid spacing in degrees of the output grid and of the
        local solutions (default: 1/30 degree or 2 arc-minutes)

    Returns
    -------
    x30: x-coordinates of high-resolution tide model
    y30: y-coordinates of high-resolution tide model
    z30: combined high-resolution tidal solution for variable
    """
    #-- create grid dimensions (2 arc-minute by default)
    d30 = SPACING
    x30 = np.arange(d30/2.0, 360.0+d30/2.0, d30)
    y30 = np.arange(-90.0+d30/2.0, 90.0+d30/2.0, d30)
    #-- interpolate global solution to the high-resolution grid
    z30 = np.ma.zeros((len(y30),len(x30)),dtype=zi.dtype)
    z30.mask = np.zeros((len(y30),len(x30)),dtype=bool)
    #-- test if combining elevation/transport variables with complex components
    if np.iscomplexobj(z30):
        #-- first-order splines of the real and imaginary components
        f1 = scipy.interpolate.RectBivariateSpline(xi, yi, zi.real.T, kx=1,ky=1)
        f2 = scipy.interpolate.RectBivariateSpline(xi, yi, zi.imag.T, kx=1,ky=1)
        z30.data.real[:,:] = f1(x30,y30).T
        z30.data.imag[:,:] = f2(x30,y30).T
    else:
        f = scipy.interpolate.RectBivariateSpline(xi, yi, zi.T, kx=1,ky=1)
        z30.data[:,:] = f(x30,y30).T
    #-- iterate over localized solutions
    for val in local.values():
        #-- local model output
        zlocal = val[VARIABLE]
        #-- indices of valid (unmasked) points in the local model output
        #-- getmaskarray also handles variables without an explicit mask
        validy,validx = np.nonzero(~np.ma.getmaskarray(zlocal))
        #-- create latitude and longitude for local model
        ilon = np.arange(val['lon'][0]+d30/2.0,val['lon'][1]+d30/2.0,d30)
        ilat = np.arange(val['lat'][0]+d30/2.0,val['lat'][1]+d30/2.0,d30)
        lon = ilon[validx]
        lat = ilat[validy]
        #-- check if model is -180:180
        lon30 = np.where(lon <= 0.0, lon + 360.0, lon)
        #-- indices within the global grid (truncated towards zero)
        ii = ((lon30 - x30[0])/d30).astype(int)
        jj = ((lat - y30[0])/d30).astype(int)
        #-- fill global model with regional solution
        z30.data[jj,ii] = np.ma.getdata(zlocal)[validy,validx]
    #-- return high-resolution solution and coordinates
    return (x30,y30,z30)

#-- For a rectangular bathymetry grid:
#-- construct masks for zeta, u and v nodes on a C-grid
def Muv(hz):
    """
    Construct masks for zeta, u and v nodes on a C-grid

    Arguments
    ---------
    hz: 2-dimensional bathymetry at the zeta nodes

    Returns
    -------
    mu: 1 where a node and the node in the preceding row are both wet
    mv: 1 where a node and the node in the preceding column are both wet
    mz: 1 where the bathymetry is greater than zero
    """
    nrows,ncols = np.shape(hz)
    #-- wet nodes have a positive depth
    mz = (hz > 0).astype(int)
    #-- index of the preceding row and column for every node
    #-- (the first wraps around to the last)
    prev_row = np.roll(np.arange(nrows), 1)
    prev_col = np.roll(np.arange(ncols), 1)
    #-- a u or v node is wet only if both of the adjacent zeta nodes are wet
    mu = np.zeros((nrows,ncols),dtype=int)
    mv = np.zeros((nrows,ncols),dtype=int)
    mu[:,:] = mz[prev_row,:]*mz
    mv[:,:] = mz[:,prev_col]*mz
    return (mu,mv,mz)

#-- PURPOSE: Interpolate bathymetry to zeta, u and v nodes on a C-grid
def Huv(hz):
    """
    Interpolate bathymetry to zeta, u and v nodes on a C-grid

    Arguments
    ---------
    hz: 2-dimensional bathymetry at the zeta nodes

    Returns
    -------
    hu: mean of the bathymetry at each node and at the node in the
        preceding row, multiplied by the mu mask from Muv
    hv: mean of the bathymetry at each node and at the node in the
        preceding column, multiplied by the mv mask from Muv
    """
    ny,nx = np.shape(hz)
    mu,mv,mz = Muv(hz)
    #-- index of the preceding column for every node (first wraps to last)
    #-- i.e. [nx-1, 0, 1, ..., nx-2]: the same neighbor that Muv combines
    #-- when building the masks
    indx = np.roll(np.arange(nx), 1)
    #-- index of the preceding row for every node (first wraps to last)
    indy = np.roll(np.arange(ny), 1)
    #-- calculate hu and hv as the masked mean of adjacent depths
    hu = mu*(hz + hz[indy,:])/2.0
    hv = mv*(hz + hz[:,indx])/2.0
    return (hu,hv)
