Browse Source

Barnes hut and other changes

master
Stephen 7 months ago
parent
commit
b0d7d5d947
5 changed files with 501 additions and 55 deletions
  1. +6
    -1
      Cargo.toml
  2. +205
    -0
      src/barnes_hut.rs
  3. +218
    -54
      src/main.rs
  4. +33
    -0
      src/rigid_point.rs
  5. +39
    -0
      src/tests.rs

+ 6
- 1
Cargo.toml View File

@ -10,4 +10,9 @@ edition = "2018"
kiss3d = "0.24.1"
nalgebra = "0.21"
rand = "0.7"
rayon = "1.3"
rayon = "1.3"
[target.x86_64-unknown-linux-gnu]
rustflags = [
"-C", "link-arg=-fuse-ld=lld",
]

+ 205
- 0
src/barnes_hut.rs View File

@ -0,0 +1,205 @@
extern crate nalgebra as na;
use na::{Vector3};
use crate::rigid_point::RigidPoint;
const G: f64 = 0.00667408;
const MAX_DEPTH: i32 = 100;
pub struct Octree {
nodes: Vec<OctNode>
}
pub struct OctNode {
parent: isize, // -1 = no parent
children: [isize; 8],
mass: f64,
min_coord: Vector3<f64>,
max_coord: Vector3<f64>,
num_particles: usize,
avg_weighted_position: Vector3<f64>,// Divide by mass for this to be accurate!
center_of_mass: Vector3<f64> // Calculated once from avg_weighted_position
}
impl Octree {
pub fn new(min: Vector3<f64>, max: Vector3<f64>) -> Self {
let mut nodes = Vec::new();
nodes.push(OctNode::new(min, max, -1)); // Create our root node
Octree { nodes: nodes }
}
pub fn construct_tree(&mut self, particles: Vec<RigidPoint>) {
self.add_particles_to_specific_node(particles, 0);
//calculate center of masses
for n in &mut self.nodes {
n.center_of_mass = n.avg_weighted_position / n.mass;
}
}
pub fn barnes_hut(&self, particle: &mut RigidPoint, delta: f64) {
self.barnes_hut_specific_node(particle, 0, delta);
}
fn barnes_hut_specific_node(&self,
particle: &mut RigidPoint,
cur_node: usize,
delta: f64) {
if self.nodes[cur_node].num_particles == 0 { return; }
let min_coord = self.nodes[cur_node].min_coord;
let max_coord = self.nodes[cur_node].max_coord;
let width = max_coord.x - min_coord.x;
let height = max_coord.y - min_coord.y;
let s = if width > height { width } else { height };
let s_squared = s * s;
//let s_squared = (self.nodes[cur_node].max_coord -
// self.nodes[cur_node].min_coord).norm_squared(); // width^2
let dis_vec = self.nodes[cur_node].center_of_mass - particle.position;
let d_squared = dis_vec.norm_squared();
let theta_squared = 0.5 * 0.5; // TODO make this const
let ratio = s_squared / d_squared;
if self.nodes[cur_node].num_particles == 1 || ratio < theta_squared {
//We can approximate
if(d_squared < 0.1) { return; }
let other = &self.nodes[cur_node];
let dist = d_squared.sqrt();
let v = G * other.mass * delta / d_squared; // / (d_squared + 0.1).sqrt();
particle.velocity += v * dis_vec / dist;
}
else {
//println!("A{}", self.nodes[cur_node].num_particles);
//Divide and conquer
for i in 0..8 {
self.barnes_hut_specific_node(particle, self.nodes[cur_node].children[i] as usize, delta);
}
}
}
fn add_particles_to_specific_node(&mut self, particles: Vec<RigidPoint>, cur_node: usize) {
if particles.len() == 0 {
return;
}
if particles.len() == 1 {
self.nodes[cur_node].mass += particles[0].mass;
self.nodes[cur_node].num_particles += 1;
self.nodes[cur_node].avg_weighted_position += particles[0].position * particles[0].mass;
return;
}
let center_point = (self.nodes[cur_node].min_coord + self.nodes[cur_node].max_coord) / 2.0;
if self.nodes[cur_node].children[0] == -1 {
for i in 0..8 {
let idx = self.nodes.len();
let (min, max) = Octree::get_octant_bounding_box_from_id(self.nodes[cur_node].min_coord,
self.nodes[cur_node].max_coord,
center_point,
i);
self.nodes.push(OctNode::new(min, max, cur_node as isize));
self.nodes[cur_node].children[i] = idx as isize;
}
}
else {
println!("This should never happen :S");
}
let mut particle_octs: [Vec<RigidPoint>; 8] = [vec![], vec![], vec![], vec![],
vec![], vec![], vec![], vec![]];
for particle in particles {
/*println!("// {} {} {} //", particle.position, center_point, self.nodes[cur_node].max_coord);
let mut input = String::new();
io::stdin().read_line(&mut input);*/
// Update ourself
self.nodes[cur_node].mass += particle.mass;
self.nodes[cur_node].num_particles += 1;
self.nodes[cur_node].avg_weighted_position += particle.position * particle.mass;
let octant_index = Octree::get_id_from_center(center_point, particle.position);
particle_octs[octant_index].push(particle);
}
//Recurse
for (i, particle_oct) in particle_octs.iter().enumerate() {
self.add_particles_to_specific_node(particle_oct.clone(),
self.nodes[cur_node].children[i] as usize);
}
}
// TODO wire this up
pub fn get_id_from_center(center: Vector3<f64>, point: Vector3<f64>) -> usize {
let offset = point - center;
//We can look at the sign of the components of the offset to figure out which octant it is in.
let x_offset = if offset.x > 0.0 { 0 } else { 1 };
let y_offset = if offset.y > 0.0 { 0 } else { 1 };
let z_offset = if offset.z > 0.0 { 0 } else { 1 };
x_offset * 1 + y_offset * 2 + z_offset * 4 // basic binary stuff here
}
pub fn get_octant_bounding_box_from_id(min: Vector3<f64>,
max: Vector3<f64>,
center: Vector3<f64>,
idx: usize) -> (Vector3<f64>, Vector3<f64>) {
let mut min_coord: Vector3<f64> = Vector3::new(0.0, 0.0, 0.0);
let mut max_coord: Vector3<f64> = Vector3::new(0.0, 0.0, 0.0);
if (idx & 1) != 0 {
min_coord.x = min.x;
max_coord.x = center.x;
}
else {
min_coord.x = center.x;
max_coord.x = max.x;
}
if (idx & 2) != 0 {
min_coord.y = min.y;
max_coord.y = center.y;
}
else {
min_coord.y = center.y;
max_coord.y = max.y;
}
if (idx & 4) != 0 {
min_coord.z = min.z;
max_coord.z = center.z;
}
else {
min_coord.z = center.z;
max_coord.z = max.z;
}
(min_coord, max_coord)
}
pub fn check_for_collision(&self, particle: RigidPoint) {
check_for_collision_with_specific_node(particle, 0);
}
fn check_for_collision_with_specific_node(&self, particle: RigidPoint, idx: usize) {
let cur_node = self.nodes[idx];
}
}
impl OctNode {
fn new(min: Vector3<f64>, max: Vector3<f64>, parent: isize) -> Self {
OctNode {
parent: parent,
children: [-1, -1, -1, -1, -1, -1, -1, -1],
mass: 0.0,
min_coord: min,
max_coord: max,
num_particles: 0,
avg_weighted_position: Vector3::new(0.0, 0.0, 0.0),
center_of_mass: Vector3::new(0.0, 0.0, 0.0)
}
}
}

+ 218
- 54
src/main.rs View File

@ -1,122 +1,286 @@
extern crate kiss3d;
extern crate nalgebra as na;
use na::{Vector3, UnitQuaternion, Translation3};
use na::{Vector3, Translation3, Point3, Point2};
use kiss3d::window::Window;
use kiss3d::light::Light;
use kiss3d::scene::SceneNode;
use kiss3d::text::Font;
use rand::prelude::*;
use std::time::Instant;
use rayon::prelude::*;
use std::sync::{Arc, Mutex};
#[derive(Clone)]
struct RigidPoint {
position: Vector3<f32>,
velocity: Vector3<f32>,
mass: f64,
index: usize
}
mod barnes_hut;
mod rigid_point;
impl RigidPoint {
fn new(position: Vector3<f32>, velocity: Vector3<f32>, mass: f64, index: usize) -> Self {
RigidPoint {
position: position,
velocity: velocity,
mass: mass,
index: index
}
}
use rigid_point::RigidPoint;
use barnes_hut::Octree;
fn update_node(&self, node: &mut SceneNode) {
node.set_local_translation(Translation3::from(self.position));
}
}
const G: f64 = 0.00667408;
#[derive(Clone)]
struct Line {
start: Point3<f32>,
end: Point3<f32>
}
struct Universe {
particles: Vec<RigidPoint>,
nodes: Vec<SceneNode>
nodes: Vec<SceneNode>,
lines: Vec<Line>,
delta: f64,
}
impl Universe {
fn new() -> Self {
Universe { particles: Vec::new(), nodes: Vec::new() }
fn new(delta: f64) -> Self {
Universe { particles: Vec::new(), nodes: Vec::new(), lines: Vec::new(), delta: delta }
}
fn push(&mut self, particle: RigidPoint, node: SceneNode) {
self.particles.push(particle);
self.nodes.push(node)
}
fn add_line(&mut self, start: Point3<f32>, end: Point3<f32>) {
self.lines.push(Line {start: start, end: end});
}
fn render(&self, window: &mut Window) {
for line in &self.lines {
window.draw_line(&line.start, &line.end, &Point3::new(0.0, 1.0, 0.0));
}
}
//Not completely correct, since the velocity and position are at slightly different times
fn calc_total_energy(&self) -> f64 {
let mut energy: f64 = 0.0;
for a in self.particles.clone() {
//Calculate kinetic energy
energy += 0.5 * a.mass * a.velocity.norm_squared();
//Calculate gravitational potential energy
for b in &self.particles {
if a.index != b.index {
let dist = na::distance(&Point3::from(a.position), &Point3::from(b.position));
let energy_one = -G * b.mass * a.mass / dist;
// println!("dist: {} energy: {}", dist, energy_one);
energy += energy_one;
}
}
}
energy
}
fn combine(&mut self, i: usize, j: usize, window: &mut Window) {
let a = self.particles[i].clone();
let b = self.particles[j].clone();
if a.index == -1 || b.index == -1 {
return;
}
let mass_total = a.mass + b.mass;
//Combine positions(weighted)
let new_pos = (a.position * a.mass + b.position * b.mass) / mass_total;
//Combine velocities (weighted)
let new_vel = (a.velocity * a.mass + b.velocity * b.mass) / mass_total;
//Calculate new radius
let r1 = a.radius;
let r2 = b.radius;
let new_r = (r1 * r1 * r1 + r2 * r2 * r2).cbrt();
//Destroy i
self.nodes[i].unlink();
self.particles[i].index = -1;
//Update j
self.nodes[j].unlink();
let mut new_node = window.add_sphere(new_r as f32);
new_node.set_color(1.0, 0.0, 0.0);
new_node.append_translation(&Translation3::from(b.position_f32()));
self.nodes[j] = new_node;
self.particles[j].position = new_pos;
self.particles[j].velocity = new_vel;
self.particles[j].radius = new_r;
self.particles[j].mass = mass_total;
}
}
fn main() {
println!("{}", std::mem::size_of::<RigidPoint>());
let mut window = Window::new("Galaxy simulator");
window.set_light(Light::StickToCamera);
let mut particles = generate_random_points(&mut window);
let delta = 0.005;
let mut universe = generate_random_points(&mut window, delta);
let mut start_time = Instant::now();
//let mut start_time = Instant::now();
while window.render() {
calc_velocities(&mut particles, start_time.elapsed().as_secs_f64());
tick(&mut particles);
start_time = Instant::now();
for _ in 0..10 {
tick(&mut universe, &mut window);
calc_velocities(&mut universe, delta);
}
universe.render(&mut window);
//Calculate energy. Used to tell how much the algorithm sucks
/*window.draw_text(&format!("Total energy: {}", universe.calc_total_energy()),
&Point2::new(0.0, 0.0),
120.0,
&Font::default(),
&Point3::new(1.0, 1.0, 1.0));*/
//start_time = Instant::now();
}
}
fn generate_random_points(window: &mut Window) -> Universe {
let mut ret: Universe = Universe::new();
fn generate_random_points(window: &mut Window, delta: f64) -> Universe {
let mut ret: Universe = Universe::new(delta);
let mut rng = rand::thread_rng();
for i in 0..3000 {
let x: f32 = rng.gen_range(-50.0, 50.0);
let y: f32 = rng.gen_range(-50.0, 50.0);
let z: f32 = rng.gen_range(-5.0, 5.0); // Flat galaxy
let position: Vector3<f32> = Vector3::new(x, y, z);
let velocity: Vector3<f32> = Vector3::new(-z, x, y);
for i in 0..5_000 {
let x: f64 = rng.gen_range(-50.0, 50.0);
let y: f64 = rng.gen_range(-50.0, 50.0);
let z: f64 = rng.gen_range(-5.0, 5.0); // Flat galaxy
let position: Vector3<f64> = Vector3::new(x, y, z);
let velocity: Vector3<f64> = Vector3::new(-y, x, 0.0);
// let velocity: Vector3<f64> = Vector3::new(0.0, 0.0, 0.0);
let node = window.add_sphere(0.1);
ret.push(RigidPoint::new(position, velocity / 1000.0, 1.0, i), node);
ret.push(RigidPoint::new(position, velocity / 60.0, 80.0, i, 0.1), node);
}
//Move velocities forward slightly
calc_velocities(&mut ret, delta / 2.0);
// ret.push(RigidPoint::new(Vector3::new(0.0, 0.0, 0.0), Vector3::new(0.0, 0.0, 0.0), 1000.0, 1000), window.add_sphere(0.1)); //Black hole
return ret;
}
const G: f32 = 0.00667408;
fn calc_velocities(u: &mut Universe, delta: f64) {
let mut p = u.particles.clone();
//Find min/max
//println!("Finding min/max");
let mut min: Vector3<f64> = Vector3::new(0.0, 0.0, 0.0);
let mut max: Vector3<f64> = Vector3::new(0.0, 0.0, 0.0);
for particle in &u.particles {
if particle.position.x < min.x {
min.x = particle.position.x;
}
if particle.position.y < min.y {
min.y = particle.position.y;
}
if particle.position.z < min.z {
min.z = particle.position.z;
}
if particle.position.x > max.x {
max.x = particle.position.x;
}
if particle.position.y > max.y {
max.y = particle.position.y;
}
if particle.position.z > max.z {
max.z = particle.position.z;
}
}
// println!("Constructing tree");
let mut oct = Octree::new(min, max);
oct.construct_tree(u.particles.clone().into_iter().filter(|x| {x.index != -1}).collect());
//println!("Barnes hut");
u.particles.par_iter_mut().for_each(|mut p| {
oct.barnes_hut(&mut p, delta);
});
//println!("Done");
}
/*fn calc_velocities(u: &mut Universe, delta: f64) {
let p = u.particles.clone();
u.particles.par_iter_mut().for_each(|mut b| {
if b.index == -1 {
return;
}
p.iter().for_each(|a| {
calc_gravity(&mut b, &a, delta);
});
});
}
}*/
fn calc_gravity(body: &mut RigidPoint, other: &RigidPoint, delta: f64) {
if body.index == other.index || body.index == -1 || other.index == -1 { // can't be the same
return;
}
let dx = body.position.x - other.position.x;
let dy = body.position.y - other.position.y;
let dz = body.position.z - other.position.z;
let dist_squared = dx * dx + dy * dy + dz * dz;
if dist_squared < 0.01 {
// don't calculate gravity for stars that are really close
return;
}
let dist = dist_squared.sqrt();
let v = -G * (other.mass as f32) / dist_squared * delta as f32;
let hyp = dist_squared.sqrt();
// let v = -G * other.mass / dist_squared * delta;
let v = -G * other.mass * delta / dist / (dist_squared + 0.1).sqrt(); // softening
let hyp = dist;
body.velocity += Vector3::new(v * dx / hyp, v * dy / hyp, v * dz / hyp);
}
fn tick(u: &mut Universe) {
fn tick(u: &mut Universe, window: &mut Window) {
let delta = u.delta;
u.particles.par_iter_mut().for_each(|a| {
a.position += a.velocity;
if a.index == -1 {
return;
}
a.position += a.velocity * delta;
});
u.particles.clone().iter().for_each(|a| {
u.nodes[a.index].set_local_translation(Translation3::from(a.position));
if a.index != -1 {
a.update_node(&mut u.nodes[a.index as usize]);
}
});
println!("tick");
// Add line for particle 0
/*{
let a = &mut u.particles[0];
let old_pos = a.position_f32();
a.position += a.velocity * delta;
let new_pos = a.position_f32();
u.add_line(Point3::from(old_pos), Point3::from(new_pos));
}*/
//Look for collisions TODO use the tree
/*let mut collisions: Arc<Mutex<Vec<(usize, usize)>>> = Arc::new(Mutex::new(Vec::with_capacity(100)));
let p = u.particles.clone();
p.par_iter().enumerate().for_each(|(i, a)| {
p.iter().enumerate().for_each(|(j, b)| {
if i != j && a.index != -1 && b.index != -1 {
let dist_squared = (a.position - b.position).norm_squared();
let r = a.radius + b.radius;
if dist_squared < r * r {
collisions.lock().unwrap().push((i, j));
}
}
});
});
for c in collisions.lock().unwrap().iter() {
u.combine(c.0, c.1, window);
}*/
//println!("tick");
}
// Tests
#[cfg(test)]
mod tests;

+ 33
- 0
src/rigid_point.rs View File

@ -0,0 +1,33 @@
extern crate nalgebra as na;
use na::{Vector3, Translation3, Point3, Point2};
use kiss3d::scene::SceneNode;
#[derive(Clone)]
pub struct RigidPoint {
pub position: Vector3<f64>,
pub velocity: Vector3<f64>, // Velocity is one half-timestep in the future
pub mass: f64,
pub index: isize,
pub radius: f64
}
impl RigidPoint {
pub fn new(position: Vector3<f64>, velocity: Vector3<f64>, mass: f64, index: isize, radius: f64) -> Self {
RigidPoint {
position: position,
velocity: velocity,
mass: mass,
index: index,
radius: radius
}
}
pub fn update_node(&self, node: &mut SceneNode) {
node.set_local_translation(Translation3::from(self.position_f32()));
}
pub fn position_f32(&self) -> Vector3<f32> {
Vector3::new(self.position.x as f32, self.position.y as f32, self.position.z as f32)
}
}

+ 39
- 0
src/tests.rs View File

@ -0,0 +1,39 @@
#[cfg(test)]
use super::*;
// Quick and dirty
fn assert_eq_floats(a: f64, b: f64) {
let good = (a - b).abs() < 0.0001;
if !good {
println!("Expected {}, actual {}", a, b);
}
assert!(good);
}
#[test]
fn test_calc_gravity() {
let mut a = RigidPoint::new(Vector3::new(0.0, 0.0, 0.0), Vector3::new(0.0, 0.0, 0.0), 1000.0, 0, 0.1);
let mut b = RigidPoint::new(Vector3::new(10.0, 10.0, 20.0), Vector3::new(0.0, 0.0, 0.0), 50.0, 1, 0.1);
calc_gravity(&mut a, &b, 1.0);
calc_gravity(&mut b, &a, 1.0);
assert_eq!(a.position, Vector3::new(0.0, 0.0, 0.0));
assert_eq_floats(G * b.mass / 24.494 / 24.494, a.velocity.norm());
assert_eq_floats(G * a.mass / 24.494 / 24.494, b.velocity.norm());
}
#[test]
fn test_octree_partitioning() {
//oct = Octree::new(Vector::new(0.0, 0.0, 0.0), Vector3::new(100.0, 100.0, 100.0));
let min: Vector3<f64> = Vector3::new(0.0, 0.0, 0.0);
let max: Vector3<f64> = Vector3::new(100.0, 100.0, 100.0);
let center = (min + max) / 2.0;
for i in (0..8) {
let (new_min, new_max) = Octree::get_octant_bounding_box_from_id(min, max, center, i);
let new_center = (new_min + new_max) / 2.0;
assert_eq!(i, Octree::get_id_from_center(center, new_center));
}
}

Loading…
Cancel
Save