#!/usr/bin/env python

"Repair or correct neuron morphology"

from __future__ import print_function, division
from builtins import list, dict, set, sum, abs, min, max, str
from morphon import Morph, select, rotation, repair_cut, resample
import numpy as np
import argparse
import math
import sys


def recycle(a):
    """Rotate list *a* left by one position, in place.

    The first element is removed and re-appended at the end; that element
    is returned. Raises IndexError if *a* is empty.
    """
    a.append(a.pop(0))
    return a[-1]


def _norm(v):
    return math.sqrt(v[0]*v[0]+v[1]*v[1]+v[2]*v[2])


def parse_args():
    """Build the command-line parser and parse the process arguments.

    Returns the argparse namespace with attributes: file, diam, jump,
    bridge, tilt, cuts, seed, shrink, res, renumber, center, out, verbose.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    add = parser.add_argument
    add('file', type=str, help='input morphology file (swc)')
    # Options that each take a comma-separated list of node IDs.
    for flag, dest in (('-d', 'diam'), ('-z', 'jump'), ('-b', 'bridge'),
                       ('-t', 'tilt'), ('-c', 'cuts')):
        add(flag, dest=dest, type=str, help='list of node IDs (csv)')
    add('-s', dest='seed', type=int, help='random seed value (int)')
    add('-k', dest='shrink', type=str,
        help='shrinkage correction factor z,y,x (csv)')
    add('-r', dest='res', type=float, help='sampling resolution, um')
    # Boolean switches.
    for flag, dest, text in (('-n', 'renumber', 'renumber nodes'),
                             ('-e', 'center', 'center root')):
        add(flag, dest=dest, action='store_true', help=text)
    add('-o', dest='out', type=str, default='rep.swc',
        help='output file name (swc) [rep.swc]')
    add('-v', dest='verbose', action='store_true', help='verbose output')
    return parser.parse_args()


def _parse_csv(text, cast=int):
    """Split a comma- and/or whitespace-separated string into cast values.

    Empty tokens (as produced by "1, 2" or a trailing comma) are dropped
    instead of being handed to *cast*.
    """
    return [cast(token) for token in text.replace(',', ' ').split()]


def main(args):
    """Apply the requested corrections to a morphology and save the result.

    *args* is the namespace produced by parse_args(). The passes run in a
    fixed order, each only if its option was given:

    * diam   -- store the mean radius of the parent and children of each
                listed node in that node's ``value[2]``;
    * jump   -- translate at each listed node by the z difference to its
                parent;
    * bridge -- move each listed node by half of the vector to its parent;
    * tilt   -- shift the nodes traversed from each listed node along z,
                by an amount that goes from the full z offset to the
                parent (at the node) to zero (at the farthest leaf);
    * cuts   -- repair each listed cut point with repair_cut(), using an
                intact branch of the same type and, preferably, order;
    * renumber, shrink (z,y,x factors), res (resample), center.

    The result is written to ``args.out``.

    Returns 0, or 1 if some cut point had no intact branch of its own
    order available (a different order was used, or the cut was skipped).
    """
    err = 0
    m = Morph()
    m.load(args.file)
    vprint = print if args.verbose else lambda *a, **k: None
    if args.diam:
        for item in _parse_csv(args.diam):
            parent = m.parent(item)
            children = m.children(item)
            radii = list()
            if parent is not None:
                radii.append(m.radius(parent))
            for child in children:
                radii.append(m.radius(child))
            if radii:
                m.nodes[item].value[2] = np.mean(radii)
            else:
                # np.mean([]) would write NaN into the node.
                vprint('cannot correct diameter of', item, '(no neighbours), ignored')
    if args.jump:
        for item in _parse_csv(args.jump):
            parent = m.parent(item)
            if parent is None:
                # Nothing to align to (same None check as the diameter pass).
                vprint('cannot correct z-jump at', item, '(no parent), ignored')
                continue
            z0 = m.coord(parent)[2]
            z1 = m.coord(item)[2]
            m.translate([0, 0, z0-z1], item)
    if args.bridge:
        for item in _parse_csv(args.bridge):
            parent = m.parent(item)
            if parent is None:
                vprint('cannot bridge', item, '(no parent), ignored')
                continue
            c0 = m.coord(parent)
            c1 = m.coord(item)
            m.move(item, (c0-c1)/2)
    if args.tilt:
        for ident in _parse_csv(args.tilt):
            parent = m.parent(ident)
            if parent is None:
                vprint('cannot correct tilt at', ident, '(no parent), ignored')
                continue
            cp = m.coord(parent)
            c0 = m.coord(ident)
            d0 = c0[2]-cp[2]
            # Distance from ident to its farthest leaf; the z shift fades
            # linearly with distance and vanishes there.
            ll = max(_norm(m.coord(item)-c0) for item in m.leaves(ident))
            for item in m.traverse(ident):
                cx = m.coord(item)
                lx = _norm(cx-c0)
                m.move(item, (0, 0, d0/ll*(lx-ll)))
    # Compare against None so that an explicit seed of 0 is honoured.
    if args.seed is not None:
        np.random.seed(args.seed)
    if args.cuts:
        idents = _parse_csv(args.cuts)
        cut_points = dict()
        types = set()
        for item in idents:
            ptype = m.type(item)
            order = m.order(item)
            cut_points[item] = dict(type=ptype, order=order)
            types.add(ptype)
        # intact_branches[type][order]['idents']: heads of sections none of
        # whose leaves is a cut point, i.e. usable as repair templates.
        intact_branches = dict()
        for sel_type in types:
            for sec in select(m, m.sections(), types=[sel_type]):
                head = sec[0]
                head_type = m.type(head)
                head_order = m.order(head)
                if set(m.leaves(head)).isdisjoint(cut_points):
                    by_order = intact_branches.setdefault(head_type, dict())
                    entry = by_order.setdefault(head_order, dict(idents=list()))
                    entry['idents'].append(head)
        for by_order in intact_branches.values():
            for entry in by_order.values():
                np.random.shuffle(entry['idents'])

        def repair(ident, ptype, order):
            # Take the next template of this type/order, round-robin.
            head = recycle(intact_branches[ptype][order]['idents'])
            len0 = sum(m.length(item) for item in m.section(head))
            len1 = sum(m.length(item) for item in m.section(ident, reverse=True))
            if len0 > len1:
                # The template section is longer than the cut one: start
                # from the first template node past the cut section's length.
                offset = 0
                for item in m.section(head):
                    offset += m.length(item)
                    if offset > len1:
                        break
                head = item
            vprint('repairing', ident, 'using', head)
            if not m.is_bifurcation(head):
                repair_cut(m, ident, head, flip_z=True)
            else:
                vprint('cannot repair', ident, 'using', head, '(bifurcation point), ignored')

        while idents:
            ident = idents.pop()
            ptype = cut_points[ident]['type']
            order = cut_points[ident]['order']
            candidates = intact_branches.get(ptype)
            if not candidates:
                # No intact branch of this type at all: report instead of
                # failing with a KeyError.
                vprint('cannot repair', ident, 'type', ptype, '(no intact branches), ignored')
                err = 1
                continue
            if order in candidates:
                repair(ident, ptype, order)
            else:
                vprint('cannot repair', ident, 'type', ptype, 'order', order, end=' ')
                # Fall back to the available order closest to the wanted one.
                nearest = sorted(candidates, key=lambda x: abs(order-x), reverse=True)[-1]
                vprint('trying order', nearest, '...', end=' ')
                repair(ident, ptype, nearest)
                err = 1
    if args.renumber:
        vprint('nodes renumbered')
        m.renumber()
    if args.shrink:
        shrink = _parse_csv(args.shrink, float)
        # Factors are given as z,y,x; missing trailing ones default to 1.
        c = np.ones(3)
        for i, v in enumerate(shrink):
            c[i] = v
        c = np.flip(c, axis=0)  # now ordered x,y,z
        zmin = min(m.coord(item)[2] for item in m.nodes)
        m.translate([0, 0, zmin*(1-c[2])/c[2]])
        vprint('scaled by (x,y,z)', c)
        m.scale(c)
    if args.res:
        m = resample(m, args.res)
    if args.center:
        m.translate(-m.coord(m.root()))
    m.save(args.out)
    return err


if __name__ == '__main__':
    # Run as a script: the value returned by main() becomes the exit status.
    sys.exit(main(parse_args()))
