# Copyright 2019 Katteli Inc.
# TestFlows Test Framework (http://testflows.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import sys
import time
import threading

import testflows.settings as settings

from .compress import compress
from .constants import id_sep, end_of_message
from .exceptions import exception as get_exception
from .message import Message, MessageObjectType, dumps
from .objects import Tag, Metric, ExamplesRow
from .funcs import top
from . import __version__

def object_fields(obj, prefix):
    """Flatten a named tuple like object into a dict of prefixed fields.

    Every field listed in ``obj._fields`` becomes a ``<prefix>_<field>``
    key, in field order, mapped to the field's value.

    :param obj: object exposing ``_fields`` (e.g. a ``namedtuple``)
    :param prefix: string prepended to every field name
    """
    fields = {}
    for name in obj._fields:
        fields[f"{prefix}_{name}"] = getattr(obj, name)
    return fields

def str_or_repr(v):
    """Return ``str(v)``, falling back to ``repr(v)`` if ``str()`` fails.

    Only ordinary errors (``Exception`` subclasses) trigger the fallback;
    ``KeyboardInterrupt`` and ``SystemExit`` raised from ``__str__`` are
    allowed to propagate instead of being silently swallowed.

    :param v: any value
    """
    try:
        return str(v)
    except Exception:
        return repr(v)

class TestOutput(object):
    """Test output protocol.

    Serializes test events (test start, attributes, results, notes, ...)
    into protocol messages and writes each one with a single write call
    to the message IO.

    :param test: test object whose events are reported
    :param io: message IO
    """
    protocol_version = "TFSPv2.1"

    def __init__(self, test, io):
        self.io = io
        self.test = test
        # hash of the previously emitted message; embedded into the next one
        self.msg_hash = ""
        # number of messages emitted so far
        self.msg_count = 0
        # fields shared by every message emitted for this test
        self.prefix = {
            "test_type": str(self.test.type),
            "test_subtype": str(self.test.subtype) if self.test.subtype is not None else None,
            "test_id": id_sep + id_sep.join(str(n) for n in self.test.id),
            "test_name": self.test.name,
            "test_flags": int(self.test.flags),
            "test_cflags": int(self.test.cflags),
            "test_level": len(self.test.id)
        }

    def message(self, keyword, message, object_type=0, stream=None):
        """Output message.

        :param keyword: message keyword
        :param message: dict of message specific fields; they are applied
            last and so override the common header and test prefix fields
        :param object_type: message object type, default: ``0``
        :param stream: name of the message stream, default: ``None``
        """
        msg_time = time.time()

        msg = {
            "message_keyword": str(keyword),
            "message_hash": self.msg_hash,
            "message_object": object_type,
            "message_num": self.msg_count,
            "message_stream": stream,
            # test, result, protocol and version messages sit at the test's
            # own level; every other message is one level below the test
            "message_level": (
                len(self.test.id) + 1
                if keyword not in (Message.TEST, Message.RESULT, Message.PROTOCOL, Message.VERSION)
                else len(self.test.id)
            ),
            "message_time": round(msg_time, settings.time_resolution),
            "message_rtime": round(msg_time - self.test.start_time, settings.time_resolution)
        }
        msg.update(self.prefix)
        msg.update(message)

        msg = dumps(msg)

        # the new hash is computed over the serialized message, which still
        # carries the previous message's hash, chaining messages together
        self.msg_hash = settings.hash_func(msg.encode("utf-8")).hexdigest()[:settings.hash_length]
        self.msg_count += 1

        # replace the embedded previous hash with the new one; this assumes
        # dumps() keeps insertion order without extra whitespace so that
        # "message_hash" is the second comma separated field -- TODO confirm
        parts = msg.split(",",2)
        parts[1] = f"\"message_hash\":\"{self.msg_hash}\""
        self.io.write(f"{parts[0]},{parts[1]},{parts[2]}{end_of_message}")

    def stop(self):
        """Output stop message."""
        self.message(Message.STOP, {})

    def protocol(self):
        """Output protocol version message.
        """
        msg = {"protocol_version": self.protocol_version}
        self.message(Message.PROTOCOL, msg)

    def version(self):
        """Output framework version message.
        """
        msg = {"framework_version": __version__}
        self.message(Message.VERSION, msg)

    def input(self, message):
        """Output input message.

        :param message: message
        """
        msg = {"message": str(message)}
        self.message(Message.INPUT, msg)

    def exception(self, exc_type=None, exc_value=None, exc_traceback=None):
        """Output exception message.

        Note: must be called from within finally block

        :param exc_type: exception type, default: ``None``
        :param exc_value: exception value, default: ``None``
        :param exc_traceback: exception traceback, default: ``None``
        """
        msg = {"message": get_exception(exc_type, exc_value, exc_traceback)}
        self.message(Message.EXCEPTION, msg)

    def test_message(self):
        """Output test message followed by the test's attributes,
        requirements, arguments, tags, examples, node and map messages.
        """
        msg = {
            "test_name": self.test.name,
            "test_uid": str(self.test.uid or "") or None,
            "test_description": str(self.test.description or "") or None,
        }

        self.message(Message.TEST, msg, object_type=MessageObjectType.TEST)

        for attr in self.test.attributes.values():
            self.attribute(attr)
        for req in self.test.requirements.values():
            self.requirement(req)
        for arg in self.test.args.values():
            self.argument(arg)
        for tag in self.test.tags:
            self.tag(Tag(tag))
        for row in self.test.examples:
            # example values are sent as strings
            self.example(ExamplesRow(row._idx, row._fields, [str(f) for f in row], row._row_format))
        if self.test.node:
            self.node(self.test.node)
        if self.test.map:
            self.map(self.test.map)

    def attribute(self, attribute, object_type=MessageObjectType.TEST):
        """Output attribute message."""
        msg = object_fields(attribute, "attribute")
        self.message(Message.ATTRIBUTE, msg, object_type=object_type)

    def requirement(self, requirement, object_type=MessageObjectType.TEST):
        """Output requirement message."""
        msg = object_fields(requirement, "requirement")
        self.message(Message.REQUIREMENT, msg, object_type=object_type)

    def argument(self, argument, object_type=MessageObjectType.TEST):
        """Output argument message; a non-None value is sent as a string."""
        msg = object_fields(argument, "argument")
        value = msg["argument_value"]
        if value is not None:
            msg["argument_value"] = str_or_repr(value)
        self.message(Message.ARGUMENT, msg, object_type=object_type)

    def tag(self, tag, object_type=MessageObjectType.TEST):
        """Output tag message."""
        msg = object_fields(tag, "tag")
        self.message(Message.TAG, msg, object_type=object_type)

    def example(self, example, object_type=MessageObjectType.TEST):
        """Output example message."""
        msg = object_fields(example, "example")
        self.message(Message.EXAMPLE, msg, object_type=object_type)

    def node(self, node, object_type=MessageObjectType.TEST):
        """Output node message."""
        msg = object_fields(node, "node")
        self.message(Message.NODE, msg, object_type=object_type)

    def map(self, map, object_type=MessageObjectType.TEST):
        """Output one map message per node in ``map``."""
        for node in map:
            msg = object_fields(node, "node")
            self.message(Message.MAP, msg, object_type=object_type)

    def ticket(self, ticket, object_type=MessageObjectType.TEST):
        """Output ticket message."""
        msg = object_fields(ticket, "ticket")
        self.message(Message.TICKET, msg, object_type=object_type)

    def metric(self, metric, object_type=MessageObjectType.TEST):
        """Output metric message."""
        msg = object_fields(metric, "metric")
        self.message(Message.METRIC, msg, object_type=object_type)

    def value(self, value, object_type=MessageObjectType.TEST):
        """Output value message."""
        msg = object_fields(value, "value")
        self.message(Message.VALUE, msg, object_type=object_type)

    def result(self, result):
        """Output result message.

        :param result: result object
        """
        msg = {
            "result_message": result.message,
            "result_reason": result.reason,
            "result_type": str(result.type),
            "result_test": result.test
        }
        self.message(Message.RESULT, msg, object_type=MessageObjectType.TEST)

    def note(self, message):
        """Output note message.

        :param message: message
        """
        msg = {"message": str(message)}
        self.message(Message.NOTE, msg)

    def debug(self, message):
        """Output debug message.

        :param message: message
        """
        msg = {"message": str(message)}
        self.message(Message.DEBUG, msg)

    def trace(self, message):
        """Output trace message.

        :param message: message
        """
        msg = {"message": str(message)}
        self.message(Message.TRACE, msg)


class TestInput(object):
    """Input side of the test protocol.

    Currently it only holds on to the test and the message io it was
    created with.

    :param test: test object
    :param io: message io
    """
    def __init__(self, test, io):
        self.test, self.io = test, io


class TestIO(object):
    """Input and output protocol of a single test.

    A :class:`TestOutput` and a :class:`TestInput` for the test share one
    line buffered :class:`MessageIO` that is backed by a :class:`LogIO`.

    :param test: test object
    """
    def __init__(self, test):
        shared = MessageIO(LogIO())
        self.io = shared
        self.output = TestOutput(test, shared)
        self.input = TestInput(test, shared)

    def message_io(self, name=None):
        """Create a line buffered message io that writes through this
        object under the given stream name.

        :param name: name of the message stream, default: ``None``
        """
        named = NamedMessageIO(self, name=name)
        return named

    def read(self, topic, timeout=None):
        """Read message (not implemented).

        :param topic: message topic
        :param timeout: timeout, default: ``None``
        """
        raise NotImplementedError

    def write(self, msg, stream=None):
        """Output one line buffered message with trailing whitespace removed.

        Falsy messages are dropped.

        :param msg: line buffered message
        :param stream: name of the stream, default: ``None``
        """
        if msg:
            self.output.message(Message.NONE, {"message": str(msg).rstrip()}, stream=stream)

    def flush(self):
        """Flush the underlying message io."""
        self.io.flush()

    def close(self, final=False):
        """Close the underlying message io.

        :param final: passed through to the message io
        """
        self.io.close(final=final)

class MessageIO(object):
    """Line buffered message input and output.

    Text passed to :meth:`write` is accumulated until a newline is seen.
    Every complete line is forwarded to the underlying io stream. A trailing
    partial line stays buffered until a later write completes it or
    :meth:`flush` is called.

    :param io: io stream to write and read
    """

    def __init__(self, io):
        self.io = io
        # trailing text that has not been terminated by a newline yet
        self.buffer = ""

    def read(self, topic, timeout=None):
        """Read message (not implemented).

        :param topic: message topic
        :param timeout: timeout, default: ``None``
        """
        raise NotImplementedError

    def write(self, msg):
        """Write message.

        Empty messages are ignored.

        :param msg: message
        """
        if not msg:
            return
        if "\n" not in msg:
            # no complete line yet; keep accumulating
            self.buffer += msg
            return
        if msg.endswith("\n") and not self.buffer:
            # fast path: msg holds only complete lines and nothing is
            # pending, so it can be passed through in a single write
            self.io.write(msg)
            return
        self.buffer += msg
        messages = self.buffer.split("\n")
        # the last element is the incomplete remainder ("" if msg ended
        # with a newline)
        for message in messages[:-1]:
            self.io.write(f"{message}\n")
        self.buffer = messages[-1]

    def flush(self):
        """Flush output buffer, terminating any pending partial line.
        """
        if self.buffer:
            self.io.write(f"{self.buffer}\n")
        self.buffer = ""

    def close(self, final=False):
        """Close the underlying io stream.

        :param final: passed through to the io stream
        """
        self.io.close(final=final)

class NamedMessageIO(MessageIO):
    """Line buffered message input and output bound to a named stream.

    Behaves like :class:`MessageIO`, but every complete line is written
    with ``stream=name`` so the underlying io can tag it with the stream.

    :param io: io stream to write and read; its ``write`` must accept a
        ``stream`` keyword argument
    :param name: name of the stream, default: None
    """

    def __init__(self, io, name=None):
        super().__init__(io)
        self.stream = name

    def write(self, msg):
        """Write message.

        Empty messages are ignored.

        :param msg: message
        """
        if not msg:
            return
        self.buffer += msg
        if "\n" not in msg:
            # no complete line yet; keep accumulating
            return
        messages = self.buffer.split("\n")
        # the last element is the incomplete remainder ("" if msg ended
        # with a newline)
        for message in messages[:-1]:
            self.io.write(f"{message}\n", stream=self.stream)
        self.buffer = messages[-1]

    def flush(self):
        """Flush output buffer, terminating any pending partial line.
        """
        if self.buffer:
            self.io.write(f"{self.buffer}\n", stream=self.stream)
        self.buffer = ""


class LogReader(object):
    """Reader of messages stored in the log file.

    Opens ``settings.read_logfile`` as UTF-8 text on creation.
    """
    def __init__(self):
        path = settings.read_logfile
        self.fd = open(path, "r", buffering=1, encoding="utf-8")

    def tell(self):
        """Return the current position in the log file."""
        position = self.fd.tell()
        return position

    def seek(self, pos):
        """Move to ``pos`` in the log file and return the new position."""
        position = self.fd.seek(pos)
        return position

    def read(self, topic, timeout=None):
        """Read message (not implemented)."""
        raise NotImplementedError

    def close(self, final=False):
        """Close the log file; ``final`` is accepted but not used."""
        self.fd.close()


class LogWriter(object):
    """Singleton log file writer.

    Messages passed to :meth:`write` are kept in memory. They are written to
    ``settings.write_logfile`` as compressed chunks either by an auto flush
    timer, re-armed every ``auto_flush_interval`` seconds, or on a final
    :meth:`close`.
    """
    lock = threading.Lock()
    instance = None
    auto_flush_interval = 0.15

    def __new__(cls, *args, **kwargs):
        with cls.lock:
            if not cls.instance:
                self = object.__new__(cls)
                # binary files do not support line buffering (buffering=1
                # only produces a RuntimeWarning), so use the default buffer;
                # flush() calls fd.flush() after every chunk anyway
                self.fd = open(settings.write_logfile, "ab")
                self.lock = threading.Lock()
                self.buffer = []
                # set once the file has been closed by a final close()
                self.closed = False
                self.timer = threading.Timer(cls.auto_flush_interval, self.flush, (True,))
                self.timer.start()
                cls.instance = self
            return cls.instance

    def __init__(self):
        # all initialization happens once, in __new__
        pass

    def write(self, msg):
        """Buffer message for writing and return its length.

        :param msg: message string, stored UTF-8 encoded
        """
        with self.lock:
            self.buffer.append(msg.encode("utf-8"))
            return len(msg)

    def flush(self, force=False, final=False):
        """Write buffered messages to the log file.

        :param force: actually flush; when ``False`` (the default) the call
            is a no-op and writing is left to the auto flush timer
        :param final: do not re-arm the auto flush timer
        """
        if not force:
            return

        with self.lock:
            self.timer.cancel()
            if self.closed:
                # a timer that fired while close() was running must neither
                # write to the closed file nor re-arm itself
                return
            if self.buffer:
                self.fd.write(compress(b"".join(self.buffer)))
                self.fd.flush()
                del self.buffer[:]
            if not final:
                self.timer = threading.Timer(self.auto_flush_interval, self.flush, (True,))
                self.timer.start()

    def close(self, final=False):
        """Flush pending messages and close the log file.

        Only a final close has any effect; repeated final closes are no-ops.

        :param final: close for real, default: ``False``
        """
        if not final:
            return
        self.flush(force=True, final=True)
        with self.lock:
            if self.closed:
                return
            self.closed = True
            # a timer that was already running during the flush above may
            # have re-armed itself; stop it before closing the file
            self.timer.cancel()
            self.fd.close()

class LogIO(object):
    """Log file reader and writer.

    Pairs the singleton :class:`LogWriter` with a new :class:`LogReader`.
    Write-side calls go to the writer and read-side calls go to the reader.
    """
    def __init__(self):
        self.writer = LogWriter()
        self.reader = LogReader()

    def write(self, msg):
        """Buffer ``msg`` in the log writer and return what it returns."""
        writer = self.writer
        return writer.write(msg)

    def flush(self):
        """Call the writer's ``flush()`` with no arguments.

        Note: ``LogWriter.flush`` returns without writing unless
        ``force=True``, so this call alone does not write buffered data.
        """
        writer = self.writer
        return writer.flush()

    def tell(self):
        """Return the reader's current position."""
        reader = self.reader
        return reader.tell()

    def seek(self, pos):
        """Move the reader to ``pos``."""
        reader = self.reader
        return reader.seek(pos)

    def read(self, topic, timeout=None):
        """Read message through the reader."""
        reader = self.reader
        return reader.read(topic, timeout)

    def close(self, final=False):
        """Close the writer first, then the reader."""
        for part in (self.writer, self.reader):
            part.close(final=final)

