#
# KdV_FFT_RK.py :
#
# A python 3 integrator for KdV. It uses FFT for spatial derivatives  
# and Runge-Kutta 4 for time evolution, and an integrating factor to 
# eliminate the dispersive term.
#
# NOTE: to get the animations working, best avoid jupyter - simply run
# the script straight from a desktop terminal, or using ipython.
# 
# Solves u_t = - 6 u u_x - u_xxx
# 
# The main program was originally written by Sam Webster in AIMS South
# Africa in 2007; modifications by Patrick Dorey in 2010, 2013 and 2014.
#
# Try changing N in the "initial conditions for u" section a little 
# way down; you can also experiment with completely different initial 
# conditions to see what happens...

import numpy as np
import matplotlib.pyplot as plt
import time

# turn ON interactive mode, so that drawing the figure does not block
# and it can be redrawn while the integration loop below runs:
plt.ion()

########################################################

# final time of the run, and the (fixed) RK4 time step:
tmax, dt = 0.5, 0.00005

# number of grid points -- keep this a power of 2 for the FFT:
M = 512

# spatial period, and the grid spacing it implies:
L = 20.0
h = L/M

# roughly how much simulated time passes between screen updates:
dtplot = 0.001

# M equally spaced grid points covering one period; the right-hand end
# point L/2 is left out because periodicity identifies it with -L/2:
x = np.linspace(-L/2, L/2-h, M)

########################################################

# *** initial conditions for u *** :
# N*(N+1) sech^2(x) -- change N to try other cases.
N = 4
u = N*(N+1)/np.cosh(x)**2

########################################################

# vertical range used when the plot first appears:
ymin = -0.1
ymax = 1.1*np.max(u)

########################################################

# Note that Uhat(k,t) = exp(-i(2 pi k/L)^3 t) uhat(k,t), where
# uhat(k,t) is the FT of u(x,t); this exponential is the integrating
# factor that removes the dispersive term. Hence Uhat(k,0)=uhat(k,0),
# and the initial Uhat is simply the Fourier Transform of
# the initial data:

Uhat = np.fft.rfft(u)

# The FFT treats its input as having period 2 pi, so the wavenumbers
# are integers k running from -M/2 to M/2. Since u is real we use
# numpy's "real" FFT, rfft, which drops the negative-frequency terms
# (they are just the complex conjugates of the positive-frequency
# ones), so k only needs to run over 0, 1, ..., M//2. Integer division
# keeps k exactly as long as rfft's output (M//2 + 1 entries) for odd
# M as well as even M:

k = np.arange(M//2 + 1)

########################################################

# The function f calculates the right-hand side of the Uhat ODE.
# The FFT routines assume period 2 pi; the factor s = 2 pi/L rescales
# to/from the actual period L, so s*k is the physical wavenumber.
# exp(A t) is the integrating factor (uhat = exp(A t) Uhat), and B is
# the Fourier multiplier for -3 d/dx, since -6 u u_x = -3 (u^2)_x:

s = 2*np.pi/L
sk = s*k
A = 1j*sk**3
B = -3j*sk


def f(tt, uu):
    """Return d(Uhat)/dt at time tt, given the coefficients uu = Uhat(k, tt).

    Multiplying by ee = exp(A*tt) recovers uhat; irfft turns that into u
    on the grid; B times the transform of u**2 is the transform of the
    nonlinear term -6 u u_x; dividing by ee maps the result back into
    the Uhat variables.

    irfft is given the grid size M explicitly: without it irfft assumes
    an even-length output, which is wrong when M is odd.
    """
    ee = np.exp(A*tt)
    a1 = np.fft.irfft(ee*uu, M)
    a2 = np.fft.rfft(a1**2)
    return B*a2/ee

########################################################

# commands to initialise plot:
line, = plt.plot(x, u)
line.axes.set_ylim(ymin, ymax)

# a counter so that only every Kth configuration is plotted.
# round() rather than int(): a ratio that floating-point division
# leaves as e.g. 19.999999... must not be truncated to 19. At least 1,
# so that c % K never divides by zero when dtplot < dt:
c = 0
K = max(1, round(dtplot/dt))

# default message that the program has finished successfully:
stop = 'run completed'

######################################
#    START OF INTEGRATION ROUTINE    #
######################################

start = time.process_time()

# Number of RK4 steps needed to reach tmax. An integer step count avoids
# np.arange's floating-point end-point ambiguity with a non-integer step
# (which could add or drop a step), so the run ends at t = tmax.
nsteps = int(round(tmax/dt))

# Time of the state held in Uhat; stays 0.0 if no step completes.
t = 0.0

try:
    for n in range(nsteps):
        t0 = n*dt

        # Solve (d/dt)Uhat(k,t)=f(t,Uhat) using a 4th-order Runge-Kutta
        # method in time, where f(,) is given by the above routine:
        k1 = f(t0, Uhat)
        k2 = f(t0+0.5*dt, Uhat+0.5*dt*k1)
        k3 = f(t0+0.5*dt, Uhat+0.5*dt*k2)
        k4 = f(t0+dt, Uhat+dt*k3)
        Uhat += (dt/6)*(k1+2*k2+2*k3+k4)

        # Uhat now holds the solution at the END of this step, so the
        # integrating factor and the plot label must use this time:
        t = (n+1)*dt

        # Every Kth configuration, compute u, check it and plot it:
        if c % K == 0:
            # Compute u by undoing the integrating factor at time t:
            uhat = np.exp(A*t)*Uhat
            u = np.fft.irfft(uhat, M)
            # Is u blowing up? If so, interrupt. The isfinite test is
            # needed because a NaN fails every magnitude comparison:
            if not np.isfinite(uhat[-1]) or abs(uhat[-1]) > 200:
                stop = 'stopped early: unstable. Decrease the time step!'
                break
            # Has u gone off the screen? If so, increase y-range for plot:
            um = np.max(u)
            if um > ymax:
                ymax = ymax*1.5
                line.axes.set_ylim(ymin, ymax)
            # Plot u:
            plt.title('t='+'%.3f' % t)
            line.set_ydata(u)
            plt.draw()
            plt.pause(0.0001)
        c += 1

# Allow the user to interrupt the program:
except KeyboardInterrupt:
    stop = 'stopped early: keyboard interrupt'

runtime = (time.process_time()-start)

# Reconstruct u at the final time t, so that any summary of u describes
# the state the run actually ended in. The last plotted frame may be up
# to K-1 steps earlier, and nothing is plotted if the run stops early.
u = np.fft.irfft(np.exp(A*t)*Uhat, M)
um = np.max(u)

####################################
#    END OF INTEGRATION ROUTINE    #
####################################

########################################################
# end by printing out some information about the run:

print(stop)
print('run time: '+'%.5f' % runtime)
print('t = '+'%.4f' % t)
print('M = '+'%i' % M)
# Recompute the maximum from the current u. um from the loop can be
# stale, e.g. when the run stopped on an instability right after u was
# updated.
um = np.max(u)
print('max u = '+'%.4f' % um)
i = int(np.argmax(u))
print('index of max u = '+'%i' % i)
print('location of max u : '+'%.4f' % x[i])
# Refine the peak position by fitting a parabola through the maximum and
# its two neighbours. The grid is periodic, so the neighbours wrap around:
# u[i-1] already wraps for i = 0, and (i+1) % len(u) covers the last
# point, which would otherwise raise IndexError. A zero curvature (flat
# top) has no unique vertex, so fall back to the grid point itself.
ul, uc, ur = u[i-1], u[i], u[(i+1) % len(u)]
curv = ul - 2*uc + ur
if curv != 0:
    umaxpos = x[i] + (ul-ur)/curv*h/2
else:
    umaxpos = x[i]
print('corrected location of max u : '+'%.4f' % umaxpos)

########################################################