# Faraday Penetration Test IDE
# Copyright (C) 2016  Infobyte LLC (http://www.infobytesec.com/)
# See the file 'doc/LICENSE' for the license information
from builtins import str, bytes
from io import TextIOWrapper

import json
import threading
import logging
import csv

from werkzeug.exceptions import Conflict
from flask import Blueprint, request, abort, jsonify, make_response
from flask_classful import route
from filteralchemy import (
    FilterSet,
    operators,
)
from flask_wtf.csrf import validate_csrf
from marshmallow import fields, ValidationError, Schema, post_load
from marshmallow.validate import OneOf
import wtforms


from faraday.server.api.base import (
    AutoSchema,
    FilterAlchemyMixin,
    FilterSetMeta,
    PaginatedMixin,
    ReadWriteView,
    FilterMixin,
    PatchableMixin
)

from faraday.server.schemas import (
    PrimaryKeyRelatedField,
    SeverityField,
    SelfNestedField,
    FaradayCustomField,
)

from faraday.server.models import (
    db,
    CustomFieldsSchema,
    Vulnerability,
    VulnerabilityTemplate,
)

# Blueprint that receives the routes of both template views registered at the
# end of this module (the legacy view and the v3 view).
vulnerability_template_api = Blueprint('vulnerability_template_api', __name__)
logger = logging.getLogger(__name__)


class ImpactSchema(Schema):
    """Impact flags of a vulnerability template, exposed as one nested object.

    Each API field maps to a flat ``impact_*`` attribute of the model;
    ``default=False`` is the value used when dumping an object that lacks
    the attribute.
    """
    accountability = fields.Boolean(attribute='impact_accountability', default=False)
    availability = fields.Boolean(attribute='impact_availability', default=False)
    confidentiality = fields.Boolean(attribute='impact_confidentiality', default=False)
    integrity = fields.Boolean(attribute='impact_integrity', default=False)


class VulnerabilityTemplateSchema(AutoSchema):
    """(De)serializer for ``VulnerabilityTemplate`` objects.

    Several model attributes are exposed under the legacy API names
    (``_id``, ``desc``, ``exploitation``, ``easeofresolution``,
    ``policyviolations``, ``customfields``, ...) through ``attribute=``.
    ``impact`` is read as a nested object whose ``impact_*`` values are moved
    to the top level of the loaded data (see ``post_load_impact``).
    """
    _id = fields.Integer(dump_only=True, attribute='id')
    id = fields.Integer(dump_only=True, attribute='id')
    _rev = fields.String(default='', dump_only=True)
    cwe = fields.String(dump_only=True, default='')  # deprecated field, the legacy data is added to refs on import
    exploitation = SeverityField(attribute='severity', required=True)
    # Dumped as one comma-separated string; loaded from a string or a list
    references = fields.Method('get_references', deserialize='load_references')
    refs = fields.List(fields.String(), dump_only=True, attribute='references')
    desc = fields.String(dump_only=True, attribute='description')
    data = fields.String(attribute='data')
    impact = SelfNestedField(ImpactSchema())
    easeofresolution = fields.String(
        attribute='ease_of_resolution',
        validate=OneOf(Vulnerability.EASE_OF_RESOLUTIONS),
        allow_none=True)
    policyviolations = fields.List(fields.String,
                                   attribute='policy_violations')
    creator = PrimaryKeyRelatedField('username', dump_only=True, attribute='creator')
    creator_id = fields.Integer(dump_only=True, attribute='creator_id')

    create_at = fields.DateTime(attribute='create_date',
                                dump_only=True)

    # Here we use vulnerability instead of vulnerability_template to avoid duplicate row
    # in the custom_fields_schema table.
    # All validation will be against vulnerability table.
    external_id = fields.String(allow_none=True)
    customfields = FaradayCustomField(table_name='vulnerability', attribute='custom_fields')

    class Meta:
        model = VulnerabilityTemplate
        fields = ('id', '_id', '_rev', 'cwe', 'description', 'desc',
                  'exploitation', 'name', 'references', 'refs', 'resolution',
                  'impact', 'easeofresolution', 'policyviolations', 'data',
                  'external_id', 'creator', 'create_at', 'creator_id',
                  'customfields')

    def get_references(self, obj):
        """Return the names of the template's references joined by ``', '``."""
        return ', '.join(ref_tmpl.name for ref_tmpl in obj.reference_template_instances)

    def load_references(self, value):
        """Deserialize ``references`` from a comma-separated string or a list.

        Bytes are decoded as UTF-8 first. Items of a string value are
        stripped; an empty string yields ``[]``.

        :return: list of reference names.
        :raises ValidationError: if *value* is neither a string nor a list,
            if a list item is not a string, or if any name is empty.
        """
        if isinstance(value, bytes):
            value = value.decode('utf-8')
        if isinstance(value, list):
            references = value
        elif isinstance(value, str):
            if len(value) == 0:
                # Required because "".split(",") == [""]
                return []
            references = [ref.strip() for ref in value.split(',')]
        else:
            raise ValidationError('references must be a either a string '
                                  'or a list')
        # Without this check a non-string item (number, null, object) would
        # reach len() below and raise TypeError, i.e. an unhandled server
        # error instead of a validation error.
        if not all(isinstance(ref, str) for ref in references):
            raise ValidationError('references must be a list of strings')
        if any(len(ref) == 0 for ref in references):
            raise ValidationError('Empty name detected in reference')
        return references

    @post_load
    def post_load_impact(self, data, **kwargs):
        """Merge the loaded nested ``impact`` values into the top-level data."""
        # Flatten impact: move data['impact'][*] to data[*]
        impact = data.pop('impact', None)
        if impact:
            data.update(impact)
        return data


class VulnerabilityTemplateFilterSet(FilterSet):
    """Filters for the vulnerability template listing: ``severity`` equality."""
    class Meta(FilterSetMeta):
        model = VulnerabilityTemplate  # It has all the fields
        # Must be a 1-tuple: ('severity') is just the string 'severity', which
        # a field-name lookup would treat as a sequence of characters.
        fields = ('severity',)
        operators = (operators.Equal,)


# Process-local lock held by VulnerabilityTemplateView.post() so template
# creations through that endpoint run one at a time within this process.
lock = threading.Lock()


class VulnerabilityTemplateView(PaginatedMixin,
                                FilterAlchemyMixin,
                                ReadWriteView,
                                FilterMixin):
    """Read/write API for vulnerability templates.

    On top of the endpoints inherited from the base views, listings are
    wrapped in a CouchDB ``_all_docs``-style envelope, and ``bulk_create``
    creates many templates at once from an uploaded CSV file or from a JSON
    list of vulnerabilities.
    """
    route_base = 'vulnerability_template'
    model_class = VulnerabilityTemplate
    schema_class = VulnerabilityTemplateSchema
    filterset_class = VulnerabilityTemplateFilterSet
    get_joinedloads = [VulnerabilityTemplate.creator]

    def _envelope_list(self, objects, pagination_metadata=None):
        """Wrap dumped templates in a CouchDB ``_all_docs``-style envelope.

        :param objects: dumped templates (dicts carrying an ``_id`` key).
        :param pagination_metadata: pagination information when the listing
            was paginated, ``None`` otherwise.
        :return: ``{'rows': [...], 'total_rows': int}``.
        """
        rows = [
            {
                'id': template['_id'],
                'key': template['_id'],
                'value': {'rev': ''},
                'doc': template,
            }
            for template in objects
        ]
        # When paginating, ``objects`` is only the current page. Prefer the
        # overall count if the metadata exposes one (a ``total`` attribute,
        # presumably a SQLAlchemy pagination object) so ``total_rows`` keeps
        # its "all rows" meaning; otherwise fall back to the page length.
        total = getattr(pagination_metadata, 'total', None)
        return {
            'rows': rows,
            'total_rows': len(objects) if total is None else total,
        }

    def post(self, **kwargs):
        """
        ---
        post:
          tags: ["VulnerabilityTemplate"]
          summary: Creates VulnerabilityTemplate
          requestBody:
            required: true
            content:
              application/json:
                schema: VulnerabilityTemplateSchema
          responses:
            201:
              description: Created
            content:
              application/json:
                schema: VulnerabilityTemplateSchema
            409:
              description: Duplicated key found
              content:
                application/json:
                  schema: VulnerabilityTemplateSchema
        """
        # Creations are serialized through the module-level lock, presumably
        # so two concurrent requests cannot both pass the uniqueness check for
        # the same template.
        with lock:
            return super(VulnerabilityTemplateView, self).post(**kwargs)

    def _get_schema_instance(self, route_kwargs, **kwargs):
        # Plain pass-through to the base implementation (no customization).
        return super(VulnerabilityTemplateView, self)._get_schema_instance(
            route_kwargs, **kwargs)

    @route('/bulk_create/', methods=['POST'])
    def bulk_create(self):
        """
        ---
        post:
          tags: ["Bulk", "VulnerabilityTemplate"]
          description: Creates Vulnerability templates in bulk
          responses:
            200:
              description: Ok, at least one template was created
            400:
              description: Bad request
            403:
              description: Forbidden
            409:
              description: Conflict, no template was created and at least one already exists
        """
        # Input is either a multipart form with a CSV 'file' or a JSON body
        # with a 'vulns' list. get_json(silent=True) returns None instead of
        # raising for non-JSON bodies such as the multipart upload; the dict
        # check also covers a JSON body that is not an object.
        json_payload = request.get_json(silent=True)
        if not isinstance(json_payload, dict):
            json_payload = {}

        csrf_token = request.form.get('csrf_token', '')
        if not csrf_token:
            csrf_token = json_payload.get('csrf_token', '')
        try:
            validate_csrf(csrf_token)
        except wtforms.ValidationError:
            logger.error("Invalid CSRF token.")
            abort(make_response({"message": "Invalid CSRF token."}, 403))

        if 'file' in request.files:
            logger.info("Create vulns template from CSV")
            vulns_to_create = self._read_vulns_from_csv(request.files['file'])
        elif json_payload.get('vulns'):
            logger.info("Create vulns template from vulnerabilities in Status Report")

            vulns_to_create = json_payload['vulns']
            if not (isinstance(vulns_to_create, list)
                    and all(isinstance(vuln, dict) for vuln in vulns_to_create)):
                logger.error("Invalid vulnerabilities data to create templates.")
                abort(make_response({"message": "Invalid vulnerabilities data to create templates."}, 400))
            for vuln in vulns_to_create:
                # Vulnerabilities carry their custom fields as 'custom_fields'
                # while this schema reads 'customfields'. Keep an already
                # present 'customfields' entry instead of wiping it with {}.
                vuln['customfields'] = vuln.get('custom_fields', vuln.get('customfields', {}))
        else:
            logger.error("Missing data to create vulnerabilities templates.")
            abort(make_response({"message": "Missing data to create vulnerabilities templates."}, 400))

        if not vulns_to_create:
            logger.error("Missing data to create vulnerabilities templates.")
            abort(make_response({"message": "Missing data to create vulnerabilities templates."}, 400))

        vulns_created = []
        vulns_with_errors = []
        vulns_with_conflict = []
        schema = self.schema_class()
        for vuln in vulns_to_create:
            # (id, name) pair reported back to the client; .get() keeps an
            # item without 'name' from turning into a 500 for the whole batch.
            vuln_ref = (vuln.get('_id', ''), vuln.get('name', ''))
            try:
                vuln_schema = schema.load(vuln)
                super(VulnerabilityTemplateView, self)._perform_create(vuln_schema)
                db.session.commit()
            except ValidationError as error:
                logger.warning(f"Invalid vulnerability template {vuln_ref}: {error.messages}")
                vulns_with_errors.append(vuln_ref)
            except Conflict:
                vulns_with_conflict.append(vuln_ref)
            else:
                vulns_created.append(vuln_ref)

        # Every item lands in exactly one of the three lists, so when nothing
        # was created and nothing conflicted, every item failed validation.
        if vulns_created:
            status_code = 200
        elif vulns_with_conflict:
            status_code = 409
        else:
            status_code = 400

        return make_response(
            jsonify(vulns_created=vulns_created,
                    vulns_with_errors=vulns_with_errors,
                    vulns_with_conflict=vulns_with_conflict),
            status_code
        )

    def _read_vulns_from_csv(self, vulns_file):
        """Decode an uploaded CSV file into template payloads.

        Aborts with a 400 response when a required column (``name``,
        ``exploitation``) is missing or the file cannot be decoded or parsed
        as CSV.
        """
        # The charset comes from the uploaded part's Content-Type (the
        # request's Content-Encoding header names a compression scheme such
        # as gzip, not a charset). UTF-8 is read as 'utf-8-sig', which also
        # drops a leading byte order mark that would otherwise be glued to
        # the first header name.
        charset = (vulns_file.mimetype_params.get('charset') or '').lower()
        encoding = 'utf-8-sig' if charset in ('', 'utf-8', 'utf8') else charset
        io_wrapper = TextIOWrapper(vulns_file, encoding=encoding)
        vulns_reader = csv.DictReader(io_wrapper, skipinitialspace=True)
        try:
            # fieldnames is None when the file is empty
            diff_required = {'name', 'exploitation'}.difference(vulns_reader.fieldnames or [])
            if diff_required:
                logger.error(f"Missing required headers in CSV: {diff_required}")
                abort(
                    make_response(
                        {"message": f"Missing required headers in CSV: {diff_required}"}, 400
                    )
                )
            return self._parse_vuln_from_file(vulns_reader)
        except (UnicodeDecodeError, csv.Error) as error:
            logger.error(f"Could not read the CSV file: {error}")
            abort(make_response({"message": "Could not read the CSV file."}, 400))

    def _parse_vuln_from_file(self, vulns_reader):
        """Convert CSV rows into payloads for ``VulnerabilityTemplateSchema``.

        Columns named after a custom field schema are copied into
        ``customfields``. ``list`` values are parsed as JSON and ``choice``
        values must be one of the configured choices; invalid values are
        logged and skipped. The four impact columns are gathered into the
        nested ``impact`` object.

        :param vulns_reader: ``csv.DictReader`` over the uploaded file.
        :return: list of the row dicts, extended in place.
        """
        custom_fields = {
            cf_schema.field_name: cf_schema
            for cf_schema in db.session.query(CustomFieldsSchema).all()
        }
        vulns_list = []
        for vuln_dict in vulns_reader:
            vuln_dict['customfields'] = {}
            vuln_dict['impact'] = {}
            for key, value in vuln_dict.items():
                field_schema = custom_fields.get(key)
                if field_schema is None:
                    continue
                if field_schema.field_type == 'list' and value:
                    # Typographic single quotes become double quotes so the
                    # cell can be parsed as a JSON list.
                    custom_field_value = value.replace('‘', '"').replace('’', '"')
                    try:
                        vuln_dict['customfields'][key] = json.loads(custom_field_value)
                    except ValueError:
                        logger.warning(f'Invalid list for custom field {key}. '
                                       f'Faraday will skip this custom field.')
                elif field_schema.field_type == 'choice' and value:
                    cf_choices = field_schema.field_metadata
                    if isinstance(cf_choices, str):
                        cf_choices = json.loads(cf_choices)
                    if value not in cf_choices:
                        logger.warning(f'Invalid choice for custom field {key}. '
                                       f'Faraday will skip this custom field.')
                    else:
                        vuln_dict['customfields'][key] = value
                else:
                    vuln_dict['customfields'][key] = value

            for impact_name in ('accountability', 'availability',
                                'confidentiality', 'integrity'):
                # Empty cells ('') and cells absent from short rows (None) are
                # not valid input for the Boolean impact fields; treat them
                # like a missing column, i.e. False.
                vuln_dict['impact'][impact_name] = vuln_dict.get(impact_name) or False
            vulns_list.append(vuln_dict)

        return vulns_list


class VulnerabilityTemplateV3View(VulnerabilityTemplateView, PatchableMixin):
    """v3 variant of the template API.

    Served under the ``v3/`` prefix without trailing slashes, and mixes in
    ``PatchableMixin`` (presumably adding PATCH support).
    """
    route_prefix = 'v3/'
    trailing_slash = False

    @route('/bulk_create', methods=['POST'])
    def bulk_create(self):
        # Identical behaviour to the legacy endpoint; only the route differs.
        return super().bulk_create()

    # Share the parent's YAML docstring so both routes are documented alike.
    bulk_create.__doc__ = VulnerabilityTemplateView.bulk_create.__doc__


# Attach both the legacy and the v3 routes to the module's blueprint.
VulnerabilityTemplateView.register(vulnerability_template_api)
VulnerabilityTemplateV3View.register(vulnerability_template_api)
