#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
from os import path
import gzip
import argparse
import sys
import gzip
from collections import defaultdict
import unicodedata
import re
import pandas as pd
import numpy as np
from Bio import SeqIO
from Bio.SeqIO.QualityIO import FastqGeneralIterator
from Bio.Seq import Seq
from typing import Dict, List, Union, Tuple
import beret

import logging

# Configure the root logger: records at INFO level and above go to stderr in a
# timestamped, multi-line layout (DEBUG records are dropped at this level).
# NOTE(review): per the logging documentation `filemode` only applies when
# `filename` is given, so it appears to have no effect here - confirm.
logging.basicConfig(level=logging.INFO,
                     format='%(levelname)-5s @ %(asctime)s:\n\t %(message)s \n',
                     datefmt='%a, %d %b %Y %H:%M:%S',
                     stream=sys.stderr,
                     filemode="w"
                     )
# Short module-level aliases for the root-logger helpers. Note that `error`
# is bound to logging.critical, i.e. it logs at CRITICAL level, not ERROR.
error   = logging.critical
warn    = logging.warning
debug   = logging.debug
info    = logging.info


# Absolute path of the directory that contains this script.
_ROOT = os.path.abspath(os.path.dirname(__file__))

class InputFileError(Exception):
    """Raised when an input file lacks content this script requires."""

####Support functions###

def check_file(filename):
    """Verify that `filename` can be opened for reading.

    Returns None when the file opens; raises a plain Exception carrying the
    file name when opening fails with an IOError.
    """
    try:
        probe = open(filename)
    except IOError:
        raise Exception('I cannot open the file: '+filename)
    # The file is only opened to prove it is readable; release it right away.
    probe.close()
 

def check_library(library_name):
    """Import the module called `library_name` and return it.

    If the module cannot be imported, a critical message is logged and the
    process exits with status 1.

    Only ImportError is treated as "module missing"; other exceptions
    (including KeyboardInterrupt and SystemExit) propagate unchanged instead
    of being misreported as a missing installation.
    """
    try:
        return __import__(library_name)
    except ImportError:
        error('You need to install %s module to use CRISPRessoCount!' % library_name)
        sys.exit(1)


def slugify(value): #adapted from the Django project
    """Return an ASCII-only, filesystem-friendly version of `value`.

    The value is converted to text and decomposed (NFKD) so that accented
    letters lose their diacritics; any remaining non-ASCII characters are
    dropped. Every character that is not a word character, whitespace or a
    hyphen becomes '_', surrounding whitespace is stripped, and each run of
    hyphens/whitespace collapses to a single '-'.
    """
    # Encode/decode round trip drops non-ASCII code points while keeping the
    # result a `str`, so the regexes below operate on text, not bytes.
    ascii_text = (
        unicodedata.normalize('NFKD', str(value))
        .encode('ascii', 'ignore')
        .decode('ascii')
    )
    cleaned = re.sub(r'[^\w\s-]', '_', ascii_text).strip()
    return re.sub(r'[-\s]+', '-', cleaned)




def read_is_good_quality(record: SeqIO.SeqRecord, 
        min_bp_quality = 0, 
        min_single_bp_quality = 0, 
        qend = -1):
    """Tell whether the leading part of a read passes both quality thresholds.

    Only the phred scores in the slice `[:qend]` of
    `record.letter_annotations["phred_quality"]` are considered. The read
    passes when their mean is at least `min_bp_quality` and their minimum is
    at least `min_single_bp_quality`.
    """
    scores = np.array(record.letter_annotations["phred_quality"])[:qend]
    # Both checks are evaluated before combining, mean first and then min.
    passes_mean = scores.mean() >= min_bp_quality
    passes_min = scores.min() >= min_single_bp_quality
    return passes_mean and passes_min

        



def find_wrong_nt(sequence):
    """List the distinct characters of `sequence` that are not A/T/C/G/N.

    The comparison is case-insensitive (the sequence is upper-cased first);
    the order of the returned list is unspecified.
    """
    observed = set(sequence.upper())
    unexpected = observed - {'A', 'T', 'C', 'G', 'N'}
    return list(unexpected)


def mask_sequence_positions(seq: str, pos: np.array) -> str:
    """Keep the characters of `seq` whose index is in `pos`; mask the rest as 'N'."""
    masked = []
    for index, base in enumerate(seq):
        if index in pos:
            masked.append(base)
        else:
            masked.append("N")
    return "".join(masked)






def _str_to_bool(text):
    """Convert a command-line flag value such as 'True'/'false'/'0' to a bool."""
    lowered = str(text).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError("Expected a boolean value, got '{}'".format(text))


def get_input_parser():
    """Print the program banner and build the command-line argument parser.

    Returns:
        argparse.ArgumentParser: parser with every option registered; the
        caller is responsible for calling `parse_args()`.

    Notes:
        * All position/length options are parsed as integers so they can be
          compared with read lengths and added together after parsing.
        * `--guide_bc` is parsed as a boolean, so `--guide_bc False` really
          disables it instead of yielding the truthy string "False".
    """
    print('  \n~~~beretCount~~~')
    print('-Utility to perform sgRNA and reporter count from CRISPR base editors-')
    print(r'''
          )                                             )
         (           ________________________          (
        __)__       | __   __            ___ |        __)__
     C\|     \      |/  ` /  \ |  | |\ |  |  |     C\|     \
       \     /      |\__, \__/ \__/ | \|  |  |       \     /
        \___/       |________________________|        \___/
    ''')
    
    
    print('\n[Luca Pinello 2017, Jayoung Ryu 2021, send bugs, suggestions or *green coffee* to jayoung_ryu AT g DOT harvard DOT edu]\n\n')
    
    
    parser = argparse.ArgumentParser(description='CRISPRessoCount parameters',formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--R1', type=str,  help='fastq file for read 1', required=True,default='Fastq filename' )
    parser.add_argument('--R2', type=str,  help='fastq file for read 2, sorted as the same name order as in --R1 file.', required=True, default='Fastq filename' )
    parser.add_argument('-b', '--edited_base', type = str, required = True, help = 'For base editors, the base that should be ignored when matching the gRNA sequence')
    parser.add_argument('-f','--sgRNA_filename', type=str, required = True, help='''sgRNA description file. The format requires three columns: gRNA, Reporter, gRNA_barcode.''')

    #optional
    parser.add_argument('--guide_start_seq', type = str, help = "Guide starts after this sequence in R1", default = "GGAAAGGACGAAACACCG")
    parser.add_argument('-r', '--count_reporter', help = "Count reporter edits.", action = 'store_true')
    parser.add_argument('-q','--min_average_read_quality', type=int, help='Minimum average quality score (phred33) to keep a read', default=30)
    parser.add_argument('-s','--min_single_bp_quality', type=int, help='Minimum single bp score (phred33) to keep a read', default=0)
    parser.add_argument('-n','--name',  help='Output name', default='')
    parser.add_argument('-o','--output_folder',  help='', default='')
    parser.add_argument('-l', '--reporter_length', type = int, help = "length of the reporter", default = 32)
    parser.add_argument('--keep_intermediate',help='Keep all the  intermediate files',action='store_true')
    parser.add_argument('--qstart_R1', help='Start position of the read when filtering for quality score of the read 1', type = int, default = 0)
    parser.add_argument('--qend_R1', help = 'End position of the read when filtering for quality score of the read 1', type = int, default = 47)
    # The R2 positions must be integers just like their R1 counterparts.
    parser.add_argument('--qstart_R2', help = 'Same as qstart_R1, for read 2 fastq file', type = int, default = 0)
    parser.add_argument('--qend_R2', help = 'Same as qend_R1, for read 2 fastq file', type = int, default = 36)
    parser.add_argument('--gstart_reporter', help = "Start position of the guide sequence in the reporter", type = int, default = 6)
    parser.add_argument('--match_target_pos', help = "Only count the edit in the exact target position.", action = 'store_true')
    
    parser.add_argument('--guide_bc', help = 'Construct has guide barcode', type = _str_to_bool, default = True)
    # A length is numeric: it is later added to --reporter_length.
    parser.add_argument('--guide_bc_len', help = 'Guide barcode sequence length at the beginning of the R2', type = int, default = 4)
    parser.add_argument('--offset', help = 'Guide file has offest column that will be added to the relative position of reporters.', action = 'store_true')
    parser.add_argument('--align_fasta', help = 'gRNA is aligned to this sequence to infer the offset. Can be used when the exact offset is not provided.', type = str, default = '')
    
    parser.add_argument('-a', '--count_allele', help = 'count gRNA alleles', action = "store_true")
    parser.add_argument('-g', "--count_guide_edits", help = "count the self editing of guides", action = 'store_true')
    parser.add_argument('-m', "--count_guide_reporter_alleles", help = "count the matched allele of guide and reporter edit", action = 'store_true')

    return parser


def check_arguments(args):
    """Validate and normalize the parsed command-line arguments.

    Args:
        args: namespace produced by the parser's `parse_args()`.

    Returns:
        The same namespace, with `edited_base` upper-cased, the R2 quality
        positions and guide barcode length converted to int, and `name`
        replaced by its slugified form when it contained disallowed
        characters.

    Raises:
        Exception: an input file cannot be opened (from `check_file`).
        ValueError: the edited base is not one of A/C/T/G, a numeric option
            is not an integer, or a quality-filter position is out of range.
        InputFileError: `--offset` is set but the sgRNA file has no
            `offset` column, or the R1 file contains no read.
    """
    check_file(args.R1)
    check_file(args.R2)

    if args.sgRNA_filename:
        check_file(args.sgRNA_filename)

    # Edited base should be one of A/C/T/G
    edited_base = args.edited_base.upper()
    if edited_base not in ('A', 'C', 'T', 'G'):
        raise ValueError("The edited base should be one of A/C/T/G, {} provided.".format(args.edited_base))
    # Store the normalized value so downstream code sees the upper-case base.
    args.edited_base = edited_base
    info('Using specified edited base: %s' % edited_base)

    # These options may arrive as strings (e.g. when supplied on the command
    # line without an int converter); normalize them before doing arithmetic.
    for option in ('qstart_R2', 'qend_R2', 'guide_bc_len'):
        setattr(args, option, int(getattr(args, option)))

    # NOTE(review): both reads are checked against the length of the first R1
    # read; confirm this is intended when R1 and R2 differ in length.
    read_length = get_first_read_length(args.R1)

    # TODO: validate the guide position arguments as well.

    if not (0 <= args.qstart_R1 < read_length and 0 <= args.qstart_R2 < read_length):
        raise ValueError("The start position of base quality filter should be nonnegative and smaller than the read length {} ({} for R1, {} for R2 provided)".format(
            read_length, args.qstart_R1, args.qstart_R2))

    if not (args.qend_R1 < read_length and args.qend_R2 < read_length):
        raise ValueError("The end position of base quality filter should be smaller than the read length {} ({} for R1, {} for R2 provided)".format(
            read_length, args.qend_R1, args.qend_R2))

    if args.qend_R2 != args.guide_bc_len + args.reporter_length:
        warn("Quality of R2 checked up until {}bp, while the length of guide barcode and reporter combined is {}bp.".format(
            args.qend_R2, args.guide_bc_len + args.reporter_length))
    info("Using guide barcode length {}, guide start '{}'".format(args.guide_bc_len, args.guide_start_seq))

    #normalize name and remove not allowed characters
    if args.name:
        clean_name = slugify(args.name)
        if args.name != clean_name:
            warn('The specified name %s contained characters not allowed and was changed to: %s' % (args.name, clean_name))
            args.name = clean_name

    if args.offset:
        df = pd.read_csv(args.sgRNA_filename)
        if 'offset' not in df.columns:
            raise InputFileError("Offset option is set but the input file doesn't contain the offset column.")
        if len(args.align_fasta) > 0:
            # NOTE(review): this only logs the conflict and carries on;
            # confirm whether the run should abort here instead.
            error("Can't have --offset and --align_fasta option together.")

    info("Done checking input arguments.")

    return args


def get_first_read_length(fastq_filename):
    """Return the length of the first record in a FASTQ file.

    Files whose last dot-separated component is "gz" are read through gzip
    in text mode; anything else is opened as plain text. The file handle is
    closed before returning, including on the early return.

    Raises:
        InputFileError: the file contains no FASTQ record.
    """
    is_gzipped = fastq_filename.split(".")[-1] == "gz"
    opener = gzip.open if is_gzipped else open

    with opener(fastq_filename, "rt") as handle:
        # Only the first record is needed; return as soon as it is parsed.
        for record in SeqIO.parse(handle, "fastq"):
            return len(record)
    raise InputFileError("Provided R1 file doesn't have any read to parse")


# Script entry point: parse and validate the arguments, run the guide/reporter
# counting through beret's GuideEditCounter and write the result as .h5ad.
if __name__ == '__main__':
    parser = get_input_parser()
    args = parser.parse_args()

    args = check_arguments(args)

    args_dict = vars(args)
    base_editing_map = {"A":"G", "C":"T"}
    # Argument validation accepts A/C/T/G in either case, so normalize the
    # case and look the target base up without raising for T or G.
    edited_from = args_dict["edited_base"].upper()
    edited_to = base_editing_map.get(edited_from)
    match_target_pos = args_dict["match_target_pos"]

    counter = beret.pp.GuideEditCounter(**args_dict)
    counter.check_filter_fastq()
    
    counter.get_counts()
    if counter.count_reporter_edits:
        # edited_from / edited_to / match_target_pos are the inputs of the
        # call below, which is currently disabled.
        pass
        #counter.screen.get_edit_mat_from_uns(edited_from, edited_to, match_target_pos)
    counter.screen.write("{}.h5ad".format(counter.output_dir))
    #counter.screen.to_Excel("{}.xlsx".format(counter.output_dir))
    # Only the .h5ad file is produced while the Excel export is disabled.
    info("Output written at:\n {}.h5ad".format(counter.output_dir))
    info('All Done!')
    print(r'''
          )                                             )
         (           ________________________          (
        __)__       | __   __            ___ |        __)__
     C\|     \      |/  ` /  \ |  | |\ |  |  |     C\|     \
       \     /      |\__, \__/ \__/ | \|  |  |       \     /
        \___/       |________________________|        \___/
    ''')
    sys.exit(0)
