import sys
sys.path.append("../..")
import torch
from torch_fem import LaplaceElementAssembler, MassElementAssembler, Mesh,Condenser
from torch_fem.dataset import WaveMultiSinCos
from torch_fem.utils import mul, dot
if __name__ == '__main__':
dt = 0.001
c = 4.0
n = 100
torch.random.manual_seed(123456)
mesh = Mesh.gen_rectangle(chara_length=0.01)
dataset = WaveMultiSinCos(K=4, c=c)
u0 = dataset.initial_condition(mesh.points)
M_asm = MassElementAssembler.from_mesh(mesh, quadrature_order=2)
A_asm = LaplaceElementAssembler.from_assembler(M_asm)
M = M_asm(mesh.points)
A = A_asm(mesh.points)
condenser = Condenser(mesh.boundary_mask)
Us = [u0]
v0 = torch.zeros_like(u0)
A = c*c*A
K = 2 * M
F = -dt * dt * A @ u0 + 2 * M @ u0 + 2 * dt * M @ v0
K_, F_ = condenser(K, F)
U_ = K_.solve(F_)
U = condenser.recover(U_)
M_ = condenser(M)[0]
Us.append(U)
for _ in range(n-2):
U1, U2 = Us[-2:]
F = 2 * M @ U2 - M @ U1 - dt * dt * A @ U2
F_ = condenser.condense_rhs(F)
U_ = M_.solve(F_)
U = condenser.recover(U_)
Us.append(U)
Us_gt = [dataset.solution(mesh.points, dt*i) for i in range(n)]
mesh.plot({"prediction":Us, "ground truth":Us_gt},save_path="wave.mp4", backend="matplotlib", dt=dt, show_mesh=True )