import numpy as np
import matplotlib.pyplot as plt
import math
from scipy import optimize
from scipy.ndimage import binary_dilation

from titk_tools.io import imread, imsave, SpatialImage

def center_on_point(img, cm):
	'''
	Centers the image img on the point cm
	'''
	x = np.array(np.where(img > 0) )
	xp=np.zeros_like(x)
	img1=np.zeros_like(img)
	s=img.shape
	center_img=np.array(s)/2
	for i in [0, 1, 2]:
		xp[i]=x[i]+center_img[i]-cm[i]
	for i in range(0, len(x[0])):
		if ((xp[0][i]>=0) and (xp[0][i]<s[0]) and (xp[1][i]>=0) and (xp[1][i]<s[1]) and (xp[2][i]>=0) and (xp[2][i]<s[2])):
			img1[xp[0][i],xp[1][i],xp[2][i]]=img[x[0][i],x[1][i],x[2][i]]
	img1=SpatialImage(img1)
	img1.resolution=img.resolution
	return img1


def center_on_cm(img, volumic=True):
	'''
	Centers the image img on the point cm
	'''
	s=img.shape
	res=img.voxelsize
	if volumic:
		obj=img
	else :
		obj=binary_dilation(img)-img
	x = np.array(np.where(obj > 0) )
	cm = np.array([np.mean(x[i]) for i in [0,1,2]])
	x = np.array(np.where(img > 0) )
	xp=np.zeros_like(x)
	img1=np.zeros_like(img)
	center_img=np.array(s)/2
	for i in [0, 1, 2]:
		xp[i]=x[i]+center_img[i]-cm[i]
	for i in range(0, len(x[0])):
		if ((xp[0][i]>=0) and (xp[0][i]<s[0]) and (xp[1][i]>=0) and (xp[1][i]<s[1]) and (xp[2][i]>=0) and (xp[2][i]<s[2])):
			img1[xp[0][i],xp[1][i],xp[2][i]]=img[x[0][i],x[1][i],x[2][i]]
	img1=SpatialImage(img1)
	img1.voxelsize=res
	return img1


def principal_directions(img, threshold=10, volumic=True):	
	'''
	Returns the first two principal eigen vectors of an input image
	Fixed threshold of 10 for a 8 bit image
	It suppposes an isotropic sampling of the image.
	'''
	if volumic:
		obj=img
	else :
		obj=binary_dilation(img)-img
	x,y,z = np.where(obj*255 > threshold) ## appropriate thresholding 
	x = x - np.mean(x)
	y = y - np.mean(y)
	z = z - np.mean(z)
	coords = np.vstack([x,y,z])
	cov = np.cov(coords)
	evals,evecs = np.linalg.eig(cov)
	sort_indices = np.argsort(evals)[::-1]
	return list(evecs[:,sort_indices][0:2])

def direct_frame(pv):
	'''
	Constructs a direct frame with first two unitvectors v0 and v1
	'''
	return np.array([pv[0], pv[1], np.cross(pv[0], pv[1])])


def write_transformation(T, trsf_file):
	'''
	writes the affine transformation T (4x4 matrix) in
	a textfile read by blockmatching.
	'''
	f = open(trsf_file,'w')
	f.write("( "+"\n")
	f.write("08 "+"\n")
	for i in T:
		for j in i:
			f.write("  "+str(j)+" ")
		f.write("\n")
	f.write(") ")
	return



def write_rigid(trsf_file, R, a):
	'''
	write a rigid transformation with rotation matrix R and translation a
	'''
	T = np.identity(4)
	T[0:3,0:3] = R
	T = np.transpose(T)
	T[0:3,3] = a
	#write_transformation(T, trsf_file)
	np.savetxt(trsf_file, T, delimiter=";")
	return



def write_superposing_transfo(pv_flo, pv_ref, img, trsf_file):
	'''
	- pv_flo and pv-ref are lists of the first two principal vectors of the floating and the reference images respectively
	- img is the floating image
	'''
	# compute the rotation matrix R
	pframe_ref = direct_frame(pv_ref)
	pframe_flo = direct_frame(pv_flo)
	F = np.transpose(pframe_ref)
	G = np.transpose(pframe_flo)
	R = np.matmul(G, np.linalg.inv(F))
	# blockmatching rotates arround the origin,
	# so a rotation arround the middle of the image contains an additional translation t
	s = img.shape
	res = img.voxelsize
	com = np.array(res)*np.array(s)/2
	i = np.identity(3)
	t = np.dot((i-R),np.array(com))
	R = np.transpose(R)
	write_rigid(trsf_file, R, t)
	return


def add_side_slices(Img, x0, x1, y0, y1, z0, z1):
	"""
	adds 0 value slices on different sides of the image
	"""
	s=Img.shape
	ss=(s[0]+x0+x1,s[1]+y0+y1,s[2]+z0+z1)
	Img_new = (SpatialImage(np.zeros(ss))).astype('uint8')
	Img_new[x0:ss[0]-x1,y0:ss[1]-y1,z0:ss[2]-z1]=Img
	Img_new.voxelsize = Img.voxelsize
	return Img_new


def remove_side_slices(Img, x0, x1, y0, y1, z0, z1):
	"""
	removes slices on different sides of the image
	"""
	ss=Img.shape
	Img_new=Img[x0:ss[0]-x1,y0:ss[1]-y1,z0:ss[2]-z1]
	#Img_new.resolution = Img.resolution
	return Img_new


def measure_hight(a):
	b= np.where(a)
	height = 0
	if b[0].shape[0]>0 :
		height = b[0][-1]
	return height


def measure_height(a):
	b= np.where(a)
	height = [0,0]
	if b[0].shape[0]>0 :
		height = [b[0].min(),b[0].max()]
	return height


def extract_curve(x_section):
	sx = x_section.shape
	xx = np.zeros((sx[0],2))
	for i in range(0,sx[0]):
		xx[i,0] = i
		xx[i,1] = measure_hight(x_section[i,:])
	return xx


def extract_curves(x_section):
	sx = x_section.shape
	xup = np.zeros((sx[0],2))
	xdown = np.zeros((sx[0],2))
	for i in range(0,sx[0]):
		xup[i,0] = i
		xdown[i,0] = i
		xup[i,1] = measure_height(x_section[i,:])[1]
		xdown[i,1] = measure_height(x_section[i,:])[0]
	return xup, xdown


def write_curve(y, pixelsize, filename) :
	FILE = open(filename,"w")
	FILE.write('# '+str(pixelsize[0])+' '+str(pixelsize[1])+'\n')
	s = y.shape
	#FILE.write('# '+str(s[0])+'\n')
	for i in range(0,s[0]) :
		FILE.write(str(y[i,0])+' '+str(y[i,1])+'\n')
	FILE.close()
	return


def struct_element_sphere(radius):
	s=int(2*radius+1)
	structure=np.zeros((s,s,s))
	center = np.array([radius, radius, radius])
	Nl=range(s)
	for i in Nl:
		for j in Nl:
			for k in Nl:
				p=np.array([i,j,k])
				d=p-center
				dist=np.sqrt(sum(d*d))
				if dist<=radius:
					structure[i,j,k]=1
	return structure


def create_artificial_sepal(AA, RR, resolution):
	# A = parameters of an ellipse, as projected on xy plane (2d array)
	# R = curvature radius in directions x and y (2d array)
	# (A and R are given in micrometers)
	# resolution = voxelsize of the output stack
	A = np.array([AA[i]/resolution[i] for i in [0,1]])
	R = np.array([RR[i]/resolution[i] for i in [0,1]])
	h = np.array([math.sqrt(R[i]**2-A[i]**2) for i in range(len(A))])
	zmax=np.array([R[i]-math.sqrt(R[i]**2-A[i]**2) for i in [0,1]])
	#zmax= int(math.sqrt(zmax[0]*zmax[1]))
	zmax=zmax.mean()
	#zmax = 5
	print(zmax)
	marge = 0.2
	# creating the stack and the surface
	s = (int(A[0]*2*(1+marge)), int(A[1]*2*(1+marge)), int(zmax*(1+2*marge)))
	cm = np.array(s)/2
	x = np.array(range(0,s[0]))
	y = np.array(range(0,s[1]))
	xx = x - cm[0]
	yy = y - cm[1]
	#xgrid, ygrid = np.meshgrid(xx, yy, sparse=True)  # make sparse output arrays
	#mask=(((xgrid**2)/float(A[0])**2+(ygrid**2)/float(A[1])**2)<1).astype('uint8')
	z = np.zeros((s[0],s[1]))
	stack= np.zeros(s).astype('uint8')
	for i in x:
		for j in range(0,s[1]):
			if xx[i]**2/float(A[0])**2+yy[j]**2/float(A[1])**2<=1 :
				zx = (math.sqrt(R[0]**2-xx[i]**2)-h[0])*(1-abs(yy[j])/A[1])
				zy = (math.sqrt(R[1]**2-yy[j]**2)-h[1])*(1-abs(xx[i])/A[0])
				z[i,j] = math.sqrt(zx*zy)
				#z[i,j] = (zx+zy)/2
				#z[i,j] = 5
				stack[i,j,int(zmax*marge+z[i,j])]=1
	stack = SpatialImage(stack)
	stack.resolution = resolution
	return z, stack



def create_artificial_sepal_ellipse(AA, H, resolution):
	'''
	# A = parameters of an ellipse, as projected on xy plane (2d array)
	# R = curvature radius in directions x and y (2d array)
	# (A and R are given in micrometers)
	# resolution = voxelsize of the output stack
	'''
	A = np.array([AA[i]/resolution[i] for i in [0,1]])
	h = H/resolution[2]
	zmax=h
	marge = 0.2
	# creating the stack and the surface
	s = (int(A[0]*2*(1+marge)), int(A[1]*2*(1+marge)), int(zmax*(1+5*marge)))
	cm = np.array(s)/2
	x = np.array(range(0,s[0]))
	y = np.array(range(0,s[1]))
	xx = x - cm[0]
	yy = y - cm[1]
	#xgrid, ygrid = np.meshgrid(xx, yy, sparse=True)  # make sparse output arrays
	#mask=(((xgrid**2)/float(A[0])**2+(ygrid**2)/float(A[1])**2)<1).astype('uint8')
	z = np.zeros((s[0],s[1]))
	#z = -h*((xgrid**2)/float(A[0])**2+(ygrid**2)/float(A[1])**2)
	stack= np.zeros(s).astype('uint8')
	for i in x:
		for j in y:
			if xx[i]**2/float(A[0])**2+yy[j]**2/float(A[1])**2<=1 :
				z[i,j] = -h*((xx[i]**2)/float(A[0])**2+(yy[j]**2)/float(A[1])**2-1)
				stack[i,j,int(zmax*4*marge+z[i,j])]=1
	stack = SpatialImage(stack)
	stack.voxelsize = resolution
	return z, stack

# -------------------------------------------------------------------

def read_curve(filename) :
	"""
	Reads the coordinates of the point-list defining a 2D curve.
	**Returns**
	y : 2-column array
		defines a curve as an ordered list of points, 
		each row of the array represents a point on the curve and contains
		its 2D cartesian coordinates
	pixelsize : tuple (res0, res1) in micrometers
	"""	
	x0 = []
	pixelsize = (1.0,1.0)
	for line in file(filename):
		if line[0] == "#" :
			line = line.lstrip('# ')
			line = line.rstrip('\n')
			line_list = [float(x) for x in line.split(' ')]
			pixelsize = (line_list[0], line_list[1])
		else :
			line = line.rstrip('\n')
			line_list = [float(x) for x in line.split(' ')]
			x0.append(line_list)
	y = np.array(x0)
	return y, pixelsize



def distance(p1, p2):
	return np.sqrt((p1[0]-p2[0])**2+(p1[1]-p2[1])**2)

def distance_to_line(p1, p2, p0):
	'''
	gives the distance of point p0 to the line defined by points p1 and p2 (in 2d)
	'''
	return abs((p2[1]-p1[1])*p0[0]-(p2[0]-p1[0])*p0[1]+p2[0]*p1[1]-p2[1]*p1[0])/distance(p1,p2)

def curvature(p1,p2,p3):
	t1 = (p3[0]**2-p2[0]**2+p3[1]**2-p2[1]**2)/(2.0*(p3[1]-p2[1]))
	t2 = (p2[0]**2-p1[0]**2+p2[1]**2-p1[1]**2)/(2.0*(p2[1]-p1[1]))
	n1 = (p3[0]-p2[0])/(p3[1]-p2[1])
	n2 = (p2[0]-p1[0])/(p2[1]-p1[1])
	pcx = (t1-t2+p3[1]-p1[1])/(n1-n2)
	pc = [pcx, -n1*pcx+t1/2+p3[1]/2+p1[1]/2]
	R = distance(p1, pc)
	return 1.0/R, pc


def curvature2(p1,p2,p3):
	# http://www.ambrsoft.com/TrigoCalc/Circle3D.htm
	A = p1[0]*(p2[1]-p3[1])-p1[1]*(p2[0]-p3[0])+p2[0]*p3[1]-p3[0]*p2[1]
	B = (p1[0]**2+p1[1]**2)*(p3[1]-p2[1])+(p2[0]**2+p2[1]**2)*(p1[1]-p3[1])+(p3[0]**2+p3[1]**2)*(p2[1]-p1[1])
	C = (p1[0]**2+p1[1]**2)*(p3[0]-p2[0])+(p2[0]**2+p2[1]**2)*(p1[0]-p3[0])+(p3[0]**2+p3[1]**2)*(p2[0]-p1[0])
	D = (p1[0]**2+p1[1]**2)*(p3[0]*p2[1]-p2[0]*p3[1])+(p2[0]**2+p2[1]**2)*(p1[0]*p3[1]-p3[0]*p1[1])+(p3[0]**2+p3[1]**2)*(p2[0]*p1[1]-p1[0]*p2[1])
	x = -B/(2*A)
	y = C/(2*A)
	pc = [x, y]
	R = distance(p1, pc)
	return 1.0/R, pc


def curve_perle(y, first, last, step):
	yy=[]
	for i in range(first, last, step):
		if y[i][1]>0:
			yy.append(y[i])
	if i<last:
		yy.append(y[last])
	return np.array(yy)



def compute_curvature_radius(x,y):
	# coordinates of the barycenter
	x_m = np.mean(x)
	y_m = np.mean(y)
	def calc_R(c):
		""" calculate the distance of each 2D points from the center c=(xc, yc) """
		return np.sqrt((x-c[0])**2 + (y-c[1])**2)
	#
	def calc_ecart(c):
		""" calculate the algebraic distance between the 2D points and the mean circle centered at c=(xc, yc) """
		Ri = calc_R(c)
		return Ri - Ri.mean()
	#
	center_estimate = x_m, y_m
	center_2, ier = optimize.leastsq(calc_ecart, center_estimate)
	#
	xc_2, yc_2 = center_2
	Ri_2       = calc_R(center_2)
	R_2        = Ri_2.mean()
	residu_2   = sum((Ri_2 - R_2)**2)
	pc=[xc_2, yc_2]
	return R_2, pc


def circle(R, c):
	phi=np.linspace(np.pi/6., 5*np.pi/6., 51)
	x=c[0]+R*np.cos(phi)
	y=c[1]+R*np.sin(phi)
	return x,y



def analyze_curve1(name, y, pixelsize, outfilename,  title, graph_name='graph.png'):
	# isotropic pixels
	y[:,0] = y[:,0]*pixelsize[0]
	y[:,1] = y[:,1]*pixelsize[1]
	maxv = [y[:,i].max() for i in [0,1]]
	# find first and last points on the curve
	step = 5
	sep_indices = np.where(y[:,1])[0]
	first = sep_indices[0]#+step/2
	last = sep_indices[-1]#-step/2
	# compute flat length
	flatlength = distance(y[first],y[last])
	print("flat length =", flatlength," um")
	# compute curved length and height (maximal distance to the base)
	yy = curve_perle(y, first, last, step)
	length = 0
	height = 0
	height_index = 0
	p1 = yy[0]
	p2 = yy[-1]
	for i in range(0, len(yy)-1):
		length = length + distance(yy[i+1],yy[i])
		d = distance_to_line(p1, p2, yy[i])
		if d>height :
			height = d
			height_index = i
	print("length = ",length," um")
	print("height = ",height," um")
	middle_point = yy[int(len(yy)/2)]
	# compute curvature radius by fitting the sepal section to a circle
	yyy=y[first:last]
	R, pcR = compute_curvature_radius(yyy[:,0],yyy[:,1])
	print("R=",R," um")
	# cutting curve in pieces in order to estimate curvature radius on portions of the curve
	slength=len(yyy)
	y21=yyy[0:int(slength/2)]
	y22=yyy[int(slength/2):slength]
	y41=yyy[0:int(slength/4)]
	y423=yyy[int(slength/4):3*int(slength/4)]
	y44=yyy[3*int(slength/4):slength]
	# and computing the curvature radius for each portion
	R21, pc21 = compute_curvature_radius(y21[:,0],y21[:,1])
	R22, pc22 = compute_curvature_radius(y22[:,0],y22[:,1])
	R41, pc41 = compute_curvature_radius(y41[:,0],y41[:,1])
	R423, pc423 = compute_curvature_radius(y423[:,0],y423[:,1])
	R44, pc44 = compute_curvature_radius(y44[:,0],y44[:,1])
	fig = plt.figure(figsize=(15, 5))
	plt.plot(yy[:,0], yy[:,1], color='lightgreen', linewidth=10, label='length='+"{:.0f}".format(length)+" um")
	#plt.plot(y[:,0], y[:,1], '--')
	#plt.plot(y[first][0], y[first][1], 'ro')
	#plt.plot(middle_point[0], middle_point[1], 'ro')
	#plt.plot(y[last][0], y[last][1], 'ro')
	#plt.plot(yy[height_index][0], yy[height_index][1], 'go')
	#plt.plot(pcR[0], pcR[1], 'ro', color='royalblue')
	#plt.plot(pc423[0], pc423[1], 'ro', color='darkorange')
	xR,yR = circle(R, pcR)
	plt.plot(xR, yR, "--",color='royalblue', linewidth=3, label='R='+"{:.0f}".format(R)+" um")
	xR,yR = circle(R423, pc423)
	plt.plot(xR, yR, "--", color='tomato', linewidth=3, label='R423='+"{:.0f}".format(R423)+" um")
	plt.gca().set_aspect('equal', adjustable='box')
	plt.xlabel('x (um)')
	plt.ylabel('z (um)')
	plt.legend()
	plt.grid(True)
	plt.title(title)
	#plt.show()
	fig.savefig(graph_name)
	#plt.close()
	# write the data into a file
	FILE = open(outfilename,"a")
	FILE.write(name+';'+str(length)+';'+str(R)+';'+str(flatlength)+';'+str(height)+'\n')
	FILE.close()
	return length, flatlength, height, R, R21, R22, R41, R423, R44

def imsave2D(filename, matrix):
	fig = plt.figure()
	plt.gca().imshow(np.transpose(matrix), interpolation='none')
	fig.savefig(filename)
	return



def register_objects(img_flo, img_ref, preregisterd=True, ph=6, pl=2, es=5, fs=3, save_files=True, iterations=2):
	# -- "Specify paths to image files" --
	# ------------------------------------------------
	sequence_name = "register_"
	file_times = [1, 2]
	filenames = [sequence_name + str(t).zfill(2)  for t in file_times]
	list_images = [img_flo, img_ref]

	images = {}
	for i in [0,1]:
		images[filenames[i]] = list_images[i]

	rigid_images = {}
	rigid_transformations = {}
	affine_images = {}
	affine_transformations = {}
	registered_images = {}
	non_linear_transformations = {}
	[reference_filename, floating_filename, reference_time, floating_time] = [ filenames[0], filenames[1], file_times[0], file_times[1]]
	#for reference_filename, floating_filename, reference_time, floating_time, in zip(filenames[1:],filenames[:-1],file_times[1:],file_times[:-1]):
	print("================================")
	print("Registering ", floating_time, " on ", reference_time)
	print("================================")

	###################################################
	# Output structure
	###################################################

	namestring = "_ph"+str(ph)+"_pl"+str(pl)+"_es"+str(es)+"_fs"+str(fs)

	output_dirname = dirname + '/' + sequence_name + "_" + str(floating_time).zfill(2) + "_on_" + str(reference_time).zfill(2) + "" + namestring
	if not os.path.exists(output_dirname):
		os.makedirs(output_dirname)

	reference_img = images[reference_filename]
	floating_img = images[floating_filename]

	####################################################################

	# Registration of the floating images on ref
	# -----------------------------------------------------
	
	if (not preregistered) :
		# A. Find optimised rigid transformations
		rigid_img, rigid_trsf = find_rigid_transfo2(reference_img, floating_img)
		rigid_images[floating_filename] = rigid_img
		rigid_transformations[floating_filename] = rigid_trsf
		
		if save_files:
			# rigid-registered images
			rigid_filename = output_dirname + '/' + floating_filename + "_rigid_on_" + str(reference_time).zfill(2) + ".inr.gz"
			imsave(rigid_filename, rigid_img)

			# optimised (with block-matching) rigid transformations
			rigid_trsf_filename = output_dirname + '/tr_' + floating_filename + '_rigid.txt'
			tfsave(rigid_trsf_filename, rigid_trsf)

		# B. Find optimised affine transformations (initialised by the rigid tfs computed above)
		affine_img, affine_trsf = find_affine_transfo(reference_img, floating_img, init_trsf=rigid_trsf)
		affine_images[floating_filename] = affine_img
		affine_transformations[floating_filename] = affine_trsf

		if save_files:
			# affine-registered images
			affine_filename = output_dirname + '/' + floating_filename + "_affine_on_" + str(reference_time).zfill(2) + ".inr.gz"
			imsave(affine_filename, affine_img)

			# optimised (with block-matching) affine transformations
			affine_trsf_filename = output_dirname + '/tr_' + floating_filename + '_affine.txt'
			tfsave(affine_filename, affine_trsf)
		floating_img0=affine_img
	else :
		floating_img0=floating_img
	# C. Find nonlinear transformations which register the affine-transformed images
	#registered_img, nl_trsf = find_nl_transfo(reference_img, affine_img, ph=ph, pl=pl, es=es, fs=fs)
	for it in range(iterations):
		registered_img, nl_trsf = find_nl_transfo(reference_img, floating_img0, ph=ph, pl=pl, es=es, fs=fs)
		registered_images[floating_filename] = registered_img
		non_linear_transformations[floating_filename] = nl_trsf

		# paths to nonlinearly registered images
		registered_filename = output_dirname + '/' + floating_filename + "_registered_on_" + str(reference_time).zfill(2) + "_it"+str(it)+".inr.gz"
		imsave(registered_filename, registered_img)
		floating_img0=registered_img

		if save_files:
			# non-linear transformations
			nl_trsf_filename = output_dirname + '/' + floating_filename + "_it"+str(it)+"_vectorfield.inr.gz"
			save_trsf(nl_trsf, nl_trsf_filename, compress=True)
	return registered_img

