#pragma once

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <ostream>
#include <vector>
#include "floatcascade.h" // Assumes the previous FloatCascade design

namespace universal {

// Adaptive precision configuration.
//
// Plain aggregate of tuning knobs consumed by priest. Member order is part of
// the interface: callers initialize it positionally and with designated
// initializers.
struct AdaptivePrecision {
    // Division stops once the residual magnitude is at or below this value.
    double absolute_tolerance{0.0};
    // Division stops once the residual relative to the dividend is at or
    // below this value.
    double relative_tolerance{1e-15};
    // Results holding more components than this are truncated.
    size_t max_components{50};
    // Results with at most this many components are never auto-compressed.
    size_t min_components{2};
    // When true, arithmetic results are compressed before being returned.
    bool auto_compress{true};

    // Termination criteria for operations like division
    // Upper bound on refinement iterations.
    size_t max_iterations{100};
    // When true, iteration stops as soon as a tolerance is satisfied.
    bool early_termination{true};
};

// Variable-length FloatCascade - the dynamic version of FloatCascade<N>.
//
// The represented value is the sum of all stored components. Components are
// kept least-significant first (increasing magnitude), so the last element is
// the leading term; sign() and to_cascade() rely on that ordering. The class
// does not enforce the ordering on construction or element assignment;
// compress() re-establishes it.
class VariableCascade {
private:
    std::vector<double> components;  // Still in increasing magnitude order

    // Relative cutoff used by compress() when the caller passes no threshold.
    static constexpr double default_relative_cutoff = 1e-16;

public:
    // Constructors

    // Empty cascade; represents zero.
    VariableCascade() : components() {}

    // Single-component cascade holding x (x is stored even when it is zero).
    explicit VariableCascade(double x) : components(1, x) {}

    // Takes the components as given; the caller is responsible for ordering.
    explicit VariableCascade(const std::vector<double>& comps) : components(comps) {}

    // Construct from fixed-size FloatCascade: copies fc[0..N) in index order
    // and then drops exact zeros.
    template<size_t N>
    explicit VariableCascade(const FloatCascade<N>& fc) : components(N) {
        for (size_t i = 0; i < N; ++i) {
            components[i] = fc[i];
        }
        remove_zeros();  // Clean up
    }

    // Extract to fixed-size FloatCascade (truncate or zero-pad as needed).
    //
    // When more than N components are stored, the N MOST significant ones
    // (the tail of the increasing-magnitude storage) are kept; copying the
    // first N instead would discard the leading terms and return a value that
    // is wrong by many orders of magnitude. When fewer than N are stored, the
    // zero padding goes into the low slots so the result stays in increasing
    // magnitude order.
    // NOTE(review): this assumes FloatCascade<N> uses the same
    // least-significant-first layout as this class - confirm against
    // floatcascade.h. The summed value is correct either way.
    template<size_t N>
    FloatCascade<N> to_cascade() const {
        FloatCascade<N> result;
        const size_t available = components.size();
        const size_t kept = std::min(N, available);
        const size_t padding = N - kept;          // low slots filled with zero
        const size_t dropped = available - kept;  // least significant terms that do not fit

        // Written explicitly rather than relying on FloatCascade's default
        // constructor to zero its storage.
        for (size_t i = 0; i < padding; ++i) {
            result[i] = 0.0;
        }
        for (size_t i = 0; i < kept; ++i) {
            result[padding + i] = components[dropped + i];
        }

        return result;
    }

    // Dynamic access

    // Read access; indices past the end read as 0.0.
    double operator[](size_t i) const {
        return (i < components.size()) ? components[i] : 0.0;
    }

    // Write access; indexing past the end grows the cascade with zeros so
    // that index i exists.
    double& operator[](size_t i) {
        if (i >= components.size()) {
            components.resize(i + 1, 0.0);
        }
        return components[i];
    }

    size_t size() const { return components.size(); }
    const std::vector<double>& data() const { return components; }

    // Estimation: plain double sum of the components in storage order
    // (smallest first when the ordering invariant holds).
    double to_double() const {
        double sum = 0.0;
        for (double comp : components) {
            sum += comp;
        }
        return sum;
    }

    // Precision management
    void reserve(size_t n) { components.reserve(n); }

    // Grows with zeros or drops elements from the end (the leading terms).
    void resize(size_t n) {
        components.resize(n, 0.0);
    }

    // Erases every component that compares equal to 0.0.
    void remove_zeros() {
        components.erase(
            std::remove(components.begin(), components.end(), 0.0),
            components.end()
        );
    }

    // Removes components whose magnitude is below `threshold`, then sorts the
    // survivors by increasing magnitude.
    //
    // A threshold of exactly 0.0 means "choose one automatically":
    // |to_double()| * 1e-16. That automatic cutoff is close to double's unit
    // roundoff, so it can discard the low-order terms that carry the extra
    // precision; pass an explicit threshold to keep them.
    void compress(double threshold = 0.0) {
        if (threshold == 0.0) {
            threshold = std::abs(to_double()) * default_relative_cutoff;
        }

        // Remove components smaller than threshold
        components.erase(
            std::remove_if(components.begin(), components.end(),
                [threshold](double x) { return std::abs(x) < threshold; }),
            components.end()
        );

        // Re-sort by magnitude
        std::sort(components.begin(), components.end(),
            [](double a, double b) { return std::abs(a) < std::abs(b); });
    }

    // True when there are no components or every component is exactly zero.
    // Components that cancel each other (e.g. {1, -1}) are not detected.
    bool is_zero() const {
        // all_of is vacuously true for an empty range, covering the empty case.
        return std::all_of(components.begin(), components.end(),
                   [](double x) { return x == 0.0; });
    }

    // Sign of the most significant non-zero component: 1, -1, or 0 when all
    // components are zero. Scans from the back because storage is
    // least-significant first.
    int sign() const {
        for (auto it = components.rbegin(); it != components.rend(); ++it) {
            if (*it > 0.0) return 1;
            if (*it < 0.0) return -1;
        }
        return 0;
    }
};

// The adaptive precision priest class.
//
// Wraps a VariableCascade (an unevaluated sum of doubles, least significant
// term first) together with the AdaptivePrecision policy that controls how
// many components results keep. Binary operators apply the LEFT operand's
// policy to the result.
class priest {
private:
    VariableCascade cascade;
    AdaptivePrecision config;

    // Workspace for intermediate calculations (not used by any operation yet)
    mutable std::vector<double> workspace;

public:
    // Constructors
    priest() : cascade(), config() {}

    explicit priest(double x, const AdaptivePrecision& cfg = {})
        : cascade(x), config(cfg) {}

    priest(const VariableCascade& vc, const AdaptivePrecision& cfg = {})
        : cascade(vc), config(cfg) {}

    // Construct from fixed-precision types - this is the key interoperability!
    explicit priest(const dd& d, const AdaptivePrecision& cfg = {})
        : cascade(d.get_cascade()), config(cfg) {}

    explicit priest(const td& t, const AdaptivePrecision& cfg = {})
        : cascade(t.get_cascade()), config(cfg) {}

    template<size_t N>
    explicit priest(const FloatCascade<N>& fc, const AdaptivePrecision& cfg = {})
        : cascade(fc), config(cfg) {}

    // Extract to fixed-precision types
    dd to_dd() const { return dd(cascade.to_cascade<2>()); }
    td to_td() const { return td(cascade.to_cascade<3>()); }

    template<size_t N>
    FloatCascade<N> to_cascade() const { return cascade.to_cascade<N>(); }

    // Conversion operators
    explicit operator double() const { return cascade.to_double(); }
    explicit operator dd() const { return to_dd(); }
    explicit operator td() const { return to_td(); }

    // Precision management
    void set_precision(const AdaptivePrecision& cfg) { config = cfg; }
    const AdaptivePrecision& get_precision() const { return config; }

    // Drops components that are negligible for the configured precision and
    // leaves the rest sorted by increasing magnitude.
    //
    // The cutoff is config.relative_tolerance * |value|. Using the cascade's
    // built-in default (a fixed 1e-16 relative cutoff) would throw away every
    // term beyond double precision no matter how tight the configured
    // tolerance is.
    priest& compress() {
        const double cutoff = config.relative_tolerance * std::abs(cascade.to_double());
        if (cutoff > 0.0) {
            cascade.compress(cutoff);
        } else {
            // A zero (or NaN) cutoff must not reach VariableCascade::compress,
            // which treats 0.0 as "use the 1e-16 default"; just strip zeros.
            cascade.remove_zeros();
        }
        return *this;
    }

    size_t num_components() const { return cascade.size(); }

    // Adaptive Addition: exact expansion sum, then this operand's
    // compress/truncate policy.
    priest operator+(const priest& other) const {
        return apply_policy(basic_add(*this, other), config);
    }

    // Negation: flips the sign of every component; keeps this value's policy.
    priest operator-() const {
        std::vector<double> negated = cascade.data();
        for (double& term : negated) {
            term = -term;
        }
        return priest(VariableCascade(negated), config);
    }

    // Adaptive Subtraction, expressed as addition of the negated operand.
    priest operator-(const priest& other) const {
        return *this + (-other);
    }

    // Adaptive Multiplication: exact sum of all component-wise products, then
    // this operand's compress/truncate policy. Cost grows with the product of
    // the operands' component counts.
    priest operator*(const priest& other) const {
        return apply_policy(basic_mul(*this, other), config);
    }

    // Adaptive Division - the 1/3 solver!
    priest operator/(const priest& other) const {
        return adaptive_divide(*this, other, config);
    }

    // This is how we handle 1.0/3.0 with controlled precision.
    //
    // Computes dividend / divisor by residual correction: starting from the
    // hardware quotient q, repeatedly form the residual r = dividend - q*divisor
    // exactly and add the double estimate r/divisor to q as a new component.
    // (The reciprocal iteration x*(2 - d*x) converges to 1/divisor, so it is
    // only a quotient when the dividend is 1.)
    //
    // Iteration stops when
    //   - cfg.early_termination is set and |r| <= cfg.absolute_tolerance or
    //     |r| <= cfg.relative_tolerance * |dividend|,
    //   - the next correction is zero or not finite,
    //   - a correction fails to shrink the residual, or
    //   - cfg.max_iterations corrections have been applied.
    // A non-finite hardware quotient (division by zero, inf or NaN operands)
    // is returned as-is. The result carries cfg and is compressed.
    static priest adaptive_divide(const priest& dividend, const priest& divisor,
                                  const AdaptivePrecision& cfg) {
        const double dividend_value = dividend.cascade.to_double();
        const double divisor_value = divisor.cascade.to_double();

        // Start with hardware precision estimate
        const double approx = dividend_value / divisor_value;
        priest quotient(approx, cfg);
        if (!std::isfinite(approx)) {
            return quotient;  // nothing to refine; mirrors IEEE division
        }

        // "abs ok OR rel ok" folded into a single bound; this form also avoids
        // dividing by |dividend|, which is NaN-prone for a zero dividend.
        const double tolerance = std::max(
            cfg.absolute_tolerance,
            cfg.relative_tolerance * std::abs(dividend_value));

        priest residual = residual_of(dividend, quotient, divisor, cfg);
        double abs_error = std::abs(residual.cascade.to_double());

        for (size_t iteration = 0; iteration < cfg.max_iterations; ++iteration) {
            if (cfg.early_termination && abs_error <= tolerance) {
                break;
            }

            const double correction = residual.cascade.to_double() / divisor_value;
            if (correction == 0.0 || !std::isfinite(correction)) {
                break;  // No more improvement possible
            }

            priest candidate = basic_add(quotient, priest(correction, cfg));

            // Limit component growth
            if (candidate.num_components() > cfg.max_components) {
                candidate = truncate_to_precision(candidate, cfg.max_components);
            }

            const priest candidate_residual = residual_of(dividend, candidate, divisor, cfg);
            const double candidate_error = std::abs(candidate_residual.cascade.to_double());
            if (!(candidate_error < abs_error)) {
                break;  // stalled (or NaN): keep the best quotient so far
            }

            quotient = candidate;
            residual = candidate_residual;
            abs_error = candidate_error;
        }

        return quotient.compress();
    }

    // Properties
    bool is_zero() const { return cascade.is_zero(); }
    int sign() const { return cascade.sign(); }

    // Debug info: component count, the components in storage order, and the
    // double approximation.
    friend std::ostream& operator<<(std::ostream& os, const priest& p) {
        os << "priest(" << p.cascade.size() << " components: ";
        const auto& data = p.cascade.data();
        for (size_t i = 0; i < data.size(); ++i) {
            if (i > 0) os << ", ";
            os << data[i];
        }
        os << ") ≈ " << p.cascade.to_double();
        return os;
    }

private:
    // Helper methods

    // Error-free product: product + error == x * y exactly (barring
    // overflow/underflow), using a fused multiply-add to recover the rounding
    // error of the double product.
    static void two_prod(double x, double y, double& product, double& error) {
        product = x * y;
        error = std::fma(x, y, -product);
    }

    // Adds `term` to `expansion` in place without rounding error, following
    // Shewchuk's GROW-EXPANSION with zero elimination: the carry is pushed
    // through every component with two_sum and the round-off of each step
    // becomes a component of the result. Expects `expansion` in increasing
    // magnitude order and leaves it that way.
    static void grow_expansion(std::vector<double>& expansion, double term) {
        if (term == 0.0) return;

        double carry = term;
        size_t out = 0;  // write cursor; never passes the read index
        for (size_t i = 0; i < expansion.size(); ++i) {
            const double component = expansion[i];
            double total, roundoff;
            expansion_ops::two_sum(carry, component, total, roundoff);
            if (roundoff != 0.0) {
                expansion[out++] = roundoff;
            }
            carry = total;
        }
        expansion.resize(out);
        if (carry != 0.0) {
            expansion.push_back(carry);
        }
    }

    // Adds sign * (lhs * rhs) to `expansion` exactly; `sign` must be +1.0 or
    // -1.0 so that scaling by it introduces no rounding.
    static void accumulate_products(std::vector<double>& expansion,
                                    const VariableCascade& lhs,
                                    const VariableCascade& rhs,
                                    double sign) {
        for (double x : lhs.data()) {
            for (double y : rhs.data()) {
                double product, error;
                two_prod(sign * x, y, product, error);
                grow_expansion(expansion, error);
                grow_expansion(expansion, product);
            }
        }
    }

    // Exact dividend - quotient * divisor, tagged with cfg. No compression or
    // truncation is applied so the convergence test sees the true residual.
    static priest residual_of(const priest& dividend, const priest& quotient,
                              const priest& divisor, const AdaptivePrecision& cfg) {
        std::vector<double> expansion = dividend.cascade.data();
        accumulate_products(expansion, quotient.cascade, divisor.cascade, -1.0);
        return priest(VariableCascade(expansion), cfg);
    }

    // Applies the precision policy shared by the arithmetic operators:
    // compress when enabled and the result has more than min_components
    // terms, then enforce max_components.
    static priest apply_policy(priest value, const AdaptivePrecision& cfg) {
        if (cfg.auto_compress && value.num_components() > cfg.min_components) {
            value.compress();
        }
        if (value.num_components() > cfg.max_components) {
            value = truncate_to_precision(value, cfg.max_components);
        }
        return value;
    }

    // Exact sum a + b as an expansion (no policy applied); the result carries
    // a's configuration.
    static priest basic_add(const priest& a, const priest& b) {
        std::vector<double> expansion = a.cascade.data();
        expansion.reserve(expansion.size() + b.cascade.size());
        for (double term : b.cascade.data()) {
            grow_expansion(expansion, term);
        }
        return priest(VariableCascade(expansion), a.config);
    }

    // Exact product a * b as an expansion (no policy applied); the result
    // carries a's configuration.
    static priest basic_mul(const priest& a, const priest& b) {
        std::vector<double> expansion;
        accumulate_products(expansion, a.cascade, b.cascade, 1.0);
        return priest(VariableCascade(expansion), a.config);
    }

    // Reduces p to at most max_comps components by keeping the largest
    // (max_comps - 1) terms exactly and folding all smaller terms into a
    // single rounded double. A max_comps of 0 is treated as 1, collapsing the
    // value to one double, since a non-zero value needs at least one term.
    static priest truncate_to_precision(const priest& p, size_t max_comps) {
        if (p.cascade.size() <= max_comps) return p;

        const size_t keep = std::max<size_t>(max_comps, 1);

        // Sort defensively: the cascade may not be in canonical order.
        std::vector<double> ordered = p.cascade.data();
        std::sort(ordered.begin(), ordered.end(),
            [](double x, double y) { return std::abs(x) < std::abs(y); });

        // Fold the least significant terms, smallest first.
        const size_t merged = ordered.size() - keep + 1;
        double tail = 0.0;
        for (size_t i = 0; i < merged; ++i) {
            tail += ordered[i];
        }

        std::vector<double> kept;
        kept.reserve(keep);
        if (tail != 0.0) {
            kept.push_back(tail);
        }
        for (size_t i = merged; i < ordered.size(); ++i) {
            kept.push_back(ordered[i]);
        }

        return priest(VariableCascade(kept), p.config);
    }
};

} // namespace universal

// Demonstration of the adaptive precision design
#ifdef PRIEST_ADAPTIVE_TEST
#include <iostream>
#include <iomanip>

int main() {
    using namespace universal;
    
    std::cout << std::setprecision(20);
    
    // Start with fixed precision types
    dd d1(1.0);
    td t1(3.0);
    
    std::cout << "Original dd: " << d1 << std::endl;
    std::cout << "Original td: " << t1 << std::endl;
    
    // Promote to adaptive precision
    priest p1(d1);  // From dd
    priest p3(t1);  // From td
    
    std::cout << "\nPromoted to priest:" << std::endl;
    std::cout << "priest from dd: " << p1 << std::endl;
    std::cout << "priest from td: " << p3 << std::endl;
    
    // Configure high precision for 1/3 calculation
    AdaptivePrecision high_precision{
        .absolute_tolerance = 1e-30,
        .relative_tolerance = 1e-25,
        .max_components = 20,
        .max_iterations = 50
    };
    
    // The magic: adaptive precision 1/3
    priest one(1.0, high_precision);
    priest three(3.0, high_precision);
    priest one_third = one / three;
    
    std::cout << "\nAdaptive 1/3 calculation:" << std::endl;
    std::cout << "1/3 = " << one_third << std::endl;
    std::cout << "Components: " << one_third.num_components() << std::endl;
    std::cout << "As double: " << double(one_third) << std::endl;
    
    // Convert back to fixed precision types
    dd result_dd = one_third.to_dd();
    td result_td = one_third.to_td();
    
    std::cout << "\nConverted back:" << std::endl;
    std::cout << "As dd: " << result_dd << std::endl;
    std::cout << "As td: " << result_td << std::endl;
    
    // Demonstrate precision control
    AdaptivePrecision low_precision{
        .relative_tolerance = 1e-10,
        .max_components = 3
    };
    
    priest quick_third = priest(1.0, low_precision) / priest(3.0, low_precision);
    std::cout << "\nLow precision 1/3: " << quick_third << std::endl;
    
    return 0;
}
#endif