#!/usr/bin/env python

# Created on Fri Sep 03 11:30:18 2021
# Author: XiaoTao Wang

## Required modules

import argparse, sys, traceback, neoloop
import numpy as np

# Version string of the installed neoloop package; getargs() reports it
# through the "-v/--version" option.
currentVersion = neoloop.__version__


def getargs():
    """Build the command-line parser and parse ``sys.argv``.

    When the script is invoked without any argument, ``-h`` is injected so
    that the help message is printed (argparse then exits the process).

    ``--cnv-profile`` and ``--output-figure-name`` are declared as required
    because ``run`` unconditionally opens the former and saves the figure to
    the latter; omitting them now yields a clear argparse usage error instead
    of a failure deep inside the plotting code.

    Returns
    -------
    args : argparse.Namespace
        Parsed command-line options.
    commands : list of str
        The raw argument list that was parsed (``sys.argv[1:]``, or
        ``['-h']`` when no argument was given).
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='''Plot genome-wide CNV profiles and segments.''',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    
    # Version
    parser.add_argument('-v', '--version', action='version',
                        version=' '.join(['%(prog)s',currentVersion]),
                        help='Print version number and exit.')

    # Input / output files
    parser.add_argument('--cnv-profile', required=True,
                        help='''Copy number profile in bedGraph format.''')
    parser.add_argument('--cnv-segment',
                        help='''Copy number segmentation in bedGraph format. Optional''')
    parser.add_argument('--output-figure-name', required=True,
                        help='''Output file path.''')
    parser.add_argument('-C', '--chroms', nargs = '*', default = ['#', 'X'],
                        help = 'List of chromosome labels. Only Hi-C data within the specified '
                        'chromosomes will be included. Specially, "#" stands for chromosomes '
                        'with numerical labels. "--chroms" with zero argument will include '
                        'all chromosome data.')

    # Appearance of the figure
    parser.add_argument('--dot-size', default=0.5, type=float, help='''Size of the dots.''')
    parser.add_argument('--dot-alpha', default=0.1, type=float, help='''The alpha blending value of the dots,
                        between 0 (transparent) and 1 (opaque).''')
    parser.add_argument('--line-width', default=1, type=float, help='''Width of the CNV segment lines.''')
    parser.add_argument('--boundary-width', default=1, type=float,
                        help='''Width of the chromosome boundaries.''')
    parser.add_argument('--label-size', default=6, type=float, help='''Label fontsize.''')
    parser.add_argument('--tick-label-size', default=3, type=float, help='''Fontsize of the tick labels.''')
    parser.add_argument('--maximum-value', type=float, help='''Maximum value.''')
    parser.add_argument('--minimum-value', type=float, help='''Minimum value.''')
    parser.add_argument('--fig-width', type=float, default=5, help='''Width of the figure.''')
    parser.add_argument('--fig-height', type=float, default=1, help='''Height of the figure.''')
    parser.add_argument('--dpi', default=500, type=int,
                        help='''Resolution of the output figure in dots per inch.''')
    parser.add_argument('--clean-mode', action = 'store_true',
                        help='''Only label every other chromosome on the x-axis.''')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        # No argument at all: show the help message rather than failing
        commands.append('-h')
    args = parser.parse_args(commands)
    
    return args, commands


def run():
    """Plot the genome-wide CNV profile (and optional segments) to a file.

    Workflow:

    1. Parse the command line with ``getargs``.
    2. Infer the bin size from the first record of the CNV profile
       (end minus start of that record).
    3. Load the per-chromosome signals through ``HMMsegment(...).bin_cnv``,
       which is used here as a mapping from chromosome label to a NumPy
       array of per-bin copy numbers.
    4. Scatter the log2 of the positive signals chromosome by chromosome
       along a shared x-axis, overlay the segment levels as red horizontal
       lines when ``--cnv-segment`` is given, and mark chromosome boundaries
       with dotted vertical lines.
    5. Save the figure to ``--output-figure-name``.

    Returns
    -------
    None
    """
    # Parse Arguments
    args, commands = getargs()
    # Nothing to plot for help/version requests, so skip the heavy imports
    if commands[0] in ['-h', '-v', '--help', '--version']:
        return

    from neoloop.cnv.segcnv import HMMsegment
    import matplotlib.pyplot as plt

    # infer resolution from the first record of the CNV profile; the context
    # manager guarantees the handle is closed as soon as the line is read
    with open(args.cnv_profile, 'r') as F:
        first_record = F.readline().split()
    res = int(first_record[2]) - int(first_record[1])

    # read CNV profiles
    cnv_signals = HMMsegment(args.cnv_profile, res).bin_cnv
    cnv_seg = None
    if args.cnv_segment is not None:
        cnv_seg = read_segment(args.cnv_segment, cnv_signals, res, minlen=3)

    # included chromosomes: everything when "--chroms" is empty, otherwise
    # explicitly listed labels plus purely numeric ones when "#" is listed
    chroms = [
        c for c in cnv_signals
        if (not args.chroms) or (c.isdigit() and '#' in args.chroms) or (c in args.chroms)
    ]
    sorted_chroms = sort_chromlabels(chroms)

    # plot
    fig = plt.figure(figsize=(args.fig_width, args.fig_height))
    ax = fig.add_subplot(111)
    xbars = [0]  # cumulative bin offsets marking chromosome boundaries
    start = 0
    for c in sorted_chroms:
        values = cnv_signals[c]
        end = start + len(values)
        coords = np.arange(start, end)
        # log2 is undefined for non-positive copy numbers, so drop those bins
        positive = values > 0
        ax.scatter(coords[positive], np.log2(values[positive]), s=args.dot_size, alpha=args.dot_alpha,
                   edgecolors='none', c='k')
        if cnv_seg is not None:
            # a chromosome may have no record in the segment file, in which
            # case read_segment never created a key for it
            for x, y, v in cnv_seg.get(c, []):
                ax.hlines(v, x+start, y+start, colors='r', lw=args.line_width)
        start = end
        xbars.append(start)
    ylim = ax.get_ylim()
    ax.vlines(xbars, ylim[0], ylim[1], linestyles=':', lw=args.boundary_width)

    # one tick at the middle of each chromosome
    xticks = [(s+e)/2 for s, e in zip(xbars[:-1], xbars[1:])]

    if args.clean_mode:
        # keep every other tick/label to reduce clutter
        xticks = xticks[::2]
        chroms = sorted_chroms[::2]
    else:
        chroms = sorted_chroms

    ax.set_xticks(xticks)
    ax.set_xticklabels(chroms, fontsize=args.tick_label_size)
    ax.set_xlabel('Chromosome Labels', fontsize=args.label_size)
    ax.set_ylabel('Log2 Copy Numbers', fontsize=args.label_size)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_linewidth(0.7)
    ax.spines['left'].set_linewidth(0.7)

    ax.xaxis.set_tick_params(width=0.7, labelsize=args.tick_label_size, pad=2, length=0)
    ax.yaxis.set_tick_params(width=0.7, labelsize=args.tick_label_size, pad=2, length=2)

    # user-supplied limits override the automatically chosen y-range
    ymin, ymax = ax.get_ylim()
    if args.maximum_value is not None:
        ymax = args.maximum_value
    if args.minimum_value is not None:
        ymin = args.minimum_value
    ax.set_ylim(ymin, ymax)
    plt.savefig(args.output_figure_name, dpi=args.dpi, bbox_inches='tight')
    plt.close()

def read_segment(seg_fil, cnv_signals, res, minlen=3):
    """Load CNV segments and compute a log2 copy-number level for each one.

    Parameters
    ----------
    seg_fil : str
        Path to a whitespace-delimited segment file whose first three
        columns are chromosome, start and end coordinates.
    cnv_signals : dict
        Per-chromosome copy number signals keyed by chromosome label. Each
        value is sliced by bin index and compared element-wise with 0, so
        NumPy arrays are expected.
    res : int
        Bin size used to convert genomic coordinates into bin indices.
    minlen : int
        Segments spanning ``minlen`` bins or fewer are ignored.

    Returns
    -------
    dict
        Maps each chromosome found in both the file and ``cnv_signals`` to a
        list of ``(start_bin, end_bin, level)`` tuples, where ``level`` is
        the log2 of the median of the non-zero signals inside the segment.
        Segments with more than 80% zero-valued bins, and segments that do
        not overlap the signal array at all, are left out.
    """
    data = {}
    with open(seg_fil, 'r') as source:
        for line in source:
            parse = line.split()
            if not parse:
                # blank line (e.g. a trailing newline at the end of the file)
                continue

            # NOTE(review): lstrip('chr') removes any leading run of the
            # characters 'c', 'h' and 'r', not only a literal "chr" prefix.
            # Kept as-is because the keys of cnv_signals are built outside
            # this file -- confirm how they are derived before changing it.
            chrom = parse[0].lstrip('chr')
            si = int(parse[1]) // res
            ei = int(parse[2]) // res
            if chrom not in cnv_signals:
                continue
            if chrom not in data:
                data[chrom] = []

            if ei - si > minlen:
                tmp = cnv_signals[chrom][si:ei]
                if tmp.size == 0:
                    # segment lies entirely beyond the profiled bins; there
                    # is nothing to summarise (and the ratio below would
                    # divide by zero)
                    continue
                mask = tmp==0
                zero_ratio = mask.sum() / mask.size
                if zero_ratio > 0.8:
                    # mostly empty bins: the median would not be reliable
                    continue
                data[chrom].append((si, ei, np.log2(np.median(tmp[tmp!=0]))))

    return data

def find_digit_parts(chrname):
    """Return the trailing run of digits of *chrname* as an integer.

    The label is scanned from its last character backwards until the first
    non-digit character is met. ``None`` is returned when the label does not
    end with a digit (this includes the empty string).
    """
    # Move the cut point left while the character just before it is a digit
    cut = len(chrname)
    while cut > 0 and chrname[cut - 1].isdigit():
        cut -= 1

    suffix = chrname[cut:]
    if not suffix:
        return None

    return int(suffix)

def sort_chromlabels(chrnames):
    """Order chromosome labels for display.

    Labels ending with digits come first, sorted by that trailing number
    (ties broken by the label itself). The remaining labels follow in their
    input order, except that the first label ending with "X", the first
    ending with "Y" and the first ending with "M" are moved to the head of
    that group, in this order.
    """
    numbered = []
    unnumbered = []
    for label in chrnames:
        number = find_digit_parts(label)
        if number is not None:
            numbered.append((number, label))
        else:
            unnumbered.append(label)

    ordered = [label for _, label in sorted(numbered)]

    # Each match is pushed to the front, so processing M, Y, X in turn
    # leaves the group starting with X, Y, M.
    for suffix in ('M', 'Y', 'X'):
        hit = next(
            (pos for pos, label in enumerate(unnumbered) if label.endswith(suffix)),
            None,
        )
        if hit is not None:
            unnumbered.insert(0, unnumbered.pop(hit))

    return ordered + unnumbered

# Run the plotting workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    run()