#!/usr/bin/env python

"Morphometric analysis"

from __future__ import print_function, division
from builtins import dict, list, sum, max, str
from collections import OrderedDict
from morphon import Morph, meter, std_features
import numpy as np
import argparse
import json
import math
import os


# Human-readable description of every feature name the tool can report.
# The descriptions are shown in the command-line help epilog and next to
# each value in the plain-text report.
feature_defs = dict([
    ('angle', 'local bifurcation angle (mean)'),
    ('area', 'total branch area'),
    ('asym', 'metrical asymmetry (mean)'),
    ('degree', 'branch degree (max)'),
    ('depth', 'z-extent of reconstruction'),
    ('diam', 'average diameter (median)'),
    ('dist', 'euclidean distance to root (max)'),
    ('height', 'y-extent of reconstruction'),
    ('length', 'total branch length'),
    ('nbifs', 'number of bifurcations'),
    ('npoints', 'number of points'),
    ('nsecs', 'number of sections'),
    ('nstems', 'number of stems'),
    ('ntips', 'number of terminations'),
    ('order', 'branch order (max)'),
    ('part', 'partition asymmetry (mean)'),
    ('path', 'path distance to root (max)'),
    ('pc1', 'principal component dimension'),
    ('pc2', 'principal component dimension'),
    ('pc3', 'principal component dimension'),
    ('pca', 'principal component dimensions'),
    ('seclen', 'section length (mean)'),
    ('tort', 'section tortuosity (mean)'),
    ('volume', 'total branch volume'),
    ('width', 'x-extent of reconstruction'),
])


# Help epilog: one aligned "name  description" line per feature, grouped
# into the standard set (plus 'nstems') and the optional add-on features.
_feature_lines = ['standard features:\n']
for feature in sorted(std_features+['nstems']):
    _feature_lines.append('  {:10s}  {}\n'.format(feature, feature_defs[feature]))
_feature_lines.append('\nadd-on features:\n')
_feature_lines.append('  {:10s}  {}\n'.format('pca', feature_defs['pca']))
# Join on an empty str() so the result has the same text type that the
# incremental concatenation onto str() used to produce.
feature_list = str().join(_feature_lines)


def parse_args():
    """Parse the command line and return the resulting argparse namespace.

    The namespace carries: `file` (one or more input paths), `feature`
    (list collected from repeated -f, empty by default), `add` (list
    collected from repeated -a, or None), `type` (csv string of point
    types) and `out` (output file name, or None).
    """
    cli = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=feature_list)
    # Options are registered in this order, which is also the order in
    # which they show up in the generated help text.
    option_table = (
        (('file',), dict(type=str, nargs='+',
                         help='input morphology file (swc)')),
        (('-f',), dict(dest='feature', type=str, default=[], action='append',
                       help='morphometric feature')),
        (('-a',), dict(dest='add', type=str, action='append',
                       help='add feature to standard set')),
        (('-p',), dict(dest='type', type=str, default='1,2,3,4',
                       help='list of point types (csv) [1,2,3,4]')),
        (('-o',), dict(dest='out', type=str,
                       help='output file name (json)')),
    )
    for flags, spec in option_table:
        cli.add_argument(*flags, **spec)
    return cli.parse_args()


# Point type code -> label used as the key of the per-type metrics.
# Codes are consecutive and start at 1.
ptmap = {code: label
         for code, label in enumerate(('soma', 'axon', 'dend', 'apic'), 1)}


def _contains(a, b):
    """Return True if at least one item of `a` is a member of `b`."""
    for candidate in a:
        if candidate in b:
            return True
    return False


def _parse_point_types(spec):
    """Return the set of integer point types listed in `spec`.

    Entries may be separated by commas and/or any amount of whitespace,
    so '1,2', '1, 2' and '1 2' are all accepted.

    Raises ValueError if an entry is not an integer literal.
    """
    # split() without an argument drops the empty tokens that
    # split(' ') would produce for inputs such as '1, 2'.
    return set(int(item) for item in spec.replace(',', ' ').split())


def _pca_extents(coords):
    """Return the extent of a point cloud along its three principal axes.

    `coords` is a sequence of 3D points. The result is a length-3 array
    holding (max - min) of the points projected onto each principal axis,
    ordered by decreasing variance. Fewer than two points yield zeros,
    because a covariance matrix cannot be estimated from them.
    """
    coords = np.asarray(coords, dtype=float)
    if coords.ndim != 2 or coords.shape[0] < 2:
        return np.zeros(3)
    centered = coords - coords.mean(axis=0)
    # The covariance matrix is symmetric: eigh returns real eigenvectors
    # with eigenvalues in ascending order, whereas eig gives no ordering
    # guarantee at all.
    eigval, eigvec = np.linalg.eigh(np.cov(centered.transpose()))
    by_variance = np.argsort(eigval)[::-1]
    projs = centered.dot(eigvec[:, by_variance])
    # Each axis gets its own extent; the sign of an eigenvector is
    # irrelevant for max - min.
    return projs.max(axis=0) - projs.min(axis=0)


def _soma_metrics(m, features):
    """Compute the requested features for the points of type 1 (soma).

    A soma made of a single point is treated as a sphere (length equals
    the diameter, volume and area follow from the radius); otherwise the
    per-point values reported by `m` are summed, and the diameter is the
    median over the points.
    """
    points = [item for item in m.nodes if m.type(item) == 1]
    npoints = len(points)
    single = npoints == 1
    stat = dict()
    if 'npoints' in features:
        stat['npoints'] = npoints
    if 'length' in features:
        if single:
            stat['length'] = m.diam(points[0])
        else:
            stat['length'] = sum(m.length(item) for item in points)
    if 'volume' in features:
        if single:
            stat['volume'] = 4.0 / 3.0 * math.pi * m.radius(points[0])**3
        else:
            stat['volume'] = sum(m.volume(item) for item in points)
    if 'area' in features:
        if single:
            stat['area'] = 4 * math.pi * m.radius(points[0])**2
        else:
            stat['area'] = sum(m.area(item) for item in points)
    if 'diam' in features:
        if single:
            stat['diam'] = m.diam(points[0])
        else:
            stat['diam'] = np.median([m.diam(item) for item in points])
    return stat


def _branch_metrics(m, root, t, features):
    """Compute the requested features for the branches of point type `t`.

    The stems are the children of `root` whose type is `t`. Each stem is
    measured with `meter` and the per-stem results are combined: counts
    and totals are summed, maxima are taken over the stems, and the
    per-section lists are concatenated before being averaged.
    """
    seclen = list()
    diam = list()
    tort = list()
    part = list()
    asym = list()
    angle = list()
    lo = np.full(3, np.inf)
    hi = np.full(3, -np.inf)
    rmax = 0.0
    path = 0.0
    degree = 0
    order = 0
    nbifs = 0
    ntips = 0
    npoints = 0
    length = 0.0
    volume = 0.0
    area = 0.0
    need_bounds = _contains(features, ['width', 'height', 'depth'])
    need_seclen = _contains(features, ['seclen', 'nsecs'])
    stems = [item for item in m.children(root) if m.type(item) == t]
    for stem in stems:
        mm = meter(m, stem, features)
        if need_bounds:
            lo = np.minimum(mm['bounds'][0], lo)
            hi = np.maximum(mm['bounds'][1], hi)
        if 'dist' in features:
            rmax = max(mm['dist'], rmax)
        if 'path' in features:
            path = max(mm['path'], path)
        if 'npoints' in features:
            npoints += mm['npoints']
        if 'nbifs' in features:
            nbifs += mm['nbifs']
        if 'ntips' in features:
            ntips += mm['ntips']
        if 'degree' in features:
            degree = max(mm['degree'], degree)
        if 'order' in features:
            order = max(mm['order'], order)
        if 'length' in features:
            length += mm['length']
        if 'volume' in features:
            volume += mm['volume']
        if 'area' in features:
            area += mm['area']
        if need_seclen:
            seclen.extend(mm['seclen'])
        if 'diam' in features:
            diam.extend(mm['diam'])
        if 'tort' in features:
            tort.extend(mm['tort'])
        if 'part' in features:
            part.extend(mm['part'])
        if 'asym' in features:
            asym.extend(mm['asym'])
        if 'angle' in features:
            angle.extend(mm['angle'])
    # Without any stem the bounds are still the +/-inf sentinels; report
    # a zero extent instead of letting -inf leak into the output.
    extent = hi - lo if (need_bounds and stems) else np.zeros(3)

    stat = dict()
    if 'width' in features:
        stat['width'] = extent[0]
    if 'height' in features:
        stat['height'] = extent[1]
    if 'depth' in features:
        stat['depth'] = extent[2]
    if 'dist' in features:
        stat['dist'] = rmax
    if 'path' in features:
        stat['path'] = path
    if 'degree' in features:
        stat['degree'] = degree
    if 'order' in features:
        stat['order'] = order
    if 'npoints' in features:
        stat['npoints'] = npoints
    if 'nstems' in features:
        stat['nstems'] = len(stems)
    if 'nbifs' in features:
        stat['nbifs'] = nbifs
    if 'ntips' in features:
        stat['ntips'] = ntips
    if 'nsecs' in features:
        stat['nsecs'] = len(seclen)
    # NOTE: np.median / np.mean of an empty list evaluate to nan.
    if 'diam' in features:
        stat['diam'] = np.median(diam)
    if 'seclen' in features:
        stat['seclen'] = np.mean(seclen)
    if 'tort' in features:
        stat['tort'] = np.mean(tort)
    if 'part' in features:
        stat['part'] = np.mean(part)
    if 'asym' in features:
        stat['asym'] = np.mean(asym)
    if 'angle' in features:
        stat['angle'] = np.mean(angle)
    if 'length' in features:
        stat['length'] = length
    if 'volume' in features:
        stat['volume'] = volume
    if 'area' in features:
        stat['area'] = area
    if 'pca' in features:
        # The principal axes are computed over all points of this type,
        # not per stem.
        pca = _pca_extents(
            [m.coord(item) for item in m.nodes if m.type(item) == t])
        stat['pc1'] = pca[0]
        stat['pc2'] = pca[1]
        stat['pc3'] = pca[2]
    return stat


def main(args):
    """Measure every input morphology and report the selected features.

    `args` is the namespace produced by `parse_args`: `file` (input
    paths), `feature` (explicit feature list; when empty the standard set
    plus 'nstems' is used), `add` (extra features), `type` (csv list of
    point types to analyse) and `out` (optional json output file).

    The base name of each file (text before the first dot) is always
    printed. Without `out` a text report follows it on stdout; with
    `out` the metrics of all files are written to that file as json,
    keyed by base name.

    Raises ValueError if `args.type` contains a non-integer entry.
    """
    m = Morph()
    recs = dict()
    # Copy so that appending the add-on features does not modify the
    # list stored in the argparse namespace.
    features = list(args.feature) if args.feature else std_features + ['nstems']
    if args.add:
        features += args.add
    types = _parse_point_types(args.type)
    for fname in args.file:
        name = os.path.basename(fname).split('.')[0]
        m.load(fname)
        metrics = OrderedDict()
        # Point types actually present in this reconstruction.
        ptypes = set(m.type(item) for item in m.nodes)
        if 1 in types:
            metrics['soma'] = _soma_metrics(m, features)
        root = m.root()
        for t in [2, 3, 4]:
            if t in types and t in ptypes:
                metrics[ptmap[t]] = _branch_metrics(m, root, t, features)
        #
        print(name)
        if not args.out:
            for (pt, stat) in metrics.items():
                for (feat, v) in sorted(stat.items()):
                    print('{} {:10g}  {:10} {}'.format(pt, v, feat, feature_defs[feat]))
            print()
        recs[name] = metrics
        m.clear()
    if args.out:
        with open(args.out, 'w') as fp:
            json.dump(recs, fp, indent=4, sort_keys=True)


# Run the analysis only when executed as a script, not when imported.
if __name__ == '__main__':
    cli_args = parse_args()
    main(cli_args)
