import math
import sys
import matplotlib.pyplot as plt
import numpy as np

# import orbit_mpi
import random
import sys

from orbit.core.bunch import Bunch
#from orbit.teapot import teapot
from orbit.teapot import TEAPOT_Lattice, TEAPOT_Ring, TEAPOT_MATRIX_Lattice
from orbit.lattice import AccLattice, AccNode, AccActionsContainer
from orbit.utils.orbit_mpi_utils import bunch_orbit_to_pyorbit, bunch_pyorbit_to_orbit
from orbit.core.spacecharge import Boundary2D
from orbit.diagnostics import TeapotTuneAnalysisNode
from orbit.diagnostics import addTeapotDiagnosticsNode
from orbit.bunch_generators import TwissContainer, TwissAnalysis
from orbit.bunch_generators import GaussDist2D,GaussDist3D
from orbit.matching import Optics

from orbit.teapot import teapot
from orbit.rf_cavities import RFNode, RFLatticeModifications



print("Start.")

# =====Make a Teapot style lattice======

# Build the TEAPOT ring from the MAD-X file (sequence "RING") and report it.
lattice = TEAPOT_Ring()
print("Read MAD.")
lattice.readMAD("FODO_pyorbit.MADX", "RING")

lattice_summary = (
    "Lattice=", lattice.getName(),
    " length [m] =", lattice.getLength(),
    " nodes=", len(lattice.getNodes()),
)
print(*lattice_summary)

# Track with the actual particle charge (see pyORBIT TEAPOT_Lattice).
lattice.setUseRealCharge(useCharge=1)

# ------------------------------
# Remove Fringe Fields
# ------------------------------

nodes = lattice.getNodes()
include_fringe_fields = False
# Output-name suffix describing the fringe-field setting. It is initialised
# here so it is defined whichever way include_fringe_fields is set.
name_extra = ''


def _no_fringe_field(node, paramsDict):
    """No-op fringe-field hook used to disable a node's entrance/exit fringe."""
    return None


if not include_fringe_fields:
    name_extra = '_no_fringe_fields'
    print('\n removing fringe fields \n')
    for node in nodes:
        # Only nodes that expose both fringe-field setters can be modified;
        # report the rest so it is visible which elements keep their fringes.
        if hasattr(node, 'setFringeFieldFunctionIN') and hasattr(node, 'setFringeFieldFunctionOUT'):
            node.setFringeFieldFunctionIN(_no_fringe_field)
            node.setFringeFieldFunctionOUT(_no_fringe_field)
        else:
            print(f'not changing fringe field for {node}')

#-------------------------------
# set up RF for simulation
#-------------------------------

#print('Adding RF element')
#ZtoPhi = 2.0 * np.pi / lattice.getLength()
#dESync = 0.0 #synchronous particle dE (GeV)
#RFHNum = 1.0 #harmonic number
#RFVoltage = 0.000 # RF voltage in GV
#RFPhase = 180.0 # RF phase in radians
#length = 0.0
#name = "harmonic_rfnode"
#rf_node = RFNode.Harmonic_RFNode(ZtoPhi, dESync, RFHNum, RFVoltage, RFPhase, length, name)
#position = 157.080 #position in the lattice
#RFLatticeModifications.addRFNode(lattice, position, rf_node)

#-------------------------------
# Slice thick elements
#-------------------------------
# Split every element longer than max_length [m] into enough parts that each
# part is at most max_length long.
max_length = 0.1

for node in lattice.getNodes():
    length = node.getLength()
    if length <= max_length:
        continue  # short element: leave its number of parts untouched
    nparts = int(length / max_length) + 1
    node.setnParts(nparts)

# ------------------------------
# Calculate lattice functions
# ------------------------------

# Reference bunch: mass [GeV], kinetic energy [GeV] and charge of the
# synchronous particle. The setters are applied in this order on purpose;
# the energy is set after the mass.
bunch = Bunch()
bunch.mass(0.93827208816)
bunch.getSyncParticle().kinEnergy(4.0)
bunch.charge(1.0)

beamline = Optics().readtwiss_teapot(lattice, bunch)

# Periodic (ring) Twiss and dispersion functions from the one-turn matrix.
matrix_lattice = TEAPOT_MATRIX_Lattice(lattice, bunch)

arrmuX, arrPosAlphaX, arrPosBetaX = matrix_lattice.getRingTwissDataX()
arrmuY, arrPosAlphaY, arrPosBetaY = matrix_lattice.getRingTwissDataY()
DispersionX, DispersionXP = matrix_lattice.getRingDispersionDataX()

# Report the values stored in the first entry of each (position, value) table.
for plane_label, table_a, table_b in (
    ("hor:", arrPosAlphaX, arrPosBetaX),
    ("ver:", arrPosAlphaY, arrPosBetaY),
    ("disp:", DispersionX, DispersionXP),
):
    print(plane_label, table_a[0][1], table_b[0][1])

# ------------------------------
# Main Bunch init
# ------------------------------

# The input coordinates are passed straight to Bunch.addParticle, whose
# coordinates are x[m], xp[rad], y[m], yp[rad], z[m], dE[GeV] in pyORBIT.
# Transverse values are therefore scaled by 1e3 so the plotted numbers match
# the mm / mrad axis labels.
TO_MILLI = 1e3

fig1 = plt.figure(1, figsize=(6.4, 7))
ax21 = fig1.add_subplot(3, 1, 1)
ax22 = fig1.add_subplot(3, 1, 2)
ax23 = fig1.add_subplot(3, 1, 3)
ax21.set_xlabel(r'x [mm]')
ax21.set_ylabel(r'px [mrad]')
ax22.set_xlabel(r'y [mm]')
ax22.set_ylabel(r'py [mrad]')
ax23.set_xlabel(r'z [m]')
# The sixth column is the energy deviation dE, not the relative deviation delta.
ax23.set_ylabel(r'dE [GeV]')

#analysis = TwissAnalysis(3)
filename = 'pyOrbit_ini_energygrid_unmatched.dat'
# ndmin=2 keeps the (n_columns, n_particles) layout even for a one-particle file;
# without it loadtxt squeezes to 1-D and the per-column indexing breaks.
dataPO_ini = np.loadtxt(filename, skiprows=14, unpack=True, ndmin=2)
x_ini, xp_ini, y_ini, yp_ini, z_ini, dE_ini = dataPO_ini[:6]

for x, xp, y, yp, z, dE in zip(x_ini, xp_ini, y_ini, yp_ini, z_ini, dE_ini):
    bunch.addParticle(x, xp, y, yp, z, dE)
    #analysis.account([x, xp, y, yp, z, dE])

# One plot call per panel instead of one per particle: a single artist per
# axis keeps plotting fast for large distributions.
ax21.plot(x_ini * TO_MILLI, xp_ini * TO_MILLI, '.', markersize=1, color='r')
ax22.plot(y_ini * TO_MILLI, yp_ini * TO_MILLI, '.', markersize=1, color='r')
ax23.plot(z_ini, dE_ini, '.', markersize=1, color='r')

fig1.subplots_adjust(bottom=.08, top=.93, hspace=.33, right=.96, wspace=.33)
plt.show()

#print(analysis.getTwiss(0))
#print(analysis.getTwiss(1))
#print(analysis.getTwiss(2))

nParticlesGlobal = bunch.getSizeGlobal()
print('bunch # particles = ',nParticlesGlobal)
print("Bunch Generated.")
bunch.macroSize(0.0)

# # -----------------------------------
# # Add Tune Analysis node
# # -----------------------------------

# Twiss values from the first entry of the ring tables, in the argument order
# expected by assignTwiss: betaX, alphaX, etaX, etapX, betaY, alphaY.
twiss_at_start = (
    arrPosBetaX[0][1], arrPosAlphaX[0][1],
    DispersionX[0][1], DispersionXP[0][1],
    arrPosBetaY[0][1], arrPosAlphaY[0][1],
)
tunes = TeapotTuneAnalysisNode("tune_analysis")
tunes.assignTwiss(*twiss_at_start)
# Insert the diagnostic node at lattice position 0.
addTeapotDiagnosticsNode(lattice, 0, tunes)

# # -----------------------------------
# # Tracking
# # -----------------------------------
import os

import nafflib

paramsDict = {}
paramsDict["bunch"] = bunch

n_turns = 100
turn_fft = n_turns

# Horizontal dispersion at the observation point. Particles are sampled after
# each full turn, i.e. back at the start of the ring, so the first entry of
# the ring dispersion table is used. Taking it from the matrix lattice, rather
# than hard-coding it, keeps it consistent with the lattice being tracked.
Dx = DispersionX[0][1]  # in meters

# pyORBIT's matrix-lattice dispersion is defined per dp/p. To first order
# dp/p = dE / (beta^2 * E_total) = dE * E_total / p^2 (E_total, p in GeV, GeV/c).
# The previous dE / 4.0 divided by the kinetic energy, which is not dp/p.
sync_part = bunch.getSyncParticle()
E_total = sync_part.kinEnergy() + bunch.mass()
dE_to_delta = E_total / sync_part.momentum() ** 2

# Use the number of particles actually in the bunch, not the file row count.
N_part = bunch.getSize()

Qx = np.zeros(N_part)
Qy = np.zeros(N_part)

x_pos = np.zeros([N_part, n_turns])
y_pos = np.zeros([N_part, n_turns])
delta = np.zeros([N_part, n_turns])

for i in range(n_turns):
    lattice.trackBunch(bunch, paramsDict)
    # Per-particle histories are indexed by bunch position. After a loss the
    # indices would shift and silently mix particles, so stop instead.
    if bunch.getSize() != N_part:
        raise RuntimeError(
            f"particle count changed from {N_part} to {bunch.getSize()} "
            f"on turn {i + 1}; per-particle tune analysis assumes no losses"
        )
    for n in range(N_part):
        x_pos[n, i] = bunch.x(n)
        y_pos[n, i] = bunch.y(n)
        delta[n, i] = bunch.dE(n) * dE_to_delta

# Remove the dispersive (off-momentum) part of x before extracting the
# betatron tune.
for i_part in range(N_part):
    Qx[i_part] = nafflib.tune(x_pos[i_part, :turn_fft] - Dx * delta[i_part, :turn_fft], window_order=2, window_type="hann")
    Qy[i_part] = nafflib.tune(y_pos[i_part, :turn_fft], window_order=2, window_type="hann")

# splitext only strips the final extension; split('.') breaks on paths with dots.
output_stem = os.path.splitext(filename)[0]
bunch.dumpBunch(output_stem + "_turn_" + str(n_turns) + ".dat")
np.save(output_stem + '_tunes_turn_' + str(n_turns), np.array([Qx, Qy]))
