/* Copyright (c) 2008-2025 the MRtrix3 contributors.
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 *
 * Covered Software is provided under this License on an "as is"
 * basis, without warranty of any kind, either expressed, implied, or
 * statutory, including, without limitation, warranties that the
 * Covered Software is free of defects, merchantable, fit for a
 * particular purpose or non-infringing.
 * See the Mozilla Public License v. 2.0 for more details.
 *
 * For more details, see http://www.mrtrix.org/.
 */

#ifndef __dwi_sdeconv_csd_h__
#define __dwi_sdeconv_csd_h__

#include "app.h"
#include "header.h"
#include "dwi/gradient.h"
#include "dwi/shells.h"
#include "math/SH.h"
#include "math/ZSH.h"
#include "dwi/directions/predefined.h"
#include "math/least_squares.h"

// Scale factor for the minimum-norm regularisation term: in
// CSD::Shared::init(), norm_lambda is multiplied by this value and by
// Mt_M(0,0) before being added to the diagonal of Mt_M.
#define NORM_LAMBDA_MULTIPLIER 0.0002

// Default values for the CSD parameters held in CSD::Shared.
// Upper bound applied to lmax in Shared::init() when no lmax was requested:
#define DEFAULT_CSD_LMAX 8
// Initial value of Shared::neg_lambda:
#define DEFAULT_CSD_NEG_LAMBDA 1.0
// Initial value of Shared::norm_lambda:
#define DEFAULT_CSD_NORM_LAMBDA 1.0
// Initial value of Shared::threshold:
#define DEFAULT_CSD_THRESHOLD 0.0
// Initial value of Shared::niter:
#define DEFAULT_CSD_NITER 50

namespace MR
{
  namespace DWI
  {
    namespace SDeconv
    {

    //! Option group for the CSD algorithm. Only declared here; the definition
    //! lives in another translation unit.
    // NOTE(review): presumably this group supplies the options read by
    // CSD::Shared::parse_cmdline_options() ("lmax", "filter", "directions",
    // "neg_lambda", "norm_lambda", "threshold", "niter") - confirm.
    extern const App::OptionGroup CSD_options;

    class CSD { MEMALIGN(CSD)
      public:

        class Shared { MEMALIGN(Shared)
          public:

            Shared (const Header& dwi_header) :
              HR_dirs (Directions::electrostatic_repulsion_300()),
              neg_lambda (DEFAULT_CSD_NEG_LAMBDA),
              norm_lambda (DEFAULT_CSD_NORM_LAMBDA),
              threshold (DEFAULT_CSD_THRESHOLD),
              lmax_response (0),
              lmax_cmdline (0),
              lmax (0),
              niter (DEFAULT_CSD_NITER) {
                grad = DWI::get_DW_scheme (dwi_header);
                // Discard b=0 (b=0 normalisation not supported in this version)
                // Only allow selection of one non-zero shell from command line
                dwis = DWI::Shells (grad).select_shells (true, false, true).largest().get_volumes();
                DW_dirs = DWI::gen_direction_matrix (grad, dwis);

                lmax_data = Math::SH::LforN (dwis.size());
              }


            /*! Replace default parameters with values given on the command line.
             *
             * Options consulted: "lmax", "filter", "directions", "neg_lambda",
             * "norm_lambda", "threshold" and "niter". An option that was not
             * supplied leaves the corresponding member untouched.
             *
             * \throws Exception if "lmax" does not parse to exactly one integer.
             */
            void parse_cmdline_options()
            {
              using namespace App;

              // harmonic order explicitly requested by the user:
              auto lmax_opt = get_options ("lmax");
              if (lmax_opt.size()) {
                auto requested = parse_ints<uint32_t> (lmax_opt[0][0]);
                if (requested.size() != 1)
                  throw Exception ("CSD algorithm expects a single lmax to be specified");
                lmax_cmdline = requested.front();
              }

              // per-order filter coefficients, loaded from file:
              auto filter_opt = get_options ("filter");
              if (filter_opt.size())
                init_filter = load_vector (filter_opt[0][0]);

              // alternative direction set replacing the built-in 300 directions:
              auto directions_opt = get_options ("directions");
              if (directions_opt.size())
                HR_dirs = load_matrix (directions_opt[0][0]);

              // scalar tuning parameters:
              auto neg_lambda_opt = get_options ("neg_lambda");
              if (neg_lambda_opt.size())
                neg_lambda = neg_lambda_opt[0][0];

              auto norm_lambda_opt = get_options ("norm_lambda");
              if (norm_lambda_opt.size())
                norm_lambda = norm_lambda_opt[0][0];

              auto threshold_opt = get_options ("threshold");
              if (threshold_opt.size())
                threshold = threshold_opt[0][0];

              auto niter_opt = get_options ("niter");
              if (niter_opt.size())
                niter = niter_opt[0][0];
            }


            /*! Load the response function coefficients from a text file.
             *
             * \param path file read with load_vector(); the resulting vector
             *        is forwarded to the matrix overload of set_response().
             */
            void set_response (const std::string& path)
            {
              INFO ("loading response function from file \"" + path + "\"");
              const auto coefficients = load_vector (path);
              set_response (coefficients);
            }

            /*! Store the response function coefficients.
             *
             * \param in coefficient vector; copied into \c response.
             *
             * \c lmax_response is updated to the harmonic order implied by
             * the number of coefficients, as given by Math::ZSH::LforN().
             */
            template <class Derived>
              void set_response (const Eigen::MatrixBase<Derived>& in)
              {
                response = in;

                const auto n_coefs = response.size();
                lmax_response = Math::ZSH::LforN (n_coefs);

                INFO ("setting response function using even SH coefficients: " + str (response.transpose()));
              }


            /*! Precompute the matrices used by CSD::set() and CSD::iterate().
             *
             * Relies on the response function having been set beforehand
             * (otherwise lmax_response is still 0 and an Exception is thrown),
             * and uses whatever values neg_lambda, norm_lambda, threshold,
             * init_filter, HR_dirs and lmax_cmdline currently hold.
             *
             * Fills in: lmax, rconv, HR_trans, M and Mt_M.
             *
             * NOTE: lmax_response, init_filter, threshold and norm_lambda are
             * modified in place (the latter two are rescaled), so calling
             * init() a second time on the same object would compound the
             * rescaling.
             *
             * \throws Exception if the data support no harmonic order above
             *         zero, if the response has no terms above order zero, or
             *         if the resulting lmax is not a positive even integer.
             */
            void init ()
            {
              using namespace Math::SH;

              // lmax_data is unsigned, so this is effectively a test for zero:
              if (lmax_data <= 0)
                throw Exception ("data contain too few directions even for lmax = 2");

              // also catches the case where set_response() was never called:
              if (lmax_response <= 0)
                throw Exception ("response function does not contain anisotropic terms");

              // an explicit request wins; otherwise use the order of the
              // response, capped at DEFAULT_CSD_LMAX:
              lmax = ( lmax_cmdline ? lmax_cmdline : std::min (lmax_response, uint32_t(DEFAULT_CSD_LMAX)) );

              if (lmax <= 0 || lmax % 2)
                throw Exception ("lmax must be a positive even integer");

              assert (response.size());
              // from here on, lmax_response is the order used for the initial
              // estimate: no higher than the response, the data, or lmax allow.
              lmax_response = std::min (lmax_response, std::min (lmax_data, lmax));
              INFO ("calculating even spherical harmonic components up to order " + str (lmax_response) + " for initialisation");

              // default filter is three ones; in either case it is then resized
              // to one entry per even order up to lmax_response (any entries
              // that need appending are taken from the zero vector supplied):
              if (!init_filter.size())
                init_filter = Eigen::VectorXd::Ones(3);
              init_filter.conservativeResizeLike (Eigen::VectorXd::Zero (Math::ZSH::NforL (lmax_response)));

              // response converted by ZSH2RH(), zero-padded if it has fewer
              // entries than needed for lmax:
              auto RH = Math::ZSH::ZSH2RH (response);
              if (size_t(RH.size()) < Math::ZSH::NforL (lmax))
                RH.conservativeResizeLike (Eigen::VectorXd::Zero (Math::ZSH::NforL (lmax)));

              // inverse sdeconv for initialisation:
              auto fconv = init_transform (DW_dirs, lmax_response);
              rconv.resize (fconv.cols(), fconv.rows());
              // offset the diagonal before taking the pseudo-inverse
              // (presumably to improve conditioning - TODO confirm):
              fconv.diagonal().array() += 1.0e-2;
              //fconv.save ("fconv.txt");
              rconv = Math::pinv (fconv);
              //rconv.save ("rconv.txt");
              // l indexes the even harmonic order (actual order is 2*l); nl is
              // the row index at which the next order begins. Each row of rconv
              // is scaled by init_filter[l] / RH[l] for the order it belongs to.
              ssize_t l = 0, nl = 1;
              for (ssize_t row = 0; row < rconv.rows(); ++row) {
                if (row >= nl) {
                  l++;
                  nl = NforL (2*l);
                }
                rconv.row (row).array() *= init_filter[l] / RH[l];
              }

              // forward sconv for iteration, using all response function
              // coefficients up to the requested lmax:
              INFO ("calculating even spherical harmonic components up to order " + str (lmax) + " for output");
              fconv = init_transform (DW_dirs, lmax);
              // same per-order bookkeeping as above, this time scaling the
              // columns of fconv by RH[l]:
              l = 0;
              nl = 1;
              for (ssize_t col = 0; col < fconv.cols(); ++col) {
                if (col >= nl) {
                  l++;
                  nl = NforL (2*l);
                }
                fconv.col (col).array() *= RH[l];
              }

              // high-res sampling to apply constraint:
              HR_trans = init_transform (HR_dirs, lmax);
              // weight of the constraint rows: proportional to neg_lambda and
              // to the first response coefficient, divided by the number of
              // high-resolution directions.
              default_type constraint_multiplier = neg_lambda * 50.0 * response[0] / default_type (HR_trans.rows());
              HR_trans.array() *= constraint_multiplier;

              // adjust threshold accordingly:
              // (it is compared against HR_trans * F in CSD::iterate(), which
              // now carries the same multiplier)
              threshold *= constraint_multiplier;

              // precompute as much as possible ahead of Cholesky decomp:
              // M holds fconv in its leading columns and zeros in any remaining
              // ones, so that it has as many columns as HR_trans.
              assert (fconv.cols() <= HR_trans.cols());
              M.resize (DW_dirs.rows(), HR_trans.cols());
              M.leftCols (fconv.cols()) = fconv;
              M.rightCols (M.cols() - fconv.cols()).setZero();
              Mt_M.resize (M.cols(), M.cols());
              // only the lower triangle (including the diagonal) of Mt_M is
              // written; the strictly upper part is left as resize() produced it.
              Mt_M.triangularView<Eigen::Lower>() = M.transpose() * M;


              // min-norm constraint:
              // norm_lambda is rescaled in place relative to Mt_M(0,0), then
              // added along the diagonal.
              if (norm_lambda) {
                norm_lambda *= NORM_LAMBDA_MULTIPLIER * Mt_M (0,0);
                Mt_M.diagonal().array() += norm_lambda;
              }

              INFO ("constrained spherical deconvolution initialised successfully");
            }

            /*! Number of coefficients in the output: the column count of
             *  \c HR_trans, which is populated by init().
             */
            size_t nSH () const
            {
              const size_t n_coefs = HR_trans.cols();
              return n_coefs;
            }

            // diffusion gradient scheme obtained from the input header:
            Eigen::MatrixXd grad;
            // response: coefficients stored by set_response();
            // init_filter: per-order filter used when building rconv (from the
            // "filter" option, or defaulted and resized in init()):
            Eigen::VectorXd response, init_filter;
            // DW_dirs: direction matrix for the volumes listed in dwis;
            // HR_dirs: direction set used for the constraint (built-in
            // 300-direction set unless replaced by the "directions" option):
            Eigen::MatrixXd DW_dirs, HR_dirs;
            // all computed by init():
            //   rconv    - scaled pseudo-inverse used for the initial estimate
            //   HR_trans - transform for HR_dirs, scaled by the constraint weight
            //   M        - forward transform for the DW directions
            //   Mt_M     - M^T M (lower triangle only), plus the norm term
            Eigen::MatrixXd rconv, HR_trans, M, Mt_M;
            // tuning parameters; threshold and norm_lambda are rescaled in
            // place by init():
            default_type neg_lambda, norm_lambda, threshold;
            // indices of the volumes selected in the constructor:
            vector<size_t> dwis;
            // harmonic orders: implied by the response (clamped in init()),
            // supported by the data, requested on the command line (0 if
            // none), and finally used for the output:
            uint32_t lmax_response, lmax_data, lmax_cmdline, lmax;
            // iteration count setting; stored here but not read by
            // CSD::iterate() (presumably enforced by the caller - confirm):
            size_t niter;
        };








        /*! Allocate the per-instance working storage.
         *
         * \param shared_data precomputed data; only a reference is stored, so
         *        it must outlive this object. Its matrices are read here for
         *        their dimensions, so they should already have been set up by
         *        Shared::init().
         */
        CSD (const Shared& shared_data) :
          shared (shared_data),
          work (shared.Mt_M.rows(), shared.Mt_M.cols()),
          HR_T (shared.HR_trans.rows(), shared.HR_trans.cols()),
          F (shared.HR_trans.cols()),
          init_F (shared.rconv.rows()),
          HR_amps (shared.HR_trans.rows()),
          Mt_b (shared.HR_trans.cols()),
          // work is declared before llt, so its size is already known here:
          llt (work.rows()),
          // sized to one entry per high-resolution direction; the contents
          // are replaced by set() before the first comparison in iterate():
          old_neg (shared.HR_trans.rows()) { }

        //! Member-wise copy. Because \c shared is a reference member, the copy
        //! refers to the same Shared object as the original, while the working
        //! buffers are duplicated.
        CSD (const CSD&) = default;

        //! Nothing is released explicitly; all members clean up after themselves.
        ~CSD() { }

        /*! Load a new set of diffusion-weighted signals and reset the solver.
         *
         * \param DW_signals signal vector; it is multiplied by \c shared.rconv
         *        and by the transpose of \c shared.M, so its length must match
         *        the column count of the former and the row count of the latter.
         *
         * The leading coefficients of the current estimate are set from the
         * linear initialisation and the remainder zeroed. The record of the
         * previous negative directions is replaced with the single entry -1,
         * which can never equal a list of valid (non-negative) indices, so the
         * first call to iterate() always performs an update.
         */
        template <class VectorType>
          void set (const VectorType& DW_signals)
          {
            const auto n_init = shared.rconv.rows();

            // initial estimate: low-order terms from rconv, the rest zero
            F.head (n_init) = shared.rconv * DW_signals;
            F.tail (F.size() - n_init).setZero();

            // invalidate the convergence check
            old_neg.assign (1, -1);

            // right-hand side of the normal equations, fixed for this signal
            Mt_b = shared.M.transpose() * DW_signals;
          }

        bool iterate() {
          neg.clear();
          HR_amps = shared.HR_trans * F;
          for (ssize_t n = 0; n < HR_amps.size(); n++)
            if (HR_amps[n] < shared.threshold)
              neg.push_back (n);

          if (old_neg == neg)
            return true;

          work.triangularView<Eigen::Lower>() = shared.Mt_M.triangularView<Eigen::Lower>();

          if (neg.size()) {
            for (size_t i = 0; i < neg.size(); i++)
              HR_T.row (i) = shared.HR_trans.row (neg[i]);
            auto HR_T_view = HR_T.topRows (neg.size());
            work.triangularView<Eigen::Lower>() += HR_T_view.transpose() * HR_T_view;
          }

          F.noalias() = llt.compute (work.triangularView<Eigen::Lower>()).solve (Mt_b);

          old_neg = neg;

          return false;
        }

        //! Current estimate of the coefficients, as last written by set() or iterate().
        const Eigen::VectorXd& FOD () const { return F; }


        //! Precomputed data this instance operates on (not owned).
        const Shared& shared;

      protected:
        // work: matrix handed to the Cholesky solver (lower triangle used);
        // HR_T: scratch copy of the rows of shared.HR_trans selected in iterate():
        Eigen::MatrixXd work, HR_T;
        // F: current estimate; init_F: allocated in the constructor but not
        // otherwise used in this class; HR_amps: shared.HR_trans * F;
        // Mt_b: shared.M^T times the current signal vector:
        Eigen::VectorXd F, init_F, HR_amps, Mt_b;
        // Cholesky decomposition object reused across iterations:
        Eigen::LLT<Eigen::MatrixXd> llt;
        // indices of the below-threshold directions for the current and the
        // previous iteration:
        vector<int> neg, old_neg;
    };


    }
  }
}

#endif


