"""Unit tests for symeig (standard and generalized eigenvalue problems)."""

import numpy, unittest
from testing_tools import \
     assert_array_almost_equal_diff as assert_array_almost_equal
from testing_tools import assert_array_equal, assert_type_equal,\
     symrand, hermitian

from symeig import symeig, SymeigException


# Short alias for matrix products used throughout the tests.
dot = numpy.dot

# Size of the square test matrices.
DIM = 5
# Decimal digits checked by the almost-equal assertions, keyed by numpy
# typecode: double ('d') / double complex ('D') vs single ('f') / single
# complex ('F').
DIGITS = dict(d=11, D=11, f=3, F=3)

class SymeigTestCase(unittest.TestCase):
    """Tests for ``symeig`` on real/complex, standard/generalized problems."""

    def eigenproblem(self, dim, dtype, overwrite, turbo, range):
        """Solve a standard eigenvalue problem and verify the result.

        A random ``dim`` x ``dim`` matrix of typecode ``dtype`` is built with
        ``symrand``. The check asserts that the eigenvector matrix ``z`` has
        type ``dtype`` and that the diagonal of ``z^H a z`` matches the
        returned eigenvalues ``w`` to ``DIGITS[dtype]`` digits.

        ``overwrite``, ``turbo`` and ``range`` are forwarded to ``symeig``
        unchanged. ``range`` keeps its name because it mirrors the ``symeig``
        keyword it is passed to.
        """
        a = symrand(dim, dtype)
        # With overwrite set, symeig is presumably allowed to modify `a` in
        # place, so keep an untouched copy for the verification step.
        if overwrite == 1:
            a_c = a.copy()
        else:
            a_c = a
        w, z = symeig(a, turbo=turbo, overwrite=overwrite, range=range)
        # assertions
        assert_type_equal(z.dtype, dtype)
        w = w.astype(dtype)
        diag = numpy.diagonal(dot(hermitian(z), dot(a_c, z))).real
        assert_array_almost_equal(diag, w, DIGITS[dtype])

    def geneigenproblem(self, dim, dtype, overwrite, turbo, range):
        """Solve a generalized eigenvalue problem ``a z = w b z`` and verify it.

        The checks are:

        * ``z`` has type ``dtype``.
        * The diagonal of ``z^H a z`` equals ``w``.
        * The diagonal of ``z^H b z`` is all ones.

        All comparisons use ``DIGITS[dtype]`` digits. ``overwrite``,
        ``turbo`` and ``range`` are forwarded to ``symeig`` unchanged.
        """
        a = symrand(dim, dtype)
        # Shift the diagonal by 2.1 to make b positive definite. This assumes
        # symrand's spectrum is small enough for the shift to suffice.
        b = symrand(dim, dtype) + numpy.diag([2.1] * dim).astype(dtype)
        # Keep pristine copies when symeig may overwrite its inputs.
        if overwrite == 1:
            a_c, b_c = a.copy(), b.copy()
        else:
            a_c, b_c = a, b
        w, z = symeig(a, b, turbo=turbo, overwrite=overwrite, range=range)
        # assertions
        assert_type_equal(z.dtype, dtype)
        w = w.astype(dtype)
        diag1 = numpy.diagonal(dot(hermitian(z), dot(a_c, z))).real
        assert_array_almost_equal(diag1, w, DIGITS[dtype])
        diag2 = numpy.diagonal(dot(hermitian(z), dot(b_c, z))).real
        assert_array_almost_equal(diag2,
                                  numpy.ones(diag2.shape[0]), DIGITS[dtype])

    def eigenproblem_nc(self, dim, dtype):
        """Check contiguity assumptions and solve for a Fortran-ordered view.

        Builds three matrices:

        * A C-contiguous matrix ``a_c``.
        * A negative-stride view ``a_nc`` of ``a_c``, asserted not to be
          Fortran-contiguous.
        * A transposed view ``a_fc``, asserted to be Fortran-contiguous.

        ``symeig`` is then run on ``a_fc`` with ``overwrite=1``.
        """
        a_c = symrand(dim, dtype)

        assert a_c.flags.c_contiguous, \
               'Contiguous matrix does not appear to be contiguous!'
        a_nc = a_c[::-1, ::-1]
        # Check the reversed view itself, not the matrix it was taken from.
        assert not a_nc.flags.f_contiguous, \
               'Not contiguous matrix appears to be fortran contiguous!'
        # NOTE(review): this check is disabled; confirm whether symeig is
        # expected to reject non-contiguous input with overwrite=1:
        # self.assertRaises(SymeigException, symeig, a_nc, overwrite=1)
        a_fc = numpy.transpose(a_c)
        # The transpose of a C-contiguous matrix must be Fortran-contiguous.
        assert a_fc.flags.f_contiguous, \
               'Fortran contiguous matrix does not appear to be ' + \
               'fortran contiguous!'
        w, z = symeig(a_fc, overwrite=1)

    def testOverwriteBug(self):
        # reproduces a bug when B=None and overwrite=0
        a = symrand(5, 'd')
        w, z = symeig(a, overwrite=0)

    def testWrongDimensionsBug(self):
        # reproduces a bug when len(A.shape) != 2 or A.shape[0] != A.shape[1]
        # 1-D input for A
        self.assertRaises(SymeigException, symeig, numpy.random.rand(10))
        # non-square A
        self.assertRaises(SymeigException, symeig, numpy.random.rand(10, 3))
        # square A with a non-square B
        a = numpy.random.rand(10, 10)
        b = numpy.random.rand(10, 2)
        self.assertRaises(SymeigException, symeig, a, b)

    def testReal(self):
        rng = (2, DIM - 1)
        self.eigenproblem(DIM, 'd', 0, 'off', None)
        self.eigenproblem(DIM, 'f', 0, 'off', None)
        self.eigenproblem(DIM, 'd', 1, 'off', None)
        self.eigenproblem(DIM, 'f', 1, 'off', None)
        self.eigenproblem(DIM, 'd', 0, 'off', rng)
        self.eigenproblem(DIM, 'f', 0, 'off', rng)
        self.eigenproblem(DIM, 'd', 1, 'off', rng)
        self.eigenproblem(DIM, 'f', 1, 'off', rng)

    def testRealGeneralized(self):
        rng = (2, DIM - 1)
        self.geneigenproblem(DIM, 'd', 0, 'on', None)
        self.geneigenproblem(DIM, 'f', 0, 'on', None)
        self.geneigenproblem(DIM, 'd', 1, 'on', None)
        self.geneigenproblem(DIM, 'f', 1, 'on', None)
        self.geneigenproblem(DIM, 'd', 0, 'off', None)
        self.geneigenproblem(DIM, 'f', 0, 'off', None)
        self.geneigenproblem(DIM, 'd', 1, 'off', None)
        self.geneigenproblem(DIM, 'f', 1, 'off', None)
        self.geneigenproblem(DIM, 'd', 0, 'off', rng)
        self.geneigenproblem(DIM, 'f', 0, 'off', rng)
        self.geneigenproblem(DIM, 'd', 1, 'off', rng)
        self.geneigenproblem(DIM, 'f', 1, 'off', rng)

    def testComplex(self):
        rng = (2, DIM - 1)
        self.eigenproblem(DIM, 'D', 0, 'off', None)
        self.eigenproblem(DIM, 'F', 0, 'off', None)
        self.eigenproblem(DIM, 'D', 1, 'off', None)
        self.eigenproblem(DIM, 'F', 1, 'off', None)
        self.eigenproblem(DIM, 'D', 0, 'off', rng)
        self.eigenproblem(DIM, 'F', 0, 'off', rng)
        self.eigenproblem(DIM, 'D', 1, 'off', rng)
        self.eigenproblem(DIM, 'F', 1, 'off', rng)

    def testComplexGeneralized(self):
        rng = (2, DIM - 1)
        self.geneigenproblem(DIM, 'D', 0, 'on', None)
        self.geneigenproblem(DIM, 'F', 0, 'on', None)
        self.geneigenproblem(DIM, 'D', 1, 'on', None)
        self.geneigenproblem(DIM, 'F', 1, 'on', None)
        self.geneigenproblem(DIM, 'D', 0, 'off', None)
        self.geneigenproblem(DIM, 'F', 0, 'off', None)
        self.geneigenproblem(DIM, 'D', 1, 'off', None)
        self.geneigenproblem(DIM, 'F', 1, 'off', None)
        self.geneigenproblem(DIM, 'D', 0, 'off', rng)
        self.geneigenproblem(DIM, 'F', 0, 'off', rng)
        self.geneigenproblem(DIM, 'D', 1, 'off', rng)
        self.geneigenproblem(DIM, 'F', 1, 'off', rng)

    def testNonContiguousMatrix(self):
        self.eigenproblem_nc(DIM, 'f')
        self.eigenproblem_nc(DIM, 'd')

    def testIntegerMatrix(self):
        # integer-typed inputs must be accepted for both problem kinds
        a = numpy.array([[1, 2], [2, 7]])
        b = numpy.array([[3, 1], [1, 5]])
        w, z = symeig(a, overwrite=0)
        w, z = symeig(a, b, overwrite=0)

    def testDTypeConversion(self):
        # (typecode of A, typecode of B) -> expected typecode of z
        types = {('f', 'd'): 'd', ('f', 'F'): 'F', ('f', 'D'): 'D',
                 ('d', 'F'): 'D', ('d', 'D'): 'D',
                 ('F', 'd'): 'D', ('F', 'D'): 'D'}
        for (t1, t2), expected in types.items():
            a = symrand(DIM, t1)
            b = symrand(DIM, t2) + numpy.diag([2.1] * DIM).astype(t2)
            w, z = symeig(a, b)
            assert_type_equal(z.dtype, expected)

    def testWrongRangeBug(self):
        # out-of-bounds, reversed and zero-based index ranges must not crash
        a = symrand(10)
        w, z = symeig(a, range=(20, 30))
        w, z = symeig(a, range=(30, 20))
        w, z = symeig(a, range=(0, 3))

def get_suite():
    """Return a TestSuite containing every test method of SymeigTestCase."""
    # unittest.makeSuite was deprecated in Python 3.11 and removed in 3.13;
    # TestLoader.loadTestsFromTestCase is the supported equivalent.
    suite = unittest.TestSuite()
    suite.addTest(unittest.TestLoader().loadTestsFromTestCase(SymeigTestCase))
    return suite

if __name__ == '__main__':
    import sys
    # Seed numpy's global RNG so the random test matrices are reproducible.
    numpy.random.seed(1268049219)
    result = unittest.TextTestRunner(verbosity=2).run(get_suite())
    # Report failures through the exit status so calling scripts/CI see them.
    sys.exit(0 if result.wasSuccessful() else 1)
