#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Demonstrateur OpenCL d'interaction NCorps

Emmanuel QUEMENER <emmanuel.quemener@ens-lyon.fr> CeCILLv2
"""
import getopt
import sys
import time
import numpy as np
import pyopencl as cl
import pyopencl.array as cl_array
from numpy.random import randint as nprnd

def DictionariesAPI():
    Marsaglia={'CONG':0,'SHR3':1,'MWC':2,'KISS':3}
    Computing={'FP32':0,'FP64':1}
    return(Marsaglia,Computing)

BlobOpenCL= """
#define znew  ((z=36969*(z&65535)+(z>>16))<<16)
#define wnew  ((w=18000*(w&65535)+(w>>16))&65535)
#define MWC   (znew+wnew)
#define SHR3  (jsr=(jsr=(jsr=jsr^(jsr<<17))^(jsr>>13))^(jsr<<5))
#define CONG  (jcong=69069*jcong+1234567)
#define KISS  ((MWC^CONG)+SHR3)

#define MWCfp MWC * 2.328306435454494e-10f
#define KISSfp KISS * 2.328306435454494e-10f
#define SHR3fp SHR3 * 2.328306435454494e-10f
#define CONGfp CONG * 2.328306435454494e-10f

#define TFP32 0
#define TFP64 1

#define LENGTH 1.

#define PI 3.14159265359

#define SMALL_NUM 0.000000001

#if TYPE == TFP32
#define MYFLOAT4 float4
#define MYFLOAT8 float8
#define MYFLOAT float
#else
#pragma OPENCL EXTENSION cl_khr_fp64: enable
#define MYFLOAT4 double4
#define MYFLOAT8 double8
#define MYFLOAT double
#endif

MYFLOAT4 Interaction(MYFLOAT4 m,MYFLOAT4 n)
{
    // return((n-m)/(MYFLOAT)pow(distance(n,m),2));
    return((n-m)/(MYFLOAT)pow(distance(n,m),2));
}

MYFLOAT PairPotential(MYFLOAT4 m,MYFLOAT4 n)
{
    return((MYFLOAT)-1./distance(n,m));
}

// Elements from : http://doswa.com/2009/01/02/fourth-order-runge-kutta-numerical-integration.html


MYFLOAT8 AtomicRungeKutta(__global MYFLOAT8* clDataIn,int gid,MYFLOAT dt)
{
    MYFLOAT4 x0=(MYFLOAT4)clDataIn[gid].lo;
    MYFLOAT4 v0=(MYFLOAT4)clDataIn[gid].hi;
    MYFLOAT4 a0=(MYFLOAT4)(0.,0.,0.,0.);
    int N = get_global_size(0);    
    
    for (int i=0;i<N;i++)
    {
        if (gid != i)
        a0+=Interaction(x0,clDataIn[i].lo);
    }
        
    MYFLOAT4 x1=x0+v0*(MYFLOAT)0.5*dt;
    MYFLOAT4 v1=v0+a0*(MYFLOAT)0.5*dt;
    MYFLOAT4 a1=(MYFLOAT4)(0.,0.,0.,0.);
    for (int i=0;i<N;i++)
    {
        if (gid != i)
        a1+=Interaction(x1,clDataIn[i].lo);
    }

    MYFLOAT4 x2=x0+v1*(MYFLOAT)0.5*dt;
    MYFLOAT4 v2=v0+a1*(MYFLOAT)0.5*dt;
    MYFLOAT4 a2=(MYFLOAT4)(0.,0.,0.,0.);
    for (int i=0;i<N;i++)
    {
        if (gid != i)
        a2+=Interaction(x2,clDataIn[i].lo);
    }
    
    MYFLOAT4 x3=x0+v2*dt;
    MYFLOAT4 v3=v0+a2*dt;
    MYFLOAT4 a3=(MYFLOAT)(0.,0.,0.,0.);
    for (int i=0;i<N;i++)
    {
        if (gid != i)
        a3+=Interaction(x3,clDataIn[i].lo);
    }
    
    MYFLOAT4 xf=x0+dt*(v0+(MYFLOAT)2.*(v1+v2)+v3)/(MYFLOAT)6.;
    MYFLOAT4 vf=v0+dt*(a0+(MYFLOAT)2.*(a1+a2)+a3)/(MYFLOAT)6.;
     
    return((MYFLOAT8)(xf.s0,xf.s1,xf.s2,xf.s3,vf.s0,vf.s1,vf.s2,vf.s3));
}

// Elements from : http://doswa.com/2009/01/02/fourth-order-runge-kutta-numerical-integration.html

MYFLOAT8 AtomicRungeKutta2(__global MYFLOAT8* clDataIn,int gid,MYFLOAT dt)
{
    MYFLOAT4 x[4],v[4],a[4],xf,vf;
    int N=get_global_size(0);

    x[0]=clDataIn[gid].lo;
    v[0]=clDataIn[gid].hi;
    a[0]=(0.,0.,0.,0.);
    
    for (int i=0;i<N;i++)
    {
        if (gid != i)
        a[0]+=Interaction(x[0],clDataIn[i].lo);
    }
        
    x[1]=x[0]+v[0]*(MYFLOAT)0.5*dt;
    v[1]=v[0]+a[0]*(MYFLOAT)0.5*dt;
    a[1]=(0.,0.,0.,0.);
    for (int i=0;i<N;i++)
    {
        if (gid != i)
        a[1]+=Interaction(x[1],clDataIn[i].lo);
    }

    x[2]=x[0]+v[1]*(MYFLOAT)0.5*dt;
    v[2]=v[0]+a[1]*(MYFLOAT)0.5*dt;
    a[2]=(0.,0.,0.,0.);
    for (int i=0;i<N;i++)
    {
        if (gid != i)
        a[2]+=Interaction(x[2],clDataIn[i].lo);
    }
    
    x[3]=x[0]+v[2]*dt;
    v[3]=v[0]+a[2]*dt;
    a[3]=(0.,0.,0.,0.);
    for (int i=0;i<N;i++)
    {
        if (gid != i)
        a[3]+=Interaction(x[3],clDataIn[i].lo);
    }
    
    xf=x[0]+dt*(v[0]+(MYFLOAT)2.*(v[1]+v[2])+v[3])/(MYFLOAT)6.;
    vf=v[0]+dt*(a[0]+(MYFLOAT)2.*(a[1]+a[2])+a[3])/(MYFLOAT)6.;
     
    return((MYFLOAT8)(xf.s0,xf.s1,xf.s2,xf.s3,vf.s0,vf.s1,vf.s2,vf.s3));
}

MYFLOAT8 AtomicEuler(__global MYFLOAT8* clDataIn,int gid,MYFLOAT dt)
{
    MYFLOAT4 x,v,a,xf,vf;

    x=clDataIn[gid].lo;
    v=clDataIn[gid].hi;
    a=(0.,0.,0.,0.);
    for (int i=0;i<get_global_size(0);i++)
    {
        if (gid != i)
        a+=Interaction(x,clDataIn[i].lo);
    }
       
    vf=v+dt*a;
    xf=x+dt*vf;
 
    return((MYFLOAT8)(xf.s0,xf.s1,xf.s2,xf.s3,vf.s0,vf.s1,vf.s2,vf.s3));
}

__kernel void SplutterPoints(__global MYFLOAT8* clData, MYFLOAT box, MYFLOAT velocity,
                               uint seed_z,uint seed_w)
{
    int gid = get_global_id(0);
    MYFLOAT N = (MYFLOAT) get_global_size(0);
    uint z=seed_z+(uint)gid;
    uint w=seed_w-(uint)gid;
    
    MYFLOAT theta=MWCfp*PI;
    MYFLOAT phi=MWCfp*PI*(MYFLOAT)2.;
    MYFLOAT sinTheta=sin(theta);
    clData[gid].s01234567 = (MYFLOAT8) (box*(MYFLOAT)(MWCfp-0.5),box*(MYFLOAT)(MWCfp-0.5),box*(MYFLOAT)(MWCfp-0.5),0.,0.,0.,0.,0.);
    MYFLOAT v=sqrt(N*(MYFLOAT)2./distance(clData[gid].lo,(MYFLOAT4) (0.,0.,0.,0.)));
    clData[gid].s4=v*sinTheta*cos(phi);
    clData[gid].s5=v*sinTheta*sin(phi);
    clData[gid].s6=v*cos(theta);
}

__kernel void RungeKutta(__global MYFLOAT8* clData,MYFLOAT h)
{
    int gid = get_global_id(0);
    
    MYFLOAT8 clDataGid=AtomicRungeKutta(clData,gid,h);
    barrier(CLK_GLOBAL_MEM_FENCE);
    clData[gid]=clDataGid;
}

__kernel void Euler(__global MYFLOAT8* clData,MYFLOAT h)
{
    int gid = get_global_id(0);
    
    MYFLOAT8 clDataGid=AtomicEuler(clData,gid,h);
    barrier(CLK_GLOBAL_MEM_FENCE);
    clData[gid]=clDataGid;
}

__kernel void Potential(__global MYFLOAT8* clData,__global MYFLOAT* clPotential)
{
    int gid = get_global_id(0);

    MYFLOAT potential=0.;
    MYFLOAT4 x=clData[gid].lo; 
    
    for (int i=0;i<get_global_size(0);i++)
    {
        if (gid != i)
        potential+=PairPotential(x,clData[i].lo);
    }
                 
    barrier(CLK_GLOBAL_MEM_FENCE);
    clPotential[gid]=(MYFLOAT)0.5*potential;
}

__kernel void Kinetic(__global MYFLOAT8* clData,__global MYFLOAT* clKinetic)
{
    int gid = get_global_id(0);
    
    clKinetic[gid]=(MYFLOAT)0.5*(pow(clData[gid].s4,2)+pow(clData[gid].s5,2)+pow(clData[gid].s6,2));
}
"""

def Energy(MyData):
    return(sum(pow(MyData,2)))

if __name__=='__main__':
    
    # ValueType
    ValueType='FP32'
    class MyFloat(np.float32):pass
    clType=cl_array.vec.float8
    # Set defaults values
    np.set_printoptions(precision=2)  
    # Id of Device : 1 is for first find !
    Device=1
    # Iterations is integer
    Number=4
    # Size of box
    SizeOfBox=MyFloat(1.)
    # Initial velocity of particules
    Velocity=MyFloat(1.)
    # Redo the last process
    Iterations=100
    # Step
    Step=MyFloat(0.01)
    # Method of integration
    Method='RungeKutta'
    # InitialRandom
    InitialRandom=False
    # RNG Marsaglia Method
    RNG='MWC'
    # CheckEnergies
    CheckEnergies=False
    # Display samples in 3D
    GraphSamples=False    
    
    HowToUse='%s -h [Help] -r [InitialRandom] -g [GraphSamples] -c [CheckEnergies] -d <DeviceId> -n <NumberOfParticules> -z <SizeOfBox> -v <Velocity> -s <Step> -i <Iterations> -m <RungeKutta|Euler> -t <FP32|FP64>'

    try:
        opts, args = getopt.getopt(sys.argv[1:],"rhgcd:n:z:v:i:s:m:t:",["random","graph","check","device=","number=","size=","velocity=","iterations=","step=","method=","valuetype="])
    except getopt.GetoptError:
        print(HowToUse % sys.argv[0])
        sys.exit(2)

    for opt, arg in opts:
        if opt == '-h':
            print(HowToUse % sys.argv[0])

            print("\nInformations about devices detected under OpenCL:")
            try:
                Id=0
                for platform in cl.get_platforms():
                    for device in platform.get_devices():
                        deviceType=cl.device_type.to_string(device.type)
                        print("Device #%i from %s of type %s : %s" % (Id,platform.vendor.lstrip(),deviceType,device.name.lstrip()))
                        Id=Id+1
                sys.exit()
            except ImportError:
                print("Your platform does not seem to support OpenCL")
                sys.exit()

        elif opt in ("-t", "--valuetype"):
            if arg=='FP64':
                class MyFloat(np.float64): pass
                clType=cl_array.vec.double8
            else:
                class MyFloat(np.float32):pass
                clType=cl_array.vec.float8
            ValueType = arg
        elif opt in ("-d", "--device"):
            Device=int(arg)
        elif opt in ("-m", "--method"):
            Method=arg
        elif opt in ("-n", "--number"):
            Number=int(arg)
        elif opt in ("-z", "--size"):
            SizeOfBox=MyFloat(arg)
        elif opt in ("-v", "--velocity"):
            Velocity=MyFloat(arg)
        elif opt in ("-s", "--step"):
            Step=MyFloat(arg)
        elif opt in ("-i", "--iterations"):
            Iterations=int(arg)
        elif opt in ("-r", "--random"):
            InitialRandom=True
        elif opt in ("-c", "--check"):
            CheckEnergies=True
        elif opt in ("-g", "--graph"):
            GraphSamples=True
                        
    SizeOfBox=MyFloat(SizeOfBox)
    Velocity=MyFloat(Velocity)
    Step=MyFloat(Step)
                
    print("Device choosed : %s" % Device)
    print("Number of particules : %s" % Number)
    print("Size of Box : %s" % SizeOfBox)
    print("Initial velocity % s" % Velocity)
    print("Number of iterations % s" % Iterations)
    print("Step of iteration % s" % Step)
    print("Method of resolution % s" % Method)
    print("Initial Random for RNG Seed % s" % InitialRandom)
    print("Check for Energies % s" % CheckEnergies)
    print("Graph for Samples % s" % GraphSamples)
    print("ValueType is % s" % ValueType)

    # Create Numpy array of CL vector with 8 FP32    
    MyData = np.zeros(Number, dtype=clType)
    MyPotential = np.zeros(Number, dtype=MyFloat)
    MyKinetic = np.zeros(Number, dtype=MyFloat)

    Marsaglia,Computing=DictionariesAPI()

    # Scan the OpenCL arrays
    Id=0
    HasXPU=False
    for platform in cl.get_platforms():
        for device in platform.get_devices():
            if Id==Device:
                PlatForm=platform
                XPU=device
                print("CPU/GPU selected: ",device.name.lstrip())
                HasXPU=True
            Id+=1

    if HasXPU==False:
        print("No XPU #%i found in all of %i devices, sorry..." % (Device,Id-1))
        sys.exit()      

    # Create Context
    try:
        ctx = cl.Context([XPU])
        queue = cl.CommandQueue(ctx,properties=cl.command_queue_properties.PROFILING_ENABLE)
    except:
        print("Crash during context creation")

    print(Marsaglia[RNG],Computing[ValueType])
    # Build all routines used for the computing
    MyRoutines = cl.Program(ctx, BlobOpenCL).build(options = "-cl-mad-enable -cl-fast-relaxed-math -DTRNG=%i -DTYPE=%i" % (Marsaglia[RNG],Computing[ValueType]))

# Initial forced values for exploration
#    MyData[0][0]=np.float32(-1.)
#    MyData[0][1]=np.float32(0.)
#    MyData[0][5]=np.float32(1.)
#    MyData[1][0]=np.float32(1.)
#    MyData[1][1]=np.float32(0.)
#    MyData[1][5]=np.float32(-1.)

    mf = cl.mem_flags
    clData = cl.Buffer(ctx, mf.READ_WRITE, MyData.nbytes)
    clPotential = cl.Buffer(ctx, mf.READ_WRITE, MyPotential.nbytes)
    clKinetic = cl.Buffer(ctx, mf.READ_WRITE, MyKinetic.nbytes)
    #clData = cl.Buffer(ctx, mf.WRITE_ONLY|mf.COPY_HOST_PTR,hostbuf=MyData)

    print('All particles superimposed.')

    print(SizeOfBox.dtype)
    
    # Set particles to RNG points
    if InitialRandom:
        MyRoutines.SplutterPoints(queue,(Number,1),None,clData,SizeOfBox,Velocity,np.uint32(nprnd(2**32)),np.uint32(nprnd(2**32)))
    else:
        MyRoutines.SplutterPoints(queue,(Number,1),None,clData,SizeOfBox,Velocity,np.uint32(110271),np.uint32(250173))

    print('All particules distributed')
 
    CLLaunch=MyRoutines.Potential(queue,(Number,1),None,clData,clPotential)
    CLLaunch.wait()
    if CheckEnergies:
        cl.enqueue_copy(queue,MyPotential,clPotential)
        CLLaunch=MyRoutines.Kinetic(queue,(Number,1),None,clData,clKinetic)
        CLLaunch.wait()
        cl.enqueue_copy(queue,MyKinetic,clKinetic)
        # print(np.sum(MyPotential)+2*np.sum(MyKinetic),np.sum(MyPotential),np.sum(MyKinetic),MyPotential,MyKinetic)
        print(np.sum(MyPotential)+2*np.sum(MyKinetic),np.sum(MyPotential),np.sum(MyKinetic))
 
    if GraphSamples:
        cl.enqueue_copy(queue, MyData, clData)
        t0=np.array([[MyData[0][0],MyData[0][1],MyData[0][2]]])
        t1=np.array([[MyData[1][0],MyData[1][1],MyData[1][2]]])
        tL=np.array([[MyData[-1][0],MyData[-1][1],MyData[-1][2]]])

    time_start=time.time()
    for i in range(Iterations):
        if Method=="RungeKutta":            
            CLLaunch=MyRoutines.RungeKutta(queue,(Number,1),None,clData,Step)
        else:
            CLLaunch=MyRoutines.Euler(queue,(Number,1),None,clData,Step)
        CLLaunch.wait()
        if CheckEnergies:
            CLLaunch=MyRoutines.Potential(queue,(Number,1),None,clData,clPotential)
            CLLaunch.wait()
            cl.enqueue_copy(queue,MyPotential,clPotential)
            CLLaunch=MyRoutines.Kinetic(queue,(Number,1),None,clData,clKinetic)
            CLLaunch.wait()
            cl.enqueue_copy(queue,MyKinetic,clKinetic)
            # print(np.sum(MyPotential)+2*np.sum(MyKinetic),np.sum(MyPotential),np.sum(MyKinetic),MyPotential,MyKinetic)
            print(np.sum(MyPotential)+2*np.sum(MyKinetic),np.sum(MyPotential),np.sum(MyKinetic))

        if GraphSamples:
            cl.enqueue_copy(queue, MyData, clData)
            t0=np.append(t0,[MyData[0][0],MyData[0][1],MyData[0][2]])
            t1=np.append(t1,[MyData[1][0],MyData[1][1],MyData[1][2]])
            tL=np.append(tL,[MyData[-1][0],MyData[-1][1],MyData[-1][2]])
    print("\nDuration on %s for each %s" % (Device,(time.time()-time_start)/Iterations))

    if GraphSamples:    
        t0=np.transpose(np.reshape(t0,(Iterations+1,3)))
        t1=np.transpose(np.reshape(t1,(Iterations+1,3)))
        tL=np.transpose(np.reshape(tL,(Iterations+1,3)))
    
        import matplotlib.pyplot as plt
        from mpl_toolkits.mplot3d import Axes3D
    
        fig = plt.figure()
        ax = fig.gca(projection='3d')
        ax.scatter(t0[0],t0[1],t0[2], marker='^',color='blue')
        ax.scatter(t1[0],t1[1],t1[2], marker='o',color='red')
        ax.scatter(tL[0],tL[1],tL[2], marker='D',color='green')
   
        ax.set_xlabel('X Label')
        ax.set_ylabel('Y Label')
        ax.set_zlabel('Z Label')

        plt.show()
    
    clData.release()
    clKinetic.release()
    clPotential.release()
