#!/usr/bin/env python

"Test morphology reconstruction for structural consistency"

from __future__ import print_function
from builtins import all, list, set, filter, str
from morphon import Morph, Point
import argparse
import json
import sys


def parse_args():
    """Parse the command line of the consistency checker.

    Returns the argparse namespace with three attributes:
    ``file`` (positional input path), ``err`` (value of ``-i``, ``None``
    when not given) and ``quiet`` (``True`` when ``-q`` is present).
    """
    cli = argparse.ArgumentParser(description=__doc__)
    # One (flags, keyword options) entry per command-line argument.
    arguments = (
        (('file',), dict(type=str, help='input morphology file (swc)')),
        (('-i',), dict(dest='err', type=str,
                       help='investigate violated condition')),
        (('-q',), dict(dest='quiet', action='store_true',
                       help='disable output')),
    )
    for flags, options in arguments:
        cli.add_argument(*flags, **options)
    return cli.parse_args()


def _offending_nodes(m, err):
    """Return the node ids of morphology `m` that violate condition `err`.

    `err` is one of the condition names reported by ``m.check()``.  A name
    without a dedicated investigation below yields an empty list.
    """
    if err == 'single_root':
        return list(m.roots())
    if err == 'defined_ids':
        return [item for item in m.nodes if m.nodes[item].ident is None]
    if err == 'unique_ids':
        # Track identifiers in a set so every lookup is O(1); the first
        # holder of an identifier is accepted, every later one is reported.
        seen = set()
        repeated = []
        for item in m.nodes:
            ident = m.nodes[item].ident
            if ident in seen:
                repeated.append(item)
            else:
                seen.add(ident)
        return repeated
    if err == 'valid_ids':
        return [item for item in m.nodes if m.nodes[item].ident not in m.nodes]
    if err == 'valid_parents':
        return [item for item in m.nodes
                if not m.is_root(item) and m.parent(item) not in m.nodes]
    if err == 'valid_children':
        return [item for item in m.nodes
                if not set(m.children(item)).issubset(m.nodes)]
    if err == 'no_root_reference':
        roots = set(m.roots())
        return [item for item in m.nodes
                if not set(m.children(item)).isdisjoint(roots)]
    if err == 'no_self_reference':
        # NOTE(review): this flags nodes whose parent is also listed among
        # their children; confirm it matches the condition in Morph.check().
        return [item for item in m.nodes
                if m.parent(item) in m.children(item)]
    if err == 'valid_links':
        return [item for item in m.nodes
                if not m.is_root(item)
                and item not in m.children(m.parent(item))]
    if err == 'traversable_data':
        # Nodes that are stored but never visited by the traversal.
        return list(set(m.nodes).difference(m.traverse()))
    if err == 'increasing_id_order':
        return [item for item in m.nodes
                if not m.is_root(item) and m.parent(item) > item]
    if err == 'valid_types':
        return [item for item in m.nodes if m.type(item) not in Point.TYPES]
    if err == 'unit_id_in_soma':
        return set(filter(m.is_soma, m.nodes)).difference(set([1]))
    if err == 'unit_id_is_root':
        root = m.root()
        if root != 1 and m.type(root) == Point.SOMA:
            return [root]
        return []
    if err == 'unit_id_is_origin':
        # Many leaves end at the same origin; report each wrong origin
        # once instead of once per leaf.
        origins = []
        for leaf in m.leaves():
            origin = list(m.traverse(leaf, reverse=True))[-1]
            if origin != 1 and origin not in origins:
                origins.append(origin)
        return origins
    if err == 'nonzero_diam':
        return [item for item in m.nodes if m.diam(item) == 0]
    if err == 'sequential_ids':
        # Report every id that follows a gap in the sorted id sequence
        # (counting starts from 0, so the first expected id is 1).
        gaps = []
        previous = 0
        for item in sorted(m.nodes):
            if item - previous > 1:
                gaps.append(item)
            previous = item
        return gaps
    return []


def main(args):
    """Check a morphology file and report violated consistency conditions.

    Loads ``args.file`` into a ``Morph`` and runs ``Morph.check()``, whose
    report maps each condition name to a truthy value when it holds.

    * With ``args.err`` set, the ids of the nodes violating that condition
      are printed on one line, but only if the condition is actually
      violated; otherwise nothing is printed.
    * Without ``args.err``, the names of all violated conditions are
      printed on one line unless ``args.quiet`` is set.

    Returns 0 when every condition holds and 1 otherwise, suitable as a
    process exit status.
    """
    m = Morph()
    m.load(args.file)
    rep = m.check()
    errors = [key for key, value in rep.items() if not value]
    if args.err:
        if args.err in errors:
            for node in _offending_nodes(m, args.err):
                print(node, end=' ')
            print()
    elif not args.quiet:
        for err in errors:
            print(err, end=' ')
        # Terminate the line only when something was written.
        if errors:
            print()
    return 1 if errors else 0


if __name__ == '__main__':
    # Run the check and hand its result to the shell as the exit status.
    exit_status = main(parse_args())
    sys.exit(exit_status)
