#include "tao.h"
#include "common.h"
#include <math.h>
#include <stdlib.h>
#include "tao_solver.h"
#include "iostream.h"

/* Test drivers defined later in this file; the three arguments are, in
   order, the create mode, the assign mode and the output mode. */
int myMatTest(int, int, int);
int myVecTest(int, int, int);

/* Help text handed to PetscInitialize() and TaoInitialize() in main().
   NOTE(review): the text mentions softmax regression, but the code visible
   in this file only exercises PETSc Mat/Vec routines -- presumably copied
   from another program; confirm before relying on it. */
static  char help[] = "Softmax Regression with Gaussian prior\n";

#undef __FUNCT__
#define __FUNCT__ "main"
/*
 * Program entry point.
 *
 * Initializes PETSc and TAO, reads the three integer options
 *   -c  create mode
 *   -a  assign mode
 *   -o  output mode
 * and runs the matrix test with them.
 *
 * Returns the status of myMatTest(), or the first non-zero PETSc/TAO
 * error code met during start-up.
 *
 * Note: PETSc and TAO are finalized inside myMatTest()/myVecTest(), so
 * no PETSc call may follow the test call here.
 */
int main(int argc,char **argv)
{
	/* Default to 0: PetscOptionsGetInt() leaves the variable untouched when
	   the option is absent, so without a default the modes were read
	   uninitialized. */
	int createMode = 0, assignMode = 0, outputMode = 0;
	int info;                  /* used to check for functions returning nonzeros */
	PetscTruth  flg;

	cout<<"In the Main() function of MyTest"<<endl;

	/* The PETSc error handler is not usable if initialization itself failed,
	   so return the code directly instead of going through CHKERRQ. */
	info = PetscInitialize(&argc,&argv,(char *)0,help);
	if (info) return info;
	info = TaoInitialize(&argc,&argv,(char *)0,help); CHKERRQ(info);

	info = PetscOptionsGetInt(PETSC_NULL,"-c", &createMode, &flg); CHKERRQ(info);
	info = PetscOptionsGetInt(PETSC_NULL,"-a", &assignMode, &flg); CHKERRQ(info);
	info = PetscOptionsGetInt(PETSC_NULL,"-o", &outputMode, &flg); CHKERRQ(info);

	/* Swap in myVecTest(createMode, assignMode, outputMode) to run the
	   vector test instead. */
	return myMatTest(createMode, assignMode, outputMode);
}

#undef __FUNCT__
#define __FUNCT__ "myMatTest"
int myMatTest(int createMode, int assignMode, int outputMode)
{
	int info, i, j;                  /* used to check for functions returning nonzeros */
	Mat x;

	PetscScalar value=10.0;
	PetscScalar *tempArray;
	PetscViewer ascIIViewer, binaryViewer;
	PetscScalar *values, *vals;
	int indices1[5], indices2[4], nnz[5], ncols, *cols;

	PetscMalloc(20 * sizeof(PetscScalar), &values);
	cout<<"Create Mode = "<<createMode<<"  Assign Mode = "<<assignMode<<"   Output Mode = " <<outputMode<<endl;
	for(i=0; i<5; i++)
		for(j=0; j<4; j++)
			values[i * 4 + j] = (i+1)*(j+1);

//	PetscScalarView(20, values, PETSC_VIEWER_STDOUT_WORLD);

	//////////////////////////////////
	//  1.    Create the Matrix
	//////////////////////////////////

	//  1.1   Use MatCreateSeqAIJ / MatCreateMPIAIJ
	if(createMode == 1)  
	{
	  // the size of matrix x is 5*4
		for(i=0; i<5; i++)
			nnz[i] = 4;
		info = MatCreateMPIAIJ(PETSC_COMM_WORLD, PETSC_DECIDE, PETSC_DECIDE, 5, 4, 
								0, nnz, 0, PETSC_NULL, &x);  
		CHKERRQ(info);

	}
	//   1.2   Use MatCreate
	else if (createMode == 2)
	{
		info = MatCreate(PETSC_COMM_WORLD, PETSC_DECIDE, PETSC_DECIDE, 5, 4, &x);	
		CHKERRQ(info);
//		info = MatSetType(x, MATMPIAIJ);   CHKERRQ(info);
		info = MatSetFromOptions(x);	CHKERRQ(info);
	}

	//   1.3   Use Dense Matrix, 
	//   切记 不能用于 MatZeroRows，否则结果有错，但仍可运行
	else if (createMode == 3)
	{
		info = MatCreateMPIDense(PETSC_COMM_WORLD, PETSC_DECIDE, PETSC_DECIDE, 5, 4, values, &x);
		CHKERRQ(info);
		cout<<"hahaha"<<endl;
//		goto  MatAccess;
	}

	//   1.4   Use Load from file
	else if (createMode == 4)
	{
		info = PetscViewerBinaryOpen(PETSC_COMM_SELF, "matBinary.txt", PETSC_FILE_RDONLY, &binaryViewer);
		CHKERRQ(info);

		//  好像只有 MATSEQAIJ 能通过，MATMPIAIJ 会出错
		MatLoad (binaryViewer, MATSEQAIJ, &x);	CHKERRQ(info);

		info = PetscViewerDestroy(binaryViewer);	CHKERRQ(info);
		cout<<"Show in create"<<endl;
		MatView(x, PETSC_VIEWER_STDOUT_WORLD);
		cout<<"==============================="<<endl;

//		goto MatAccess;
	}
	//   1.5   Use Sub-Matrix  ref Mat/ex3
	//   用一个长数组表示大矩阵，要从中抽取出一个
	else if (createMode == 5)
	{
		PetscScalar *b;
		Mat A22;
		int ierr, size=8, size1=3, size2=5;

		info = PetscMalloc(size * size * sizeof(PetscScalar), &b);	CHKERRQ(info);
		for(i=0; i<size; i++)
			for(j=0; j<size; j++)
				b[i + j * size] = (1 + i) * 10 + j + 1;

		info = MatCreate(MPI_COMM_SELF,size2,size2,size2,size2,&A22); CHKERRQ(info);
		info = MatSetType(A22,MATSEQDENSE); CHKERRQ(info);

		//  注意 A22 是以 Column Major 顺序填的
		info = MatSeqDenseSetPreallocation(A22, b+size1*size+size1); CHKERRQ(info);
		//  用 size 来指出原来 Array 的维度，否则无法从 Array 中读取（如右下角）一块
		info = MatSeqDenseSetLDA(A22,size); CHKERRQ(info);
		info = MatAssemblyBegin(A22, MAT_FINAL_ASSEMBLY); CHKERRQ(info);
		info = MatAssemblyEnd(A22, MAT_FINAL_ASSEMBLY); CHKERRQ(info);
		
		//    即使已经给矩阵赋值了，也不能把原数组 Free 掉
		//	info = PetscFree(b);	CHKERRQ(info);
		MatView(A22, PETSC_VIEWER_STDOUT_WORLD);
		MatDestroy(A22);
		cout<<"Going out of Create Mode 5"<<endl;
	}

	//////////////////////////////////	  
	//  2.  Assign Value
	//////////////////////////////////
	for(i=0; i<5; i++)
		  indices1[i] = i;

	for(i=0; i<4; i++)
		  indices2[i] = i;

	info = MatSetOption(x, MAT_ROWS_SORTED);   CHKERRQ(info);
	info = MatSetOption(x, MAT_COLUMNS_SORTED);   CHKERRQ(info);
	info = MatSetValues(x, 5, indices1, 4, indices2, values, INSERT_VALUES); CHKERRQ(info);

	info = MatAssemblyBegin(x, MAT_FINAL_ASSEMBLY);   CHKERRQ(info);
	info = MatAssemblyEnd(x, MAT_FINAL_ASSEMBLY);   CHKERRQ(info);
	cout<<"Show right after create"<<endl;
	MatView(x, PETSC_VIEWER_STDOUT_WORLD);
	
	//  Use matZeroRows
	if(assignMode == 1)
	{
		int *rows;
		IS is;
		Mat y;

		PetscMalloc(2 * sizeof(int), &rows);
		rows[0] = 1;
		rows[1] = 2;

		//  IS 的构建方法随选一种
		ISCreateGeneral(PETSC_COMM_WORLD, 2, rows, &is);
		//ISCreateStride(PETSC_COMM_WORLD, 1, 1, 1, &is); 

//		info = MatCreate(PETSC_COMM_WORLD, PETSC_DECIDE, PETSC_DECIDE, 5, 4, &y);	CHKERRQ(info);
//		cout<<"0"<<endl;
//		info = MatSetType(y, MATMPIAIJ);   CHKERRQ(info);
//		cout<<"1"<<endl;
//		MatConvert(x, MATSAME, &y);
		cout<<"2"<<endl;
		info = MatZeroRows(x, is, &value);		CHKERRQ(info);
		cout<<"3"<<endl;
		//MatZeroRows(x, is, PETSC_NULL);

		PetscFree(rows);
		ISDestroy(is);
/*
		cout<<"The matrix right after MatZeroRows:"<<endl;
		MatView(x, PETSC_VIEWER_STDOUT_SELF);
		cout<<"================================="<<endl;
*/
	}
	//  2.2   Use MatSet
    if(assignMode == 2)
	{
		for(i=0; i<5; i++)
			  indices1[i] = i;

		for(i=0; i<4; i++)
			  indices2[i] = i;

		info = MatSetOption(x, MAT_ROWS_SORTED);   CHKERRQ(info);
		info = MatSetOption(x, MAT_COLUMNS_SORTED);   CHKERRQ(info);
		info = MatSetValues(x, 5, indices1, 4, indices2, values, INSERT_VALUES); CHKERRQ(info);

		info = MatAssemblyBegin(x, MAT_FINAL_ASSEMBLY);   CHKERRQ(info);
		info = MatAssemblyEnd(x, MAT_FINAL_ASSEMBLY);   CHKERRQ(info);

//		info = MatSetValues(x, 5, indices1, 4, indices2, values, ADD_VALUES); CHKERRQ(info);

//		info = MatAssemblyBegin(x, MAT_FINAL_ASSEMBLY);   CHKERRQ(info);
//		info = MatAssemblyEnd(x, MAT_FINAL_ASSEMBLY);   CHKERRQ(info);
	}	


MatAccess:
    /////////////////////////////////////////////
	//  3.  Get access to the components' value
	/////////////////////////////////////////////
	
	// do not allocate memory for tempArray, just point to the content of x
	cout<<"Come to Access Portion"<<endl;
	info = MatGetRow(x, 2, &ncols, &cols, &vals);   CHKERRQ(info);  
	cout<<"There are "<<ncols<<" non-zero numbers in Row 2"<<endl;
	for(i = 0; i < ncols; i ++)
	{
		//  Note:  We CAN change value of the Vector x by writing to its memory
		//         But we are NOT allowed to do so for Matrix cases
		cout<<"Row 2, Column "<<cols[i]<<"  Val: "<<vals[i]<<endl;
	}
	info = MatRestoreRow(x, 2, &ncols, &cols, &vals);		CHKERRQ(info);

	///////////////////////////////////////
	//  4.  Print value to screen or file
	///////////////////////////////////////

	//  4.1  Output the vector to stdout (screen)
	if(outputMode == 1)
		MatView(x, PETSC_VIEWER_STDOUT_WORLD);
	
	//  4.2  Output the vector to a file in ASCII format "output.txt"
	else if(outputMode == 2)
	{
		info = PetscViewerASCIIOpen(PETSC_COMM_WORLD, "output.txt", &ascIIViewer);  CHKERRQ(info);
		PetscViewerSetFormat (ascIIViewer, PETSC_VIEWER_ASCII_IMPL);

		PetscViewerPushFormat (ascIIViewer, PETSC_VIEWER_ASCII_DEFAULT);
		MatView(x, ascIIViewer);
		PetscViewerPopFormat (ascIIViewer);
		
		//  Still use output.txt,  will NOT overwrite the former VecView
		MatView(x, ascIIViewer);		
		info = PetscViewerDestroy(ascIIViewer);  CHKERRQ(info);
	}

	//  4.3   Output the vector to a binary file "binary.txt"
	else if(outputMode == 3)
	{
		info = PetscViewerBinaryOpen(PETSC_COMM_WORLD, "matBinary.txt", PETSC_FILE_CREATE, &binaryViewer);
		CHKERRQ(info);
		MatView(x, binaryViewer);
		info = PetscViewerDestroy(binaryViewer);	CHKERRQ(info);
	}
	
	info = MatDestroy(x); CHKERRQ(info);
	PetscFree(values);
	TaoFinalize();
	PetscFinalize();
	return 0;
}


#undef __FUNCT__
#define __FUNCT__ "myVecTest"
int myVecTest(int createMode, int assignMode, int outputMode)
{
	int info, i;                  // used to check for functions returning nonzeros
	Vec x;

	PetscScalar value = 10.0, values[10], *tempArray;
	PetscViewer ascIIViewer, binaryViewer;
	int indices[5];

	cout<<"Create Mode = "<<createMode<<"  Assign Mode = "<<assignMode<<"   Output Mode = " <<outputMode<<endl;
  
	//////////////////////////////////
	//  1.    Create the Vector
	//////////////////////////////////

	//  1.1   Use VecCreateMPI / VecCreateSeq
	if(createMode == 1)  
	{
	  // the length of vector x is 10
		info = VecCreateMPI(PETSC_COMM_WORLD, PETSC_DECIDE, 10, &x);  
		CHKERRQ(info);
	}
	//   1.2   Use VecCreate
	else if (createMode == 2)
	{
		info = VecCreate(PETSC_COMM_WORLD, &x);		CHKERRQ(info);
		info = VecSetSizes(x, PETSC_DECIDE, 10);	CHKERRQ(info);
		info = VecSetFromOptions(x);				CHKERRQ(info);
	}
	//   1.3   Use Load from file to Create
	else if(createMode == 3)
	{
		info = PetscViewerBinaryOpen(PETSC_COMM_WORLD, "vecBinary.txt", PETSC_FILE_RDONLY, &binaryViewer);
		CHKERRQ(info);
		VecLoad(binaryViewer, VECMPI, &x);		CHKERRQ(info);
		info = PetscViewerDestroy(binaryViewer);	CHKERRQ(info);
		goto vecDisplay;
	}

/*
	//   1.3   先写 Array, 然后一次性生成 Vector
	else if (createMode == 3)
	{
		PetscScalar *x, one=1;
		Vec X;
		ierr = PetscMalloc(size*sizeof(PetscScalar), &x); CHKERRQ(ierr);
		for (i=0; i<size; i++) {
			x[i] = one;
		}
		ierr = VecCreateSeqWithArray(MPI_COMM_SELF,size,x,&X); CHKERRQ(ierr);   ///////////////
		ierr = VecAssemblyBegin(X); CHKERRQ(ierr);
		ierr = VecAssemblyEnd(X); CHKERRQ(ierr);
	}
*/
	//////////////////////////////////	  
	//  2.  Assign Value
	//////////////////////////////////

	//  2.1   Use VecSet
    if(assignMode == 1)
	{
		info = VecSet(&value, x);		CHKERRQ(info);
	}
	//  2.2   Use VecSetValues
	else if(assignMode == 2)
	{
		for(i=0; i < 10; i++)		// 0 1 2 3 4 5 6 7 8 9
			indices[i] = i;
		for(i=0; i < 10; i++)		// 1 2 3 4 5 6 7 8 9 10
			values[i] = i + 1;

		info = VecSetValues(x, 10, indices, values, INSERT_VALUES);
		VecAssemblyBegin(x);
		VecAssemblyEnd(x);

		//   Then try ADD_VALUE,  but addition and insertion must be intervened with assembly routines
		for(i=0; i < 5; i++)		// only alter line 1 3 5 7 9
			indices[i] = 2 * i;
		for(i=0; i < 5; i++)		//  change X to X0X
			values[i] = 100 * (2 * i + 1);

		info = VecSetValues(x, 5, indices, values, ADD_VALUES);
		VecAssemblyBegin(x);
		VecAssemblyEnd(x);
	}
	

    /////////////////////////////////////////////
	//  3.  Get access to the components' value
	/////////////////////////////////////////////

	// do not allocate memory for tempArray, just point to the content of x
	info = VecGetArray(x,&tempArray);   CHKERRQ(info);  
	for(i = 1; i < 10; i += 2)
	{
		//  Note:  We CAN change value of the Vector x by writing to its memory
		//         But we are NOT allowed to do so for Matrix cases
		tempArray[i] += (i + 1) * 4;	// change even lines from X to 5 * X 
	}
	info = VecRestoreArray(x,&tempArray); CHKERRQ(info);

	///////////////////////////////////////
	//  4.  Print value to screen or file
	///////////////////////////////////////
vecDisplay:
	//  4.1  Output the vector to stdout (screen)
	if(outputMode == 1)
		VecView(x, PETSC_VIEWER_STDOUT_WORLD);
	
	//  4.2  Output the vector to a file in ASCII format "output.txt"
	else if(outputMode == 2)
	{
		info = PetscViewerASCIIOpen(PETSC_COMM_WORLD, "output.txt", &ascIIViewer);  CHKERRQ(info);
		PetscViewerSetFormat (ascIIViewer, PETSC_VIEWER_ASCII_IMPL);

		PetscViewerPushFormat (ascIIViewer, PETSC_VIEWER_ASCII_DEFAULT);
		VecView(x, ascIIViewer);
		PetscViewerPopFormat (ascIIViewer);
		
		//  Still use output.txt,  will NOT overwrite the former VecView
		VecView(x, ascIIViewer);		
		info = PetscViewerDestroy(ascIIViewer);  CHKERRQ(info);
	}

	//  4.3   Output the vector to a binary file "binary.txt"
	else if(outputMode == 3)
	{
		info = PetscViewerBinaryOpen(PETSC_COMM_WORLD, "vecBinary.txt", PETSC_FILE_CREATE, &binaryViewer);
		CHKERRQ(info);
		VecView(x, binaryViewer);
		info = PetscViewerDestroy(binaryViewer);	CHKERRQ(info);
	}
	
	info = VecDestroy(x); CHKERRQ(info);

	TaoFinalize();
	PetscFinalize();
	return 0;
}