# coding=utf-8
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------

import json
import functools
from collections import defaultdict
from six.moves.urllib.parse import urlparse, parse_qsl
from azure.core.exceptions import (
    HttpResponseError,
    ClientAuthenticationError,
    ODataV4Format,
)
from azure.core.paging import ItemPaged
from ._models import (
    RecognizeEntitiesResult,
    CategorizedEntity,
    TextDocumentStatistics,
    RecognizeLinkedEntitiesResult,
    LinkedEntity,
    ExtractKeyPhrasesResult,
    AnalyzeSentimentResult,
    SentenceSentiment,
    DetectLanguageResult,
    DetectedLanguage,
    DocumentError,
    SentimentConfidenceScores,
    TextAnalyticsError,
    TextAnalyticsWarning,
    RecognizePiiEntitiesResult,
    PiiEntity,
    AnalyzeHealthcareEntitiesResult,
    ExtractSummaryResult,
    _AnalyzeActionsType,
)


class CSODataV4Format(ODataV4Format):
    """OData V4 error format that prefers the nested ``innererror`` payload.

    If the error body contains a non-empty ``["error"]["innererror"]`` entry,
    that entry is handed to :class:`~azure.core.exceptions.ODataV4Format`;
    in every other case the full error body is used.

    :param odata_error: The deserialized JSON error body returned by the service.
    """

    def __init__(self, odata_error):
        try:
            inner_error = odata_error["error"]["innererror"]
            if inner_error:
                super(CSODataV4Format, self).__init__(inner_error)
                return
        except KeyError:
            pass
        # Reached when "error"/"innererror" is missing, when "innererror" is
        # present but empty (None, {}), or when parsing the inner error raised
        # KeyError. Always initialising here guarantees the base class runs
        # for the empty-innererror case instead of leaving the instance
        # uninitialised.
        super(CSODataV4Format, self).__init__(odata_error)


def process_http_response_error(error):
    """Raise an azure-core exception built from the failed response.

    The raised exception is constructed with ``error.response`` and
    ``CSODataV4Format`` as its error format.

    :param error: The caught error; its ``status_code`` and ``response``
        attributes are read.
    :raises ~azure.core.exceptions.ClientAuthenticationError: If the status
        code is 401.
    :raises ~azure.core.exceptions.HttpResponseError: For any other status code.
    """
    is_unauthorized = error.status_code == 401
    exception_type = ClientAuthenticationError if is_unauthorized else HttpResponseError
    raise exception_type(response=error.response, error_format=CSODataV4Format)


def order_results(response, combined):
    """Return results and errors in the order the documents were submitted.

    The submission order is recovered from the ``"documents"`` list in the
    JSON body of the original HTTP request.

    :param response: Pipeline response; ``response.http_response.request.body``
        must hold the JSON request body.
    :param combined: A combined list of the results | errors, each exposing ``id``.
    :return: In order list of results | errors (if any)
    :raises KeyError: If a requested document id has no entry in ``combined``.
    """
    requested_documents = json.loads(response.http_response.request.body)[
        "documents"
    ]

    # Index every result/error by document id; a later duplicate id replaces
    # an earlier one.
    by_id = {}
    for entry in combined:
        by_id[entry.id] = entry

    return [by_id[document["id"]] for document in requested_documents]


def order_lro_results(doc_id_order, combined):
    """Return results and errors in the order the documents were submitted.

    Long running operations no longer have access to the initial request,
    so the submitted document ids are passed in explicitly.

    :param doc_id_order: A list of document IDs from the original request.
    :param combined: A combined list of the results | errors, each exposing ``id``.
    :return: In order list of results | errors (if any)
    :raises ValueError: If an item's id is not present in ``doc_id_order``.
    """
    keyed = [(entry.id, entry) for entry in combined]
    # The sort is stable, so entries sharing an id keep their relative order.
    keyed.sort(key=lambda pair: doc_id_order.index(pair[0]))
    return [entry for _, entry in keyed]


def prepare_result(func):
    """Turn a per-document converter into a whole-response deserializer.

    The decorated callable is invoked as
    ``decorated(response, obj, response_headers, lro=False)``:

    * ``obj.documents`` and ``obj.errors`` are merged and, when there are
      errors, re-ordered to match the order of the submitted documents.
    * Every entry that has an ``error`` attribute becomes a ``DocumentError``;
      every other entry is converted with ``func(item, results)``.

    With ``lro=True`` the first positional argument is the list of document
    ids (see ``order_lro_results``) instead of the pipeline response.

    :param func: Converter called as ``func(item, results)`` for each
        successful document.
    :return: The wrapping callable, carrying ``func``'s name and docstring.
    """

    @functools.wraps(func)
    def choose_wrapper(*args, **kwargs):
        def wrapper(
            response, obj, response_headers, ordering_function
        ):  # pylint: disable=unused-argument
            if obj.errors:
                combined = obj.documents + obj.errors
                results = ordering_function(response, combined)

            else:
                results = obj.documents

            # Entries are replaced in place, and the same (partially
            # converted) list is handed to ``func``. When there are no
            # errors this list is ``obj.documents`` itself.
            for idx, item in enumerate(results):
                if hasattr(item, "error"):
                    results[idx] = DocumentError(
                        id=item.id,
                        error=TextAnalyticsError._from_generated(  # pylint: disable=protected-access
                            item.error
                        ),
                    )
                else:
                    results[idx] = func(item, results)
            return results

        # Only the ``lro`` keyword is consulted; any other keyword argument
        # is accepted but not forwarded.
        lro = kwargs.get("lro", False)

        if lro:
            return wrapper(*args, ordering_function=order_lro_results)
        return wrapper(*args, ordering_function=order_results)

    return choose_wrapper


@prepare_result
def language_result(language, results):  # pylint: disable=unused-argument
    """Convert one generated language-detection document to a DetectLanguageResult.

    :param language: Generated document exposing ``id``, ``detected_language``,
        ``warnings`` and ``statistics``.
    :param results: The full list of documents being converted (unused here).
    :return: The populated ``DetectLanguageResult``.
    """
    # pylint: disable=protected-access
    doc_id = language.id
    primary = DetectedLanguage._from_generated(language.detected_language)
    converted_warnings = [
        TextAnalyticsWarning._from_generated(warning) for warning in language.warnings
    ]
    stats = TextDocumentStatistics._from_generated(language.statistics)
    return DetectLanguageResult(
        id=doc_id,
        primary_language=primary,
        warnings=converted_warnings,
        statistics=stats,
    )


@prepare_result
def entities_result(
    entity, results, *args, **kwargs
):  # pylint: disable=unused-argument
    """Convert one generated entity-recognition document to a RecognizeEntitiesResult.

    :param entity: Generated document exposing ``id``, ``entities``,
        ``warnings`` and ``statistics``.
    :param results: The full list of documents being converted (unused here).
    :return: The populated ``RecognizeEntitiesResult``.
    """
    # pylint: disable=protected-access
    doc_id = entity.id
    categorized = [CategorizedEntity._from_generated(found) for found in entity.entities]
    converted_warnings = [
        TextAnalyticsWarning._from_generated(warning) for warning in entity.warnings
    ]
    stats = TextDocumentStatistics._from_generated(entity.statistics)
    return RecognizeEntitiesResult(
        id=doc_id,
        entities=categorized,
        warnings=converted_warnings,
        statistics=stats,
    )


@prepare_result
def linked_entities_result(
    entity, results, *args, **kwargs
):  # pylint: disable=unused-argument
    """Convert one generated entity-linking document to a RecognizeLinkedEntitiesResult.

    :param entity: Generated document exposing ``id``, ``entities``,
        ``warnings`` and ``statistics``.
    :param results: The full list of documents being converted (unused here).
    :return: The populated ``RecognizeLinkedEntitiesResult``.
    """
    # pylint: disable=protected-access
    doc_id = entity.id
    linked = [LinkedEntity._from_generated(found) for found in entity.entities]
    converted_warnings = [
        TextAnalyticsWarning._from_generated(warning) for warning in entity.warnings
    ]
    stats = TextDocumentStatistics._from_generated(entity.statistics)
    return RecognizeLinkedEntitiesResult(
        id=doc_id,
        entities=linked,
        warnings=converted_warnings,
        statistics=stats,
    )


@prepare_result
def key_phrases_result(
    phrases, results, *args, **kwargs
):  # pylint: disable=unused-argument
    """Convert one generated key-phrase document to an ExtractKeyPhrasesResult.

    :param phrases: Generated document exposing ``id``, ``key_phrases``,
        ``warnings`` and ``statistics``.
    :param results: The full list of documents being converted (unused here).
    :return: The populated ``ExtractKeyPhrasesResult``.
    """
    # pylint: disable=protected-access
    doc_id = phrases.id
    extracted = phrases.key_phrases  # passed through without conversion
    converted_warnings = [
        TextAnalyticsWarning._from_generated(warning) for warning in phrases.warnings
    ]
    stats = TextDocumentStatistics._from_generated(phrases.statistics)
    return ExtractKeyPhrasesResult(
        id=doc_id,
        key_phrases=extracted,
        warnings=converted_warnings,
        statistics=stats,
    )


@prepare_result
def sentiment_result(
    sentiment, results, *args, **kwargs
):  # pylint: disable=unused-argument
    """Convert one generated sentiment document to an AnalyzeSentimentResult.

    :param sentiment: Generated document exposing ``id``, ``sentiment``,
        ``warnings``, ``statistics``, ``confidence_scores`` and ``sentences``.
    :param results: The full list of documents being converted; forwarded to
        ``SentenceSentiment._from_generated`` together with ``sentiment``.
    :return: The populated ``AnalyzeSentimentResult``.
    """
    # pylint: disable=protected-access
    doc_id = sentiment.id
    overall = sentiment.sentiment
    converted_warnings = [
        TextAnalyticsWarning._from_generated(warning) for warning in sentiment.warnings
    ]
    stats = TextDocumentStatistics._from_generated(sentiment.statistics)
    scores = SentimentConfidenceScores._from_generated(sentiment.confidence_scores)
    per_sentence = [
        SentenceSentiment._from_generated(sentence, results, sentiment)
        for sentence in sentiment.sentences
    ]
    return AnalyzeSentimentResult(
        id=doc_id,
        sentiment=overall,
        warnings=converted_warnings,
        statistics=stats,
        confidence_scores=scores,
        sentences=per_sentence,
    )


@prepare_result
def pii_entities_result(
    entity, results, *args, **kwargs
):  # pylint: disable=unused-argument
    """Convert one generated PII document to a RecognizePiiEntitiesResult.

    :param entity: Generated document exposing ``id``, ``entities``,
        ``warnings``, ``statistics`` and, optionally, ``redacted_text``.
    :param results: The full list of documents being converted (unused here).
    :return: The populated ``RecognizePiiEntitiesResult``; ``redacted_text``
        is ``None`` when the generated document has no such attribute.
    """
    # pylint: disable=protected-access
    doc_id = entity.id
    pii_entities = [PiiEntity._from_generated(found) for found in entity.entities]
    redacted = getattr(entity, "redacted_text", None)
    converted_warnings = [
        TextAnalyticsWarning._from_generated(warning) for warning in entity.warnings
    ]
    stats = TextDocumentStatistics._from_generated(entity.statistics)
    return RecognizePiiEntitiesResult(
        id=doc_id,
        entities=pii_entities,
        redacted_text=redacted,
        warnings=converted_warnings,
        statistics=stats,
    )


@prepare_result
def healthcare_result(
    health_result, results, *args, **kwargs
):  # pylint: disable=unused-argument
    """Convert one generated healthcare document to an AnalyzeHealthcareEntitiesResult.

    :param health_result: Generated healthcare document for a single input.
    :param results: The full list of documents being converted (unused here).
    :return: The result of ``AnalyzeHealthcareEntitiesResult._from_generated``.
    """
    convert = (
        AnalyzeHealthcareEntitiesResult._from_generated  # pylint: disable=protected-access
    )
    return convert(health_result)


@prepare_result
def summary_result(
    summary, results, *args, **kwargs
):  # pylint: disable=unused-argument
    """Convert one generated summarization document to an ExtractSummaryResult.

    :param summary: Generated summarization document for a single input.
    :param results: The full list of documents being converted (unused here).
    :return: The result of ``ExtractSummaryResult._from_generated``.
    """
    convert = ExtractSummaryResult._from_generated  # pylint: disable=protected-access
    return convert(summary)


def healthcare_extract_page_data(
    doc_id_order, obj, response_headers, health_job_state
):  # pylint: disable=unused-argument
    """Split one healthcare job page into its next link and converted results.

    :param doc_id_order: A list of document IDs from the original request.
    :param obj: Unused.
    :param response_headers: Forwarded to ``healthcare_result``.
    :param health_job_state: The page; ``next_link`` and ``results`` are read.
    :return: Tuple of ``(next_link, list of converted results | errors)``.
    """
    next_link = health_job_state.next_link
    page_items = healthcare_result(
        doc_id_order, health_job_state.results, response_headers, lro=True
    )
    return next_link, page_items


def _get_deserialization_callback_from_task_type(task_type):
    """Pick the deserializer that matches an analyze-actions task type.

    :param task_type: A value compared (with ``==``) against the
        ``_AnalyzeActionsType`` members.
    :return: The matching ``*_result`` callable; ``key_phrases_result`` when
        none of the listed types matches.
    """
    # Checked in order with ``==`` so matching semantics do not depend on
    # how task types hash.
    callbacks_by_type = (
        (_AnalyzeActionsType.RECOGNIZE_ENTITIES, entities_result),
        (_AnalyzeActionsType.RECOGNIZE_PII_ENTITIES, pii_entities_result),
        (_AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES, linked_entities_result),
        (_AnalyzeActionsType.ANALYZE_SENTIMENT, sentiment_result),
        (_AnalyzeActionsType.EXTRACT_SUMMARY, summary_result),
    )
    for known_type, callback in callbacks_by_type:
        if task_type == known_type:
            return callback
    return key_phrases_result


def _get_property_name_from_task_type(task_type):
    """Name the tasks-object attribute that holds results for a task type.

    :param task_type: A value compared (with ``==``) against the
        ``_AnalyzeActionsType`` members.
    :return: The attribute name; ``"key_phrase_extraction_tasks"`` when none
        of the listed types matches.
    """
    # Checked in order with ``==`` so matching semantics do not depend on
    # how task types hash.
    property_names_by_type = (
        (_AnalyzeActionsType.RECOGNIZE_ENTITIES, "entity_recognition_tasks"),
        (_AnalyzeActionsType.RECOGNIZE_PII_ENTITIES, "entity_recognition_pii_tasks"),
        (_AnalyzeActionsType.RECOGNIZE_LINKED_ENTITIES, "entity_linking_tasks"),
        (_AnalyzeActionsType.ANALYZE_SENTIMENT, "sentiment_analysis_tasks"),
        (_AnalyzeActionsType.EXTRACT_SUMMARY, "extractive_summarization_tasks"),
    )
    for known_type, property_name in property_names_by_type:
        if task_type == known_type:
            return property_name
    return "key_phrase_extraction_tasks"


def _get_good_result(
    current_task_type,
    index_of_task_result,
    doc_id_order,
    response_headers,
    returned_tasks_object,
):
    """Deserialize the results of a single task from the returned tasks object.

    :param current_task_type: Task type used to select both the deserializer
        and the attribute of ``returned_tasks_object`` to read.
    :param index_of_task_result: Position of the task within that attribute's
        sequence.
    :param doc_id_order: A list of document IDs from the original request.
    :param response_headers: Forwarded to the deserializer.
    :param returned_tasks_object: Object holding one sequence of tasks per
        task type.
    :return: The deserializer's output for the selected task's ``results``.
    """
    to_results = _get_deserialization_callback_from_task_type(current_task_type)
    tasks_attribute = _get_property_name_from_task_type(current_task_type)
    tasks_of_this_type = getattr(returned_tasks_object, tasks_attribute)
    selected_task = tasks_of_this_type[index_of_task_result]
    return to_results(doc_id_order, selected_task.results, response_headers, lro=True)


def get_iter_items(doc_id_order, task_order, response_headers, analyze_job_state):
    """Group the action results of an analyze job by document.

    :param doc_id_order: A list of document IDs from the original request.
    :param task_order: Task types in the order the actions were submitted;
        the same type may appear more than once.
    :param response_headers: Forwarded to the deserializers.
    :param analyze_job_state: Job state whose ``tasks`` attribute is read.
    :return: One list of action results per document, in ``doc_id_order``
        order. Documents with no results are omitted.
    """
    results_by_doc_id = {}
    # How many tasks of each type have been consumed so far; this is the
    # index of the next task of that type.
    consumed_per_type = {}
    tasks = analyze_job_state.tasks

    for task_type in task_order:
        occurrence = consumed_per_type.get(task_type, 0)
        task_results = _get_good_result(
            task_type, occurrence, doc_id_order, response_headers, tasks
        )
        for action_result in task_results:
            results_by_doc_id.setdefault(action_result.id, []).append(action_result)
        consumed_per_type[task_type] = occurrence + 1

    return [
        results_by_doc_id[doc_id]
        for doc_id in doc_id_order
        if doc_id in results_by_doc_id
    ]


def analyze_extract_page_data(
    doc_id_order, task_order, response_headers, analyze_job_state
):
    """Split one analyze job page into its next link and per-document results.

    :param doc_id_order: A list of document IDs from the original request.
    :param task_order: Task types in the order the actions were submitted.
    :param response_headers: Forwarded to ``get_iter_items``.
    :param analyze_job_state: The page; ``next_link`` is read after the
        items have been built.
    :return: Tuple of ``(next_link, list of per-document action result lists)``.
    """
    page_items = get_iter_items(
        doc_id_order, task_order, response_headers, analyze_job_state
    )
    next_link = analyze_job_state.next_link
    return next_link, page_items


def lro_get_next_page(
    lro_status_callback, first_page, continuation_token, show_stats=False
):
    """Fetch the page identified by a continuation token.

    :param lro_status_callback: Called as
        ``lro_status_callback(job_id, **query_params)`` to fetch a page.
    :param first_page: Returned as-is when ``continuation_token`` is ``None``.
    :param continuation_token: Next-link URL (text or UTF-8 bytes), or ``None``.
    :param show_stats: Always passed to the callback as ``show_stats``,
        replacing any ``showStats`` query parameter in the link.
    :return: ``first_page`` or the callback's return value.
    """
    if continuation_token is None:
        return first_page

    # Bytes tokens are decoded; objects without ``decode`` are used as-is.
    try:
        next_link = continuation_token.decode("utf-8")
    except AttributeError:
        next_link = continuation_token

    url_parts = urlparse(next_link)
    # The job id is the last path segment of the link.
    job_id = url_parts.path.rpartition("/")[2]

    # "$" is removed from the whole query string (e.g. "$top" -> "top") so
    # the parameter names can be used as keyword arguments.
    callback_kwargs = {
        name: value
        for name, value in parse_qsl(url_parts.query.replace("$", ""))
        if name != "showStats"
    }
    callback_kwargs["show_stats"] = show_stats

    return lro_status_callback(job_id, **callback_kwargs)


def healthcare_paged_result(
    doc_id_order, health_status_callback, _, obj, response_headers, show_stats=False
):  # pylint: disable=unused-argument
    """Build the pageable returned for a healthcare long running operation.

    :param doc_id_order: A list of document IDs from the original request.
    :param health_status_callback: Callback used by ``lro_get_next_page`` to
        fetch subsequent pages.
    :param _: Ignored.
    :param obj: The first page; also bound as the ``obj`` argument of
        ``healthcare_extract_page_data``.
    :param response_headers: Forwarded to ``healthcare_extract_page_data``.
    :param show_stats: Forwarded to ``lro_get_next_page``.
    :return: An ``ItemPaged`` built from the two bound callables.
    """
    fetch_page = functools.partial(
        lro_get_next_page, health_status_callback, obj, show_stats=show_stats
    )
    extract_page = functools.partial(
        healthcare_extract_page_data, doc_id_order, obj, response_headers
    )
    return ItemPaged(fetch_page, extract_page)


def analyze_paged_result(
    doc_id_order,
    task_order,
    analyze_status_callback,
    _,
    obj,
    response_headers,
    show_stats=False,
):  # pylint: disable=unused-argument
    """Build the pageable returned for an analyze-actions long running operation.

    :param doc_id_order: A list of document IDs from the original request.
    :param task_order: Task types in the order the actions were submitted.
    :param analyze_status_callback: Callback used by ``lro_get_next_page`` to
        fetch subsequent pages.
    :param _: Ignored.
    :param obj: The first page, returned by ``lro_get_next_page`` when there
        is no continuation token.
    :param response_headers: Forwarded to ``analyze_extract_page_data``.
    :param show_stats: Forwarded to ``lro_get_next_page``.
    :return: An ``ItemPaged`` built from the two bound callables.
    """
    fetch_page = functools.partial(
        lro_get_next_page, analyze_status_callback, obj, show_stats=show_stats
    )
    extract_page = functools.partial(
        analyze_extract_page_data, doc_id_order, task_order, response_headers
    )
    return ItemPaged(fetch_page, extract_page)
