# Poisson Equation

# Adaptive Mesh Refinement

import torch
import numpy as np
from tqdm import tqdm
from torch_fem import LaplaceElementAssembler, Mesh,  Condenser
from torch_fem.dataset import PoissonMultiFrequency
from torch_fem.visualization import StreamPlotter
import matplotlib.pyplot as plt

def _save_loss_curve(losses, filename: str = "loss.png") -> None:
    """Plot the recorded losses on a log-scaled y axis and save the figure.

    Args:
        losses: Sequence of per-epoch loss values (plain Python floats).
        filename: Path the figure is written to.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.plot(np.arange(len(losses)), losses, label="loss")
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.legend()
    ax.set_yscale("log")
    fig.savefig(filename)
    # The figure is already on disk; close it so pyplot does not keep it alive.
    plt.close(fig)


def main() -> None:
    """Optimise the mesh parameters against a Poisson residual loss.

    Each epoch assembles the operator ``K`` from the current ``mesh.points``,
    solves ``K u = f``, evaluates ``MSE(K @ u, f)`` and takes one Adam step on
    ``mesh.parameters()``. The mesh is drawn through a ``StreamPlotter``
    created with ``filename="poisson.mp4"`` before training and after every
    step, and the loss history is saved to ``loss.png``.
    """
    torch.random.manual_seed(123456)
    mesh = Mesh.gen_rectangle(chara_length=0.1)
    assembler = LaplaceElementAssembler.from_mesh(mesh)
    equation = PoissonMultiFrequency(K=8)
    # NOTE(review): `condenser` is constructed but never used below. It is kept
    # in case construction has side effects -- confirm, then wire in or remove.
    condenser = Condenser(mesh.boundary_mask)

    optimizer = torch.optim.Adam(mesh.parameters(), lr=1e-3)
    # Multiply the learning rate by 0.9 every 10 scheduler steps.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)

    num_epochs = 100

    # Right-hand side, evaluated once at the initial mesh points.
    f = equation.initial_condition(mesh.points)
    loss_fn = torch.nn.MSELoss()

    losses = []

    with StreamPlotter(filename="poisson.mp4") as plotter:
        plotter.draw_mesh(mesh, f)
        # Context-managed so the bar is closed on normal exit and on exceptions.
        with tqdm(total=num_epochs) as pbar:
            for _ in range(num_epochs):
                optimizer.zero_grad()
                K = assembler(mesh.points)
                u = K.solve(f)
                loss = loss_fn(K @ u, f)
                # TODO: why retain_graph=True?
                # NOTE(review): `f` is computed once, outside the loop, from
                # `mesh.points`. If that attaches `f` to the autograd graph,
                # every backward pass re-traverses the same graph, which would
                # explain the need for retain_graph=True (and means `f` is
                # never re-evaluated at the updated points) -- confirm.
                loss.backward(retain_graph=True)
                optimizer.step()
                scheduler.step()
                plotter.draw_mesh(mesh, f)
                # Read the scalar once and reuse it for display and history.
                loss_value = loss.item()
                pbar.set_postfix(loss=loss_value)
                pbar.update(1)
                losses.append(loss_value)

    _save_loss_curve(losses, "loss.png")


if __name__ == "__main__":
    main()