#!/usr/bin/env python

# before lounching the executable activate the timagetk conda environement
# conda activate florivar

# general imports :
from sys import path, argv
import os
import time
import numpy as np
from skimage.filters import threshold_otsu

from titk_tools.io import imread, imsave, SpatialImage
#from timagetk.plugins import linear_filtering
from scipy.ndimage import gaussian_filter

# multiprocessing
from multiprocessing import Pool
nproc = 20 # number of processors

print('=======================================')

def gauss_titk(img, radius):
	'''
	gaussian filter with radius in real units
	- if radius=a number : isotropic
	- if radius= a list of numbers : anisotropic
	'''
	img_filtered = linear_filtering(img, std_dev=radius, method='gaussian_smoothing')#, real=True)
	return img_filtered


def gauss_nd(img, radius):
	'''
	gaussian filter with radius in real units
	- if radius=a number : isotropic
	- if radius= a list of numbers : anisotropic
	'''
	img_filtered = gaussian_filter(img, radius)
	return img_filtered

def iterated_gauss_2d(img):
	vs=img.voxelsize
	rad=2*vs[2]/vs[0]
	#rad=10
	radius = [rad, rad, 0]
	img_filtered=gaussian_filter(img, radius)
	rad=rad/2
	while rad>1:
		print(rad)
		radius=[rad,rad,0]
		img_filtered = gaussian_filter(img_filtered, radius)
		rad=rad/2
	return SpatialImage(img_filtered, dtype='uint8', voxelsize=vs, origin=[0, 0, 0])


indir = argv[1]
T = int(argv[2])

newdir=indir+'_normalised-OtsuT'+str(T)
os.system("mkdir -p "+newdir)

print('Normalising images in the folder:',indir)
print('Common Otsu threshold for normalisation:',T)
print('Results will be written in the folder:',newdir)

pathdir=indir
listfiles=os.listdir(pathdir)

def treat_file(i):
	filename = listfiles[i]
	extension = filename.split('.')[-1]
	if extension=='tif':
		print('--------------------------')
		print(filename)
		im = imread(pathdir+'/'+filename)
		# filter for computing Otsu
		#imf = iterated_gauss_2d(im)
		imf=im
		if len(np.unique(imf))>0:#20:
			iotsu = threshold_otsu(imf)
			# normalise the original image
			imin = np.min(im)+0.0
			imax= np.max(im)+0.0
			a = T/(iotsu-imin)
			b = - imin*a
			transformed_im = SpatialImage(a*im+b, dtype='uint16', voxelsize=im.voxelsize, origin=[0, 0, 0])
			transformed_im[transformed_im>255]=255 		
			transformed_im = SpatialImage(transformed_im, dtype='uint8', voxelsize=im.voxelsize, origin=[0, 0, 0])
			tiotsu = threshold_otsu(transformed_im)
			imf = iterated_gauss_2d(transformed_im)
			tiotsuf = threshold_otsu(imf)
			print('   otsu=',iotsu,'  otsu normalised=',tiotsu,'  otsu normalised filtered=',tiotsuf)
			newname = 'norm'+str(T)+'-'+filename
			newnamef = 'filtered-'+newname
			#imsave(newdir+'/'+newname,transformed_im)
			imsave(newdir+'/'+newnamef,imf)
	return 0


inputs = range(len(listfiles))


for i in range(len(inputs)):
	treat_file(i)


'''
pool = Pool(processes=nproc)
pool.map(treat_file, inputs)
pool.close()
pool.join()
'''
