import numpy as np
import warnings
import matplotlib.pyplot as plt
from matplotlib import animation, rc
from IPython.display import HTML

def generateBunch(bunch_position, bunch_length,
                  bunch_energy, energy_spread,
                  n_macroparticles, n_grid=100):
    """Draw macroparticles from a smooth elliptical phase-space density.

    The density is evaluated on an ``n_grid`` x ``n_grid`` mesh covering
    ``bunch_position +/- bunch_length/2`` in phase and
    ``bunch_energy +/- energy_spread/2`` in energy. With ``r2`` the squared
    normalised elliptical radius (``r2 == 1`` on the ellipse whose full axes
    are ``bunch_length`` and ``energy_spread``), the density is
    ``1 - r2**2`` inside the ellipse and zero outside. Grid cells are drawn
    with probability proportional to that density, and every particle is
    then shifted uniformly by up to half a bin in each direction.

    Random numbers come from the global ``np.random`` state; seed it with
    ``np.random.seed`` for reproducible bunches.

    Parameters
    ----------
    bunch_position : float
        Centre of the bunch along the phase axis.
    bunch_length : float
        Full extent of the bunch along the phase axis (must be non-zero).
    bunch_energy : float
        Centre of the bunch along the energy axis.
    energy_spread : float
        Full extent of the bunch along the energy axis (must be non-zero).
    n_macroparticles : int
        Number of particles to draw.
    n_grid : int, optional
        Number of grid points per axis used to tabulate the density.
        Defaults to 100.

    Returns
    -------
    particle_phase, particle_energy : numpy.ndarray
        Two contiguous 1-D arrays of length ``n_macroparticles``.

    Raises
    ------
    ValueError
        If ``n_grid`` is smaller than 3, or if ``bunch_length`` or
        ``energy_spread`` is zero.
    """
    # With only two points per axis the mesh consists of the four corners of
    # the bounding box, where the density is zero and cannot be normalised.
    if n_grid < 3:
        raise ValueError("n_grid must be at least 3")
    # A zero extent makes the normalised radius 0/0 and the density NaN.
    if bunch_length == 0 or energy_spread == 0:
        raise ValueError("bunch_length and energy_spread must be non-zero")

    # Generating phase and energy arrays
    phase_array = np.linspace(bunch_position-bunch_length/2,
                              bunch_position+bunch_length/2,
                              n_grid)
    energy_array = np.linspace(bunch_energy-energy_spread/2,
                               bunch_energy+energy_spread/2,
                               n_grid)

    # Coordinates of every grid cell
    phase_grid, deltaE_grid = np.meshgrid(phase_array, energy_array)

    # Bin sizes
    bin_phase = phase_array[1]-phase_array[0]
    bin_energy = energy_array[1]-energy_array[0]

    # Density grid: squared normalised radius, then 1 - r2**2 clipped at zero
    isodensity_lines = ((phase_grid-bunch_position)/bunch_length*2)**2. + \
        ((deltaE_grid-bunch_energy)/energy_spread*2)**2.
    density_grid = 1-isodensity_lines**2.
    density_grid[density_grid < 0] = 0
    density_grid /= np.sum(density_grid)

    # Pick one grid cell per particle according to the tabulated density
    indexes = np.random.choice(np.arange(0, np.size(density_grid)),
                               n_macroparticles, p=density_grid.flatten())

    # Randomize particles inside each grid cell (uniform distribution)
    particle_phase = np.ascontiguousarray(
        phase_grid.flatten()[indexes] +
        (np.random.rand(n_macroparticles) - 0.5) * bin_phase)
    particle_energy = np.ascontiguousarray(
        deltaE_grid.flatten()[indexes] +
        (np.random.rand(n_macroparticles) - 0.5) * bin_energy)

    return particle_phase, particle_energy


def plotPhaseSpace(distribution, figname=None,
                   xbins=50, ybins=50,
                   xlim=None, ylim=None):
    """Draw a phase-space scatter plot with its two projected profiles.

    The figure ``figname`` (8 x 8) is selected or created and cleared, then
    three axes are laid out: the scatter plot, a profile of
    ``distribution[0]`` above it and a profile of ``distribution[1]`` to its
    right. Both profiles are histograms divided by their maximum count.

    The three line artists are stored in the module-level globals
    ``line_phase``, ``line_energy`` and ``distri_plot`` so that other code
    in this module can update them afterwards.

    Parameters
    ----------
    distribution : sequence
        ``distribution[0]`` holds the phases and ``distribution[1]`` the
        energies; the sequence is also unpacked into the scatter ``plot``
        call.
    figname : optional
        Identifier handed to ``plt.figure``.
    xbins, ybins : optional
        Bin specification handed to ``np.histogram`` for each profile.
    xlim, ylim : optional
        Used as the histogram ``range`` and as the limits passed to
        ``plt.xlim`` / ``plt.ylim``.
    """
    global line_phase, line_energy, distri_plot

    plt.figure(figname, figsize=(8,8))
    plt.clf()

    # Figure-relative geometry of the three panels
    left, width = 0.115, 0.63
    bottom, height = 0.115, 0.63
    left_h = left+width+0.03
    bottom_h = left_h
    profile_size = 0.2

    # The scatter axes is created last on purpose: it is then the current
    # axes that plt.xlim / plt.ylim act on at the end of this function.
    axHistx = plt.axes([left, bottom_h, width, profile_size])
    axHisty = plt.axes([left_h, bottom, profile_size, height])
    axScatter = plt.axes([left, bottom, width, height])

    # Phase profile (top panel), drawn at the bin centres
    phase_counts, phase_edges = np.histogram(
        distribution[0], xbins, range=xlim)
    phase_centres = phase_edges[0:-1]+(phase_edges[1]-phase_edges[0])/2
    line_phase, = axHistx.plot(
        phase_centres, phase_counts/np.max(phase_counts))
    for axis in (axHistx.get_xaxis(), axHistx.get_yaxis()):
        axis.set_ticklabels([])
    axHistx.set_ylabel('Bunch profile $\\lambda_{\\phi}$')

    # Energy profile (right panel), drawn sideways
    energy_counts, energy_edges = np.histogram(
        distribution[1], ybins, range=ylim)
    energy_centres = energy_edges[0:-1]+(energy_edges[1]-energy_edges[0])/2
    line_energy, = axHisty.plot(
        energy_counts/np.max(energy_counts), energy_centres)
    for axis in (axHisty.get_xaxis(), axHisty.get_yaxis()):
        axis.set_ticklabels([])
    axHisty.set_xlabel('Energy spread $\\lambda_{\\Delta E}$')

    # Particle cloud (main panel)
    distri_plot, = axScatter.plot(*distribution, 'o', alpha=0.5)
    axScatter.set_xlabel('Phase $\\phi$ [rad]')
    axScatter.set_ylabel('Energy $\\Delta E$ [arb. units]')
    plt.xlim(xlim)
    plt.ylim(ylim)


def reduction_ratio(deltaE, voltage):
    """Return ``(1 - sin(phis)) / (1 + sin(phis))``.

    ``phis`` is ``arcsin(deltaE / voltage)``, so the result is NaN wherever
    ``|deltaE / voltage|`` exceeds 1. Scalars and NumPy arrays are both
    accepted.
    """
    sin_phis = np.sin(np.arcsin(deltaE/voltage))
    return (1-sin_phis)/(1+sin_phis)


def separatrix(phase_array, eta, beta, energy, charge, voltage, harmonic, energy_gain):
    """Build a closed curve from the region where the potential is negative.

    The potential ``s*voltage*cos(phi) - s*energy_gain/|charge|*phi`` with
    ``s = -sign(eta)`` is evaluated on ``phase_array`` and shifted by its
    value at ``phi = -pi + phi_s``, where
    ``phi_s = arcsin(energy_gain/|charge|/voltage)``. Wherever the shifted
    potential is negative the curve is ``sqrt(-potential)``, rescaled so
    that its maximum equals
    ``sqrt(2*energy*beta**2*charge*voltage/(pi*harmonic*eta))``.

    Note that every call installs a process-wide ``"once"`` warnings filter.

    Parameters
    ----------
    phase_array : numpy.ndarray
        1-D array of phases at which the curve is evaluated.
    eta, beta, energy, charge, voltage, harmonic, energy_gain : float
        Quantities entering the expressions above.

    Returns
    -------
    phase_sep, separatrix_array : numpy.ndarray
        The positive branch in the order of ``phase_array`` followed by the
        negative branch in reverse order, with non-finite points removed.
    """
    warnings.filterwarnings("once")

    sign = -np.sign(eta)
    depth = sign * voltage
    tilt = sign*energy_gain/abs(charge)

    # Potential along the requested phases
    potential_well = depth * np.cos(phase_array) - tilt*phase_array

    # Reference level: the same potential taken at -pi + phi_s
    phi_s = np.arcsin(energy_gain/abs(charge)/voltage)
    reference_phase = -np.pi+phi_s
    potential_well -= depth * np.cos(reference_phase) - tilt*reference_phase

    # Positive branch, defined only where the shifted potential is negative
    below_reference = potential_well < 0
    upper_branch = np.full(len(phase_array), np.nan)
    upper_branch[below_reference] = np.sqrt(-potential_well[below_reference])
    target_height = np.sqrt(
        2*energy*beta**2*charge*voltage/(np.pi*harmonic*eta))
    upper_branch[below_reference] *= \
        target_height/np.nanmax(upper_branch[below_reference])

    # Close the curve with the mirrored branch traversed backwards
    closed_curve = np.append(upper_branch, -upper_branch[::-1])
    closed_phase = np.append(phase_array, phase_array[::-1])

    finite = np.isfinite(closed_curve)
    return closed_phase[finite], closed_curve[finite]


class TrackAnimation(object):
    """Animate a particle distribution that is advanced once per frame.

    Each frame calls ``trackingFunction(particles)`` and redraws the scatter
    plot and both profiles created by ``plotPhaseSpace``. The return value
    of ``trackingFunction`` is ignored, so it is presumably expected to
    update ``particles`` in place.

    Parameters
    ----------
    particles : sequence
        ``particles[0]`` holds the phases, ``particles[1]`` the energies.
    trackingFunction : callable
        Called as ``trackingFunction(particles)`` before each frame.
    figname :
        Figure identifier forwarded to ``plotPhaseSpace``.
    iterations :
        Passed as ``frames`` to ``FuncAnimation``.
    framerate :
        Frames per second; the frame interval is ``1000/framerate`` ms.
    xbins, ybins, xlim, ylim : optional
        Histogram binning and axis limits, as in ``plotPhaseSpace``.
    phase_sep, separatrix_array : optional
        When both are given they are drawn as a red curve.
    """

    def __init__(
            self, particles, trackingFunction, figname, iterations, framerate,
            xbins=50, ybins=50, xlim=None, ylim=None,
            phase_sep=None, separatrix_array=None):

        # What is tracked, and how
        self.particles = particles
        self.trackingFunction = trackingFunction

        # Animation settings
        self.figname = figname
        self.iterations = iterations
        self.framerate = framerate

        # Display settings
        self.xbins = xbins
        self.ybins = ybins
        self.xlim = xlim
        self.ylim = ylim

        # Optional overlay curve
        self.phase_sep = phase_sep
        self.separatrix_array = separatrix_array

    def run_animation(self):
        """Build the animation and return it wrapped in an ``HTML`` object."""
        # Draw once up front so that self.anim_fig exists before it is
        # handed to FuncAnimation.
        self._init()
        frame_interval_ms = 1000/self.framerate
        movie = animation.FuncAnimation(
            self.anim_fig, self._animate, init_func=self._init,
            frames=self.iterations, interval=frame_interval_ms, blit=True)
        return HTML(movie.to_jshtml())

    def _init(self):
        """Draw the initial frame and return the artists to be animated."""
        plotPhaseSpace(self.particles, figname=self.figname,
                       xbins=self.xbins, ybins=self.ybins,
                       xlim=self.xlim, ylim=self.ylim)

        overlay_missing = (self.phase_sep is None
                           or self.separatrix_array is None)
        if not overlay_missing:
            plt.plot(self.phase_sep, self.separatrix_array, 'r')

        self.anim_fig = plt.gcf()

        return (line_phase, line_energy, distri_plot)

    def _animate(self, i):
        """Advance the particles by one step and refresh all three artists."""
        self.trackingFunction(self.particles)

        # Phase profile, normalised to its maximum
        phase_counts, phase_edges = np.histogram(
            self.particles[0], self.xbins, range=self.xlim)
        phase_centres = phase_edges[0:-1]+(phase_edges[1]-phase_edges[0])/2
        line_phase.set_data(phase_centres, phase_counts/np.max(phase_counts))

        # Energy profile, normalised to its maximum and drawn sideways
        energy_counts, energy_edges = np.histogram(
            self.particles[1], self.ybins, range=self.ylim)
        energy_centres = \
            energy_edges[0:-1]+(energy_edges[1]-energy_edges[0])/2
        line_energy.set_data(
            energy_counts/np.max(energy_counts), energy_centres)

        # Particle cloud
        distri_plot.set_data(self.particles[0], self.particles[1])

        return (line_phase, line_energy, distri_plot)

    
def run_animation(particles, trackingFunction, figname, iterations, framerate,
                  xbins=50, ybins=50, xlim=None, ylim=None,
                  phase_sep=None, separatrix_array=None):
    """Create a ``TrackAnimation`` from the arguments and run it.

    All arguments are forwarded unchanged to the ``TrackAnimation``
    constructor; the value returned is whatever its ``run_animation``
    method returns.
    """
    display_options = dict(
        xbins=xbins, ybins=ybins,
        xlim=xlim, ylim=ylim,
        phase_sep=phase_sep, separatrix_array=separatrix_array)

    return TrackAnimation(
        particles, trackingFunction, figname, iterations, framerate,
        **display_options).run_animation()


def oscillation_spectrum(phase_track, fft_zero_padding=0):
    """Return the one-sided amplitude spectrum of a phase track.

    The mean of ``phase_track`` is removed, the real FFT is taken over
    ``len(phase_track) + fft_zero_padding`` points, and the magnitude is
    scaled by ``2 / len(phase_track)``.

    Parameters
    ----------
    phase_track : numpy.ndarray
        Phase of a particle recorded once per turn.
    fft_zero_padding : int, optional
        Number of extra points added to the FFT length.

    Returns
    -------
    freq_array : numpy.ndarray
        Frequencies from ``np.fft.rfftfreq`` (unit sample spacing).
    fft_osc : numpy.ndarray
        Scaled spectrum magnitude at those frequencies.
    """
    n_turns = len(phase_track)
    n_fft = n_turns+fft_zero_padding

    freq_array = np.fft.rfftfreq(n_fft)

    centred_track = phase_track-np.mean(phase_track)
    raw_spectrum = np.fft.rfft(centred_track, n_fft)
    fft_osc = np.abs(raw_spectrum * 2/(n_turns))

    return freq_array, fft_osc


def synchrotron_tune(phase_track, fft_zero_padding=0):
    """Return the amplitude and frequency of the dominant spectral line.

    The spectrum is obtained from ``oscillation_spectrum``; the bin with the
    largest magnitude is reported.

    Parameters
    ----------
    phase_track : numpy.ndarray
        Phase of a particle recorded once per turn.
    fft_zero_padding : int, optional
        Forwarded to ``oscillation_spectrum``.

    Returns
    -------
    oscillation_amplitude : float
        Largest value of the spectrum.
    sync_tune : float
        Frequency of the bin holding that value. When several bins share
        the maximum (for instance an all-zero spectrum from a constant
        track) the lowest such frequency is returned.
    """
    freq_array, spectrum_array = oscillation_spectrum(
        phase_track, fft_zero_padding=fft_zero_padding)

    # Index the peak directly. Selecting it with a boolean mask and calling
    # float() on the resulting array fails when several bins tie for the
    # maximum, and converting a 1-D array to a scalar is deprecated in
    # recent NumPy releases.
    peak_index = np.argmax(spectrum_array)
    oscillation_amplitude = spectrum_array[peak_index]
    sync_tune = float(freq_array[peak_index])

    return oscillation_amplitude, sync_tune
