# coding: utf-8
"""
CAS Intensity Limitations in Hadron Beams

This Python package contains all the support functions for the hands-on session
on beam-induced heating of the CERN SPS wire scanners.

Authors: H. Damerau, J. Flowerdew, L. Intelisano, A. Lasheen, M. Neroni, C. Völlinger (CERN)
"""

import numpy as np
from scipy.constants import e
from scipy.integrate import cumulative_trapezoid
from scipy.special import hyp0f1, gamma


def parabolic_bunch(time_array, bunch_position, bunch_length, bunch_charge):
    """
    Return the current of a single parabolic bunch sampled on ``time_array``.

    Inside the bunch the shape is 1 - (2 (t - t0) / L)^2; outside it is zero.
    The shape is normalised so that its time integral equals the bunch
    charge (in Coulomb).

    Parameters
    ----------
    time_array : array
        Time array [s]
    bunch_position : float
        Centre of the bunch [s]
    bunch_length : float
        Full length of the bunch [s]
    bunch_charge : float
        Charge of the bunch [number of elementary charges e]

    Returns
    -------
    profile : array
        Parabolic bunch profile (current) [Ampères]
    """

    # Distance from the bunch centre in units of the bunch length (|x| <= 1/2 inside)
    x = (time_array - bunch_position) / bunch_length
    shape = 1 - 4 * x ** 2.0
    shape[shape < 0] = 0  # the parabola is cut to zero outside the bunch

    # Inverse of the time integral of ``shape``:
    # 2 * Gamma(5/2) / (L * sqrt(pi) * Gamma(2)) = 3 / (2 L)
    norm = 2 * gamma(2.5) / (bunch_length * np.sqrt(np.pi) * gamma(2))

    return bunch_charge * e * shape * norm


def generate_full_beam_profile(
    time_array,
    rf_bucket_length,
    bunch_charge,
    bunch_length,
    number_of_bunches_per_batch,
    bunch_spacing,
    number_of_batches=1,
    batch_spacing=0,
):
    """
    Generates a full beam profile from a parabolic bunch shape.

    The first bunch is centred at ``rf_bucket_length / 2``. Consecutive bunches
    of a batch are separated by ``bunch_spacing``. After each batch an extra
    ``batch_spacing`` is added, so the distance between the last bunch of one
    batch and the first bunch of the next is ``bunch_spacing + batch_spacing``.
    Contributions of overlapping bunches are summed.

    Parameters
    ----------
    time_array : np.array
        The time array in seconds
    rf_bucket_length : float
        The length of the RF bucket in seconds
    bunch_charge : float
        The bunch charge in number of elementary charges (particles)
    bunch_length : float
        The bunch length in seconds
    number_of_bunches_per_batch : int
        The number of bunches per batch
    bunch_spacing : float
        The spacing between bunches in seconds
    number_of_batches : int, optional
        The number of batches, by default 1
    batch_spacing : float, optional
        The additional spacing between batches in seconds, by default 0

    Returns
    -------
    np.array
        The full beam profile (current) in Ampères
    """

    beam_profile = np.zeros(len(time_array))

    bunch_position = rf_bucket_length / 2
    half_length = bunch_length / 2

    for _ in range(number_of_batches):
        for _ in range(number_of_bunches_per_batch):
            in_bunch = (time_array >= bunch_position - half_length) & (
                time_array <= bunch_position + half_length
            )
            # Accumulate rather than overwrite, so that bunches longer than the
            # bunch spacing superimpose instead of erasing each other.
            beam_profile[in_bunch] += parabolic_bunch(
                time_array[in_bunch],
                bunch_position,
                bunch_length,
                bunch_charge,
            )
            bunch_position += bunch_spacing
        bunch_position += batch_spacing

    return beam_profile


def compute_beam_spectrum(time_array, beam_current):
    """
    Compute the one-sided (rfft) spectrum of a sampled beam current.

    The FFT is scaled by 2/N (N = len(time_array)). A positive-frequency bin
    therefore gives the amplitude of the corresponding harmonic. The DC bin,
    which is scaled the same way, equals twice the mean current.

    Parameters
    ----------
    time_array : array
        Time array [s], assumed uniformly spaced
    beam_current : array
        Beam current [A]

    Returns
    -------
    beam_spectrum_freq : array
        Frequency array [Hz]
    beam_spectrum : array
        Beam current spectrum [A]
    """

    n_points = len(time_array)
    sampling_step = time_array[1] - time_array[0]

    frequencies = np.fft.rfftfreq(n_points, d=sampling_step)
    spectrum = np.fft.rfft(beam_current, n_points) * 2 / n_points

    return frequencies, spectrum


def analytical_parabolic_bunch_spectrum(
    frequency_array,
    bunch_length,
):
    """
    Analytical (normalised) spectrum of a parabolic bunch.

    The Fourier transform of the unit-area profile 1 - (2 t / L)^2 is
    0F1(; 5/2; -(pi L f)^2 / 4). It equals 1 at f = 0.

    Parameters
    ----------
    frequency_array : array
        Array of frequencies [Hz].
    bunch_length : float
        Length of the bunch [s].

    Returns
    -------
    spectrum : array
        Analytical spectrum of the parabolic bunch.
    """

    argument = -((np.pi * bunch_length * frequency_array) ** 2.0) / 4.0
    return hyp0f1(2.5, argument)


def resonator_wakefield(time_array, resonant_frequency, r_shunt, quality_factor):
    """
    Calculates the wakefield using the Resonator model

    For each resonator, with omega_r = 2 pi f_r, alpha = omega_r / (2 Q) and
    omega_bar = sqrt(omega_r^2 - alpha^2):

        W(t) = 2 R alpha exp(-alpha t) [cos(omega_bar t) - alpha / omega_bar sin(omega_bar t)]   t > 0
        W(0) = R alpha                                                                          t = 0
        W(t) = 0                                                                                t < 0

    The wakes of all resonators are summed. Negative times are set to zero
    explicitly, so exp(-alpha t) is never evaluated there. Evaluating it for
    large negative t would overflow and turn the causal zero into NaN.

    Only under-damped resonators (Q > 0.5) are supported; for Q <= 0.5
    omega_bar is not real and the result is NaN.

    Parameters
    ----------
    time_array : array
        Time array [s]
    resonant_frequency : float or array_like
        Resonant frequency of the resonator [Hz]
    r_shunt : float or array_like
        Shunt impedance of the resonator [Ohm]
    quality_factor : float or array_like
        Quality factor of the resonator

    Returns
    -------
    wake : array
        The wakefield [V/C]
    """

    time_array = np.asarray(time_array)
    wake = np.zeros(len(time_array))

    resonant_frequency = np.array(resonant_frequency, ndmin=1)
    r_shunt = np.array(r_shunt, ndmin=1)
    quality_factor = np.array(quality_factor, ndmin=1)

    # Causality: only t >= 0 contributes
    causal = time_array >= 0
    t_causal = time_array[causal]

    for idx_resonator, f_r in enumerate(resonant_frequency):
        omega_r = 2 * np.pi * f_r
        alpha = omega_r / (2 * quality_factor[idx_resonator])

        omega_bar = np.sqrt(omega_r**2 - alpha**2)

        # (sign(t) + 1) is 2 for t > 0 and 1 at t = 0 (half value at the origin)
        wake[causal] += (
            (np.sign(t_causal) + 1)
            * r_shunt[idx_resonator]
            * alpha
            * np.exp(-alpha * t_causal)
            * (
                np.cos(omega_bar * t_causal)
                - alpha / omega_bar * np.sin(omega_bar * t_causal)
            )
        )

    return wake


def resonator_impedance(frequency_array, resonant_frequency, r_shunt, quality_factor):
    """
    Impedance of one or several resonators (summed), evaluated on a frequency grid.

    For f > 0 each resonator contributes R / (1 + i Q (f / f_r - f_r / f)).
    Entries with f <= 0 are left at zero.

    Parameters
    ----------
    frequency_array : array
        Frequency array [Hz]
    resonant_frequency : float or array_like
        Resonant frequency of the resonator [Hz]
    r_shunt : float or array_like
        Shunt impedance of the resonator [Ohm]
    quality_factor : float or array_like
        Quality factor of the resonator

    Returns
    -------
    impedance : array
        The impedance of the resonator [Ohm]
    """

    impedance = np.zeros(len(frequency_array), complex)

    f_res = np.array(resonant_frequency, ndmin=1)
    r_sh = np.array(r_shunt, ndmin=1)
    q_fac = np.array(quality_factor, ndmin=1)

    # Only strictly positive frequencies are filled (avoids the 1/f pole at f = 0)
    above_zero = frequency_array > 0
    f = frequency_array[above_zero]

    for k, f_r in enumerate(f_res):
        detuning = f / f_r - f_r / f
        impedance[above_zero] += r_sh[k] / (1 + 1j * q_fac[k] * detuning)

    return impedance


def compute_induced_voltage_time_domain(
    time_array,
    beam_current,
    wakefield,
):
    """
    Computes the induced voltage as the (negative) convolution of beam current and wakefield.

    V(t_n) = - dt * sum_k I(t_k) W(t_n - t_k). Only the first len(beam_current)
    samples of the full discrete convolution are kept. ``wakefield[0]``
    corresponds to zero time lag, and ``wakefield`` is assumed to be sampled
    with the same uniform step ``dt = time_array[1] - time_array[0]``.

    Parameters
    ----------
    time_array : array
        Time array [s]
    beam_current : array
        Beam current [A]
    wakefield : array
        Single particle wakefield [V/C]

    Returns
    -------
    induced_voltage : array
        Induced voltage [V]
    """

    n_samples = len(beam_current)
    convolution = np.convolve(beam_current, wakefield, mode="full")
    time_step = time_array[1] - time_array[0]

    return -convolution[:n_samples] * time_step


def compute_induced_voltage_frequency_domain(
    frequency_array,
    beam_spectrum,
    impedance,
    number_of_samples=None,
):
    """
    Computes the induced voltage from the beam spectrum and impedance.

    ``beam_spectrum`` must be normalised like the output of
    ``compute_beam_spectrum`` (rfft scaled by 2/N). ``irfft`` divides by the
    output length, so the result is multiplied by (output length) / 2 to undo
    that scaling.

    Parameters
    ----------
    frequency_array : array
        Frequency array [Hz] (only its length is used)
    beam_spectrum : array
        Beam spectrum [A], normalised as in ``compute_beam_spectrum``
    impedance : array
        Impedance [Ohm]
    number_of_samples : int, optional
        Number of time samples of the original signal. Pass it when that number
        is odd: an rfft spectrum does not record the parity of the original
        length. The default reconstructs 2 * (len - 1) samples, which is only
        exact for even-length signals.

    Returns
    -------
    induced_voltage : array
        Induced voltage [V]
    """

    product = impedance * beam_spectrum

    if number_of_samples is None:
        # irfft default output length is 2 * (len(product) - 1): even signals only
        return -1 * np.fft.irfft(product) * (len(frequency_array) - 1)

    return -1 * np.fft.irfft(product, number_of_samples) * (number_of_samples / 2)


def compute_beam_induced_power_time_domain(
    time_array, beam_current, induced_voltage, time_averaging=None
):
    """
    Computes the beam induced power in time domain.

    P = - (1 / T) * integral of I(t) V(t) dt, using the trapezoidal rule.
    With V = -(I convolved with W), as in the induced-voltage helpers, the
    power lost by the beam comes out positive.

    Parameters
    ----------
    time_array : array
        Time array [s]
    beam_current : array
        Beam current [A]
    induced_voltage : array
        Induced voltage [V]
    time_averaging : float, optional
        Averaging time T [s]. Defaults to time_array[-1] - time_array[0],
        i.e. the span covered by ``time_array``.

    Returns
    -------
    beam_induced_power : float
        Beam induced power [W]
    """

    if time_averaging is not None:
        averaging_span = time_averaging
    else:
        averaging_span = time_array[-1] - time_array[0]

    # Total of the cumulative integral = integral of I*V over the whole time array [J]
    energy = cumulative_trapezoid(beam_current * induced_voltage, time_array)[-1]

    return -(energy / averaging_span)


def compute_beam_induced_power_frequency_domain(beam_current_spectrum, impedance):
    """
    Computes the beam induced power in frequency domain.

    P = 2 * sum_k Re(Z_k) |I_k / 2|^2, summing over the one-sided (rfft) bins.

    NB: the energy loss is averaged on the time span that was used to compute
    the beam spectrum.

    Parameters
    ----------
    beam_current_spectrum : array
        Spectrum of the beam current [A], normalised as in compute_beam_spectrum
        (rfft scaled by 2/N)
    impedance : array
        Impedance [Ohm]

    Returns
    -------
    beam_induced_power : float
        Beam induced power [W]
    """

    # Undo the factor 2 of the one-sided amplitude normalisation
    half_amplitude = beam_current_spectrum / 2
    per_bin = impedance.real * np.abs(half_amplitude) ** 2

    # Factor 2 accounts for the negative frequencies of the two-sided sum.
    # NB: it is applied to every bin, including DC (and Nyquist for even N);
    # this only matters where Re(Z) is non-zero at those bins.
    return np.sum(per_bin) * 2


def load_impedance(filename, frequency_array):
    """
    Read a tabulated impedance from a text file and resample it on ``frequency_array``.

    The file is read with ``np.loadtxt`` and must have exactly three columns:
    frequency, real part and imaginary part of the impedance. The frequency
    column is multiplied by 1e9 (i.e. it is expected in GHz). The tabulated
    frequencies must be increasing, as required by ``np.interp``.

    The real and imaginary parts are linearly interpolated separately. Above
    the last tabulated frequency the impedance is set to 0. Below the first one,
    the first tabulated value is repeated (``np.interp`` default).

    Parameters
    ----------
    filename : str
        Path to the file containing the impedance data.
    frequency_array : array
        Array of frequencies [Hz] onto which the impedance data will be interpolated.

    Returns
    -------
    loaded_impedance : array
        Complex array of interpolated impedance values [Ohm].
    """

    table_freq, table_re, table_im = np.loadtxt(filename, unpack=True)
    table_freq_hz = table_freq * 1e9

    real_part = np.interp(frequency_array, table_freq_hz, table_re, right=0)
    imag_part = np.interp(frequency_array, table_freq_hz, table_im, right=0)

    return real_part + 1j * imag_part


# Environment check executed on import: for every required package
# (display name -> import name), print its version or report it missing.
__packages = {
    "numpy": "numpy",
    "scipy": "scipy",
    "matplotlib": "matplotlib",
    # Uncomment to also check ipympl:
    # "ipympl": "ipympl",
}

__setup_ok = True
for __package, __import_name in __packages.items():
    try:
        __module = __import__(__import_name)
        print(f"{__package} is installed, version: {__module.__version__}")
    except ImportError:
        print(f"{__package} is not installed")
        __setup_ok = False

print("-> Setup is OK!! Have fun!!" if __setup_ok else "-> Setup is NOT OK!!")
