# SPDX-FileCopyrightText: © 2024-2025 Jimmy Fitzpatrick <jcfitzpatrick12@gmail.com>
# This file is part of SPECTRE
# SPDX-License-Identifier: GPL-3.0-or-later

from dataclasses import dataclass
from datetime import datetime

import numpy as np
import numpy.typing as npt
from astropy.io import fits
from astropy.io.fits.hdu.image import PrimaryHDU
from astropy.io.fits.hdu.table import BinTableHDU
from astropy.io.fits.hdu.hdulist import HDUList

from spectre_core.spectrograms import Spectrogram, SpectrumUnit
from ._batch_keys import BatchKey
from .._base import BaseBatch, BatchFile
from .._register import register_batch


@dataclass(frozen=True)
class _BatchExtension:
    """Supported extensions for a `CallistoBatch`.

    :ivar FITS: Corresponds to the `.fits` file extension.
    """

    FITS: str = "fits"


class _FitsFile(BatchFile[Spectrogram]):
    """Stores spectrogram data in the FITS format generated by the e-Callisto network."""

    def __init__(self, parent_dir_path: str, base_file_name: str) -> None:
        """Initialise a `_FitsFile` instance.

        :param parent_dir_path: The parent directory for the batch.
        :param base_file_name: The batch name.
        """
        super().__init__(parent_dir_path, base_file_name, _BatchExtension.FITS)

    def _read(self) -> Spectrogram:
        """Parses a FITS file to generate a `Spectrogram` instance.

        Converts raw digit values to linearised units and orders the spectral
        components (rows of the dynamic spectra) by ascending frequency. The
        spectrum type is inferred from the `BUNIT` header.

        :raises ValueError: If the `BUNIT` header value is not a valid `SpectrumUnit`
            (assuming `SpectrumUnit` is an `Enum`), or if the observation start
            time cannot be parsed.
        :raises NotImplementedError: If the `BUNIT` header value represents an unsupported spectrum type.
        :return: A `Spectrogram` instance containing the parsed FITS file data.
        """
        with fits.open(self.file_path, mode="readonly") as hdulist:
            primary_hdu = self._get_primary_hdu(hdulist)
            dynamic_spectra = self._get_dynamic_spectra(primary_hdu)
            spectrogram_start_datetime = self._get_spectrogram_start_datetime(
                primary_hdu
            )
            bintable_hdu = self._get_bintable_hdu(hdulist)
            times = self._get_times(bintable_hdu)
            frequencies = self._get_frequencies(bintable_hdu)
            bunit = self._get_bunit(primary_hdu)

            # bunit is interpreted as a SpectrumUnit
            spectrum_unit = SpectrumUnit(bunit)
            if spectrum_unit != SpectrumUnit.DIGITS:
                raise NotImplementedError(
                    f"SPECTRE does not currently support spectrum type with BUNITS '{spectrum_unit}'"
                )

            dynamic_spectra_linearised = self._convert_units_to_linearised(
                dynamic_spectra
            )

            # Axis 0 of the dynamic spectra is the frequency axis, so the rows
            # and the frequencies must be reordered with the same permutation.
            # Reversing first means frequencies stored in non-increasing order
            # yield the identity permutation below (a plain reversal, duplicates
            # included); the stable sort then guarantees ascending order for any
            # other storage order as well.
            reversed_frequencies = frequencies[::-1]
            ascending_order = np.argsort(reversed_frequencies, kind="stable")

            return Spectrogram(
                dynamic_spectra_linearised[::-1, :][ascending_order],
                times,
                reversed_frequencies[ascending_order],
                self.tag,
                spectrum_unit,
                spectrogram_start_datetime,
            )

    def _get_primary_hdu(self, hdulist: HDUList) -> PrimaryHDU:
        """Return the HDU named `PRIMARY` from the open HDU list."""
        return hdulist["PRIMARY"]

    def _get_bintable_hdu(self, hdulist: HDUList) -> BinTableHDU:
        """Return the HDU at index 1, i.e. the first HDU after the primary one."""
        return hdulist[1]

    def _get_dynamic_spectra(self, primary_hdu: PrimaryHDU) -> npt.NDArray[np.float32]:
        """Return the primary HDU's data as a new `float32` array."""
        return primary_hdu.data.astype(np.float32)

    def _get_spectrogram_start_datetime(self, primary_hdu: PrimaryHDU) -> datetime:
        """Combine the `DATE-OBS` and `TIME-OBS` headers into a naive `datetime`.

        `DATE-OBS` is parsed as `YYYY/MM/DD` and `TIME-OBS` as `HH:MM:SS`, with
        an optional fractional-seconds part.

        :param primary_hdu: The primary HDU whose header holds the observation start.
        :raises ValueError: If the combined value matches neither accepted format.
        :return: The observation start time.
        """
        date_obs = primary_hdu.header["DATE-OBS"]
        time_obs = primary_hdu.header["TIME-OBS"]
        timestamp = f"{date_obs}T{time_obs}"
        try:
            return datetime.strptime(timestamp, "%Y/%m/%dT%H:%M:%S.%f")
        except ValueError:
            # Fall back to whole seconds when TIME-OBS has no fractional part;
            # if this also fails, its ValueError propagates to the caller.
            return datetime.strptime(timestamp, "%Y/%m/%dT%H:%M:%S")

    def _get_bunit(self, primary_hdu: PrimaryHDU) -> str:
        """Return the raw `BUNIT` header value of the primary HDU."""
        return primary_hdu.header["BUNIT"]

    def _convert_units_to_linearised(
        self, raw_digits: npt.NDArray[np.float32]
    ) -> npt.NDArray[np.float32]:
        """Converts spectrogram data from raw digit values to linearised units.

        Applies a transformation based on ADC specifications to convert raw values
        to dB and then to linearised units.

        :param raw_digits: Raw dynamic spectra in digit values.
        :return: The dynamic spectra with linearised units.
        """
        # conversion as per ADC specs [see email from C. Monstein]
        # NOTE(review): presumably 255 digits span a 2500 mV ADC range with a
        # detector slope of 25 mV per dB - confirm against the ADC specs.
        dB = (raw_digits / 255) * (2500 / 25)
        return 10 ** (dB / 10)

    def _get_times(self, bintable_hdu: BinTableHDU) -> npt.NDArray[np.float32]:
        """Extracts the elapsed times for each spectrum in seconds, with the first spectrum set to t=0
        by convention.

        The first row of the binary table's `TIME` column is returned unchanged.
        """
        return bintable_hdu.data["TIME"][0]  # already in seconds

    def _get_frequencies(self, bintable_hdu: BinTableHDU) -> npt.NDArray[np.float32]:
        """Extracts the frequencies for each spectral component, converted from MHz to Hz."""
        frequencies_MHz = bintable_hdu.data["FREQUENCY"][0]
        return frequencies_MHz * 1e6  # convert to Hz


@register_batch(BatchKey.CALLISTO)
class CallistoBatch(BaseBatch):
    """A batch of data produced by the e-Callisto network.

    Files held by this batch, by extension:
    - `.fits`, exposed through the `spectrogram_file` attribute
    """

    def __init__(self, start_time: str, tag: str) -> None:
        """Create a `CallistoBatch`.

        :param start_time: Start time of the batch.
        :param tag: The tag for the batch name.
        """
        super().__init__(start_time, tag)
        fits_file = _FitsFile(self.parent_dir_path, self.name)
        self._fits_file = fits_file

        # Formally register the FITS file as part of this batch.
        self.add_file(self.spectrogram_file)

    @property
    def spectrogram_file(self) -> _FitsFile:
        """The `.fits` file belonging to this batch."""
        return self._fits_file
