#!/usr/bin/env python

# Created on Mon Jan 06 11:14:13 2020
# Author: XiaoTao Wang

## Required modules

import argparse, sys, logging, logging.handlers, traceback, neoloop, os

# Version string of the installed neoloop package; shown by the -v/--version option.
currentVersion = neoloop.__version__


def getargs():
    """Build the command-line parser and parse ``sys.argv``.

    When no arguments are given at all, ``-h`` is substituted so that the
    help text is printed.

    Returns a tuple ``(args, commands)``: the parsed namespace and the list
    of raw argument strings that was handed to the parser.
    """
    cli = argparse.ArgumentParser(
        description='''Assemble complex SVs given an individual SV list.''',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # (option strings, keyword settings), registered in this exact order so
    # that the generated help text lists them in the same sequence.
    option_table = [
        # Version
        (('-v', '--version'),
         dict(action='version',
              version=' '.join(['%(prog)s', currentVersion]),
              help='Print version number and exit.')),
        # Output
        (('-O', '--output'),
         dict(help='Output prefix.')),
        # Input
        (('-B', '--break-points'),
         dict(help='''Path to a TXT file containing pairs
                        of break points.''')),
        (('-H', '--hic'),
         dict(nargs='+',
              help='''List of cooler URIs. If URIs at multiple
                        resolutions are provided, the program will first detect complex SVs from
                        each individual resolution, and then combine results from all resolutions
                        in a non-redundant way.''')),
        (('-R', '--region-size'),
         dict(default=5000000, type=int,
              help='''The extended genomic span of SV break points.(bp)''')),
        (('--minimum-size',),
         dict(default=500000, type=int,
              help='''For intra-chromosomal SVs, only SVs that are larger than this size will be considered by the pipeline.''')),
        (('--balance-type',),
         dict(default='CNV', choices=['CNV', 'ICE'],
              help='Normalization method.')),
        (('--protocol',),
         dict(default='insitu', choices=['insitu', 'dilution'],
              help='''Experimental protocol of your Hi-C.''')),
        (('--nproc',),
         dict(default=1, type=int, help='Number of worker processes.')),
        (('--logFile',),
         dict(default='assembleSVs.log', help='''Logging file name.''')),
    ]
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)

    ## Parse the command-line arguments (fall back to the help screen)
    argv = sys.argv[1:] or ['-h']
    parsed = cli.parse_args(argv)

    return parsed, argv


def run():
    """Command-line entry point.

    Parses the arguments, configures the root logger (console plus a
    rotating log file), and then, for every cooler URI given with ``--hic``,
    computes the expected contact frequencies, filters and assembles SVs
    and writes ``<output>.assemblies.<binsize>.txt``.

    With a single resolution that file is renamed to
    ``<output>.assemblies.txt``; with several, the per-resolution files are
    merged through ``combine`` into ``<output>.assemblies.txt`` and then
    removed.

    Any exception raised by the pipeline has its traceback appended to the
    log file and the process exits with status 1.
    """
    # Parse Arguments
    args, commands = getargs()
    # Nothing to set up when only the help or version text was requested
    if commands[0] in ['-h', '-v', '--help', '--version']:
        return

    ## Root Logger Configuration
    logger = logging.getLogger()
    logger.setLevel(10)
    console = logging.StreamHandler()
    filehandler = logging.handlers.RotatingFileHandler(args.logFile,
                                                       maxBytes=100000,
                                                       backupCount=5)
    # Set level for Handlers
    console.setLevel('INFO')
    filehandler.setLevel('INFO')
    # Customizing Formatter
    formatter = logging.Formatter(fmt = '%(name)-25s %(levelname)-7s @ %(asctime)s: %(message)s',
                                  datefmt = '%m/%d/%y %H:%M:%S')

    ## Unified Formatter
    console.setFormatter(formatter)
    filehandler.setFormatter(formatter)
    # Add Handlers
    logger.addHandler(console)
    logger.addHandler(filehandler)

    ## Logging for argument setting
    arglist = ['# ARGUMENT LIST:',
               '# Output Prefix = {0}'.format(args.output),
               '# Break Points = {0}'.format(args.break_points),
               '# Minimum fragment size = {0}bp'.format(args.minimum_size),
               '# Cooler URI = {0}'.format(args.hic),
               '# Extended Genomic Span = {0}bp'.format(args.region_size),
               '# Balance Type = {0}'.format(args.balance_type),
               '# Experimental protocol = {0}'.format(args.protocol),
               '# Number of Processes = {0}'.format(args.nproc),
               '# Log file name = {0}'.format(args.logFile)
               ]
    argtxt = '\n'.join(arglist)
    logger.info('\n' + argtxt)

    # Heavy dependencies are imported only once real work is about to start
    from neoloop.assembly import assembleSV
    from neoloop.util import calculate_expected
    import cooler
    import shutil

    # Name of the weight column used for the chosen normalization
    balance = {'CNV': 'sweight', 'ICE': 'weight'}.get(args.balance_type, False)

    try:
        # load Hi-C matrix; index the URIs by their bin size
        cools = {}
        for path in args.hic:
            lib = cooler.Cooler(path)
            cools[lib.binsize] = path

        by_res = {}
        for res in cools:
            clr = cooler.Cooler(cools[res])
            logger.info('Current resolution: {0}'.format(res))
            logger.info('Calculate the global average contact frequencies at each genomic distance ...')
            # Skip every chromosome whose name contains one of these substrings
            excluded = ('_', 'M', 'X', 'Y', 'MT', 'EBV')
            chroms = [c for c in clr.chromnames
                      if not any(tag in c for tag in excluded)]
            expected = calculate_expected(clr, chroms, maxdis=5000000, balance=balance, nproc=args.nproc)
            logger.info('Done')
            logger.info('Filtering SVs by checking distance decay of chromatin contacts across SV breakpoints ...')
            work = assembleSV(clr, args.break_points, span=args.region_size, col=balance,
                              n_jobs=args.nproc, protocol=args.protocol, minIntra=args.minimum_size,
                              expected_values=expected)
            logger.info('{0} SVs left'.format(len(work.queue)))
            logger.info('Building SV connecting graph ...')
            work.build_graph()
            logger.info('Discovering and re-ordering complex SVs ...')
            work.find_complexSV()
            allele = len(work.alleles)
            logger.info('Called {0} connected assemblies'.format(allele))
            outfil = '{0}.assemblies.{1}.txt'.format(args.output, res)
            work.output_all(outfil)
            by_res[res] = outfil

        outfil = '{0}.assemblies.txt'.format(args.output)
        if len(by_res) == 1:
            # Rename the only per-resolution file. shutil.move avoids building
            # a shell command line, so prefixes containing spaces or shell
            # metacharacters are handled safely.
            only_file = next(iter(by_res.values()))
            shutil.move(only_file, outfil)
        else:
            logger.info('Merge complex SV results from multiple resolutions ...')
            As, Cs = combine(by_res)
            with open(outfil, 'w') as out:
                # "A" records are written longest key first; the leading
                # length is only a sort key and is dropped on output.
                pool = sorted(((len(k),) + k + tuple(As[k]) for k in As),
                              reverse=True)
                for i, record in enumerate(pool):
                    tmp = ['A{0}'.format(i)] + list(record[1:])
                    out.write('\t'.join(tmp)+'\n')

                # "C" records keep the order in which combine() returned them
                for i, k in enumerate(Cs):
                    tmp = ['C{0}'.format(i)] + list(k) + list(Cs[k])
                    out.write('\t'.join(tmp)+'\n')

            # The per-resolution files are intermediate once merged
            for res in by_res:
                os.remove(by_res[res])

        logger.info('Done')
    except Exception:
        # Append the traceback to the log file and close the handle promptly
        with open(args.logFile, 'a') as log:
            traceback.print_exc(file=log)
        sys.exit(1)

def _issubset(p1, p2):
    """Tell whether ``p1`` occurs inside ``p2`` as a consecutive, in-order run.

    Every element of ``p1`` is looked up in ``p2`` (first occurrence). The
    result is ``False`` as soon as one element is missing; otherwise it is
    true only when the positions found increase by exactly one from each
    element to the next. An empty or single-element ``p1`` therefore counts
    as contained.
    """
    import numpy as np

    positions = []
    for item in p1:
        if item not in p2:
            return False
        positions.append(p2.index(item))

    # Consecutive positions have a step of exactly 1 between neighbours
    steps = np.diff(np.r_[positions])
    return np.all(steps == 1)

def parse_assemblies(fil):
    """Read an assemblies file into two dictionaries.

    Each non-empty line is split on whitespace. The first field is the
    record label, the last two fields are stored as the value (a list of
    two strings) and the fields in between form the key (a tuple of
    strings). Records whose label starts with ``A`` go into the first
    dictionary, those starting with ``C`` into the second; any other label
    is ignored. Blank or whitespace-only lines are skipped.

    Returns the tuple ``(As, Cs)``.
    """
    As = {}
    Cs = {}
    with open(fil, 'r') as source:
        for line in source:
            parse = line.rstrip().split()
            if not parse:
                # An empty line has no label to inspect; indexing it would
                # raise IndexError.
                continue
            label = parse[0]
            key = tuple(parse[1:-2])
            values = parse[-2:]
            if label.startswith('A'):
                As[key] = values
            elif label.startswith('C'):
                Cs[key] = values

    return As, Cs

def combine(by_res):
    """Merge the assemblies files of several resolutions.

    ``by_res`` maps a resolution to the path of its assemblies file. The
    file of the smallest resolution is loaded first; records from the other
    files are added only when their key is not present yet.

    The "A" records are then visited from the longest key to the shortest
    and a record is dropped when ``_issubset`` reports its key as contained
    in an already accepted key. The "C" records are returned as merged.

    Returns the tuple ``(Afinal, Cs)``.
    """
    finest = min(by_res)
    As, Cs = parse_assemblies(by_res[finest])
    for resolution, path in by_res.items():
        if resolution == finest:
            continue
        extra_a, extra_c = parse_assemblies(path)
        # Entries that are already present take precedence
        for key, value in extra_a.items():
            As.setdefault(key, value)
        for key, value in extra_c.items():
            Cs.setdefault(key, value)

    # remove redundant assembly: longest keys first, ties broken by the key
    ranked = sorted(As, key=lambda key: (len(key), key), reverse=True)
    accepted = []
    for candidate in ranked:
        if not any(_issubset(candidate, longer) for longer in accepted):
            accepted.append(candidate)

    Afinal = {key: As[key] for key in accepted}

    return Afinal, Cs

# Start the command-line workflow only when this file is executed as a script.
if __name__ == '__main__':
    run()