import itertools
import logging
import logging.config
import multiprocessing
import os
import traceback

# Version string of this module.
__version__ = "1.2.0"

# Author of this module.
__author__ = "Luís Gomes"


# Module-level logger used by the pipeline object and by the produce/process/
# consume functions in this file.
LOG = logging.getLogger(__name__)


class Job(object):
    """A submitted job: a serial number, the caller's data and a status object.

    Two jobs compare equal when their serials are equal, regardless of their
    data or status. Hashing follows the same rule so that jobs can be stored
    in sets and used as dictionary keys.
    """

    def __init__(self, serial, data, status):
        """Store the job's serial number, its data and its status object."""
        self.serial = serial
        self.data = data
        self.status = status

    def __str__(self):
        return f"Job {self.serial}, {self.status}"

    def __eq__(self, other):
        # Returning NotImplemented (rather than False) lets Python try the
        # other operand's __eq__ before falling back to "not equal".
        if not isinstance(other, Job):
            return NotImplemented
        return other.serial == self.serial

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None; keep jobs hashable and
        # consistent with equality by hashing the serial.
        return hash(self.serial)


class Status(object):
    """Progress counters of one job, stored in manager-backed shared state.

    The counters live in a dict created by ``manager.dict()`` so that every
    process holding this object sees the same values. Increments go through
    a lock created by ``manager.Lock()``: ``d[key] += 1`` on a dict proxy is
    a separate read followed by a separate write, so two processes
    incrementing the same counter at the same time could otherwise lose an
    update.
    """

    def __init__(self, manager):
        """Create the shared counters.

        ``manager`` must provide ``dict(**kwargs)`` and ``Lock()``, as the
        managers returned by ``multiprocessing`` contexts do.
        """
        self.d = manager.dict(
            producing=False,
            produced=0,
            processed=0,
            failed=0,
            consumed=0,
        )
        self.lock = manager.Lock()

    def _incr(self, key):
        """Add one to counter ``key`` while holding the shared lock."""
        with self.lock:
            self.d[key] += 1

    @property
    def producing(self):
        """True while production for the job is flagged as in progress."""
        return self.d["producing"]

    def started_producing(self):
        """Flag production as in progress."""
        self.d["producing"] = True

    def stopped_producing(self):
        """Flag production as no longer in progress."""
        self.d["producing"] = False

    @property
    def produced(self):
        """Number of times incr_produced() has been called."""
        return self.d["produced"]

    def incr_produced(self):
        self._incr("produced")

    @property
    def processed(self):
        """Number of times incr_processed() has been called."""
        return self.d["processed"]

    def incr_processed(self):
        self._incr("processed")

    @property
    def failed(self):
        """Number of times incr_failed() has been called."""
        return self.d["failed"]

    def incr_failed(self):
        self._incr("failed")

    @property
    def consumed(self):
        """Number of times incr_consumed() has been called."""
        return self.d["consumed"]

    def incr_consumed(self):
        self._incr("consumed")

    @property
    def standing(self):
        """Produced minus consumed, i.e. units still somewhere in flight."""
        return self.produced - self.consumed

    @property
    def running(self):
        """True while producing, or while produced units await consumption."""
        return self.producing or self.standing > 0

    def __str__(self):
        return (
            f"{'running' if self.running else 'finished'}"
            f" ({self.produced} produced,"
            f" {self.processed} processed,"
            f" {self.failed} failed,"
            f" {self.consumed} consumed)"
        )


class Component(object):
    """Common base of the pipeline's pluggable parts.

    Provides no-op ``init`` and ``shutdown`` hooks that subclasses may
    override, and a string form that identifies the instance and the process
    it is being printed from.
    """

    def init(self):
        """Initialisation hook; does nothing by default."""
        return None

    def shutdown(self):
        """Shutdown hook; does nothing by default."""
        return None

    def __str__(self):
        kind = self.__class__.__name__
        pid, ppid = os.getpid(), os.getppid()
        return f"{kind} id={id(self)} pid={pid} ppid={ppid}"


class Producer(Component):
    """Component that turns a job's data into work-unit data.

    The production sub-process in this module calls ``produce(job.data)``
    and iterates over whatever it returns, wrapping each item in a work unit.
    """

    def produce(self, data):
        """Return an iterable of work-unit data for ``data``; must be overridden."""
        raise NotImplementedError()


class Processor(Component):
    """Component that computes the result of a single work unit.

    The processing sub-processes in this module call
    ``process(job.data, work_unit.data)``. The returned value is stored as
    the work unit's result; if the call raises, the formatted traceback is
    stored as the work unit's exception instead.
    """

    def process(self, job_data, data):
        """Process one work unit; must be overridden.

        The signature mirrors how the pipeline invokes it (two positional
        arguments), so that subclasses modelled on this base are callable by
        the pipeline.

        :param job_data: data of the job the work unit belongs to.
        :param data: data of the work unit itself.
        :returns: the result to hand to the consumer.
        """
        raise NotImplementedError


class Consumer(Component):
    """Component that receives processed work units.

    The consumption sub-process in this module calls
    ``consume(job.data, work_unit.data, work_unit.result, work_unit.exception)``
    once per processed work unit.
    """

    def consume(self, job_data, data, result, exception):
        """Consume one processed work unit; must be overridden.

        The signature mirrors how the pipeline invokes it (four positional
        arguments), so that subclasses modelled on this base are callable by
        the pipeline.

        :param job_data: data of the job the work unit belongs to.
        :param data: data of the work unit itself.
        :param result: value returned by the processor, or None.
        :param exception: formatted traceback text if processing raised,
            otherwise None.
        """
        raise NotImplementedError


class WorkUnit(object):
    """One item travelling through the pipeline.

    Carries its own serial, the serial of the job it belongs to, the input
    data, and slots for the result and the exception which start as None.
    """

    def __init__(self, serial, job_serial, data):
        self.serial, self.job_serial, self.data = serial, job_serial, data
        self.result = self.exception = None

    def __str__(self):
        # Payloads are never printed; only whether each one is set.
        def shown(value):
            if value is None:
                return "None"
            return "<...>"

        data, result, exception = map(
            shown, (self.data, self.result, self.exception)
        )
        return (
            f"WorkUnit {self.serial} job_serial={self.job_serial} "
            f"data={data} result={result} exception={exception}"
        )


class Shutdown(WorkUnit):
    """Marker work unit telling a processor to shut down.

    It belongs to no job and carries no data.
    """

    def __init__(self, serial):
        # Pass by keyword: the serial given here is the unit's own serial,
        # not a job serial (a Shutdown is not associated with any job).
        super().__init__(serial=serial, job_serial=None, data=None)


class ParallelSequencePipeline(object):
    """Pipeline of one producer, several processors and one consumer.

    Each of them runs in its own sub-process. Jobs are submitted with
    ``submit()``, and finished jobs are collected with ``fetch()``.
    """

    def __init__(
        self,
        producer,
        processor,
        consumer,
        n_processors=None,
        require_in_order=None,
        mp_context=None,
        logging_config=None,
    ):
        """
        Processes a sequence in parallel.

        Argument mp_context must be obtained by calling
        multiprocessing.get_context().
        See https://docs.python.org/3/library/multiprocessing.html

        :param producer: a Producer instance.
        :param processor: a Processor instance; it is handed to each of the
            n_processors processing sub-processes.
        :param consumer: a Consumer instance.
        :param n_processors: number of processing sub-processes; defaults to
            mp_context.cpu_count().
        :param require_in_order: whether work units must be consumed in the
            order they were produced; defaults to True.
        :param mp_context: multiprocessing context; defaults to the
            "forkserver" context.
        :param logging_config: if truthy, passed to
            logging.config.dictConfig() in every sub-process.
        :raises TypeError: if producer, processor or consumer is not an
            instance of the corresponding base class.
        :raises ValueError: if n_processors is less than 1.
        """
        # Validate arguments before creating the manager: Manager() starts a
        # server process, which would be left behind if we raised afterwards.
        if not isinstance(producer, Producer):
            raise TypeError("producer not an instance of Producer")
        if not isinstance(processor, Processor):
            raise TypeError("processor not an instance of Processor")
        if not isinstance(consumer, Consumer):
            raise TypeError("consumer not an instance of Consumer")
        if mp_context is None:
            mp_context = multiprocessing.get_context(method="forkserver")
        self.mp_context = mp_context
        if n_processors is None:
            n_processors = self.mp_context.cpu_count()
        if n_processors < 1:
            # Without at least one processor no work unit is ever processed
            # and submitted jobs would never finish.
            raise ValueError("n_processors must be at least 1")
        self.n_processors = n_processors
        self.mp_manager = self.mp_context.Manager()
        self.require_in_order = True if require_in_order is None else require_in_order
        self.job_serial = itertools.count(start=1)
        self.job_input_queue = self.mp_context.Queue()
        self.job_output_queue = self.mp_context.Queue()
        # Bounded queues: put() blocks once 2 * n_processors units are waiting.
        self.work_unit_input_queue = self.mp_context.Queue(
            maxsize=2 * self.n_processors
        )
        self.work_unit_output_queue = self.mp_context.Queue(
            maxsize=2 * self.n_processors
        )
        # serial -> Job, shared with every sub-process through the manager.
        self.active_jobs = self.mp_manager.dict()
        self.procs = [
            self.mp_context.Process(
                target=produce,
                args=(
                    producer,
                    self.job_input_queue,
                    self.active_jobs,
                    self.work_unit_input_queue,
                    self.n_processors,
                    logging_config,
                ),
            ),
            self.mp_context.Process(
                target=consume,
                args=(
                    consumer,
                    self.work_unit_output_queue,
                    self.active_jobs,
                    self.job_output_queue,
                    self.n_processors,
                    self.require_in_order,
                    logging_config,
                ),
            ),
        ]
        self.procs.extend(
            self.mp_context.Process(
                target=process,
                args=(
                    processor,
                    self.work_unit_input_queue,
                    self.active_jobs,
                    self.work_unit_output_queue,
                    logging_config,
                ),
            )
            for _ in range(self.n_processors)
        )

    def start(self):
        """Start the producer, consumer and processor sub-processes."""
        LOG.debug("starting [%s]", self)
        for proc in self.procs:
            proc.start()

    def shutdown(self):
        """Ask the producer to stop and wait for all sub-processes to exit.

        None on the job input queue is the producer's stop signal. The
        manager is left running, so job objects and ``fetch()`` remain
        usable afterwards.
        """
        LOG.debug("shutting down [%s]", self)
        self.job_input_queue.put(None)
        for proc in self.procs:
            proc.join()

    def __str__(self):
        return f"{self.__class__.__name__} pid={os.getpid()}"

    def submit(self, job_data):
        """Register a new job for ``job_data``, queue it, and return the Job."""
        job = Job(
            serial=next(self.job_serial),
            data=job_data,
            status=Status(self.mp_manager),
        )
        # %-style arguments are only formatted if the record is emitted;
        # formatting a job reads its shared status through the manager.
        LOG.debug("submitting [%s] to [%s]", job, self)
        # Register the job before queueing its serial so that the producer
        # can look it up as soon as it receives the serial.
        self.active_jobs[job.serial] = job
        self.job_input_queue.put(job.serial)
        return job

    def fetch(self, wait=False):
        """Return a finished job, removing it from the active jobs.

        With ``wait`` false, returns None when the output queue reports
        itself empty. Otherwise blocks until a job serial is available.
        """
        if self.job_output_queue.empty() and not wait:
            LOG.debug("no job to fetch from [%s]", self)
            return None
        serial = self.job_output_queue.get()
        job = self.active_jobs.pop(serial)
        LOG.debug("fetched [%s] from [%s]", job, self)
        return job


def _log_gen_exc(gen, logmsg):
    try:
        yield from gen
    except:  # noqa E722
        LOG.exception(logmsg)


def produce(
    producer,
    job_input_queue,
    active_jobs,
    work_unit_input_queue,
    n_processors,
    logging_config,
):
    """Body of the production sub-process.

    Reads job serials from ``job_input_queue`` until None is received. For
    each job, iterates ``producer.produce(job.data)`` and puts one WorkUnit
    per item on ``work_unit_input_queue``. Work-unit serials start at 1 and
    keep increasing across jobs. After the None, shuts the producer down
    and queues ``n_processors`` Shutdown units, one per processor.

    :param producer: object providing init(), produce(data) and shutdown().
    :param job_input_queue: queue of job serials, terminated by None.
    :param active_jobs: mapping from job serial to job.
    :param work_unit_input_queue: queue read by the processors.
    :param n_processors: number of Shutdown units to send at the end.
    :param logging_config: if truthy, passed to logging.config.dictConfig().
    """
    if logging_config:
        logging.config.dictConfig(logging_config)
    # Log arguments are passed %-style so they are only formatted when the
    # record is actually emitted; formatting a job reads its shared status.
    LOG.info("starting production sub-process (pid=%s)", os.getpid())
    LOG.debug("initializing [%s]", producer)
    producer.init()
    work_unit_serial = itertools.count(start=1)

    def produced_data(job_data):
        # Calling produce() from inside a generator defers the call until
        # iteration starts, so an exception raised by the call itself is
        # handled like one raised while iterating, instead of terminating
        # this sub-process before the shutdown units are sent.
        yield from producer.produce(job_data)

    LOG.debug("waiting for a job")
    job_serial = job_input_queue.get()
    while job_serial is not None:
        job = active_jobs[job_serial]
        LOG.debug("starting production for [%s] by [%s]", job, producer)
        job.status.started_producing()
        logmsg = f"exception raised during production for [{job}] by [{producer}]"
        for work_unit_data in _log_gen_exc(produced_data(job.data), logmsg):
            work_unit = WorkUnit(
                serial=next(work_unit_serial),
                job_serial=job_serial,
                data=work_unit_data,
            )
            LOG.debug("produced [%s] for [%s] by [%s]", work_unit, job, producer)
            # Count the unit before queueing it, so the produced counter is
            # never behind the units that are already visible downstream.
            job.status.incr_produced()
            work_unit_input_queue.put(work_unit)
        LOG.debug("production finished for [%s] by [%s]", job, producer)
        # NOTE(review): nothing in this function reports the job as finished.
        # If the job yielded no units, or its last unit is consumed before
        # this flag is cleared, it looks like the job is never put on the
        # job output queue - confirm against the consumer.
        job.status.stopped_producing()
        LOG.debug("waiting for a job")
        job_serial = job_input_queue.get()
    LOG.info("shutting down production sub-process (pid=%s)", os.getpid())
    try:
        LOG.debug("shutting down [%s]", producer)
        producer.shutdown()
    except:  # noqa E722
        LOG.exception("exception raised while shutting down [%s] ", producer)
    LOG.debug("sending %s shutdown messages (one for each processor)", n_processors)
    for _ in range(n_processors):
        shutdown_unit = Shutdown(next(work_unit_serial))
        work_unit_input_queue.put(shutdown_unit)
    LOG.info("exiting production sub-process (pid=%s)", os.getpid())


def process(
    processor,
    work_unit_input_queue,
    active_jobs,
    work_unit_output_queue,
    logging_config,
):
    """Body of one processing sub-process.

    Repeatedly takes a work unit from ``work_unit_input_queue``, stores the
    value of ``processor.process(job.data, work_unit.data)`` in its
    ``result`` (or the formatted traceback in its ``exception`` if the call
    raises), and puts it on ``work_unit_output_queue``. A Shutdown unit
    shuts the processor down, is forwarded to the output queue as well, and
    ends the loop.

    :param processor: object providing init(), process(job_data, data) and
        shutdown().
    :param work_unit_input_queue: queue of work units to process.
    :param active_jobs: mapping from job serial to job.
    :param work_unit_output_queue: queue read by the consumer.
    :param logging_config: if truthy, passed to logging.config.dictConfig().
    """
    if logging_config:
        logging.config.dictConfig(logging_config)
    # Log arguments are passed %-style so they are only formatted when the
    # record is actually emitted; formatting a job reads its shared status,
    # which would otherwise happen for every work unit even with DEBUG off.
    LOG.info("starting processing sub-process (pid=%s)", os.getpid())
    LOG.debug("initializing [%s]", processor)
    processor.init()
    work_unit = None
    while not isinstance(work_unit, Shutdown):
        LOG.debug("waiting for work unit")
        work_unit = work_unit_input_queue.get()
        if isinstance(work_unit, Shutdown):
            try:
                LOG.debug("shutting down [%s]", processor)
                processor.shutdown()
            except:  # noqa E722
                work_unit.exception = traceback.format_exc()
                LOG.exception("exception raised while shutting down [%s]", processor)
        else:
            job = active_jobs[work_unit.job_serial]
            try:
                LOG.debug(
                    "processing [%s] of [%s] with [%s]", work_unit, job, processor
                )
                work_unit.result = processor.process(job.data, work_unit.data)
                job.status.incr_processed()
            except:  # noqa E722
                # Record the failure on the unit and still forward it below,
                # so the consumer receives every unit that was produced.
                work_unit.exception = traceback.format_exc()
                LOG.exception(
                    "exception raised while processing [%s] of [%s] with [%s]",
                    work_unit,
                    job,
                    processor,
                )
                job.status.incr_failed()
        # Shutdown units are forwarded too: the consumer counts them to know
        # when every processor has finished.
        LOG.debug("sending processed [%s] to consumer", work_unit)
        work_unit_output_queue.put(work_unit)
    LOG.info("exiting processing sub-process (pid=%s)", os.getpid())


def consume(
    consumer,
    work_unit_output_queue,
    active_jobs,
    job_output_queue,
    n_processors,
    require_in_order,
    logging_config,
):
    """Body of the consumption sub-process.

    Hands every processed work unit to ``consumer.consume(job.data,
    work_unit.data, work_unit.result, work_unit.exception)``, optionally in
    serial order. After each unit it bumps the job's consumed counter and,
    if the job's status is no longer running, puts the job serial on
    ``job_output_queue``. Ends once ``n_processors`` Shutdown units have
    been received, then shuts the consumer down.

    :param consumer: object providing init(), consume(...) and shutdown().
    :param work_unit_output_queue: queue of processed work units.
    :param active_jobs: mapping from job serial to job.
    :param job_output_queue: queue receiving serials of finished jobs.
    :param n_processors: number of Shutdown units to wait for.
    :param require_in_order: consume units in increasing serial order.
    :param logging_config: if truthy, passed to logging.config.dictConfig().
    """
    if logging_config:
        logging.config.dictConfig(logging_config)
    # Log arguments are passed %-style so they are only formatted when the
    # record is actually emitted; formatting a job reads its shared status.
    LOG.info("starting consumption sub-process (pid=%s)", os.getpid())
    LOG.debug("initializing [%s]", consumer)
    consumer.init()
    work_units = get_done_work_units(work_unit_output_queue, n_processors)
    if require_in_order:
        work_units = arrange_work_units_in_order(work_units)
    for work_unit in work_units:
        if isinstance(work_unit, Shutdown):
            continue
        job = active_jobs[work_unit.job_serial]
        try:
            LOG.debug("consuming [%s] of [%s] with [%s]", work_unit, job, consumer)
            consumer.consume(
                job.data, work_unit.data, work_unit.result, work_unit.exception
            )
        except:  # noqa E722
            LOG.exception(
                "exception raised while consuming [%s] of [%s] with [%s]",
                work_unit,
                job,
                consumer,
            )
        # The unit counts as consumed even if the consumer raised.
        job.status.incr_consumed()
        # NOTE(review): this is the only place a job is reported as finished.
        # A job that yields no units, or whose last unit gets here while the
        # status still says "producing", appears never to be reported -
        # confirm against the producer.
        if not job.status.running:
            job_output_queue.put(job.serial)
    # Guarded like the producer's and the processors' shutdown, so that a
    # failing shutdown hook is logged instead of crashing the sub-process.
    try:
        LOG.debug("shutting down [%s]", consumer)
        consumer.shutdown()
    except:  # noqa E722
        LOG.exception("exception raised while shutting down [%s]", consumer)
    LOG.info("exiting consumption sub-process (pid=%s)", os.getpid())


def get_done_work_units(processed, n_processors):
    """Yield work units taken from the ``processed`` queue.

    Shutdown units are not yielded; each one lowers the number of processors
    still considered alive, and the generator ends when that number reaches
    zero.
    """
    still_alive = n_processors
    while still_alive:
        LOG.debug("waiting for work unit")
        item = processed.get()
        if not isinstance(item, Shutdown):
            assert isinstance(item, WorkUnit)
            yield item
            continue
        LOG.debug("acknowledging that a processor has been shutdown")
        still_alive -= 1


def arrange_work_units_in_order(work_units):
    """Re-yield ``work_units`` in increasing serial order, starting at 1.

    A unit that arrives ahead of its turn is parked until every lower
    serial has been yielded. Once the input is exhausted, no unit may still
    be parked.
    """
    expected = 1
    parked = {}
    for unit in work_units:
        if unit.serial != expected:
            LOG.debug(f"holding processed [{unit}] to preserve ordering")
            parked[unit.serial] = unit
            continue
        yield unit
        expected += 1
        # Release the run of parked units that has just become contiguous.
        while expected in parked:
            yield parked.pop(expected)
            expected += 1
    assert not parked


# Names exported by `from <module> import *`. Job, WorkUnit and Shutdown are
# not listed here.
__all__ = [
    "Status",
    "Component",
    "Producer",
    "Processor",
    "Consumer",
    "ParallelSequencePipeline",
]
