! Description:
!> @file
!!   Subroutines needed by MFASIS-NN for computation of cloud NN input parameters
!!   from model profiles
!
!> @brief
!!   Subroutines needed by MFASIS-NN for computation of cloud NN input parameters
!!   from model profiles
!!
!! @details
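!!   Adjoint (AD) of the calculation of the MFASIS-NN cloud input parameters
!!   (water/ice optical depths of the upper and lower cloud parts, their
!!   effective radii and the cloud-top pressure normalised by the surface
!!   pressure) from layer optical depths and hydrometeor effective diameters.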
!
! Copyright:
!    This software was developed within the context of
!    the EUMETSAT Satellite Application Facility on
!    Numerical Weather Prediction (NWP SAF), under the
!    Cooperation Agreement dated 7 September 2021, between
!    EUMETSAT and the Met Office, UK, by one or more partners
!    within the NWP SAF. The partners in the NWP SAF are
!    the Met Office, ECMWF, DWD and MeteoFrance.
!
!    Copyright 2024, EUMETSAT, All Rights Reserved.

!> Adjoint of the MFASIS-NN cloud input parameter calculation for one cloud
!! column (cc) and one NN index (nch).
!!
!! Propagates the adjoints of the NN cloud input parameters (tauw_ad, taui_ad,
!! tauwi_top_ad, tauwi_bot_ad, reff*_ad, fpct_ad, psfc_in_ad) back to:
!!   - opdpext_ad(:,:,ist)                  layer extinction optical depths
!!   - aux_ad%hydro_deff(:,:,profad)        hydrometeor effective diameters
!!   - profiles_ad(profad)%p_half           only if pressure_gradients
!!   - psfc_ad                              surface pressure
!!
!! All direct-model intermediate quantities (tauw, taui, dtauw, fmp, fbotw,
!! logtauw, tauw_thr, ...) are supplied by the caller. Only the cloud-top
!! layer search and the bottom-layer weights are recomputed internally.
!! On exit the INTENT(INOUT) input-parameter adjoints (tauw_ad ... fpct_ad)
!! are reset to zero; psfc_in_ad is INTENT(IN) and is added to psfc_ad.
!!
!! NOTE(review): err is a local variable, so the fatal errors raised below via
!! THROWM are not returned to the caller through the argument list.
SUBROUTINE rttov_calc_mfasis_nn_hydro_inpar_ad( &
    nlayers,            &
    prof,               &
    profad,             &
    cloud_columns,      &
    ist,                &
    nch,                &
    cc,                 &
    opdpext,            &
    opdpext_ad,         &
    mfasisnn_coefs,     &
    profiles,           &
    profiles_ad,        &
    aux,                &
    aux_ad,             &
    nn_inpar_idx,       &
    pressure_gradients, &
    sza,                &
    vza,                &
    tauw,               &
    taui,               &
    tauwi_top,          &
    tauwi_bot,          &
    reffw_top,          &
    reffi_top,          &
    reffw_bot,          &
    reffi_bot,          &
    reffwi_top,         &
    reffwi_bot,         &
    psfc,               &
    tauw_top,           & ! TL/AD version, previously optional output in direct subroutine
    tauw_bot,           &
    taui_top,           &
    taui_bot,           &
    tauwi,              &
    logtauw,            &
    tauw_thr,           &
    logtaui,            &
    taui_thr,           &
    dtauw,              &
    dtaui,              &
    dreffw,             &
    dreffi,             &
    fmp,                &
    fbotw,              &
    fboti,              &
    tauw_ad,         &  ! TL/AD output
    taui_ad,         &
    tauwi_top_ad,    &
    tauwi_bot_ad,    &
    reffw_top_ad,    &
    reffi_top_ad,    &
    reffw_bot_ad,    &
    reffi_bot_ad,    &
    reffwi_top_ad,   &  
    reffwi_bot_ad,   &
    fpct_ad,         &
    psfc_in_ad,         &
    psfc_ad)

!INTF_OFF
#include "throw.h"
!INTF_ON
  USE rttov_kinds, ONLY : jpim, jprv, jplm

  USE rttov_types, ONLY :          &
      rttov_cloud_columns,         &
      rttov_coef_mfasis_nn,        &
      rttov_profile_internal,      &
      rttov_profile_aux
!INTF_OFF

  USE rttov_const, ONLY :          &
      errorstatus_fatal,           &
      wcl_opac_deff,               &
      pi,                          &
      nopac_wcl,                   &
      clw_deff_index,              &
      ice_baum_index

  USE rttov_mfasis_nn_mod, ONLY :  &
      aa_sc,                       &
      nn_idx_reffw_top, nn_idx_reffi_top

  USE yomhook, ONLY : lhook, dr_hook, jphook
!INTF_ON
  IMPLICIT NONE

  ! Dimensions and indices:
  !   prof / profad index the direct / adjoint profile arrays,
  !   cc is the cloud column, ist the last index of opdpext,
  !   nch indexes mfasisnn_coefs%nn and the first dimension of nn_inpar_idx
  INTEGER(jpim),             INTENT(IN)    :: nlayers
  INTEGER(jpim),             INTENT(IN)    :: prof
  INTEGER(jpim),             INTENT(IN)    :: profad
  TYPE(rttov_cloud_columns), INTENT(IN)    :: cloud_columns
  INTEGER(jpim),             INTENT(IN)    :: ist
  INTEGER(jpim),             INTENT(IN)    :: nch
  INTEGER(jpim),             INTENT(IN)    :: cc
  REAL(jprv),       INTENT(IN)    :: opdpext(:,:,:)
  REAL(jprv),       INTENT(INOUT) :: opdpext_ad(:,:,:)
  TYPE(rttov_coef_mfasis_nn),INTENT(IN)    :: mfasisnn_coefs
  TYPE(rttov_profile_internal),INTENT(IN)  :: profiles(:)
  TYPE(rttov_profile_internal),INTENT(INOUT)  :: profiles_ad(:)
  TYPE(rttov_profile_aux),   INTENT(IN)    :: aux
  TYPE(rttov_profile_aux),   INTENT(INOUT) :: aux_ad
  INTEGER(jpim),             INTENT(IN)    :: nn_inpar_idx(:,:)
  LOGICAL(jplm),             INTENT(IN)    :: pressure_gradients
  REAL(jprv),                INTENT(IN)    :: sza
  REAL(jprv),                INTENT(IN)    :: vza
  ! Direct-model NN input parameters
  REAL(jprv),                INTENT(IN)    :: tauw
  REAL(jprv),                INTENT(IN)    :: taui
  REAL(jprv),                INTENT(IN)    :: tauwi_top
  REAL(jprv),                INTENT(IN)    :: tauwi_bot
  REAL(jprv),                INTENT(IN)    :: reffw_top
  REAL(jprv),                INTENT(IN)    :: reffi_top
  REAL(jprv),                INTENT(IN)    :: reffw_bot
  REAL(jprv),                INTENT(IN)    :: reffi_bot
  REAL(jprv),                INTENT(IN)    :: reffwi_top
  REAL(jprv),                INTENT(IN)    :: reffwi_bot
  REAL(jprv),                INTENT(IN)    :: psfc
  ! Direct-model intermediate quantities (linearisation point)
  REAL(jprv),       INTENT(IN)   :: tauw_top  ! TL/AD version
  REAL(jprv),       INTENT(IN)   :: tauw_bot
  REAL(jprv),       INTENT(IN)   :: taui_top
  REAL(jprv),       INTENT(IN)   :: taui_bot
  REAL(jprv),       INTENT(IN)   :: tauwi
  REAL(jprv),       INTENT(IN)   :: tauw_thr
  REAL(jprv),       INTENT(IN)   :: taui_thr
  REAL(jprv),       INTENT(IN)   :: logtauw
  REAL(jprv),       INTENT(IN)   :: logtaui
  REAL(jprv), DIMENSION(nlayers),   INTENT(IN) :: dtauw
  REAL(jprv), DIMENSION(nlayers),   INTENT(IN) :: dtaui
  REAL(jprv), DIMENSION(nlayers),   INTENT(IN) :: dreffw
  REAL(jprv), DIMENSION(nlayers),   INTENT(IN) :: dreffi
  REAL(jprv), DIMENSION(nlayers),   INTENT(IN) :: fmp
  REAL(jprv), DIMENSION(nlayers),   INTENT(IN) :: fbotw
  REAL(jprv), DIMENSION(nlayers),   INTENT(IN) :: fboti
  ! Adjoints of the NN input parameters (zeroed on exit, except psfc_in_ad)
  REAL(jprv),                INTENT(INOUT)   :: tauw_ad
  REAL(jprv),                INTENT(INOUT)   :: taui_ad
  REAL(jprv),                INTENT(INOUT)   :: tauwi_top_ad
  REAL(jprv),                INTENT(INOUT)   :: tauwi_bot_ad
  REAL(jprv),                INTENT(INOUT)   :: reffw_top_ad
  REAL(jprv),                INTENT(INOUT)   :: reffi_top_ad
  REAL(jprv),                INTENT(INOUT)   :: reffw_bot_ad
  REAL(jprv),                INTENT(INOUT)   :: reffi_bot_ad
  REAL(jprv),                INTENT(INOUT)   :: reffwi_top_ad
  REAL(jprv),                INTENT(INOUT)   :: reffwi_bot_ad
  REAL(jprv),                INTENT(INOUT)   :: fpct_ad
  REAL(jprv),                INTENT(IN)   :: psfc_in_ad
  ! Surface pressure adjoint (accumulated)
  REAL(jprv),                INTENT(INOUT)   :: psfc_ad

!INTF_END

  REAL(jprv)                     :: fpct
  REAL(jprv)                     :: fac_hi, fac_lo, tau_lo, tau_hi, tau, tau_thr
  REAL(jprv)                     :: tau_thr_ad, tau_thr_v1, tau_ad
  REAL(jprv)                     :: tauw_thr_ad, taui_thr_ad, logtauw_ad, logtaui_ad
  REAL(jprv)                     :: tauw_top_ad, tauw_bot_ad, taui_top_ad, taui_bot_ad, tauwi_ad
  REAL(jprv), DIMENSION(nlayers) :: dtauw_ad, dtaui_ad, fmp_ad, dreffw_ad, dreffi_ad, fbotw_ad, fboti_ad, tauv1
  REAL(jprv), DIMENSION(nlayers) :: ddtaui_ad
  INTEGER(jpim)                  :: j
  INTEGER(jpim)                  :: err

  REAL(jphook) :: zhook_handle
  !- End of header --------------------------------------------------------

  TRY

  IF (lhook) CALL dr_hook('RTTOV_CALC_MFASIS_NN_HYDRO_INPAR_AD',0_jpim,zhook_handle)

  ! --------------------------------------------------------------------------
  ! Adjoint of the computation of NN input parameters from model profiles:
  ! optical depths, effective radii, cloud-top pressure.
  ! The direct-model statements are processed in reverse order below.
  ! --------------------------------------------------------------------------

  ! Local adjoint variables start from zero
  tauwi_ad     = 0

  tau_thr_ad   = 0
  tauw_top_ad  = 0
  taui_top_ad  = 0
  tauw_thr_ad  = 0
  taui_thr_ad  = 0
  logtauw_ad   = 0
  logtaui_ad   = 0
  tauw_bot_ad  = 0
  taui_bot_ad  = 0

  dtauw_ad  = 0
  dtaui_ad  = 0
  ddtaui_ad = 0
  fmp_ad    = 0
  dreffw_ad = 0
  dreffi_ad = 0
  fbotw_ad  = 0
  fboti_ad  = 0

  psfc_ad       = psfc_ad       + psfc_in_ad

  !=================================================================================================
  ! 1. Cloud-top pressure fpct
  !
  ! Direct model:
  !   tau_thr_v1 = (tauw + taui + tauwi) / 2
  !   tau_thr    = tau_thr_v1 / (tau_thr_v1 + 1)
  !   tau        = running sum of dtauw(j) + dtaui(j) over layers 1..j
  !   in the first layer j with tau > tau_thr:
  !     fpct = ( p_half(j+1) - (tau - tau_thr)/(dtauw(j) + dtaui(j))
  !                            * (p_half(j+1) - p_half(j)) ) / psfc
  !=================================================================================================
  tau_thr_v1 = (tauw + taui + tauwi)/2._jprv
  tau_thr    = tau_thr_v1/(tau_thr_v1 + 1._jprv)

  ! tauv1(j) stores the running optical depth up to the layer in which the
  ! threshold is first exceeded; the remaining layers keep a large negative
  ! sentinel, so only that layer passes the tauv1(j) > tau_thr test below.
  tau = 0._jprv
  tauv1(:) = - 9999999._jprv
  DO j = 1, nlayers
    tau = tau + dtauw(j) + dtaui(j)
    tauv1(j)=tau
    IF(tauv1(j) > tau_thr) THEN
      EXIT
    ENDIF
  ENDDO

  tau_ad = 0
  DO j = nlayers, 1, -1
    IF ( tauv1(j) > tau_thr) THEN
      ! Normalised direct-model fpct, needed for the psfc adjoint
      fpct    = profiles   (prof)%p_half(j+1) -   (tauv1(j)-tau_thr)/(dtauw(j) + dtaui(j))    &
                                     * (profiles(prof)%p_half(j+1) - profiles(prof)%p_half(j))
      fpct = fpct / psfc

      ! AD of fpct = fpct / psfc
      psfc_ad = psfc_ad - fpct_ad * fpct/psfc
      fpct_ad  = fpct_ad / psfc

      ! AD of the in-layer interpolation w.r.t. tau, tau_thr, dtauw(j), dtaui(j)
      tau_ad     =    tau_ad  - fpct_ad/(dtauw(j)+dtaui(j)) &
                                             * (profiles(prof)%p_half(j+1) - profiles(prof)%p_half(j))
      tau_thr_ad = tau_thr_ad + fpct_ad/(dtauw(j)+dtaui(j)) &
                                             * (profiles(prof)%p_half(j+1) - profiles(prof)%p_half(j))
      dtauw_ad(j)= dtauw_ad(j)+ fpct_ad * (tauv1(j)-tau_thr)/(dtauw(j)+ dtaui(j))**2 &
                                             * (profiles(prof)%p_half(j+1)-profiles(prof)%p_half(j))
      dtaui_ad(j)= dtaui_ad(j)+ fpct_ad * (tauv1(j)-tau_thr)/(dtauw(j)+ dtaui(j))**2 &
                                             * (profiles(prof)%p_half(j+1)-profiles(prof)%p_half(j))
      ! ... and w.r.t. the half-level pressures (only when they are active variables)
      IF (pressure_gradients) THEN
        profiles_ad(profad)%p_half(j+1) = profiles_ad(profad)%p_half(j+1) - &
            fpct_ad * (tauv1(j)-tau_thr)/(dtauw(j)+dtaui(j)) + fpct_ad
        profiles_ad(profad)%p_half(j  ) = profiles_ad(profad)%p_half(j  ) + &
            fpct_ad * (tauv1(j)-tau_thr)/(dtauw(j)+dtaui(j))
      ENDIF
      fpct_ad = 0
    ENDIF
    ! AD of tau = tau + dtauw(j) + dtaui(j)
    dtauw_ad(j) = dtauw_ad(j) + tau_ad
    dtaui_ad(j) = dtaui_ad(j) + tau_ad
  ENDDO

  ! AD of tau_thr = tau_thr_v1/(tau_thr_v1 + 1) and tau_thr_v1 = (tauw + taui + tauwi)/2
  tau_thr_ad = tau_thr_ad/(tau_thr_v1 + 1._jprv)  -  &
                    tau_thr_ad * tau_thr_v1/(tau_thr_v1 + 1._jprv)**2
  tauw_ad    = tauw_ad + tau_thr_ad/2._jprv
  taui_ad    = taui_ad + tau_thr_ad/2._jprv
  tauwi_ad   = tauwi_ad+ tau_thr_ad/2._jprv
  tau_thr_ad = 0

  !=================================================================================================
  ! 2. Ice cloud two-layer parameters
  !
  ! Direct model (ice outside the water cloud, weight 1 - fmp):
  !   taui_thr  = g(logtaui) * taui
  !   fboti     = fbot(taui_thr, dtaui * (1 - fmp))
  !   taui_bot  = SUM(fboti * dtaui * (1 - fmp)),   taui_top = taui - taui_bot
  !   reffi_top = SUM((1 - fboti) * dtaui * (1 - fmp) * dreffi) / taui_top
  !   reffi_bot = SUM(     fboti  * dtaui * (1 - fmp) * dreffi) / taui_bot
  ! with g(x) = fac_lo - (fac_lo - fac_hi) * 0.5 * (1 - COS(pi * x)).
  ! The 1E-6 thresholds are kept as in the original so that branch decisions
  ! stay identical to the direct/TL code.
  !=================================================================================================
  IF (SIZE(mfasisnn_coefs%nn(nch)%in(nn_inpar_idx(nch, nn_idx_reffi_top))%auxparams) > 0) THEN
    IF(taui > 0) THEN
      IF ( taui_top > 1E-6 ) THEN
        ! AD of reffi_top
        fboti_ad(:)   = fboti_ad(:)    - dtaui(:)           * (1._jprv-fmp) * dreffi(:) * reffi_top_ad/ taui_top
        dtaui_ad(:)   = dtaui_ad(:)    + (1._jprv-fboti(:)) * (1._jprv-fmp) * dreffi(:) * reffi_top_ad/ taui_top
        fmp_ad        = fmp_ad         - (1._jprv-fboti(:)) * dtaui(:)      * dreffi(:) * reffi_top_ad/ taui_top
        dreffi_ad(:)  = dreffi_ad(:)   + (1._jprv-fboti(:)) * (1._jprv-fmp) * dtaui(:)  * reffi_top_ad/ taui_top

        taui_top_ad   = taui_top_ad    -                                     reffi_top  * reffi_top_ad/ taui_top
        reffi_top_ad  = 0
      ELSE
        reffi_top_ad = 0._jprv
      ENDIF

      IF( taui_bot > 1E-6 ) THEN
        ! AD of reffi_bot
        fboti_ad(:)    = fboti_ad(:)    + dtaui(:) *(1._jprv-fmp) * dreffi(:) * reffi_bot_ad/taui_bot
        dtaui_ad(:)    = dtaui_ad(:)    + fboti(:) *(1._jprv-fmp) * dreffi(:) * reffi_bot_ad/taui_bot
        fmp_ad         = fmp_ad         - fboti(:) * dtaui(:)     * dreffi(:) * reffi_bot_ad/taui_bot
        dreffi_ad(:)   = dreffi_ad(:)   + fboti(:) *(1._jprv-fmp) * dtaui(:)  * reffi_bot_ad/taui_bot
        taui_bot_ad    = taui_bot_ad    -                          reffi_bot  * reffi_bot_ad/taui_bot
        reffi_bot_ad = 0
      ELSE
        reffi_bot_ad = 0._jprv 
      ENDIF

      ! AD of taui_top = taui - taui_bot
      taui_ad     = taui_ad     + taui_top_ad
      taui_bot_ad = taui_bot_ad - taui_top_ad
      taui_top_ad = 0
      ! AD of taui_bot = SUM(fboti * dtaui * (1 - fmp))
      dtaui_ad(:) = dtaui_ad(:) + fboti(:) * (1._jprv - fmp (:))   * taui_bot_ad 
      fboti_ad(:) = fboti_ad(:) + dtaui(:) * (1._jprv - fmp (:))   * taui_bot_ad
      fmp_ad(:)   = fmp_ad(:)   - fboti(:) * dtaui(:)              * taui_bot_ad
      taui_bot_ad = 0

      ! AD of fboti = fbot(taui_thr, ddtaui) with ddtaui = dtaui * (1 - fmp)
      call comp_fbot_ad(fboti_ad, taui_thr_ad, ddtaui_ad, taui_thr, dtaui*(1._jprv - fmp(:)), nlayers, aa_sc)
      fboti_ad = 0

      ! AD of ddtaui = dtaui * (1 - fmp)
      fmp_ad      = fmp_ad        - dtaui(:)* ddtaui_ad(:)
      dtaui_ad(:) = dtaui_ad(:)   +  ddtaui_ad(:)*(1._jprv-fmp)                 
      ddtaui_ad(:)= 0

      ! Coefficients of the threshold function g for ice
      fac_hi =  mfasisnn_coefs%nn(nch)%in(nn_inpar_idx(nch, nn_idx_reffi_top))%auxparams(3) &
                - mfasisnn_coefs%nn(nch)%in(nn_inpar_idx(nch, nn_idx_reffi_top))%auxparams(4) * MAX( vza, sza ) / 90._jprv
      fac_lo = 0.5_jprv 
      tau_lo = mfasisnn_coefs%nn(nch)%in(nn_inpar_idx(nch, nn_idx_reffi_top))%auxparams(1)
      tau_hi = mfasisnn_coefs%nn(nch)%in(nn_inpar_idx(nch, nn_idx_reffi_top))%auxparams(2)

      ! AD of taui_thr = g(logtaui) * taui, with
      ! logtaui = MIN( MAX( (LOG(taui) - LOG(tau_lo)) / (LOG(tau_hi) - LOG(tau_lo)), 0 ), 1 ).
      ! logtaui has a non-zero derivative only strictly inside (0, 1).
      IF(logtaui > 0.0_jprv .AND. logtaui <  1.0_jprv ) then
        logtaui_ad  = logtaui_ad  - (fac_lo-fac_hi)*0.5_jprv*pi* SIN(logtaui * pi) * taui * taui_thr_ad

        taui_ad     = taui_ad + logtaui_ad / taui / (LOG(tau_hi) - LOG(tau_lo))
        logtaui_ad  = 0
      ENDIF
      taui_ad = taui_ad + (fac_lo- (fac_lo-fac_hi) * 0.5_jprv*( 1._jprv - COS( logtaui * pi))) * taui_thr_ad
      taui_thr_ad = 0

    ELSE ! taui > 0)
      reffi_bot_ad = 0._jprv
      reffi_top_ad = 0._jprv
    ENDIF
  ELSE
    err = errorstatus_fatal
    THROWM(err.NE.0, "MFASIS-NN has no information on two-layer parameterisation of ice cloud")
  ENDIF

  !=================================================================================================
  ! 3. Water cloud two-layer parameters (including mixed-phase ice in the water cloud, weight fmp)
  !
  ! Direct model:
  !   tauw_thr   = g(logtauw) * tauw,     fbotw = fbot(tauw_thr, dtauw)
  !   tauw_bot   = SUM(fbotw * dtauw),    tauw_top = tauw - tauw_bot
  !   reffw_top  = SUM((1 - fbotw) * dtauw * dreffw) / tauw_top
  !   reffw_bot  = SUM(     fbotw  * dtauw * dreffw) / tauw_bot
  !   tauwi_top  = SUM((1 - fbotw) * dtaui * fmp)
  !   tauwi_bot  = SUM(     fbotw  * dtaui * fmp)
  !   reffwi_top = SUM((1 - fbotw) * dtaui * fmp * dreffi) / tauwi_top
  !   reffwi_bot = SUM(     fbotw  * dtaui * fmp * dreffi) / tauwi_bot
  !=================================================================================================
  IF (SIZE(mfasisnn_coefs%nn(nch)%in(nn_inpar_idx(nch, nn_idx_reffw_top))%auxparams) > 0) THEN
    IF(tauw > 0) THEN
      IF ( tauwi_top > 1E-6 ) THEN
        ! AD of reffwi_top
        fbotw_ad(:)    = fbotw_ad(:)    - dtaui(:)           * fmp(:)  * dreffi(:) * reffwi_top_ad/tauwi_top
        dtaui_ad(:)    = dtaui_ad(:)    +(1._jprv - fbotw(:))* fmp(:)  * dreffi(:) * reffwi_top_ad/tauwi_top 
        fmp_ad(:)      = fmp_ad(:)      +(1._jprv - fbotw(:))* dtaui(:)* dreffi(:) * reffwi_top_ad/tauwi_top
        dreffi_ad(:)   = dreffi_ad(:)   +(1._jprv - fbotw(:))* fmp(:)  * dtaui(:)  * reffwi_top_ad/tauwi_top
        tauwi_top_ad   = tauwi_top_ad   -                              reffwi_top  * reffwi_top_ad/tauwi_top
        reffwi_top_ad = 0
      ELSE
        reffwi_top_ad = 0._jprv !
      ENDIF
      ! AD of tauwi_top
      fmp_ad(:)   = fmp_ad(:)   + (1._jprv - fbotw(:)) * dtaui(:) * tauwi_top_ad 
      dtaui_ad(:) = dtaui_ad(:) + (1._jprv - fbotw(:)) * fmp(:)   * tauwi_top_ad 
      fbotw_ad(:) = fbotw_ad(:) -            dtaui(:)  * fmp(:)   * tauwi_top_ad 
      tauwi_top_ad = 0


      IF ( tauwi_bot > 1E-6 ) THEN
        ! AD of reffwi_bot
        fbotw_ad(:)   = fbotw_ad(:)    +  dtaui(: )* fmp(:)   * dreffi(:) * reffwi_bot_ad/tauwi_bot
        dtaui_ad(:)   = dtaui_ad(:)    +  fbotw(:) * fmp(:)   * dreffi(:) * reffwi_bot_ad/tauwi_bot
        fmp_ad(:)     = fmp_ad(:)      +  fbotw(:) * dtaui(:) * dreffi(:) * reffwi_bot_ad/tauwi_bot
        dreffi_ad(:)  = dreffi_ad(:)   +  fbotw(:) * fmp(:)   * dtaui(:)  * reffwi_bot_ad/tauwi_bot
        tauwi_bot_ad  = tauwi_bot_ad   -                       reffwi_bot * reffwi_bot_ad/tauwi_bot
        reffwi_bot_ad = 0
      ELSE
        reffwi_bot_ad = 0._jprv !
      ENDIF
      ! AD of tauwi_bot
      fbotw_ad(:)  = fbotw_ad(:)  + dtaui(:)   * fmp(:)  * tauwi_bot_ad
      dtaui_ad(:)  = dtaui_ad(:)  + fbotw(:)   * fmp(:)  * tauwi_bot_ad
      fmp_ad(:)    = fmp_ad(:)    + fbotw(:)   * dtaui(:)* tauwi_bot_ad
      tauwi_bot_ad = 0

      IF ( tauw_top > 1E-6 ) THEN
        ! AD of reffw_top
        fbotw_ad    = fbotw_ad    -                   dtauw(:)  * dreffw(:) * reffw_top_ad/tauw_top
        dtauw_ad(:) = dtauw_ad(:) +(1._jprv-fbotw(:))           * dreffw(:) * reffw_top_ad/tauw_top
        dreffw_ad(:)= dreffw_ad(:)+(1._jprv-fbotw(:))*dtauw(:)              * reffw_top_ad/tauw_top
        tauw_top_ad = tauw_top_ad -                               reffw_top * reffw_top_ad/tauw_top
        reffw_top_ad = 0
      ELSE
        reffw_top_ad = 0._jprv !
      ENDIF

      IF ( tauw_bot > 1E-6 ) THEN
        ! AD of reffw_bot
        fbotw_ad(:)  = fbotw_ad(:)  +         dtauw(:) * dreffw(:) * reffw_bot_ad/ tauw_bot
        dtauw_ad(:)  = dtauw_ad(:)  +fbotw(:)          * dreffw(:) * reffw_bot_ad/ tauw_bot
        dreffw_ad(:) = dreffw_ad(:) +fbotw(:)*dtauw(:)             * reffw_bot_ad/ tauw_bot
        tauw_bot_ad  = tauw_bot_ad  -                   reffw_bot  * reffw_bot_ad/ tauw_bot
        reffw_bot_ad = 0
      ELSE
        reffw_bot_ad = 0._jprv !
      ENDIF

      ! AD of tauw_top = tauw - tauw_bot
      tauw_ad     = tauw_ad     + tauw_top_ad 
      tauw_bot_ad = tauw_bot_ad - tauw_top_ad 
      tauw_top_ad = 0
      ! AD of tauw_bot = SUM(fbotw * dtauw)
      fbotw_ad(:) = fbotw_ad(:) + dtauw(:) * tauw_bot_ad
      dtauw_ad(:) = dtauw_ad(:) + fbotw(:) * tauw_bot_ad
      tauw_bot_ad = 0

      ! AD of fbotw = fbot(tauw_thr, dtauw)
      call comp_fbot_ad(fbotw_ad, tauw_thr_ad, dtauw_ad, tauw_thr, dtauw, nlayers, aa_sc)
      fbotw_ad = 0

      ! Coefficients of the threshold function g for water
      fac_hi =  mfasisnn_coefs%nn(nch)%in(nn_inpar_idx(nch, nn_idx_reffw_top))%auxparams(3) &
                - mfasisnn_coefs%nn(nch)%in(nn_inpar_idx(nch, nn_idx_reffw_top))%auxparams(4) * MAX( vza, sza ) / 90._jprv

      fac_lo = 0.5_jprv !
      tau_lo = mfasisnn_coefs%nn(nch)%in(nn_inpar_idx(nch, nn_idx_reffw_top))%auxparams(1)
      tau_hi = mfasisnn_coefs%nn(nch)%in(nn_inpar_idx(nch, nn_idx_reffw_top))%auxparams(2)

      ! AD of tauw_thr = g(logtauw) * tauw, with
      ! logtauw = MIN( MAX( (LOG(tauw) - LOG(tau_lo)) / (LOG(tau_hi) - LOG(tau_lo)), 0 ), 1 ).
      ! logtauw has a non-zero derivative only strictly inside (0, 1).
      IF(logtauw > 0.0_jprv .AND. logtauw <  1.0_jprv ) then
        logtauw_ad = logtauw_ad   - (fac_lo-fac_hi)*0.5_jprv*pi* SIN(logtauw * pi) * tauw * tauw_thr_ad
        tauw_ad    = tauw_ad      + logtauw_ad / tauw / (LOG(tau_hi) - LOG(tau_lo))
        logtauw_ad = 0
      ENDIF

      tauw_ad     = tauw_ad +(fac_lo- (fac_lo-fac_hi)* 0.5_jprv*( 1._jprv - COS(logtauw*pi))) * tauw_thr_ad
      tauw_thr_ad = 0
    ELSE ! tauw > 0)
      reffw_bot_ad = 0._jprv
      reffw_top_ad = 0._jprv
      reffwi_bot_ad = 0._jprv
      reffwi_top_ad = 0._jprv
    ENDIF  ! tauw > 0
  ELSE
    err = errorstatus_fatal
    THROWM(err.NE.0, "MFASIS-NN has no information on two-layer parameterisation of water cloud")
  ENDIF

  !=================================================================================================
  ! 4. Column ice optical depths
  !   tauwi = SUM(dtaui * fmp)          mixed-phase ice in the water cloud
  !   taui  = SUM(dtaui * (1 - fmp))    remaining ice
  !=================================================================================================
  fmp_ad(:)   = fmp_ad(:)   - dtaui(:)           * taui_ad
  dtaui_ad(:) = dtaui_ad(:) + (1._jprv - fmp(:)) * taui_ad
  taui_ad  = 0
  fmp_ad(:)   = fmp_ad(:)   + dtaui(:)           * tauwi_ad
  dtaui_ad(:) = dtaui_ad(:) + fmp(:)             * tauwi_ad
  tauwi_ad = 0

  !=================================================================================================
  ! 5. Mixed-phase weights: fmp = fbot(1, dtauw), i.e. fixed threshold optical depth 1.
  !    The threshold is constant, so the tau_thr_ad returned here is discarded.
  !=================================================================================================
  tau_thr    =  1._jprv
  call comp_fbot_ad(fmp_ad, tau_thr_ad, dtauw_ad, tau_thr, dtauw, nlayers, aa_sc)
  fmp_ad = 0
  tau_thr_ad =  0

  !=================================================================================================
  ! 6. Layer optical depths and effective radii for cloudy layers of column cc
  !   tauw      = SUM(dtauw)
  !   dreffi(j) = hydro_deff(ice_baum_index) / 2
  !   dreffw(j) = ( opdpext(clw_deff_index) * hydro_deff(clw_deff_index) / 2
  !               + SUM(opdpext(1:nopac_wcl) * wcl_opac_deff(1:nopac_wcl)) / 2 ) / dtauw(j)
  !   dtaui(j)  = opdpext(ice_baum_index)
  !   dtauw(j)  = SUM(opdpext(1:clw_deff_index))
  !=================================================================================================
  DO j = 1, nlayers
    IF ( cloud_columns%icldarr(cc,j,prof) .EQ. 1 ) THEN
      ! AD of tauw = tauw + dtauw(j)
      dtauw_ad(j) = dtauw_ad(j) + tauw_ad
      ! AD of layer ice effective radius
      aux_ad%hydro_deff(ice_baum_index,j,profad) = &
          aux_ad%hydro_deff(ice_baum_index,j,profad) + dreffi_ad(j)/2._jprv  !
      dreffi_ad(j) = 0

      ! AD of layer effective radius for water (division by dtauw(j) first)
      IF ( dtauw(j) .GT. 0 ) THEN
        dtauw_ad(j)  = dtauw_ad(j)    -  dreffw(j)/(dtauw(j))    * dreffw_ad(j)
        dreffw_ad(j) = dreffw_ad(j)/dtauw(j)

        aux_ad%hydro_deff(clw_deff_index,j,profad) = &
            aux_ad%hydro_deff(clw_deff_index,j,profad) + &
            dreffw_ad(j) /2._jprv * opdpext(clw_deff_index,j,ist)
        opdpext_ad(clw_deff_index,j,ist) = &
            opdpext_ad(clw_deff_index,j,ist) + &
            dreffw_ad(j) /2._jprv * aux%hydro_deff(clw_deff_index,j,prof)
        opdpext_ad(1:nopac_wcl,j,ist) = &
            opdpext_ad(1:nopac_wcl,j,ist) + &
            dreffw_ad(j) * wcl_opac_deff(1:nopac_wcl)/2._jprv
      ELSE
        dreffw_ad(j) = 0._jprv !
      ENDIF

      ! AD of layer optical depths
      opdpext_ad(ice_baum_index,j,ist) = &
          opdpext_ad(ice_baum_index,j,ist) + dtaui_ad(j)
      opdpext_ad(1:clw_deff_index,j,ist) = &
          opdpext_ad(1:clw_deff_index,j,ist) + dtauw_ad(j)
    ENDIF

  ENDDO ! DO j = 1, nlayers

  ! The input-parameter adjoints have been fully propagated
  tauw_ad = 0.0_jprv
  taui_ad = 0.0_jprv
  tauwi_top_ad = 0.0_jprv
  tauwi_bot_ad = 0.0_jprv
  reffw_top_ad = 0.0_jprv
  reffi_top_ad = 0.0_jprv
  reffw_bot_ad = 0.0_jprv
  reffi_bot_ad = 0.0_jprv
  reffwi_top_ad = 0.0_jprv
  reffwi_bot_ad = 0.0_jprv
  fpct_ad = 0._jprv

  IF (lhook) CALL dr_hook('RTTOV_CALC_MFASIS_NN_HYDRO_INPAR_AD',1_jpim,zhook_handle)

  CATCH

  IF (lhook) CALL dr_hook('RTTOV_CALC_MFASIS_NN_HYDRO_INPAR_AD',1_jpim,zhook_handle)


  CONTAINS

  !----------------------------------------------------------------------------------------------
  !> Adjoint of the weighting factors fbot for the bottom part of a cloud.
  !!
  !! Direct model (recomputed in part 1 below):
  !!   a_0 = -tau_thr,  a_j = a_(j-1) + dtaul(j),  aa_lim = aa_sc * pi / 2
  !!   F(a) = 0                                        for a <= -aa_lim
  !!        = 0.5 * (a + aa_lim - aa_sc * COS(a/aa_sc)) for |a| < aa_lim
  !!        = a                                        for a >= aa_lim
  !!   fbot(j) = (F(a_j) - F_old) / dtaul(j) for layers with dtaul(j) /= 0,
  !!   fbot(j) = previous weight for layers with dtaul(j) == 0, and
  !!   fbot = 1 for all layers after the first one with a_j >= aa_lim.
  !!
  !! fbot_ad is consumed (zeroed on exit); contributions are accumulated into
  !! tau_thr_ad and dtaul_ad. The dummy aa_sc shadows the module variable of
  !! the same name.
  !----------------------------------------------------------------------------------------------
  SUBROUTINE  comp_fbot_ad(fbot_ad, tau_thr_ad, dtaul_ad, tau_thr, dtaul, nlay, aa_sc)

    IMPLICIT NONE
!===========================================================
    INTEGER(jpim),    INTENT(IN)    :: nlay
    REAL(jprv),       INTENT(INOUT) :: fbot_ad(nlay) 
    REAL(jprv),       INTENT(INOUT) :: tau_thr_ad
    REAL(jprv),       INTENT(INOUT) :: dtaul_ad(:)
    REAL(jprv),       INTENT(IN)    :: tau_thr
    REAL(jprv),       INTENT(IN)    :: dtaul(:)
    REAL(jprv),       INTENT(IN)    :: aa_sc

    REAL(jprv)                      :: fbot (nlay)
    REAL(jprv)                      :: ffunc(nlay)
    REAL(jprv)                      :: ffunc_ad(nlay)

    REAL(jprv)                      :: fbot_old, ffunc_old
    REAL(jprv)                      :: fbot_old_ad, ffunc_old_ad
    REAL(jprv)                      :: aa, aa_lim
    REAL(jprv)                      :: aa_ad
    REAL(jprv)                      :: aaa(0:nlay)
    INTEGER(jpim)                   :: j, jj

    aa_lim = aa_sc * 0.5_jprv  * pi

!------------------------------------------------------------------------------
!     1. compute nl quantities "aaa(:), fbot(:)" for linear computations below
!        aaa(j) holds the running argument a_j after layer j
!------------------------------------------------------------------------------
    fbot(:)   = 0._jprv
    ffunc(:)  = 0._jprv
    fbot_old  = 0._jprv
    ffunc_old = 0._jprv
    aa = - tau_thr
    aaa(:) = 0
    aaa(0) = aa
    if(aaa(0) > -aa_sc * 0.5_jprv  * pi ) THEN
      ffunc_old = 0.5_jprv*(aaa(0)+ aa_sc*0.5_jprv*pi - aa_sc*COS(aaa(0)/aa_sc) )
    endif
    DO j = 1, nlay
      IF(dtaul(j) /= 0) THEN
        aa = aa + dtaul(j)
        aaa(j)=aa
        IF(aa <= -aa_lim) cycle
        IF(aa >=  aa_lim) THEN
          ffunc(j) = aa
        ELSE
          ffunc(j) = 0.5_jprv*(aa+ aa_lim - aa_sc*COS(aa/aa_sc) )
        ENDIF
        fbot(j) = (ffunc(j) - ffunc_old)/dtaul(j)
        fbot_old  = fbot(j)
        ffunc_old = ffunc(j)

        IF(aa >=  aa_lim) THEN !set all remaining to one and quit
          DO jj=j+1, nlay
            fbot(jj) = 1._jprv
            aaa(jj)=aa
          ENDDO
          EXIT
        ENDIF
      ELSE
        aaa(j)=aa
        fbot(j)  = fbot_old
      ENDIF
    ENDDO
!------------------------------------------------------------------------------
!     2. start linear computations (reverse layer order)
!------------------------------------------------------------------------------

    ffunc_ad(:)  = 0._jprv
    fbot_old_ad  = 0._jprv
    ffunc_old_ad = 0._jprv
    aa_ad = 0
    DO j = nlay, 1, -1
      ! Layers after the saturation layer have the constant weight 1
      IF( aaa(j-1) >= aa_lim) THEN
        cycle
      ENDIF
      IF(dtaul(j) /= 0) THEN
        ! Layers with a_j <= -aa_lim have the constant weight 0
        IF(aaa(j) > -aa_lim) THEN

          ! AD of fbot_old = fbot(j)
          fbot_ad(j) = fbot_ad(j) + fbot_old_ad 
          fbot_old_ad= 0

          ! AD of ffunc_old = ffunc(j)
          ffunc_ad(j)= ffunc_ad(j) + ffunc_old_ad
          ffunc_old_ad = 0

          ! AD of fbot(j) = (ffunc(j) - ffunc_old)/dtaul(j)
          ! (ffunc_old_ad is zero here, so plain assignment equals accumulation)
          ffunc_ad(j)= ffunc_ad(j) + fbot_ad(j)/dtaul(j)
          ffunc_old_ad=              - fbot_ad(j)/dtaul(j)
          dtaul_ad(j) = dtaul_ad(j)  - fbot_ad(j)/dtaul(j) * fbot(j)
          fbot_ad(j)  = 0

          ! AD of ffunc(j) = F(aa); dF/da = 1 (linear part) or 0.5*(1 + SIN(a/aa_sc))
          IF(aaa(j) >=  aa_lim) THEN
            aa_ad = aa_ad + ffunc_ad(j)
            ffunc_ad(j) = 0
          ELSE
            aa_ad= aa_ad + ffunc_ad(j) * 0.5_jprv* (1._jprv + sin(aaa(j)/aa_sc))
            ffunc_ad(j) = 0
          ENDIF
        ENDIF
        ! AD of aa = aa + dtaul(j)
        dtaul_ad(j) = dtaul_ad(j) + aa_ad
      ELSE
        ! AD of fbot(j) = fbot_old
        fbot_old_ad  = fbot_old_ad + fbot_ad(j)
        fbot_ad(j)   = 0
        cycle
      ENDIF
    ENDDO
    ! AD of the initial ffunc_old = F(a_0)
    if(aaa(0) > -aa_sc * 0.5_jprv  * pi ) THEN
      aa_ad= aa_ad + ffunc_old_ad * 0.5_jprv* (1._jprv + sin(aaa(0)/aa_sc))
      ffunc_old_ad = 0
    endif
    ! AD of aa = - tau_thr
    tau_thr_ad = tau_thr_ad - aa_ad

    fbot_ad(:)   = 0._jprv
  END SUBROUTINE comp_fbot_ad



END SUBROUTINE rttov_calc_mfasis_nn_hydro_inpar_ad
