#!/usr/bin/env python

"View neuron morphology"

from __future__ import print_function, division
from builtins import list, sum, set, str
from morphon import Morph, Point, plot, select
import numpy as np
import argparse
import math

from matplotlib import pyplot as plt


def _draw_sphere(ax, origin, radius, n=20, **kwargs):
    """Draw a sphere surface on a 3D axes.

    ax      -- axes object providing ``plot_surface``
    origin  -- indexable with the sphere centre as (x, y, z) in elements 0..2
    radius  -- sphere radius
    n       -- number of samples along each of the two angles
    kwargs  -- forwarded unchanged to ``ax.plot_surface``
    """
    polar = np.linspace(0, np.pi, n)
    azimuth = np.linspace(0, 2*np.pi, n)
    sin_polar = np.sin(polar)
    centre_x, centre_y, centre_z = origin[0], origin[1], origin[2]
    xs = centre_x + radius*np.outer(sin_polar, np.sin(azimuth))
    ys = centre_y + radius*np.outer(sin_polar, np.cos(azimuth))
    zs = centre_z + radius*np.outer(np.cos(polar), np.ones_like(azimuth))
    # The z grid is passed as the second argument and the y grid as the third.
    ax.plot_surface(xs, zs, ys, **kwargs)


def _draw_cylinder(ax, center1, center2, radius1, radius2, n=8, **kwargs):
    """Draw a truncated cone (a cylinder when both radii are equal).

    ax       -- axes object providing ``plot_surface``
    center1  -- (x, y, z) centre of the first end cap (any array-like)
    center2  -- (x, y, z) centre of the second end cap (any array-like)
    radius1  -- radius at ``center1``
    radius2  -- radius at ``center2``
    n        -- number of facets around the circumference
    kwargs   -- forwarded unchanged to ``ax.plot_surface``

    Nothing is drawn when both centres coincide, because the axis direction
    is undefined in that case.
    """
    center1 = np.asarray(center1, dtype=float)
    center2 = np.asarray(center2, dtype=float)
    v = center2 - center1
    mag = np.linalg.norm(v)
    if mag == 0:
        # Zero-length segment: normalising would divide by zero.
        return
    v = v / mag
    # Build n1 perpendicular to the axis from a reference direction. If the
    # axis is (anti)parallel to the reference the cross product vanishes, so
    # switch to a second reference that cannot be parallel at the same time.
    n1 = np.cross(v, np.array([1, 1, 0]))
    n1_norm = np.linalg.norm(n1)
    if n1_norm < 1e-8:
        n1 = np.cross(v, np.array([0, 1, 0]))
        n1_norm = np.linalg.norm(n1)
    n1 = n1 / n1_norm
    # n2 completes the orthonormal frame (v, n1, n2).
    n2 = np.cross(v, n1)
    t = np.linspace(0, mag, 2)
    theta = np.linspace(0, 2*np.pi, n+1)
    t, theta = np.meshgrid(t, theta)
    # One radius per end; broadcasts along the last (t) axis of the grids.
    R = np.linspace(radius1, radius2, 2)
    X, Y, Z = [center1[i] + v[i]*t + R*np.sin(theta)*n1[i]
        + R*np.cos(theta)*n2[i] for i in [0, 1, 2]]
    # The Z grid is passed as the second argument and the Y grid as the third.
    ax.plot_surface(X, Z, Y, **kwargs)


def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    argv -- optional list of argument strings; when None, argparse reads
            the process command line (``sys.argv[1:]``).

    Returns the ``argparse.Namespace`` with attributes ``file``, ``proj``,
    ``type``, ``diam``, ``mark``, ``marker_ids``, ``branch``, ``color``,
    ``no_axes``, ``scale``, ``title`` and ``out``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('file', type=str, nargs='+', help='input morphology file (swc)')
    parser.add_argument('-j', dest='proj', type=str, help='projection (xy, xz, zy) [3d]')
    parser.add_argument('-p', dest='type', type=str, default='1,2,3,4',
        help='list of point types (csv) [1,2,3,4]')
    parser.add_argument('-d', dest='diam', action='store_true', help='show dendrite diameter')
    # -m and -b may be repeated; each occurrence appends one csv string.
    parser.add_argument('-m', dest='mark', action='append', help='list of marker IDs (csv)')
    parser.add_argument('-mi', dest='marker_ids', action='store_true', help='show marker IDs')
    parser.add_argument('-b', dest='branch', action='append', help='list of branch IDs (csv)')
    parser.add_argument('-c', dest='color', action='store_true', help='colorize multiple plots')
    parser.add_argument('-a', dest='no_axes', action='store_true', help='disable axes')
    parser.add_argument('-s', dest='scale', type=float, default=0.0, help='show scale bar, um')
    parser.add_argument('-t', dest='title', type=str, help='set figure title')
    parser.add_argument('-o', dest='out', type=str, help='save image to file')
    return parser.parse_args(argv)


def _parse_int_list(text):
    """Parse a comma- and/or whitespace-separated string into a list of ints.

    Splitting on runs of whitespace means inputs such as ``'1, 2'`` or
    ``'1,2,'`` do not yield empty tokens. Raises ValueError when a token is
    not an integer literal.
    """
    return [int(token) for token in text.replace(',', ' ').split()]


def main(args):
    """Render the morphology described by the parsed command-line options.

    args -- namespace as returned by ``parse_args``.

    The first file in ``args.file`` is the primary morphology; any further
    files are drawn underneath it. The figure is written to ``args.out``
    when given, otherwise it is shown interactively.

    Raises ValueError if a type, marker or branch list contains a token
    that is not an integer.
    """
    m = Morph()
    m.load(args.file[0])
    fig, ax = plot()
    if args.no_axes:
        ax.set_axis_off()
    if args.title:
        ax.set_title(args.title)
    if args.proj:
        # Planar projections: rotate the view and hide the collapsed axis.
        # Unrecognised projection names leave the default 3D view.
        proj = args.proj.lower()
        if proj == 'xy':
            ax.view_init(0,-90.01)
            ax.set_ylabel('')
            ax.set_yticks([])
        elif proj == 'xz':
            ax.view_init(89.99,-90.01)
            ax.set_zlabel('')
            ax.set_zticks([])
        elif proj == 'zy':
            ax.view_init(0.01,0.01)
            ax.set_xlabel('')
            ax.set_xticks([])
    types = set(_parse_int_list(args.type))
    if len(args.file) > 1:
        # Additional morphologies are drawn first, either colour-cycled
        # (-c) or as thick translucent light-grey lines.
        from cycler import cycler
        from matplotlib import colors as mcolors
        color_cycle = cycler(color=mcolors.TABLEAU_COLORS.values())
        for morphology, cycle in zip(args.file[1:], color_cycle()):
            n = Morph()
            n.load(morphology)
            if args.color:
                style = dict(color=cycle['color'], linewidth=1)
            else:
                style = dict(color='#cccccc', linewidth=3, alpha=0.5)
            plot(n, select(n, n.sections(with_parent=True), types=types, sec=True),
                ax, **style)
    # Point types 2, 3 and 4 each have a fixed line style.
    for type_id, style in ((2, dict(color='grey', linewidth=0.5)),
                           (3, dict(color='b', linewidth=1)),
                           (4, dict(color='g', linewidth=1))):
        if type_id in types:
            plot(m, select(m, m.sections(with_parent=True), types=[type_id], sec=True),
                ax, **style)
    other_types = types.difference([1,2,3,4])
    if other_types:
        # NOTE(review): unlike the calls above, this select() omits sec=True;
        # confirm that is intended.
        plot(m, select(m, m.sections(with_parent=True), types=other_types),
            ax, color='lightgrey', linewidth=1)
    if 1 in types:
        root = m.root()
        if not args.diam:
            _draw_sphere(ax, m.coord(root), m.radius(root), n=20, color='grey', alpha=0.2)
        plot(m, [[root]], ax, color='k', marker='o')
    if args.diam:
        # One translucent frustum per point, spanning from its parent.
        for sec in select(m, m.sections(), types=types.difference([1,2]), sec=True):
            for item in sec:
                parent = m.parent(item)
                _draw_cylinder(ax, m.coord(parent), m.coord(item), m.radius(parent), m.radius(item),
                        n=6, color='grey', linewidth=0, antialiased=False, alpha=0.1)
    if args.mark:
        for group in args.mark:
            markers = _parse_int_list(group)
            plot(m, [markers], ax, marker='o', linestyle='', alpha=0.5)
            if args.marker_ids:
                plt.rcParams.update({'font.size': 8})
                for marker in markers:
                    x, y, z = m.coord(marker)
                    ax.text3D(x, z, y, '  ' + str(marker))
    if args.branch:
        for group in args.branch:
            branches = _parse_int_list(group)
            for branch in branches:
                plot(m, m.sections(branch, with_parent=True), ax, color='r', linewidth=2)
            if args.marker_ids:
                plt.rcParams.update({'font.size': 8})
                for branch in branches:
                    x, y, z = m.coord(branch)
                    ax.text3D(x, z, y, '  ' + str(branch))
            plot(m, [branches], ax, marker='o', linestyle='', color='r', alpha=0.5)
    if args.scale > 0 and args.no_axes:
        # The scale bar is drawn only when the axes are hidden.
        xmax, ymax, zmax = ax.xy_dataLim.xmax, ax.xy_dataLim.ymax, ax.zz_dataLim.xmax
        xmin, ymin, zmin = ax.xy_dataLim.xmin, ax.xy_dataLim.ymin, ax.zz_dataLim.xmin
        label = '  ' + str(args.scale) + ' um'
        if args.proj and args.proj.lower() == 'xz':
            ax.plot([xmax, xmax], [ymin, ymin+args.scale], [zmax, zmax], color='k', linewidth=3)
            ax.text3D(xmax, ymin+args.scale/2, zmax, label)
        else:
            ax.plot([xmax, xmax], [zmax, zmax], [ymax, ymax-args.scale], color='k', linewidth=3)
            ax.text3D(xmax, zmax, ymax-args.scale/2, label)
    if args.out:
        plt.savefig(args.out)
    else:
        plt.show()


# Script entry point: parse the command line, then render the figure.
if __name__ == "__main__":
    main(parse_args())
