import sys
import unittest

from rdkit.Chem import Mol, AllChem as Chem
from ase.atoms import Atoms
from modules.formats import adapt_format, confs_to_mol_list, \
    rdkit_mol_to_ase_atoms

mol = Chem.MolFromMolFile('acetic.mol', removeHs=False)
num_confs = 5
Chem.EmbedMultipleConfs(mol, num_confs)


class TestConfsToMolList(unittest.TestCase):
    def test_confs_to_mol_list(self):
        self.assertIsInstance(confs_to_mol_list(mol), list)
        self.assertEqual(len(confs_to_mol_list(mol)), num_confs)
        self.assertEqual(len(confs_to_mol_list(mol, [0, 2, 4])), 3)


class TestMolToAtoms(unittest.TestCase):
    def test_rdkit_mol_to_ase_atoms(self):
        self.assertIsInstance(rdkit_mol_to_ase_atoms(mol), Atoms)
        atoms = rdkit_mol_to_ase_atoms(mol)
        self.assertEqual(mol.GetNumAtoms(), atoms.get_global_number_of_atoms())


class TestAdaptFormat(unittest.TestCase):
    def test_rdkit_mol(self):
        self.assertIsInstance(
            adapt_format('rdkit', 'acetic.mol'),
            Mol)

    def test_rdkit_xyz(self):
        self.assertIsInstance(
            adapt_format('rdkit', 'acetic.xyz'), Mol)

    def test_ase_mol(self):
        self.assertIsInstance(
            adapt_format('ase', 'acetic.mol'),
            Atoms)

    def test_ase_xyz(self):
        self.assertIsInstance(
            adapt_format('ase', 'acetic.xyz'), Atoms)

    def test_not_adeq_req(self):
        self.assertRaises(NotImplementedError, adapt_format, 'hola',
                          'acetic.xyz')

    def test_wrong_file_type(self):
        self.assertRaises(NotImplementedError, adapt_format, 'ase', 'good.inp')


if __name__ == '__main__':
    unittest.main()
