#! /usr/bin/env python3

from typing import *

from hashlib import sha1, sha256, sha512, md5
from io import TextIOBase, BufferedIOBase
import math
import regex as re

import tagstats as tagmatches
from toolz.itertoolz import no_default

from .seqtools import commonsubseq, align, seq2grams, enumeratesubseqs
from .rangetools import intersect

def commonsubstr(a: str, b: str) -> str:
    """Return ``seqtools.commonsubseq`` of the characters of *a* and *b*, joined into a string.

    NOTE(review): despite the name, whether the result is contiguous in the
    inputs depends entirely on ``commonsubseq`` — confirm before relying on it.
    """
    shared = commonsubseq([*a], [*b])
    return "".join(shared)


def editdist(a: str, b: str, bound: float = math.inf) -> Optional[float]:
    """Return the edit distance between *a* and *b* as computed by ``seqtools.align``.

    The strings are aligned character by character. *bound* is forwarded to
    ``align`` unchanged (presumably a cut-off beyond which alignment gives up —
    TODO confirm in seqtools).

    Returns:
        The first element of ``align``'s result, or ``None`` when ``align``
        returns a falsy result.
    """
    res = align(list(a), list(b), bound=bound)
    return res[0] if res else None


def tagstats(tags: Iterable[str], lines: Iterable[str], separator: Optional[str] = None) -> Mapping[str, int]:
    """Sum, per tag, the match values that the ``tagstats`` package computes over *lines*.

    Each tag is registered as its own single-entry group (``{tag: [tag]}``) and
    the values returned for it by ``tagstats.compute`` are summed.

    NOTE: this assigns the module-global ``tagstats.tagstats.tokenizer`` on every
    call — a compiled *separator* regex, or ``None`` when no separator is given.
    That state persists after the call. Presumably the tokenizer controls how
    lines are split into tokens — TODO confirm against the tagstats package.
    """
    tagmatches.tagstats.tokenizer = None if separator is None else re.compile(separator)

    groups = {tag: [tag] for tag in tags}
    return {
        tag: sum(matches)
        for tag, matches in tagmatches.compute(lines, groups).items()
    }


def str2grams(s: str, n: int, pad: str = '') -> Iterable[str]:
    """Return an iterator over the *n*-grams of *s* from ``seqtools.seq2grams``, each joined into a string.

    Args:
        s: Input string.
        n: Gram size, forwarded to ``seq2grams``.
        pad: Padding character forwarded to ``seq2grams``. ``''`` is translated
            to toolz's ``no_default`` sentinel (presumably meaning "no padding" —
            TODO confirm in seq2grams).

    Raises:
        ValueError: immediately (not on first iteration) if *pad* is longer
            than one character.
    """
    # Validate eagerly: inside a generator body this check would be deferred
    # until the caller first iterates the result.
    if len(pad) > 1:
        raise ValueError("pad must be empty or a single character, got {!r}".format(pad))

    fill = no_default if pad == '' else pad
    return (''.join(gram) for gram in seq2grams(s, n, fill))


def rewrite(s: str, regex: Any, template: str, transformations: Optional[Mapping[Union[str, int], Callable[[str], str]]] = None) -> str:
    """Rewrite *s* by matching it against *regex* and filling *template* with the captured groups.

    Args:
        s: Input string; *regex* must match all of it (``fullmatch``).
        regex: Pattern string (compiled here) or an already-compiled pattern.
        template: ``str.format`` template. ``{0}``, ``{1}``, ... refer to the
            groups in order (0-based), ``{name}`` to named groups. Groups that
            did not participate in the match are ``None``.
        transformations: Optional functions applied to group values before
            formatting, keyed by 0-based group index for positional fields and
            by group name for named fields. Groups without an entry are left as-is.

    Returns:
        The formatted template.

    Raises:
        ValueError: if *regex* does not match the whole of *s*.
    """
    pattern = re.compile(regex) if isinstance(regex, str) else regex

    m = pattern.fullmatch(s)
    if m is None:
        # Without this check the caller would get an opaque AttributeError on None.
        raise ValueError(
            "{!r} does not fully match pattern {!r}".format(s, getattr(pattern, "pattern", pattern))
        )

    positional = m.groups()
    named = m.groupdict()
    if transformations:
        def keep(value):
            return value

        positional = [transformations.get(i, keep)(v) for i, v in enumerate(positional)]
        named = {k: transformations.get(k, keep)(v) for k, v in named.items()}

    return template.format(*positional, **named)


def learnrewrite(src: str, dst: str, minlen: int = 3) -> Tuple[str, str]:
    """Learn a ``(regex, template)`` pair describing how *src* maps onto *dst*.

    Scanning *src* left to right, each start position is extended greedily to
    the longest substring that also occurs in *dst*. Substrings of at least
    *minlen* characters are candidates. Candidates are then kept longest-first,
    dropping any whose range in *dst* ``rangetools.intersect`` reports as
    intersecting (``allowempty=True``) a range already kept.

    Each kept substring becomes a ``(.*)`` capture group in the regex (built
    from *src*) and a ``{k}`` placeholder in the template (built from *dst*),
    where ``k`` numbers the groups in *src* order. All other text is kept
    literally: escaped with ``re.escape`` in the regex and with doubled braces
    in the template. The result is in the shape :func:`rewrite` consumes.

    Returns:
        ``(regex, template)``.
    """

    def replace(target, spans, forregex):
        # Splice a placeholder over each (k, start, end) span of *target*
        # (spans must not overlap) and escape the literal text between them.
        if forregex:
            escape = re.escape
        else:
            def escape(text):
                return text.replace("{", "{{").replace("}", "}}")

        pieces = []
        prev = 0
        for k, start, end in sorted(spans, key=lambda p: p[1]):
            pieces.append(escape(target[prev:start]))
            pieces.append("(.*)" if forregex else "{{{}}}".format(k))
            prev = end
        pieces.append(escape(target[prev:]))
        return "".join(pieces)

    # (start in src, first start in dst, length) for each greedy shared substring.
    found: List[Tuple[int, int, int]] = []

    lastj = 0
    for i in range(len(src)):
        if i < lastj:
            continue  # still inside the substring taken at an earlier i

        currp = -1
        # Upper bound is len(src) + 1 so the slice can include the last character.
        for j in range(i + 1, len(src) + 1):
            p = dst.find(src[i:j])
            if p < 0:
                break
            currp = p
            lastj = j

        if currp >= 0 and lastj - i >= minlen:
            found.append((i, currp, lastj - i))

    # Longest candidates first; skip any whose dst range intersects a kept one.
    kept: List[Tuple[int, int, int]] = []
    for x, y, length in sorted(found, key=lambda p: p[2], reverse=True):
        if any(
            intersect((y, y + length), (yy, yy + ll), allowempty=True) is not None
            for _, yy, ll in kept
        ):
            continue
        kept.append((x, y, length))

    # Number the groups in src order so {k} in the template matches group k.
    kept.sort(key=lambda p: p[0])

    return (
        replace(src, ((k, x, x + length) for k, (x, _, length) in enumerate(kept)), forregex=True),
        replace(dst, ((k, y, y + length) for k, (_, y, length) in enumerate(kept)), forregex=False),
    )


def extract(s: str, entities: Iterable[str], useregex=False, ignorecase=True) -> Iterable[str]:
    """Yield each occurrence in *s* of any of *entities*, delimited by word boundaries (``\\b``).

    Args:
        s: Text to search.
        entities: Entities to look for. Unless *useregex* is set they are taken
            literally, except that each space matches any run of whitespace.
        useregex: Treat each entity as a regular expression instead.
        ignorecase: Match case-insensitively.

    Yields:
        The matched text, in order of occurrence. Nothing is yielded when
        *entities* is empty.
    """
    if useregex:
        alternatives = list(entities)
    else:
        # Escape each space-separated word on its own and join with \s+, rather
        # than relying on how escape() happens to render a space.
        alternatives = [r"\s+".join(re.escape(word) for word in e.split(' ')) for e in entities]

    if not alternatives:
        # An empty alternation would match the empty string at every word boundary.
        return

    pattern = re.compile(
        r"\b(?:{})\b".format(r"|".join(alternatives)),
        re.I if ignorecase else 0
    )
    for m in pattern.finditer(s):
        yield m.group(0)


def __findeqtagpair(s: str, pos: int, tag: str) -> Optional[str]:
    """Return the ``tag ... tag`` span of *s* that contains index *pos*, or ``None``.

    Spans are found left to right, non-overlapping, each as short as possible.
    Because ``.`` is used without DOTALL, a span never crosses a newline.
    """
    quoted = re.escape(tag)
    pattern = quoted + r".*?" + quoted
    return next(
        (m.group() for m in re.finditer(pattern, s) if m.start() <= pos < m.end()),
        None,
    )


def findtagpair(s: str, pos: int, tag: str, closetag: Optional[str] = None) -> Optional[str]:
    """Return the innermost ``tag ... closetag`` span of *s* that contains index *pos*.

    When *closetag* is ``None`` or equal to *tag*, this delegates to
    ``__findeqtagpair``, which handles non-nested pairs. Otherwise *s* is
    scanned left to right and each *closetag* is matched with the most recent
    unmatched *tag*; unmatched closing tags are skipped. The first completed
    pair whose span ``[start, end)`` contains *pos* is returned. Inner pairs
    close first, so this is the innermost one. At each position the opening
    tag is tested before the closing tag.

    Returns:
        The matched span including both tags, or ``None`` if no pair contains *pos*.

    Raises:
        ValueError: if *tag* or *closetag* is empty (when they differ). An
            empty tag matches everywhere without advancing, so the scan would
            never terminate.
    """
    if closetag is None or tag == closetag:
        return __findeqtagpair(s, pos, tag)

    if not tag or not closetag:
        raise ValueError("tag and closetag must be non-empty")

    openstack: List[int] = []
    cursor = 0
    while cursor < len(s):
        if s.startswith(tag, cursor):
            openstack.append(cursor)
            cursor += len(tag)
        elif s.startswith(closetag, cursor):
            if openstack:
                start = openstack.pop()
                end = cursor + len(closetag)
                if start <= pos < end:
                    return s[start:end]
            cursor += len(closetag)
        else:
            cursor += 1

    return None


def enumeratesubstrs(s: str) -> Iterable[str]:
    """Return an iterator over the sub-sequences of *s* from ``seqtools.enumeratesubseqs``, as strings.

    Each item is joined with ``''.join``, as ``commonsubstr`` and ``str2grams``
    do for other seqtools results. The output is therefore the substring text
    whether ``enumeratesubseqs`` yields ``str`` slices or sequences of
    characters; ``str()`` would render a list/tuple as its repr, e.g.
    ``"['a', 'b']"``.
    """
    return (''.join(sub) for sub in enumeratesubseqs(s))


# Matches runs of one or more non-word characters; smartsplit treats each run
# as a separator candidate.
__renontext = re.compile(r"\W+", flags=re.UNICODE)

def smartsplit(s: str) -> Tuple[Optional[str], Iterable[str]]:
    """Guess the dominant separator in *s* and split *s* on it.

    Every run of non-word characters (the ``\\W+`` pattern) is a candidate.
    Each occurrence of a run adds one vote for the run itself and one vote for
    every distinct string ``enumeratesubstrs`` yields for it. The candidate
    with the most votes wins; ties go to the longer string.

    Returns:
        ``(separator, parts)`` where ``parts`` is ``s.split(separator)``, or
        ``(None, [s])`` when *s* contains no non-word characters.
    """
    import collections

    counts: Counter = collections.Counter()
    for sep in __renontext.findall(s):
        counts.update([sep])
        counts.update(set(enumeratesubstrs(sep)))

    # Defensive: if enumeratesubstrs ever yields the empty string it could win
    # the vote, and str.split("") raises "empty separator". Counter deletion
    # ignores missing keys, so this is a no-op otherwise.
    del counts[""]

    if not counts:
        return (None, [s])

    bestsep = max(counts.items(), key=lambda p: (p[1], len(p[0])))[0]
    return (bestsep, s.split(bestsep))


def __checksum(f: Any, func: Callable[[bytes], Any]) -> str:
    """Return the hex digest of *f* computed with the hashlib-style constructor *func*.

    Accepted inputs:
        * ``str``: hashed as UTF-8.
        * bytes-like (``bytes``, ``bytearray``, ``memoryview``): hashed as-is.
        * text stream (``TextIOBase``): read to EOF, hashed as UTF-8.
        * binary stream (``BufferedIOBase``): read to EOF.

    Streams are read from their current position in fixed-size chunks, so
    large files are not held in memory at once. *func* must accept an initial
    bytes argument and return an object with ``update`` and ``hexdigest``
    methods, as the hashlib constructors do.

    Raises:
        TypeError: if *f* is none of the supported types (previously this
            surfaced as an ``UnboundLocalError``).
    """
    chunksize = 1 << 16

    if isinstance(f, str):
        return func(f.encode("utf-8")).hexdigest()
    if isinstance(f, (bytes, bytearray, memoryview)):
        return func(f).hexdigest()
    if isinstance(f, TextIOBase):
        hasher = func(b"")
        for chunk in iter(lambda: f.read(chunksize), ""):
            hasher.update(chunk.encode("utf-8"))
        return hasher.hexdigest()
    if isinstance(f, BufferedIOBase):
        hasher = func(b"")
        for chunk in iter(lambda: f.read(chunksize), b""):
            hasher.update(chunk)
        return hasher.hexdigest()

    raise TypeError("unsupported input type for checksum: {}".format(type(f).__name__))


def sha1sum(f: Any) -> str:
    """Return the SHA-1 hex digest of *f* (accepted inputs as for ``__checksum``)."""
    return __checksum(f, func=sha1)


def sha256sum(f: Any) -> str:
    """Return the SHA-256 hex digest of *f* (accepted inputs as for ``__checksum``)."""
    return __checksum(f, func=sha256)


def sha512sum(f: Any) -> str:
    """Return the SHA-512 hex digest of *f* (accepted inputs as for ``__checksum``)."""
    return __checksum(f, func=sha512)


def md5sum(f: Any) -> str:
    """Return the MD5 hex digest of *f* (accepted inputs as for ``__checksum``)."""
    return __checksum(f, func=md5)
