# =====================================================================
# InitLambda.R - Lambda Grid Generation
# =====================================================================

#' Build the (lambda_B, lambda_Theta) tuning grid
#'
#' Produces every pairing of a penalty grid for B with a penalty grid for
#' Theta. Each grid is either taken from the user (duplicates removed, sorted
#' in decreasing order) or generated adaptively from summary statistics
#' stored in `init.obj`.
#'
#' @param lamB,lamTh Optional user-supplied penalty values. `NULL` means
#'   "generate this grid adaptively".
#' @param init.obj List of initialisation statistics. Fields read (all
#'   optional, fallback in parentheses): `n_eff` (`n`), `np_ratio`
#'   (`n_eff / p`), `nq_ratio` (`n_eff / q`), `estimated_sparsity_B` (0.8),
#'   `estimated_sparsity_Th` (0.7), `high_dimensional`
#'   (`np_ratio < 1 || nq_ratio < 1`), `overall_miss` (0), `lamB.max` (1),
#'   `lamTh.max` (1).
#' @param n Sample size (used when `init.obj$n_eff` is absent).
#' @param p,q Problem dimensions used to size and position the grids.
#' @param n.lamB,n.lamTh Grid lengths, or `NULL` for a size chosen from the
#'   problem scale.
#' @param lamB.min.ratio,lamTh.min.ratio Ratio of smallest to largest lambda,
#'   or `NULL` for a data-driven default. Clamped to [1e-6, 0.1].
#' @param lamB.scale.factor,lamTh.scale.factor Multipliers on the maximum
#'   lambda, or `NULL` for a data-driven default. Clamped to [0.1, 10].
#' @return A list with numeric vectors `lamB.vec` and `lamTh.vec` of equal
#'   length (`length(B grid) * length(Theta grid)`); position k holds the
#'   k-th (lamB, lamTh) pair, with lamTh varying fastest.
InitLambda <- function(lamB, lamTh, init.obj,
                       n, p, q, n.lamB, n.lamTh,
                       lamB.min.ratio, lamTh.min.ratio,
                       lamB.scale.factor = NULL, lamTh.scale.factor = NULL) {

  # -------------------- Extract Data Characteristics --------------------

  n_eff <- init.obj$n_eff %||% n
  np_ratio <- init.obj$np_ratio %||% (n_eff / p)
  nq_ratio <- init.obj$nq_ratio %||% (n_eff / q)
  sparsity_B <- init.obj$estimated_sparsity_B %||% 0.8
  sparsity_Th <- init.obj$estimated_sparsity_Th %||% 0.7
  high_dim <- init.obj$high_dimensional %||% (np_ratio < 1 || nq_ratio < 1)
  # Defaulted like every other field: previously an absent value made the
  # beta-grid branch fail with "missing value where TRUE/FALSE needed".
  overall_miss <- init.obj$overall_miss %||% 0

  # -------------------- Small Helpers --------------------

  # Clamp a scalar into [lo, hi].
  clamp <- function(x, lo, hi) max(lo, min(hi, x))

  # Default smallest/largest lambda ratio for a samples-per-dimension ratio:
  # with fewer samples per dimension the path is not taken as far down.
  default_min_ratio <- function(ratio) {
    if (ratio > 5) {
      1e-4
    } else if (ratio > 2) {
      5e-4
    } else if (ratio > 1) {
      1e-3
    } else {
      5e-3
    }
  }

  # User-supplied grid: numeric, deduplicated, decreasing (sort() drops NA).
  normalize_user_grid <- function(x) {
    sort(unique(as.numeric(x)), decreasing = TRUE)
  }

  # -------------------- Smart Scale Factors --------------------

  if (is.null(lamB.scale.factor)) {
    lamB.scale.factor <- if (high_dim) {
      1.5  # More regularization for high-dim
    } else if (sparsity_B > 0.9) {
      1.2  # Slightly more for very sparse
    } else {
      1.0
    }
  }

  if (is.null(lamTh.scale.factor)) {
    lamTh.scale.factor <- if (nq_ratio < 1.5) {
      1.5
    } else if (sparsity_Th > 0.8) {
      1.2
    } else {
      1.0
    }
  }

  lamB.scale.factor <- clamp(lamB.scale.factor, 0.1, 10)
  lamTh.scale.factor <- clamp(lamTh.scale.factor, 0.1, 10)

  # -------------------- Adaptive Min Ratios --------------------

  if (is.null(lamB.min.ratio)) {
    lamB.min.ratio <- default_min_ratio(np_ratio) * (1 + sparsity_B)
  }
  if (is.null(lamTh.min.ratio)) {
    lamTh.min.ratio <- default_min_ratio(nq_ratio) * (1 + sparsity_Th * 0.5)
  }

  lamB.min.ratio <- clamp(lamB.min.ratio, 1e-6, 0.1)
  lamTh.min.ratio <- clamp(lamTh.min.ratio, 1e-6, 0.1)

  # -------------------- Adaptive Grid Sizes --------------------

  if (is.null(n.lamB)) {
    n.lamB <- if (p * q > 50000) {
      20  # Fewer points for very large problems
    } else if (p * q > 10000) {
      30
    } else if (p * q > 1000) {
      40
    } else if (np_ratio > 2) {
      50  # More points when we have enough data
    } else {
      30  # Conservative for high-dim
    }
  }

  if (is.null(n.lamTh)) {
    n.lamTh <- if (q > 200) {
      20
    } else if (q > 100) {
      30
    } else if (nq_ratio > 2) {
      50
    } else {
      30
    }
  }

  # -------------------- Multi-Resolution Grid Generation --------------------

  # Decreasing lambda grid from max_val down to max_val * min_ratio.
  # Grids of <= 10 points are plain geometric sequences. Larger grids are
  # built on the log10 scale from three zones (upper tail, a denser middle
  # region where the optimum is expected, lower tail), then trimmed or
  # padded to exactly n_points values.
  generate_adaptive_grid <- function(max_val, min_ratio, n_points,
                                     sparsity_est = 0.8, high_dim = FALSE,
                                     grid_type = "beta", miss_rate = 0) {

    if (!is.finite(max_val) || max_val <= 0) {
      warning("Invalid max lambda, using 1.0")
      max_val <- 1
    }

    if (n_points <= 1) return(max_val)

    log_max <- log10(max_val)
    log_min <- log10(max_val * min_ratio)
    log_range <- log_max - log_min

    if (n_points <= 10) {
      # Simple geometric sequence for small grids
      return(10^seq(log_max, log_min, length.out = n_points))
    }

    # Beta grids in high-dim or heavily-missing settings focus the middle
    # zone higher up the path and give more points to the lower tail.
    upper_focus <- grid_type == "beta" &&
      (high_dim || isTRUE(miss_rate > 0.25))

    if (upper_focus) {
      region_center <- log_max - log_range * 0.3
      region_width <- log_range * 0.4
      n_zone3 <- ceiling(n_points * 0.35)
    } else {
      region_center <- log_max - log_range * 0.4
      region_width <- log_range * 0.6
      n_zone3 <- ceiling(n_points * 0.21)
    }
    n_zone1 <- ceiling(n_points * 0.10)

    # Shift the centre toward log_max in proportion to expected sparsity
    region_center <- region_center +
      (log_max - region_center) * sparsity_est * 0.3

    zone1_end <- region_center + region_width / 2    # Upper zone boundary
    zone3_start <- region_center - region_width / 2  # Lower zone boundary
    n_zone2 <- n_points - n_zone1 - n_zone3

    # A tail zone that falls outside [log_min, log_max] is dropped and its
    # point budget is handed to the middle zone.
    if (zone1_end < log_max) {
      zone1 <- seq(log_max, zone1_end, length.out = max(2, n_zone1))
    } else {
      zone1 <- numeric(0)
      n_zone2 <- n_zone2 + n_zone1
    }

    if (zone3_start > log_min) {
      zone3 <- seq(zone3_start, log_min, length.out = max(2, n_zone3))
    } else {
      zone3 <- numeric(0)
      n_zone2 <- n_zone2 + n_zone3
    }

    zone2 <- seq(min(zone1_end, log_max),
                 max(zone3_start, log_min),
                 length.out = max(2, n_zone2))

    # Adjacent zones share endpoints; unique() removes those duplicates.
    pts <- unique(sort(c(zone1, zone2, zone3), decreasing = TRUE))

    if (length(pts) > n_points) {
      # Thin evenly; indices 1 and length(pts) keep both extremes.
      pts <- pts[round(seq(1, length(pts), length.out = n_points))]
    } else if (length(pts) < n_points) {
      # Pad with interior points of a uniform log grid
      n_add <- n_points - length(pts)
      filler <- seq(log_max, log_min, length.out = n_add + 2)[2:(n_add + 1)]
      pts <- unique(sort(c(pts, filler), decreasing = TRUE))
    }

    # If padding collided with existing points, indexing past the end would
    # produce NA lambdas; fall back to a plain geometric grid instead.
    if (length(pts) < n_points) {
      pts <- seq(log_max, log_min, length.out = n_points)
    }

    10^pts[seq_len(n_points)]
  }

  # -------------------- Generate Lambda Grids --------------------

  lamB.vec <- if (is.null(lamB)) {
    generate_adaptive_grid(
      (init.obj$lamB.max %||% 1) * lamB.scale.factor,
      lamB.min.ratio, n.lamB,
      sparsity_est = sparsity_B,
      high_dim = (n_eff < p * q),
      grid_type = "beta",
      miss_rate = overall_miss
    )
  } else {
    normalize_user_grid(lamB)
  }

  lamTh.vec <- if (is.null(lamTh)) {
    generate_adaptive_grid(
      (init.obj$lamTh.max %||% 1) * lamTh.scale.factor,
      lamTh.min.ratio, n.lamTh,
      sparsity_est = sparsity_Th,
      high_dim = (nq_ratio < 1),
      grid_type = "theta",
      miss_rate = overall_miss
    )
  } else {
    normalize_user_grid(lamTh)
  }

  # -------------------- Create Output Grids --------------------

  # All (lamB, lamTh) combinations, lamTh varying fastest
  list(
    lamB.vec = rep(lamB.vec, each = length(lamTh.vec)),
    lamTh.vec = rep(lamTh.vec, times = length(lamB.vec))
  )
}
