#!/usr/bin/env python

"Morphology modification"

from __future__ import print_function, division
from builtins import max, list, sum, str
from morphon import Morph, resample, select, stretch, rotation
import numpy as np
import argparse
import sys


def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: optional list of argument strings to parse instead of the
            process command line. When None (the default) argparse reads
            sys.argv[1:], which is the original behaviour.

    Returns:
        argparse.Namespace with attributes file, type, order, degree,
        stretch, scale, seed, jitter, prune, swap, twist, out and verbose.
        Options that were not given are None, except type ('2,3,4'),
        out ('mod.swc') and verbose (False).
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('file', type=str, help='input morphology file (swc)')

    # Selection criteria shared by all modification steps (kept as raw csv
    # strings here; they are converted by the caller).
    parser.add_argument('-p', dest='type', type=str, default='2,3,4',
        help='list of point types (csv) [2,3,4]')
    parser.add_argument('-e', dest='order', type=str, help='list of branch orders (csv)')
    parser.add_argument('-d', dest='degree', type=str, help='list of branch degrees (csv)')

    # Deterministic modifications.
    parser.add_argument('-t', dest='stretch', type=float, help='stretch factor')
    parser.add_argument('-c', dest='scale', type=str, help='scale factor x,y,z,r (csv)')

    # Randomized modifications and their seed.
    parser.add_argument('-s', dest='seed', type=int, help='random seed value (int)')
    parser.add_argument('-j', dest='jitter', type=float, help='coordinate jitter, um')
    parser.add_argument('-r', dest='prune', type=int, help='prune random branches (1..N)')
    parser.add_argument('-w', dest='swap', type=int, help='swap random branches (1..N)')
    parser.add_argument('-i', dest='twist', type=int, help='twist random branches (1..N)')

    # Output control.
    parser.add_argument('-o', dest='out', type=str, default='mod.swc',
        help='output file name (swc) [mod.swc]')
    parser.add_argument('-v', dest='verbose', action='store_true', help='verbose output')
    return parser.parse_args(argv)


def _parse_csv(text, cast):
    """Split a comma- and/or whitespace-separated string and convert each item with `cast`."""
    # split() without an argument collapses runs of whitespace, so inputs
    # such as "2, 3" do not produce empty items.
    return [cast(item) for item in text.replace(',', ' ').split()]


def main(args):
    """Load a morphology, apply the requested modifications and save the result.

    The steps run in a fixed order, each only when its option was given:
    stretch, scale, (seed), jitter, prune, swap, twist. Every step operates
    on the sections returned by select() for the requested point types,
    branch orders and branch degrees.

    Args:
        args: namespace as produced by parse_args().

    Returns:
        0 on success, 1 if a prune, swap or twist request could not be
        carried out in full. The output file is written in either case.

    Raises:
        ValueError: if a csv option contains a non-numeric item or if more
            than four scale factors are given.
    """
    err = 0
    m = Morph()
    m.load(args.file)
    vprint = print if args.verbose else lambda *a, **k: None

    # Selection criteria; empty collections are passed when an option is absent.
    types = set(_parse_csv(args.type, int))
    orders = set(_parse_csv(args.order, int)) if args.order else []
    degrees = set(_parse_csv(args.degree, int)) if args.degree else []

    if args.stretch:
        for sec in select(m, m.sections(), types=types, orders=orders, degrees=degrees, sec=True):
            stretch(m, sec, args.stretch)

    if args.scale:
        scale = _parse_csv(args.scale, float)
        if len(scale) > 4:
            raise ValueError('scale accepts at most 4 factors (x,y,z,r), got %d' % len(scale))
        # Factors that were not supplied default to 1 (no scaling).
        c = np.ones(4)
        for i, v in enumerate(scale):
            c[i] = v
        vprint('scale x,y,z,r', c)
        # v1 is the displacement the root would get from scaling; subtracting
        # it from every scaled point keeps the root at its original position.
        c0 = m.coord(m.root())
        c1 = c0 * c[:3]
        v1 = c1 - c0
        for sec in select(m, m.sections(), types=types, orders=orders, degrees=degrees):
            for ident in sec:
                # value[1] presumably holds the coordinates and value[2] the
                # radius (per the x,y,z,r option) -- confirm against morphon.
                m.nodes[ident].value[1] *= c[:3]
                m.nodes[ident].value[1] -= v1
                m.nodes[ident].value[2] *= c[3]

    # Compare with None so that an explicit seed of 0 is honoured.
    if args.seed is not None:
        np.random.seed(args.seed)

    if args.jitter:
        for sec in select(m, m.sections(), types=types, orders=orders, degrees=degrees):
            ident = sec[0]
            # Uniform offset within +/- jitter/2 along each axis.
            jitter = np.random.uniform(-args.jitter/2, args.jitter/2, 3)
            vprint('perturbing', ident, 'by', jitter)
            m.translate(jitter, ident)

    if args.prune:
        for i in range(args.prune):
            # Candidates are re-selected every iteration because the previous
            # prune changed the morphology.
            idents = [sec[0] for sec in
                select(m, m.sections(), types=types, orders=orders, degrees=degrees, sec=True)]
            if idents:
                ident = np.random.choice(idents)
                vprint('prunining', ident)
                m.prune(ident)
            else:
                vprint('cannot prune at iteration', i)
                err = 1
                break
        m.renumber()
        vprint('nodes renumbered')

    if args.swap:
        for i in range(args.swap):
            # Materialize real lists: filter() is a lazy iterator on Python 3
            # and supports neither len() nor remove().
            idents = [sec[0] for sec in
                select(m, m.sections(), types=types, orders=orders, degrees=degrees, sec=True)]
            idents = [item for item in idents if not m.is_fork(item)]
            if len(idents) < 2:
                vprint('cannot swap at iteration', i)
                err = 1
                break
            head1 = np.random.choice(idents)
            ptype = m.type(head1)
            order = m.order(head1)
            vprint('swapping', head1, end=' ')
            # The partner must have the same point type and branch order as
            # head1 (the degree criterion is not applied here) and must not be
            # head1 itself.
            partners = [sec[0] for sec in
                select(m, m.sections(), types=[ptype], orders=[order], sec=True)]
            partners = [item for item in partners
                if item != head1 and not m.is_fork(item)]
            if not partners:
                vprint('... ignored')
                err = 1
                continue
            head2 = np.random.choice(partners)
            vprint('with', head2)
            p1 = m.parent(head1)
            p2 = m.parent(head2)
            b1 = m.copy(head1)
            b2 = m.copy(head2)
            # Direction of each branch's first segment relative to its parent.
            v1 = m.coord(head1) - m.coord(p1)
            v2 = m.coord(head2) - m.coord(p2)
            # Align the copy of branch 1 with branch 2's direction, move it to
            # branch 2's position and attach it to branch 2's parent.
            axis, angle = rotation(v1, v2)
            b1.rotate(axis, angle)
            b1.translate(m.coord(head2) - b1.coord(b1.root()))
            m.graft(p2, b1)
            # Same for the copy of branch 2, in the opposite direction.
            axis, angle = rotation(v2, v1)
            b2.rotate(axis, angle)
            b2.translate(m.coord(head1) - b2.coord(b2.root()))
            m.graft(p1, b2)
            # The originals are removed only after both copies are grafted.
            m.prune(head1)
            m.prune(head2)
        m.renumber()
        vprint('nodes renumbered')

    if args.twist:
        sections = select(m, m.sections(), types=types, orders=orders, degrees=degrees, sec=True)
        # Terminal nodes of the selected sections that are bifurcations.
        bifurcations = [ident for ident in (sec[-1] for sec in sections)
            if m.is_bifurcation(ident)]
        for i in range(args.twist):
            if bifurcations:
                ident = np.random.choice(bifurcations)
                # Each bifurcation is twisted at most once.
                bifurcations.remove(ident)
                parent = m.parent(ident)
                c0 = m.coord(ident)
                # Rotation axis runs along the segment entering the bifurcation.
                axis = c0 - m.coord(parent)
                angle = np.random.uniform(np.pi/4, np.pi*7/4)
                vprint('twisting', ident, 'by', angle)
                for child in m.children(ident):
                    # Rotate about the bifurcation point: shift it to the
                    # origin, rotate, then shift back.
                    m.translate(-c0, child)
                    m.rotate(axis, angle, child)
                    m.translate(c0, child)
            else:
                vprint('cannot twist at iteration', i)
                err = 1
                break

    m.save(args.out)
    return err


if __name__ == '__main__':
    # Run as a script: use main()'s return value as the process exit status.
    exit_status = main(parse_args())
    sys.exit(exit_status)
