from unittest import TestCase

from isolated import *

mol = Chem.MolFromMolFile('acetic.mol', removeHs=False)
num_confs = 5
Chem.EmbedMultipleConfs(mol, num_confs)


class Test(TestCase):
    def test_remove_c_linked_hs(self):
        h2o = Chem.MolFromSmiles('O')
        h2o = Chem.AddHs(h2o)
        ch4 = Chem.MolFromSmiles('C')
        ch4 = Chem.AddHs(ch4)
        self.assertEqual(len(remove_C_linked_Hs(h2o).GetAtoms()), 3)
        self.assertEqual(len(remove_C_linked_Hs(ch4).GetAtoms()), 1)

    def test_gen_confs(self):
        self.assertEqual(gen_confs(mol, num_confs).GetNumConformers(),
                         num_confs)

    """def test_rmsd(self):
        self.assertIsInstance(get_rmsd(mol), np.ndarray)
        self.assertEqual(get_rmsd(mol).shape, (num_confs, num_confs))
        tril_T = np.tril(get_rmsd(mol)).T
        triu = np.triu(get_rmsd(mol))
        self.assertTrue(np.array_equal(tril_T, triu))
        mol.RemoveAllConformers()
        self.assertRaises(ValueError, get_rmsd, mol)"""  # TODO -> clustering

    def test_moments_of_inertia(self):
        self.assertIsInstance(get_moments_of_inertia(mol), np.ndarray)
        self.assertEqual(get_moments_of_inertia(mol).shape, (num_confs, 3))

    def test_mmff_opt_confs(self):
        Chem.EmbedMultipleConfs(mol, num_confs)
        self.assertIsInstance(mmff_opt_confs(mol)[0], Chem.rdchem.Mol)
        self.assertIsInstance(mmff_opt_confs(mol)[1], np.ndarray)
        self.assertIsInstance(mmff_opt_confs(mol, max_iters=0), np.ndarray)
