import numpy as np
import xtrack as xt
import xobjects as xo
import xpart as xp
import matplotlib.pyplot as plt

# Build the lattice from its JSON description, attach a proton reference
# particle (momentum p0c given in eV) and create the tracker.
line = xt.Line.from_json('FODO_4GeV.json')
line.particle_ref = xt.Particles(
    p0c=4.8482168e9,
    q0=1,
    mass0=xt.PROTON_MASS_EV,
)
line.build_tracker()

# Switch off all RF cavities: this study is done without RF.
tab = line.get_table()
tab_cav = tab.rows[tab.element_type == 'Cavity']
for nn in tab_cav.name:
    line[nn].voltage = 0.0  # previously 2000000; restore to re-enable RF
    # Report the settings of this cavity (its own name, not the whole list).
    print(nn, line[nn].voltage, line[nn].lag)


## Choose a context
context = xo.ContextCpu()         # For CPU
# context = xo.ContextCupy()      # For CUDA GPUs
# context = xo.ContextPyopencl()  # For OpenCL GPUs

## Transfer lattice on context and compile tracking code
# A tracker was already built earlier on the default context. xtrack's
# build_tracker does not rebuild an existing tracker (it keeps the old one),
# so _context would be silently ignored. Discard the existing tracker first
# so that the context chosen above actually takes effect.
line.discard_tracker()
line.build_tracker(_context=context)

# Load the initial pyORBIT distribution (first 14 lines are skipped as header).
# Columns are used below, in order, as: x, px, y, py, z, dE.
filename = 'pyOrbit_ini_energygrid_unmatched.dat'
dataPO_ini = np.loadtxt(filename,skiprows=14,unpack=True)

# Convert the pyORBIT energy offset (column 5) into xtrack's delta.
# NOTE(review): column 5 looks like dE in GeV (it is scaled by 1e9 and added to
# the reference p0c = 4.8482168e9 eV) — confirm against the pyORBIT dump.
# gamma treats (p0c + dE) as the particle's pc; gamma0 uses the reference p0c.
gamma = np.sqrt(1+((4.8482168e9+dataPO_ini[5]*1e9)/xt.PROTON_MASS_EV)**2)
gamma0 = np.sqrt(1+(4.8482168e9/xt.PROTON_MASS_EV)**2)
# factor = beta0 / beta, with beta = sqrt(1 - 1/gamma**2) (one value per particle).
factor = np.sqrt(1-1/gamma0**2)/np.sqrt(1-1/gamma**2)

# Build xtrack particles from the pyORBIT coordinates; zeta is taken directly
# from the z column.
particles = line.build_particles(
                        x=dataPO_ini[0],
                        px=dataPO_ini[1],
                        y=dataPO_ini[2],
                        py=dataPO_ini[3],
                        zeta=dataPO_ini[4],
                        # delta = dE / p0c[GeV] * beta0/beta.
                        # Older alternative kept for reference (m0c2 is not defined in this file):
                        # (np.sqrt((dataPO_ini[5]+5.8636781018061)**2+m0c2**2)-m0c2-5.0)/5.0
                        delta=dataPO_ini[5]/4.8482168*factor)

# Plot the initial phase-space distribution: horizontal, vertical and
# longitudinal planes in three stacked panels.
fig1 = plt.figure(1, figsize=(6.4, 7))
ax21, ax22, ax23 = (fig1.add_subplot(3, 1, k) for k in (1, 2, 3))

phase_space = (
    (ax21, particles.x, particles.px, r'x [m]', r'px [rad]'),
    (ax22, particles.y, particles.py, r'y [m]', r'py [rad]'),
    # xtrack's delta is the relative momentum deviation (p - p0)/p0, not dE/E.
    (ax23, particles.zeta, particles.delta, r'z [m]', r'$\delta = \Delta p/p_0$'),
)
for ax, coord, momentum, xlabel, ylabel in phase_space:
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.plot(coord, momentum, '.', markersize=1, color='r')

fig1.subplots_adjust(bottom=.08, top=.93, hspace=.33, right=.96, wspace=.33)
# Blocks until the window is closed (with an interactive backend); tracking
# starts afterwards.
plt.show()

## Track, recording turn-by-turn coordinates for the tune analysis below
n_turns = 100
line.track(particles, num_turns=n_turns, turn_by_turn_monitor=True)

# Calculate tunes with NAFF
import nafflib

turn_fft = n_turns

# NOTE(review): presumably the horizontal dispersion at the monitor location
# (start of the line) — confirm, e.g. against line.twiss(method='4d').
Dx = 2.7067979928599013  # in meters

N_part = len(dataPO_ini[0])

# Turn-by-turn record: one row per particle, one column per turn.
rec = line.record_last_track
# Subtract the dispersive contribution so only betatron motion is analysed.
x_betatron = rec.x[:N_part, :turn_fft] - Dx * rec.delta[:N_part, :turn_fft]
y_betatron = rec.y[:N_part, :turn_fft]

Qx = np.array(
    [nafflib.tune(xs, window_order=2, window_type="hann") for xs in x_betatron],
    dtype=float,
)
Qy = np.array(
    [nafflib.tune(ys, window_order=2, window_type="hann") for ys in y_betatron],
    dtype=float,
)

# Qx, Qy and factor follow the initial particle order (particle_id); align the
# final coordinates with it. This is a no-op if tracking did not reorder them.
order = np.argsort(particles.particle_id)

# 'pyOrbit_<tag>.dat' -> 'XSuite_<tag>'; np.save appends '.npy'.
filename = 'XSuite' + filename[len('pyOrbit'):-len('.dat')]
# Save final coordinates (delta divided by factor to undo the beta0/beta
# scaling applied at load time) together with the tunes.
np.save(
    f'{filename}_turn_{n_turns}',
    np.array([
        particles.x[order],
        particles.px[order],
        particles.y[order],
        particles.py[order],
        particles.zeta[order],
        particles.delta[order] / factor,
        Qx,
        Qy,
    ]),
)
