본문 바로가기

PS

개인용 다항식 템플릿

최근 개인 PS용 템플릿을 만들고 있는데 그중에서 다항식 관련 연산들에 관한 템플릿을 공유해보려 합니다.

최적화는 그렇게 잘되어 있지 않고, 버그가 있을 수 있습니다.

자유롭게 사용하셔도 되지만, 잘 안돌아갈수도 있습니다(...?)

지원하는 기능은 다음과 같습니다.

 

1. 사칙연산: 다항식간의 +,-,*,/ 연산을 지원합니다. 나눗셈은 일반적으로 잘 정의가 안되니 다항식 환(혹은 formal power series)이여야만 정의가 된다는것을 유의해주세요. 자세한것은 제 이전 다항식에 관한 글 참고해주세요.

2. exp, log, pow를 지원합니다. 역시 다항식 환위에서만 정의됩니다.

3. 미분,적분인 derivative, intergral을 지원합니다.

4. 다항식 값을 판단하는 eval 함수를 지원합니다. 다항식 P에 대해, P.eval(a)를 하면, P(x)에 x=a를 대입했을 때 결과값을 반환합니다. vector로 넣을 수 있고, (a0, a1, ... ) 벡터를 넣으면 P((a0,a1,...,))은 (P(a0), P(a1), ...)을 반환합니다.

5. interpolation을 지원합니다. (x_i,y_i)값들 pair를 넣으면 P(x_i) = y_i가 되도록 P를 갱신시킵니다.

6. polynomial taylor shift를 지원합니다. 다항식 P(x)가 주어질때, P.taylor_shift(c)를 하면, P(x+c)의 계수들을 return해줍니다.

 

주의 사항은 다음과 같습니다.

1. atcoder/convolution과 atcoder/modint를 include해야합니다. 만약 타 OJ에 제출하려고 하는 경우 atcoder library를 잘 convert 해야합니다.

2. 다항식환(혹은 formal power series)을 사용하고 싶은 경우, polynomial_max_size를 적당히 바꿔주세요. 이는 다항식들을 mod x^{polynomial_max_size} 위에서 보겠다는 의미입니다.

3. 조금 더 자세한 사용법이 궁금하면 main함수를 참고해주세요.

 

아래 코드를 시험해보고 싶으신 분들은 이 문제를 풀어보세요.

https://www.acmicpc.net/problem/26037

#include <atcoder/modint>
#include <atcoder/convolution>
#include <vector>
#include <assert.h>
int polynomial_max_size = 12345678;
template <typename T> // mod x^n
struct Polynomial{
	//Polynomial made by hyperbolic
    std::vector<T> V;

    Polynomial(){ }
    Polynomial(T val)
    {
        V.push_back(val);
    }
    Polynomial(std::vector<T> A)
    {
        assert(A.size()<=polynomial_max_size);
        V = A;
    }

    int size() { return V.size(); }
    void pop_back(){ V.pop_back(); }
    T back(){ return V.back(); }
    T& operator[](int ind)
    {
        assert(ind<polynomial_max_size);
        if(ind>=V.size()) V.resize(ind+1);
        return V[ind];
    }

    Polynomial operator+(Polynomial A)
    {
        Polynomial ans;
        for(int i=0;i<V.size();i++) ans[i] += V[i];
        for(int i=0;i<A.size();i++) ans[i] += A[i];
        return ans;
    }
    Polynomial operator-(Polynomial A)
    {
        Polynomial ans;
        for(int i=0;i<V.size();i++) ans[i] += V[i];
        for(int i=0;i<A.size();i++) ans[i] -= A[i];
        return ans;
    }
    Polynomial operator*(Polynomial A)
    {
        std::vector<T> res = atcoder::convolution(V,A.V);
        while(res.size()>polynomial_max_size) res.pop_back();
        return Polynomial(res);
    }
    Polynomial operator/(Polynomial A)
    {
        Polynomial B = A.inv();
        return Polynomial(V) * B;
    }
    Polynomial operator*(T val)
    {
        Polynomial ans = Polynomial(V);
        for(int i=0;i<ans.size();i++) ans[i] *= val;
        return ans;
    }
    Polynomial operator/(T val)
    {
        assert(val!=0);
        Polynomial ans = Polynomial(V);
        for(int i=0;i<ans.size();i++) ans[i] /= val;
        return ans;
    }
    Polynomial derivative(){ return derivative(polynomial_max_size); }
    Polynomial integral(){ return integral(polynomial_max_size); }
    Polynomial inv(){ return inv(polynomial_max_size); }
    Polynomial log(){ return log(polynomial_max_size); }
    Polynomial exp(){ return exp(polynomial_max_size); }
    Polynomial pow(long long int k){ return pow(k,polynomial_max_size); }
    std::pair<Polynomial,Polynomial> quotient(Polynomial G)
    {
        Polynomial F = Polynomial(V);
        int deg_f = size()-1;
        int deg_g = G.size()-1;
        if(deg_f < deg_g) return std::make_pair(Polynomial(),F);

        Polynomial F2 = F.reverse();
        Polynomial G2 = G.reverse();
        F2.truncated(deg_f-deg_g+1);
        G2.truncated(deg_f-deg_g+1);

        Polynomial Q2 = F2 * G2.inv(deg_f - deg_g + 1);
        Q2.truncated(deg_f-deg_g+1);

        Polynomial Q = Q2.reverse();
        Polynomial R = Polynomial(V) - G * Q;

        Q.remove_leading_zero();
        R.remove_leading_zero();
        return std::make_pair(Q,R);
    }
    T eval(T ind)
    {
        std::vector<T> temp;
        temp.push_back(ind);
        return eval(temp)[0];
    }
    std::vector<T> eval(std::vector<T> &ind)
    {
        std::vector<Polynomial> check(4*ind.size()+2);
        std::vector<T> ans(ind.size());
        monomial_prod(0,ind.size()-1,1,ind,check);
        multipoint_evaluation(Polynomial(V),0,ind.size()-1,1,check,ans);
        return ans;
    }
    void interpolation(std::vector< std::pair<T,T> > points)
    {
        assert(points.size()<=polynomial_max_size);
        std::vector<T> x,y;
        for(int i=0;i<points.size();i++) x.push_back(points[i]);
        for(int i=0;i<points.size();i++) y.push_back(points[i]);
        interpolation(x,y);
    }
    void interpolation(std::vector<T> x, std::vector<T> y)
    {
        assert(x.size()==y.size());
        assert(x.size()<=polynomial_max_size);
        V.clear();

        std::vector<Polynomial> check(4*x.size()+2);
        monomial_prod(0,x.size()-1,1,x,check);
        Polynomial g = check[1];
        std::vector<T> derivative_values = g.derivative().eval(x);
        std::vector<T> coeff;
        for(int i=0;i<x.size();i++) coeff.push_back(y[i]/derivative_values[i]);
        Polynomial h = interpolation_calculate(0,x.size()-1,x,coeff).second;
        for(int i=0;i<h.size();i++) V.push_back(h[i]);
    }
    Polynomial taylor_shift(T c) // coeff of f(x+c)
    {
        std::vector<T> fact;
        fact.push_back(1);
        for(int i=1;i<V.size();i++) fact.push_back(fact.back()*i);
        
        std::vector<T> A;
        for(int i=0;i<V.size();i++) A.push_back(V[i]*fact[i]);
        std::vector<T> C;
        for(int i=0;i<V.size();i++) C.push_back(c.pow(i)/fact[i]);
        std::reverse(C.begin(),C.end());

        std::vector<T> B2 = atcoder::convolution(A,C);
        std::vector<T> B;

        for(int i=V.size()-1;i<2*V.size()-1;i++) B.push_back(B2[i]);
        for(int i=0;i<V.size();i++) B[i] /= fact[i];
        Polynomial ans(B);
        ans.remove_leading_zero();
        return ans;
    }

    private:
    Polynomial derivative(int polynomial_max_size)
    {
        while(V.size()>0 && V.back()==0) V.pop_back();
        Polynomial ans;
        for(int i=1;i<V.size()&&i-1<polynomial_max_size;i++) ans[i-1] = i*V[i];
        ans.truncated(polynomial_max_size);
        return ans;
    }
    Polynomial integral(int polynomial_max_size)
    {
        while(V.size()>0 && V.back()==0) V.pop_back();
        Polynomial ans;
        for(int i=0;i<V.size()&&i+1<polynomial_max_size;i++) ans[i+1] = V[i]/(i+1);
        ans.truncated(polynomial_max_size);
        return ans;
    }
    Polynomial inv(int polynomial_max_size)
    {
        while(V.size()>0 && V.back()==0) V.pop_back();
        assert(V.size()>0);
        assert(V[0]!=0);

        Polynomial g(1/V[0]);

        while(g.size()<polynomial_max_size)
        {
            int g_size = g.size();
            Polynomial h;
            for(int i=0;i<2*g_size&&i<V.size();i++) h[i] = V[i];
            h = Polynomial(2) - g*h;
            h.truncated(2*g_size);
            g = g*h;
            g.truncated(2*g_size);
        }
        g.truncated(polynomial_max_size);
        g.remove_leading_zero();
        return g;
    }
    Polynomial log(int polynomial_max_size)
    {
        while(V.size()>0 && V.back()==0) V.pop_back();
        assert(V.size()>0);
        assert(V[0]==1);

        if(V.size()==1) return Polynomial(0);
        else
        {
            Polynomial f_2 = Polynomial();
            for(int i=0;i<V.size() && i<polynomial_max_size;i++) f_2[i] = V[i];
            Polynomial f_1 = f_2.derivative(polynomial_max_size);
            Polynomial f_3 = f_2.inv(polynomial_max_size);
            Polynomial g = f_1 * f_3;
            g.truncated(polynomial_max_size);
            g = g.integral(polynomial_max_size);
            g[0] = 0;
            g.remove_leading_zero();
            return g;
        }
    }
    Polynomial exp(int polynomial_max_size)
    {
        while(V.size()>0 && V.back()==0) V.pop_back();
        if(V.size()==0) return Polynomial(1);
        assert(V[0]==0);

        Polynomial g(1);
        while(g.size()<polynomial_max_size)
        {
            int g_size = g.size();
            Polynomial h;
            for(int i=0;i<2*g_size&&i<V.size();i++) h[i] = V[i];
            h = h + Polynomial(1) - g.log(2*g_size<polynomial_max_size?2*g_size:polynomial_max_size);
            h.truncated(2*g_size);
            g = g*h;
            g.truncated(2*g_size);
        }
        g.truncated(polynomial_max_size);
        g.remove_leading_zero();
        return g;
    }
    Polynomial pow(long long int k, int polynomial_max_size)
    {
        remove_leading_zero();
        if(k==0) return Polynomial(1);
        int start = -1;
        for(int i=0;i<V.size();i++)
        {
            if(V[i]!=0)
            {
                start = i;
                break;
            }
        }
        if(start==-1) return Polynomial();
        if(start>0)
        {
            if(k>=polynomial_max_size) return Polynomial();
            else if(k*start>=polynomial_max_size) return Polynomial();
        }

        Polynomial P;
        for(int i=start;i<V.size()&&i-start<polynomial_max_size;i++) P[i-start] = V[i];
        T coeff = P[0];
        P = P / coeff;

        P = P.log(polynomial_max_size);
        P = P * k;
        P = P.exp(polynomial_max_size);
        P = P * coeff.pow(k);

        Polynomial ans;
        long long int shift = k*start;
        for(int i=0;i<P.size() && shift+i<polynomial_max_size;i++) ans[shift+i] = P[i];
        ans.truncated(polynomial_max_size);
        ans.remove_leading_zero();
        return ans;
    }
    Polynomial reverse()
    {
        std::vector<T> temp;
        for(int i=(int)V.size()-1;i>=0;i--) temp.push_back(V[i]);
        return Polynomial(temp);
    }
    void truncated(int k)
    {
        while(V.size()>k) V.pop_back();
        while(V.size()<k) V.push_back(0);
    }
    void remove_leading_zero()
    {
        while(V.size()>0 && V.back()==0) V.pop_back();
    }
    void monomial_prod(int l, int r, int v, std::vector<T> &ind, std::vector<Polynomial> &check)
    {
        if(l==r)
        {
            check[v][0] = -ind[l];
            check[v][1] = 1;
        }
        else
        {
            int h = (l+r)/2;
            monomial_prod(l,h,2*v,ind,check);
            monomial_prod(h+1,r,2*v+1,ind,check);
            check[v] = check[2*v] * check[2*v+1];
        }
    }
    void multipoint_evaluation(Polynomial P, int l, int r, int v, std::vector<Polynomial>& check, std::vector<T> &ans)
    {
        P = P.quotient(check[v]).second;
        if(l==r) ans[l] = P[0];
        else
        {
            int h = (l+r)/2;
            multipoint_evaluation(P,l,h,2*v,check,ans);
            multipoint_evaluation(P,h+1,r,2*v+1,check,ans);
        }
    }
    std::pair<Polynomial,Polynomial> interpolation_calculate(int l, int r, std::vector<T> &x, std::vector<T> &coeff)
    {
        if(l==r)
        {
            Polynomial ans;
            ans[0] = -x[l];
            ans[1] = 1;
            return std::make_pair(ans,Polynomial(coeff[l]));
        }
        else
        {
            int h = (l+r)/2;
            auto P1 = interpolation_calculate(l,h,x,coeff);
            auto P2 = interpolation_calculate(h+1,r,x,coeff);
            Polynomial ans1 = P1.first * P2.first;
            Polynomial ans2 = P1.first * P2.second + P1.second * P2.first;
            return std::make_pair(ans1,ans2);
        }
    }
};
using mint = atcoder::modint998244353;

int main()
{
    Polynomial<mint> P; // p(x) = 0+x+2x^2+3x^3+4x^4
    for(int i=0;i<5;i++) P[i] = i;
    Polynomial<mint> Q(3); // q(x) = 3
    Polynomial<mint> mul_pq = P*Q;
    Polynomial<mint> div_pq = P/Q;
    Polynomial<mint> add_pq = P+Q;
    Polynomial<mint> sub_pq = P-Q;
    Polynomial<mint> exp_p = P.exp(); // e^p(x)
    Polynomial<mint> log_p = P.log(); // log(p(x))
    Polynomial<mint> pow_p = P.pow(3); // p(x)^3
    std::vector<mint> x;
    for(int i=1;i<=3;i++) x.push_back(i);
    std::vector<mint> eval_p = P.eval(x); // return p(x_0), p(x_1), p(x_2), ...
    std::vector<mint> y;
    for(int i=1;i<=3;i++) y.push_back(100*i);
    Polynomial<mint> S;
    S.interpolation(x,y); // make S satisfying S(x_i) = y_i
    
}

 

'PS' 카테고리의 다른 글

Power Projection Algorithm  (0) 2024.05.10
Half GCD Algorithm  (0) 2023.10.05
롤링 해시를 할때 주의해야 할점  (0) 2023.03.05
다항식 연산들로 할 수 있는 것  (0) 2022.10.13
어려운 다항식 연산들에 대하여  (0) 2022.10.05