#!/usr/bin/env python

"Find singularities in morphological reconstruction"

from __future__ import print_function, division
from builtins import filter, list, set, str, eval, min, abs
from morphon import Morph, select
import numpy as np
import argparse

def _cos(u, v):
    """Return the signed cosine of the angle between vectors u and v.

    The sign is kept (no absolute value is taken), so vectors pointing in
    opposite directions give a negative result.
    """
    norm_product = np.linalg.norm(u) * np.linalg.norm(v)
    return np.dot(u, v) / norm_product

def parse_args(argv=None):
    """Parse the command-line options of the script.

    argv: optional list of argument strings to parse instead of the
        process command line.  The default (None) makes argparse read
        sys.argv[1:], which is what the script entry point relies on.

    Returns the argparse namespace with the attributes file, type, diam,
    len, turn, zzag, jump and cut.  Options that were not given are None,
    except type, which defaults to '1,2,3,4'.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('file', type=str, help='input morphology file (swc)')
    parser.add_argument("-p", dest='type', type=str, default='1,2,3,4',
        help='list of point types (csv) [1,2,3,4]')
    # -d and -l take a comparison fragment such as ">2.5" that is appended
    # to the measured value by the caller.
    parser.add_argument("-d", dest='diam', type=str, help='diameter condition, um (> < == x)')
    parser.add_argument("-l", dest='len', type=str, help='segment length condition, um (> < == x)')
    parser.add_argument("-a", dest='turn', type=float, help='turn angle threshold (cos)')
    parser.add_argument("-g", dest='zzag', type=float, help='zigzag turn threshold (cos)')
    parser.add_argument("-z", dest='jump', type=float, help='zjump threshold, um')
    parser.add_argument("-c", dest='cut', type=float, help='cut plane thickness, um')
    return parser.parse_args(argv)


def main(args):
    """Print the identifiers of the nodes that satisfy every requested condition.

    The morphology is loaded from ``args.file``.  Nodes whose point type is
    listed in ``args.type`` form the initial selection, which is then
    narrowed by each optional condition that was supplied: diameter
    (``args.diam``), segment length (``args.len``), z jump (``args.jump``),
    turn angle (``args.turn``), zigzag (``args.zzag``) and cut plane
    (``args.cut``).  The surviving identifiers are printed on one line,
    separated by spaces.

    args: namespace carrying the attributes produced by parse_args().
        Numeric thresholds are applied whenever they are not None, so an
        explicit threshold of 0 is honoured.

    Raises ValueError if ``args.type`` contains a token that is not an
    integer.
    """
    m = Morph()
    m.load(args.file)
    # split() without an argument tolerates repeated separators such as
    # "1, 2" or "1,,2", which split(' ') turned into empty tokens.
    types = set(int(item) for item in args.type.replace(',', ' ').split())
    idents = filter(lambda item: m.type(item) in types, m.nodes)
    # NOTE(security): the -d / -l conditions come straight from the command
    # line and are evaluated as Python code.  This is only acceptable for a
    # locally run tool fed with trusted input.
    if args.diam:
        idents = filter(lambda item: eval(str(m.diam(item))+args.diam), idents)
    if args.len:
        idents = filter(lambda item: eval(str(m.length(item))+args.len), idents)
    if args.jump is not None:
        def is_jump(ident):
            # A root has no parent to compare against.
            if m.is_root(ident):
                return False
            c0 = m.coord(m.parent(ident))
            c1 = m.coord(ident)
            # Index 2 is the third coordinate component (z for the
            # "zjump" option).
            return abs(c1[2]-c0[2]) > args.jump

        idents = filter(is_jump, idents)
    if args.turn is not None:
        def is_turn(ident):
            # Both the node and its parent must be inner nodes.
            if not m.is_inner(ident):
                return False
            parent = m.parent(ident)
            if not m.is_inner(parent):
                return False
            c0 = m.coord(parent)
            c1 = m.coord(ident)
            v0 = c0 - m.coord(m.parent(parent))
            v1 = c1 - c0
            # A cosine below the threshold means the direction changes
            # more sharply than allowed between consecutive segments.
            return _cos(v0, v1) < args.turn

        idents = filter(is_turn, idents)
    if args.zzag is not None:
        def is_zigzag(ident):
            if not m.is_inner(ident):
                return False
            parent = m.parent(ident)
            if not m.is_inner(parent):
                return False
            # Only the first child is inspected.
            # NOTE(review): assumes an inner node always has at least one
            # child - confirm against morphon.
            child = m.children(ident)[0]
            c0 = m.coord(parent)
            c1 = m.coord(ident)
            c2 = m.coord(child)
            v0 = c0 - m.coord(m.parent(parent))
            v1 = c1 - c0
            v2 = c2 - c1
            # Zigzag: two consecutive sharp turns around this node.
            return _cos(v0, v1) < args.zzag and _cos(v1, v2) < args.zzag

        idents = filter(is_zigzag, idents)
    if args.cut is not None:
        h = args.cut
        idents = list(filter(m.is_leaf, idents))
        # np.min / np.max raise ValueError on an empty selection, so the
        # bounding box is computed only when at least one leaf is left.
        if idents:
            coords = [m.coord(item) for item in idents]
            # The bounding box is taken over the selected leaves only.
            xmin, ymin, zmin = np.min(coords, axis=0)
            xmax, ymax, zmax = np.max(coords, axis=0)

            def is_cut(ident):
                # Keep leaves lying within h of any bounding-box face.
                c = m.coord(ident)
                return c[0] > xmax-h or c[0] < xmin+h \
                    or c[1] > ymax-h or c[1] < ymin+h \
                    or c[2] > zmax-h or c[2] < zmin+h

            idents = filter(is_cut, idents)
    for item in idents:
        print(item, end=' ')
    print()


# Script entry point: parse the command line, then run the selection.
if __name__ == '__main__':
    cli_args = parse_args()
    main(cli_args)
