/* Ergo, version 3.3, a program for linear scaling electronic structure
 * calculations.
 * Copyright (C) 2013 Elias Rudberg, Emanuel H. Rubensson, and Pawel Salek.
 * 
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 * 
 * Primary academic reference:
 * Kohn−Sham Density Functional Theory Electronic Structure Calculations 
 * with Linearly Scaling Computational Time and Memory Usage,
 * Elias Rudberg, Emanuel H. Rubensson, and Pawel Salek,
 * J. Chem. Theory Comput. 7, 340 (2011),
 * <http://dx.doi.org/10.1021/ct100611z>
 * 
 * For further information about Ergo, see <http://www.ergoscf.org>.
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <unistd.h>
#include <memory>
#include <limits>
#include <stdexcept>
#include "basisinfo.h"

#ifdef USE_CHUNKS_AND_TASKS

#include "chunks_and_tasks.h"
#include "BasisInfoStructChunk.h"
#include "IntegralInfoChunk.h"
#include "compute_overlap_task_implementations.h"
#include "matrix_utilities.h"
#include "integrals_general.h"
#include "integral_matrix_wrappers.h"
#include "utilities.h"

/* Chunks&Tasks template library (CHTTL) type registrations.
   NOTE(review): presumably every chunk/task type that is registered,
   fetched or executed at runtime must be listed here so the runtime can
   create it by type; some entries (e.g. ChunkBasic<basisset_struct>,
   ChunkVector<Atom>) are not referenced directly in this file and are
   presumably used inside the overlap-matrix tasks -- confirm. */
CHTTL_REGISTER_CHUNK_TYPE((chttl::ChunkBasic<int>));
CHTTL_REGISTER_CHUNK_TYPE((chttl::ChunkBasic<size_t>));
CHTTL_REGISTER_CHUNK_TYPE((chttl::ChunkBasic<basisset_struct>));
CHTTL_REGISTER_CHUNK_TYPE((chttl::ChunkVector<int>));
CHTTL_REGISTER_CHUNK_TYPE((chttl::ChunkVector<double>));
CHTTL_REGISTER_CHUNK_TYPE((chttl::ChunkVector<Atom>));
CHTTL_REGISTER_TASK_TYPE((chttl::ChunkBasicAdd<int>));
CHTTL_REGISTER_TASK_TYPE((chttl::ChunkBasicAdd<size_t>));

/* Chunks&Tasks matrix library (CHTML) type registrations: the matrix chunk
   type, its parameter chunk, and the matrix tasks used by this test
   (element extraction, multiplication, nnz counting) plus others that
   NOTE(review): are presumably spawned internally by those tasks. */
CHTML_REGISTER_CHUNK_TYPE((CHTMLMatType));
CHTML_REGISTER_CHUNK_TYPE((chtml::MatrixParams<LeafMatType>));
CHTML_REGISTER_TASK_TYPE((chtml::MatrixGetElements<LeafMatType>));
CHTML_REGISTER_TASK_TYPE((chtml::MatrixMultiply<LeafMatType, false, false>));
CHTML_REGISTER_TASK_TYPE((chtml::MatrixNNZ<LeafMatType>));
CHTML_REGISTER_TASK_TYPE((chtml::MatrixAssignFromChunkIDs<LeafMatType>));
CHTML_REGISTER_TASK_TYPE((chtml::MatrixCombineElements<double>));
CHTML_REGISTER_TASK_TYPE((chtml::MatrixAdd<LeafMatType>));
CHTML_REGISTER_TASK_TYPE((chtml::MatrixAssignFromSparse<LeafMatType>));


/* Prepares the HML (hierarchical matrix library) block structure and the
   basis-function permutation for an n x n matrix, where
   n = basisInfo.noOfBasisFuncs.
   The leaf block size is 16, and the same factor 4 is passed for all three
   block-factor arguments of both helpers.
   Outputs: sizeBlockInfo, permutation, inversePermutation. */
static void
preparePermutationsHML(const BasisInfoStruct& basisInfo,
		       mat::SizesAndBlocks& sizeBlockInfo, 
		       std::vector<int>& permutation,
		       std::vector<int>& inversePermutation)
{
  const int leafBlockSize = 16;
  const int factor = 4;

  const int nBasisFuncs = basisInfo.noOfBasisFuncs;
  sizeBlockInfo = prepareMatrixSizesAndBlocks(nBasisFuncs, leafBlockSize,
                                              factor, factor, factor);

  getMatrixPermutation(basisInfo, leafBlockSize, factor, factor, factor,
                       permutation, inversePermutation);
}

/* Computes the basis-function permutation (and its inverse) used for the
   CHTML matrices by forwarding all arguments unchanged to
   getMatrixPermutationOnlyFactor2. */
static void
preparePermutationsCHTML(const BasisInfoStruct& basisInfo,
                         int blockSize_lowest,
                         int first_blocksize_factor,
                         std::vector<int>& permutation,
                         std::vector<int>& inversePermutation)
{
  getMatrixPermutationOnlyFactor2(basisInfo, blockSize_lowest, first_blocksize_factor,
                                  permutation, inversePermutation);
}

/* Returns rand() scaled by RAND_MAX, i.e. a pseudo-random value in the
   closed interval [0, 1] (1.0 is returned when rand() == RAND_MAX). */
static double rand_0_to_1() {
  return static_cast<double>(rand()) / RAND_MAX;
}

/* Returns a pseudo-random index in [0, n-1] drawn from rand().
   The draw is scaled by RAND_MAX + 1 so that the scaled value is strictly
   below n. Scaling a value that may equal 1.0 could produce n itself and
   spuriously trip the range check below.
   Throws std::runtime_error if n < 1, since no valid index exists. */
static int get_random_index_0_to_nm1(int n) {
  const double u = static_cast<double>(rand()) / (static_cast<double>(RAND_MAX) + 1.0); // in [0, 1)
  const int result = static_cast<int>(static_cast<double>(n) * u);
  if(result < 0 || result >= n)
    throw std::runtime_error("Error: (result < 0 || result >= n).");
  return result;
}

/* Fetches selected elements of an n x n CHTML matrix (cid_matrix).

   rowind_in[i] and colind_in[i] give the i-th requested position. Each index
   is translated through `permutation` before the lookup, which is done by a
   chtml::MatrixGetElements mother task. On return resultValues has
   rowind_in.size() entries, with resultValues[i] holding the value for
   position i. blockSize is passed as the leavesSizeMax of the matrix
   parameter chunk. All temporary chunks registered here are deleted before
   returning.

   Throws std::runtime_error if rowind_in and colind_in differ in length, or
   if the task returns a different number of values than were requested. */
static void get_elements_from_cht_matrix(int n, 
					 const std::vector<int> & rowind_in, 
					 const std::vector<int> & colind_in, 
					 std::vector<double> & resultValues, 
					 cht::ChunkID cid_matrix, 
					 int blockSize,
					 const std::vector<int> & permutation) {
  // colind_in is indexed with rowind_in's length below, so the sizes must agree.
  if(rowind_in.size() != colind_in.size())
    throw std::runtime_error("Error in get_elements_from_cht_matrix: rowind and colind sizes differ.");
  // Create params
  int M = n;
  int N = n;
  int leavesSizeMax = blockSize;
  cht::ChunkID cid_param = cht::registerChunk(new chtml::MatrixParams<LeafMatType>(M, N, leavesSizeMax, 0, 0));
  int nValuesToGet = rowind_in.size();
  // Create permuted row/column index vectors
  std::vector<int> rowind(nValuesToGet);
  std::vector<int> colind(nValuesToGet);
  for(int i = 0; i < nValuesToGet; i++) {
    rowind[i]= permutation[rowind_in[i]];
    colind[i]= permutation[colind_in[i]];
  }
  cht::ChunkID cid_rowind = cht::registerChunk(new chttl::ChunkVector<int>(rowind));
  cht::ChunkID cid_colind = cht::registerChunk(new chttl::ChunkVector<int>(colind));
  // Register task
  std::vector<cht::ChunkID> inputChunks(4);
  inputChunks[0] = cid_param;
  inputChunks[1] = cid_rowind;
  inputChunks[2] = cid_colind;
  inputChunks[3] = cid_matrix;
  cht::ChunkID cid_result = 
    cht::executeMotherTask<chtml::MatrixGetElements<LeafMatType> >(inputChunks);
  // Get resulting chunk object.
  cht::shared_ptr<chttl::ChunkVector<double> const> ptr_result;
  cht::getChunk(cid_result, ptr_result);
  // Copy only when the task returned exactly one value per request, so the
  // copy loop can never index past either vector.
  const bool sizeMatches =
    static_cast<size_t>(ptr_result->size()) == static_cast<size_t>(nValuesToGet);
  if(sizeMatches) {
    resultValues.resize(nValuesToGet);
    for(int i = 0; i < nValuesToGet; i++)
      resultValues[i] = (*ptr_result)[i];
  }
  cht::deleteChunk(cid_param);
  cht::deleteChunk(cid_rowind);
  cht::deleteChunk(cid_colind);
  cht::deleteChunk(cid_result);
  // Throw only after the temporary chunks above have been released.
  if(!sizeMatches)
    throw std::runtime_error("Error in get_elements_from_cht_matrix: unexpected number of values returned.");
}

/* Returns the number of nonzero elements of the CHTML matrix cid_matrix, as
   reported by a chtml::MatrixNNZ mother task. The temporary result chunk is
   deleted before returning. */
static size_t get_nnz_for_CHTML_matrix(cht::ChunkID cid_matrix) {
  typedef chttl::ChunkBasic<size_t> CountChunk;
  cht::ChunkID countId =
    cht::executeMotherTask<chtml::MatrixNNZ<LeafMatType> >(cid_matrix);
  cht::shared_ptr<CountChunk const> countChunk;
  cht::getChunk(countId, countChunk);
  const size_t count = countChunk->x;
  cht::deleteChunk(countId);
  return count;
}

static double get_single_element_from_HML_matrix(int n, int row, int col, const symmMatrix & S, const std::vector<int> & permutationHML) {
  std::vector<int> rowind(1);
  std::vector<int> colind(1);
  std::vector<ergo_real> values(1);
  rowind[0] = row;
  colind[0] = col;
  S.get_values(rowind, colind, values, permutationHML, permutationHML);
  return values[0];
}

int main(int argc, char *argv[])
{
  int nAtoms = 3;
  if(argc > 1)
    nAtoms = atoi(argv[1]);
  bool useLinearMolecule = false;
  if(argc > 2) {
    if(strcmp(argv[2], "linear") == 0)
      useLinearMolecule = true;
  }
  int nWorkers = 2;
  if(argc > 3)
    nWorkers = atoi(argv[3]);
  if(nWorkers < 1) {
    printf("Error: (nWorkers < 1).\n");
    return 1;
  }
  int blockSizeCHTML = 3;
  if(argc > 4)
    blockSizeCHTML = atoi(argv[4]);
  if(blockSizeCHTML < 1) {
    printf("Error: (blockSizeCHTML < 1).\n");
    return 1;
  }

  int nThreads    = 1;
  cht::extras::setNoOfWorkerThreads(nThreads);
  cht::setOutputMode(cht::Output::AllInTheEnd);
  cht::extras::setNWorkers(nWorkers);
  cht::extras::setCacheSize(500000000);
  cht::start();

  if(nAtoms < 1) {
    printf("Error: (nAtoms < 1).\n");
    return 1;
  }
  printf("nWorkers = %d\n", nWorkers);
  printf("nAtoms = %d\n", nAtoms);
  printf("useLinearMolecule = %d\n", (int)useLinearMolecule);
  printf("blockSizeCHTML = %4d\n", blockSizeCHTML);

  std::auto_ptr<IntegralInfo> biBasic(new IntegralInfo(true));
  BasisInfoStruct* bis = new BasisInfoStruct();

  static Molecule m; /* Don't allocate it on stack, it's too large. */
  int verbose = getenv("VERBOSE") != NULL;

  if(useLinearMolecule) {
    double spacing = 5.0;
    printf("Creating linear molecule with spacing %7.3f between atoms.\n", spacing);
    for(int i = 0; i < nAtoms; i++) {
      double x = 0;;
      double y = 0;
      double z = i * spacing;
      m.addAtom(1, x, y, z);
    }
  }
  else {
    double atomsPerUnitVolume = 0.001;
    double boxVolume = nAtoms / atomsPerUnitVolume;
    double boxWidth = pow(boxVolume, 1.0/3.0);
    printf("boxVolume = %9.3f, boxWidth = %9.3f\n", boxVolume, boxWidth);
    for(int i = 0; i < nAtoms; i++) {
      double x = boxWidth * rand_0_to_1();
      double y = boxWidth * rand_0_to_1();
      double z = boxWidth * rand_0_to_1();
      int atomType = 1;
      if(i % 2 == 0)
	atomType = 1;
      m.addAtom(atomType, x, y, z);
    }
  }

  if(bis->addBasisfuncsForMolecule(m, ERGO_SPREFIX "/basis/6-31G",
                                   0, NULL, *biBasic, 0, 0, 0) != 0) {
    printf("bis->addBasisfuncsForMolecule failed.\n");
    return 1;
  }

  // Create basisinfo chunk
  printf("Creating basisinfo chunk...\n");
  cht::ChunkID cid_basisinfo = cht::registerChunk(new BasisInfoStructChunk(*bis));

  int n = bis->noOfBasisFuncs;

  // Get overlap matrix
  std::vector<int> permutationHML, inversePermutationHML;
  mat::SizesAndBlocks sizeBlockInfo;
  preparePermutationsHML(*bis, sizeBlockInfo,
			 permutationHML, inversePermutationHML);
  symmMatrix S;
  S.resetSizesAndBlocks(sizeBlockInfo, sizeBlockInfo); 
  if(compute_overlap_matrix_sparse(*bis, S, 
				   permutationHML) != 0) {
    puts("error in compute_overlap_matrix_sparse");
    return -1;
  }
  symmMatrix S2;
  S2 = 1.0 * S * S;

  printf("nAtoms = %d, n = %d\n", nAtoms, n);
  printf("noOfShells = %d\n", bis->noOfShells);
  printf("noOfSimplePrimitives = %d\n", bis->noOfSimplePrimitives);

  printf("Trying to get chunk...\n");
  cht::shared_ptr<BasisInfoStructChunk const> ptr_c;
  cht::getChunk(cid_basisinfo, ptr_c);
  printf("After getting chunk, got n = %d\n", ptr_c->b.noOfBasisFuncs);

  std::vector<int> permutationCHTML, inversePermutationCHTML;
  // ELIAS NOTE 2013-04-04: hard-coding blocksize_lowest=1 here, for this to work good with block-sparse lib this should be changed.
  int blocksize_lowest = 1;
  int first_blocksize_factor = blockSizeCHTML;
  preparePermutationsCHTML(*bis,
			   blocksize_lowest,
			   first_blocksize_factor,
			   permutationCHTML,
			   inversePermutationCHTML
			   );
#if 0
  for(int i = 0; i < n; i++) {
    permutationCHTML[i] = i;
    inversePermutationCHTML[i] = i;
  }
#endif

  cht::ChunkID cid_perm = cht::registerChunk(new chttl::ChunkVector<int>(inversePermutationCHTML));
  cht::ChunkID cid_blsz = cht::registerChunk(new chttl::ChunkBasic<int>(blockSizeCHTML));

  // Prepare basis func extent list needed for overlap matrix computation.
  ergo_real largest_simple_integral = get_largest_simple_integral(*bis);
  printf("largest_simple_integral = %22.11f\n", largest_simple_integral);
  std::vector<ergo_real> basisFuncExtentList(n);
  const ergo_real MATRIX_ELEMENT_THRESHOLD_VALUE = 1e-12;
  get_basis_func_extent_list(*bis, &basisFuncExtentList[0], MATRIX_ELEMENT_THRESHOLD_VALUE / largest_simple_integral);
  cht::ChunkID cid_extentList = cht::registerChunk(new chttl::ChunkVector<double>(basisFuncExtentList));
  // Now do overlap matrix computation.
  int startIdx = 0;
  cht::ChunkID cid_startIdx = cht::registerChunk(new chttl::ChunkBasic<int>(startIdx));
  cht::ChunkID cid_n = cht::registerChunk(new chttl::ChunkBasic<int>(n));
  std::vector<cht::ChunkID> inputChunks(6);
  inputChunks[0] = cid_basisinfo;
  inputChunks[1] = cid_extentList;
  inputChunks[2] = cid_blsz;
  inputChunks[3] = cid_perm;
  inputChunks[4] = cid_startIdx;
  inputChunks[5] = cid_n;
  Util::TimeMeter tmComputeOverlapMatrix;
  printf("Before cht::executeMotherTask for TaskTypeComputeOverlapMatrix\n");
  cht::ChunkID cid_matrix_S = cht::executeMotherTask<TaskTypeComputeOverlapMatrix>(inputChunks);
  double secondsTakenComputeOverlapMatrix = Util::TimeMeter::get_wall_seconds() - tmComputeOverlapMatrix.get_start_time_wall_seconds();
  printf("cht::executeMotherTask for TaskTypeComputeOverlapMatrix took %12.5f wall seconds.\n", secondsTakenComputeOverlapMatrix);

  // Check result by checking some elements.
  {
    int nElementsToCheck = 222;
    printf("Checking result by looking at %d matrix elements...\n", nElementsToCheck);
    double maxAbsDiff = 0;
    std::vector<int> rowind(nElementsToCheck);
    std::vector<int> colind(nElementsToCheck);
    std::vector<double> values(nElementsToCheck);
    for(int i = 0; i < nElementsToCheck; i++) {
      int j = get_random_index_0_to_nm1(n);
      int k = get_random_index_0_to_nm1(n);
      rowind[i] = j;
      colind[i] = k;
    }
    get_elements_from_cht_matrix(n, 
				 rowind, colind, values, 
				 cid_matrix_S, blockSizeCHTML, permutationCHTML);
    for(int i = 0; i < nElementsToCheck; i++) {
      int j = rowind[i];
      int k = colind[i];
      // Check matrix element (k,j).
      double refValue1 = compute_one_element_of_overlap_mat(*bis, j, k);
      double refValue2 = get_single_element_from_HML_matrix(n, j, k, S, permutationHML);
      // Get matrix element from CHTML matrix.
      double matrixElementValue = values[i];
      // Compare
      double absDiff1 = fabs(matrixElementValue - refValue1);
      if(absDiff1 > maxAbsDiff)
	maxAbsDiff = absDiff1;
      double absDiff2 = fabs(matrixElementValue - refValue2);
      if(absDiff2 > maxAbsDiff)
	maxAbsDiff = absDiff2;
    }
    printf("Checked %d matrix elements, maxAbsDiff = %8.4g\n", nElementsToCheck, maxAbsDiff);
    if(maxAbsDiff > 1e-8)
      throw std::runtime_error("Error: wrong result, (maxAbsDiff > 1e-8).");
  }

  // Check nnz for computed overlap matrix
  int nnz_S = get_nnz_for_CHTML_matrix(cid_matrix_S);
  printf("NNZ: %12d  <-->  %8.3f %% nonzero elements  <-->  %8.3f nonzero elements per row.\n", 
	 nnz_S, (double)nnz_S*100.0/((double)n*n), (double)nnz_S/((double)n));

  // Do multiplication S*S
  std::vector<cht::ChunkID> inputChunksMmul(2);
  inputChunksMmul[0] = cid_matrix_S;
  inputChunksMmul[1] = cid_matrix_S;
  cht::resetStatistics();
  Util::TimeMeter tmMatrixMultiply;
  printf("Before cht::executeMotherTask for chtml::MatrixMultiply\n");
  cht::ChunkID cid_matrix_S2 = 
    cht::executeMotherTask< chtml::MatrixMultiply<LeafMatType, false, false> > (inputChunksMmul);
  printf("After cht::executeMotherTask for chtml::MatrixMultiply\n");
  double secondsTakenMatrixMultiply = Util::TimeMeter::get_wall_seconds() - tmMatrixMultiply.get_start_time_wall_seconds();
  printf("cht::executeMotherTask for chtml::MatrixMultiply took %12.5f wall seconds.\n", secondsTakenMatrixMultiply);
  cht::reportStatistics();

  // Check nnz for S2 matrix
  int nnz_S2 = get_nnz_for_CHTML_matrix(cid_matrix_S2);
  printf("NNZ for S2: %12d  <-->  %8.3f %% nonzero elements  <-->  %8.3f nonzero elements per row.\n", 
	 nnz_S2, (double)nnz_S2*100.0/((double)n*n), (double)nnz_S2/((double)n));


  // Compare to S2 computed in traditional way.
  {
    int nElementsToCheck = 222;
    printf("Checking result of S*S operation by looking at %d matrix elements...\n", nElementsToCheck);
    double maxAbsDiff = 0;
    double maxAbsElement = 0;
    std::vector<int> rowind(nElementsToCheck);
    std::vector<int> colind(nElementsToCheck);
    std::vector<double> values(nElementsToCheck);
    for(int i = 0; i < nElementsToCheck; i++) {
      int j = get_random_index_0_to_nm1(n);
      int k = get_random_index_0_to_nm1(n);
      rowind[i] = j;
      colind[i] = k;
    }
    get_elements_from_cht_matrix(n, 
				 rowind, colind, values, 
				 cid_matrix_S2, blockSizeCHTML, permutationCHTML);
    for(int i = 0; i < nElementsToCheck; i++) {
      int j = rowind[i];
      int k = colind[i];
      double refValue = get_single_element_from_HML_matrix(n, j, k, S2, permutationHML);
      double matrixElementValue = values[i];
      double absElement = fabs(matrixElementValue);
      if(absElement > maxAbsElement)
	maxAbsElement = absElement;
      // Compare
      double absDiff = fabs(matrixElementValue - refValue);
      if(absDiff > maxAbsDiff)
	maxAbsDiff = absDiff;
    }
    printf("Checked %d matrix elements of S2, maxAbsDiff = %8.4g\n", nElementsToCheck, maxAbsDiff);
    if(maxAbsDiff > 1e-8)
      throw std::runtime_error("Error: wrong result, (maxAbsDiff > 1e-8).");  
    printf("Congratulations, S*S result seems OK!\n");
    printf("Largest abs value of any checked element: %12.6f\n", maxAbsElement);
  }

  printf("Calling cht::deleteChunk...\n");
  cht::deleteChunk(cid_basisinfo);
  cht::deleteChunk(cid_extentList);
  cht::deleteChunk(cid_perm);
  cht::deleteChunk(cid_startIdx);
  cht::deleteChunk(cid_n);
  cht::deleteChunk(cid_blsz);
  cht::deleteChunk(cid_matrix_S);
  cht::deleteChunk(cid_matrix_S2);
  printf("After cht::deleteChunk.\n");

  cht::stop();

  puts("CHT test succeeded."); 
  unlink("ergoscf.out");
  return 0;
}

#else

/* Fallback entry point used when USE_CHUNKS_AND_TASKS is not defined:
   prints a notice that the test is skipped and reports success (0).
   Command-line arguments are ignored. */
int main(int argc, char *argv[])
{
  (void)argc;
  (void)argv;
  const char* notice = "Skipping Chunks&Tasks overlap matrix creation test since USE_CHUNKS_AND_TASKS macro not defined.\n";
  fputs(notice, stdout);
  return 0;
}

#endif
