(*
  Copyright (C) 2010 Florent Monnier

  Permission is hereby granted, free of charge, to any person obtaining a
  copy of this software and associated documentation files (the "Software"),
  to deal in the Software without restriction, including without limitation the
  rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
  sell copies of the Software, and to permit persons to whom the Software is
  furnished to do so, subject to the following conditions:

  The above copyright notice and this permission notice shall be included in
  all copies or substantial portions of the Software.

  The Software is provided "as is", without warranty of any kind, express or
  implied, including but not limited to the warranties of merchantability,
  fitness for a particular purpose and noninfringement. In no event shall
  the authors or copyright holders be liable for any claim, damages or other
  liability, whether in an action of contract, tort or otherwise, arising
  from, out of or in connection with the software or the use or other dealings
  in the Software.
*)

(** A 4x4 matrix stored as a flat array of 16 floats. translation_matrix
    puts the translation in elements 12, 13 and 14, which is the OpenGL
    column-major layout (element [4 * col + row]). *)
type t = float array

(** The constant pi, followed by a few derived constants. *)
let pi = 3.14159_26535_89793_23846_2643
let pi_twice = pi *. 2.0
let pi_half = pi /. 2.0
(** Multiply an angle in degrees by [deg_to_rad] to obtain radians. *)
let deg_to_rad = pi /. 180.0
(** Multiply an angle in radians by [rad_to_deg] to obtain degrees. *)
let rad_to_deg = 180.0 /. pi

(** [get_identity ()] returns a freshly allocated 4x4 identity matrix as a
    flat 16-float array. Each call allocates a new array, so the result may
    be mutated freely. *)
let get_identity () =
  (* the diagonal elements are 0, 5, 10 and 15: exactly the multiples of 5 *)
  Array.init 16 (fun i -> if i mod 5 = 0 then 1.0 else 0.0)

(** [copy m] returns a fresh shallow copy of the matrix [m]. *)
let copy m = Array.copy m ;;


(** [perspective_projection ~fov ~ratio ~near ~far] builds a perspective
    projection matrix with the same coefficients as gluPerspective.

    - [fov]: vertical field of view, in degrees.
    - [ratio]: aspect ratio; the horizontal extent of the near plane is
      its vertical extent multiplied by [ratio].
    - [near], [far]: distances of the near and far clipping planes.

    The result uses the same layout as translation_matrix (column-major):
    the -1 perspective term is element 11 and the depth translation term
    is element 14. *)
let perspective_projection ~fov ~ratio ~near ~far =
  (* fov *. pi /. 360.0 is half of the field of view, in radians *)
  let top = near *. tan (fov *. pi /. 360.0) in
  let bottom = -. top in
  let left = bottom *. ratio
  and right = top *. ratio in
  let width = right -. left
  and height = top -. bottom
  and depth = far -. near
  and two_near = 2.0 *. near in
  [| two_near /. width; 0.0; 0.0; 0.0;
     0.0; two_near /. height; 0.0; 0.0;
     (right +. left) /. width; (top +. bottom) /. height;
     -. (far +. near) /. depth; -1.0;
     0.0; 0.0; -. (two_near *. far) /. depth; 0.0; |]


(** [ortho_projection ~left ~right ~bottom ~top ~near ~far] builds an
    orthographic projection matrix with the coefficients of glOrtho. It is
    stored in the layout of translation_matrix (column-major), so the
    translation terms are elements 12, 13 and 14. *)
let ortho_projection ~left ~right ~bottom ~top ~near ~far =
  let width = right -. left in
  let height = top -. bottom in
  let depth = far -. near in
  (* diagonal scale factors *)
  let sx = 2.0 /. width
  and sy = 2.0 /. height
  and sz = -2.0 /. depth in
  (* translation terms: -(right + left) / width, and likewise for y, z *)
  let tx = (-. right -. left) /. width
  and ty = (-. top -. bottom) /. height
  and tz = (-. far -. near) /. depth in
  [| sx;  0.0; 0.0; 0.0;
     0.0; sy;  0.0; 0.0;
     0.0; 0.0; sz;  0.0;
     tx;  ty;  tz;  1.0; |]


(** [frustum ~left ~right ~bottom ~top ~near ~far] builds a perspective
    projection matrix for the given view frustum, using the coefficients
    of glFrustum.

    The matrix is stored column-major (element [4 * col + row]), like
    perspective_projection, ortho_projection and translation_matrix, so it
    can be combined with them through mult_matrix. The -1 perspective term
    is element 11 and the depth translation term is element 14. Earlier
    versions returned the row-major transpose. *)
let frustum ~left ~right ~bottom ~top ~near ~far =
  let near_twice = 2.0 *. near
  and right_minus_left = right -. left
  and top_minus_bottom = top -. bottom
  and far_minus_near = far -. near
  in
  let e = near_twice /. right_minus_left
  and f = near_twice /. top_minus_bottom
  and a = (right +. left) /. right_minus_left
  and b = (top +. bottom) /. top_minus_bottom
  and c = -. ((far +. near) /. far_minus_near)
  and d = -. (near_twice *. far /. far_minus_near)
  in
  (* column-major: each source line below is one column of the matrix *)
  [| e;   0.0; 0.0;  0.0;
     0.0; f;   0.0;  0.0;
     a;   b;   c;   -1.0;
     0.0; 0.0; d;    0.0; |]


(** [translation_matrix (x, y, z)] returns a new translation matrix: the
    identity with [x], [y] and [z] stored in elements 12, 13 and 14. *)
let translation_matrix (x, y, z) =
  let m = Array.init 16 (fun i -> if i mod 5 = 0 then 1.0 else 0.0) in
  m.(12) <- x;
  m.(13) <- y;
  m.(14) <- z;
  m


(** [scale_matrix (x, y, z)] returns a new diagonal scaling matrix with
    diagonal [x; y; z; 1.0]. *)
let scale_matrix (x, y, z) =
  let m = Array.make 16 0.0 in
  m.(0) <- x;
  m.(5) <- y;
  m.(10) <- z;
  m.(15) <- 1.0;
  m

(** [print_mat m] prints the first 16 elements of [m] on stdout, four per
    line, each line indented by two spaces, and then flushes stdout.
    Elements are printed in storage order, so with the column-major layout
    each printed line is one column of the matrix. If [m] has fewer than 16
    elements, Invalid_argument is raised and nothing is printed. *)
let print_mat m =
  let row r =
    Printf.sprintf "  %6.3f  %6.3f  %6.3f  %6.3f\n"
      m.(4 * r) m.(4 * r + 1) m.(4 * r + 2) m.(4 * r + 3)
  in
  (* every element is read before anything is written to stdout *)
  let text = String.concat "" (List.map row [0; 1; 2; 3]) in
  print_string text;
  flush stdout
;;


(** [x_rotation_matrix ~angle] returns a new matrix rotating by [angle]
    degrees about the X axis. It is the same matrix as
    rotation_matrix_of_axis with axis (1, 0, 0). *)
let x_rotation_matrix ~angle =
  let rad = angle *. deg_to_rad in
  let c = cos rad and s = sin rad in
  let m = Array.make 16 0.0 in
  m.(0) <- 1.0;
  m.(5) <- c;
  m.(6) <- s;
  m.(9) <- -. s;
  m.(10) <- c;
  m.(15) <- 1.0;
  m

(** [y_rotation_matrix ~angle] returns a new matrix rotating by [angle]
    degrees about the Y axis. It is the same matrix as
    rotation_matrix_of_axis with axis (0, 1, 0). *)
let y_rotation_matrix ~angle =
  let rad = angle *. deg_to_rad in
  let c = cos rad and s = sin rad in
  let m = Array.make 16 0.0 in
  m.(0) <- c;
  m.(2) <- -. s;
  m.(5) <- 1.0;
  m.(8) <- s;
  m.(10) <- c;
  m.(15) <- 1.0;
  m

(** [z_rotation_matrix ~angle] returns a new matrix rotating by [angle]
    degrees about the Z axis. It is the same matrix as
    rotation_matrix_of_axis with axis (0, 0, 1). *)
let z_rotation_matrix ~angle =
  let rad = angle *. deg_to_rad in
  let c = cos rad and s = sin rad in
  let m = Array.make 16 0.0 in
  m.(0) <- c;
  m.(1) <- s;
  m.(4) <- -. s;
  m.(5) <- c;
  m.(10) <- 1.0;
  m.(15) <- 1.0;
  m


(** [normalise_vector (x, y, z)] scales the vector to unit length by
    multiplying each component by the reciprocal of its length. A zero
    vector yields NaN components (0 times infinity). *)
let normalise_vector (x, y, z) =
  let inv_len = 1.0 /. sqrt (x *. x +. y *. y +. z *. z) in
  let shrink v = v *. inv_len in
  (shrink x, shrink y, shrink z)



(** [rotation_matrix_of_axis ~axis ~angle] returns a new matrix rotating by
    [angle] degrees about [axis], using the glRotate formula. [axis] is
    normalised first with normalise_vector, so it does not need unit
    length, but it must not be zero (NaN entries would result). *)
let rotation_matrix_of_axis ~axis ~angle =
  let rad = angle *. deg_to_rad in
  let c = cos rad and s = sin rad in
  let (x, y, z) = normalise_vector axis in
  let k = 1.0 -. c in
  let xy = x *. y and xz = x *. z and yz = y *. z in
  (* diagonal term for a squared component sq: sq + (1 - sq) c *)
  let diag sq = sq +. (1.0 -. sq) *. c in
  (* one source line per column *)
  [| diag (x *. x);     xy *. k +. z *. s; xz *. k -. y *. s; 0.0;
     xy *. k -. z *. s; diag (y *. y);     yz *. k +. x *. s; 0.0;
     xz *. k +. y *. s; yz *. k -. x *. s; diag (z *. z);     0.0;
     0.0;               0.0;               0.0;               1.0 |]


(** [mult_matrix ~m1 ~m2] returns, as a new array, the product [m1 m2] of
    two 4x4 column-major matrices (element [4 * col + row]).
    @raise Invalid_argument "mult_matrix" if either array does not have
    exactly 16 elements. *)
let mult_matrix ~m1 ~m2 =
  if Array.length m1 <> 16 || Array.length m2 <> 16 then
    invalid_arg "mult_matrix";
  Array.init 16 (fun idx ->
    let col = idx / 4 and row = idx mod 4 in
    (* both lengths were checked above, so unchecked access is safe *)
    let lhs k = Array.unsafe_get m1 (4 * k + row)
    and rhs k = Array.unsafe_get m2 (4 * col + k) in
    lhs 0 *. rhs 0 +. lhs 1 *. rhs 1 +. lhs 2 *. rhs 2 +. lhs 3 *. rhs 3)


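(* Functional transforms: each *_mat function below returns a new matrix
   equal to [m] composed with a scale, translation or rotation (the
   composition glScale, glTranslate and glRotate apply to the current
   matrix). [m] itself is never modified. *)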

(** [scale_mat ~m (x, y, z)] returns a new matrix equal to [m] multiplied
    on the right by a scaling of [(x, y, z)], as glScale does to the current
    matrix: columns 0, 1 and 2 of [m] are scaled by [x], [y] and [z], and
    column 3 is copied unchanged. [m] is not modified.
    @raise Invalid_argument "scale" if [m] does not have 16 elements. *)
let scale_mat ~m (x, y, z) =
  if Array.length m <> 16 then invalid_arg "scale";
  Array.init 16 (fun i ->
    let v = Array.unsafe_get m i in
    match i / 4 with
    | 0 -> v *. x
    | 1 -> v *. y
    | 2 -> v *. z
    | _ -> v)


(** [translate_mat ~m (x, y, z)] returns a new matrix equal to [m]
    multiplied on the right by a translation of [(x, y, z)], as glTranslate
    does to the current matrix. Only the last column (elements 12-15)
    differs from [m], and [m] is not modified.
    @raise Invalid_argument "translate" if [m] does not have 16 elements. *)
let translate_mat ~m (x, y, z) =
  if Array.length m <> 16 then invalid_arg "translate";
  let at = Array.unsafe_get m in
  Array.init 16 (fun i ->
    if i < 12 then at i
    else
      let row = i - 12 in
      (at row *. x) +. (at (4 + row) *. y) +. (at (8 + row) *. z)
      +. at (12 + row))


(** [x_rotate_mat ~m a] returns a new matrix equal to [m] multiplied on the
    right by a rotation of [a] degrees about the X axis, which is what
    glRotate with axis (1, 0, 0) does to the current matrix. Only columns 1
    and 2 (elements 4-11) differ from [m], and [m] is not modified.
    @raise Invalid_argument "rotate_x" if [m] does not have 16 elements. *)
let x_rotate_mat ~m a =
  if Array.length m <> 16 then invalid_arg "rotate_x";
  let rad = a *. deg_to_rad in
  let c = cos rad and s = sin rad in
  let ns = -. s in
  let at = Array.unsafe_get m in
  Array.init 16 (fun i ->
    match i / 4 with
    | 1 -> (at i *. c) +. (at (i + 4) *. s)
    | 2 -> (at (i - 4) *. ns) +. (at i *. c)
    | _ -> at i)

(** [y_rotate_mat ~m a] returns a new matrix equal to [m] multiplied on the
    right by a rotation of [a] degrees about the Y axis, which is what
    glRotate with axis (0, 1, 0) does to the current matrix. Only columns 0
    and 2 (elements 0-3 and 8-11) differ from [m], and [m] is not modified.
    @raise Invalid_argument "rotate_y" if [m] does not have 16 elements. *)
let y_rotate_mat ~m a =
  if Array.length m <> 16 then invalid_arg "rotate_y";
  let rad = a *. deg_to_rad in
  let c = cos rad and s = sin rad in
  let ns = -. s in
  let at = Array.unsafe_get m in
  Array.init 16 (fun i ->
    match i / 4 with
    | 0 -> (at i *. c) +. (at (i + 8) *. ns)
    | 2 -> (at (i - 8) *. s) +. (at i *. c)
    | _ -> at i)

(** [z_rotate_mat ~m a] returns a new matrix equal to [m] multiplied on the
    right by a rotation of [a] degrees about the Z axis, which is what
    glRotate with axis (0, 0, 1) does to the current matrix. Only columns 0
    and 1 (elements 0-7) differ from [m], and [m] is not modified.
    @raise Invalid_argument "rotate_z" if [m] does not have 16 elements. *)
let z_rotate_mat ~m a =
  if Array.length m <> 16 then invalid_arg "rotate_z";
  let rad = a *. deg_to_rad in
  let c = cos rad and s = sin rad in
  let ns = -. s in
  let at = Array.unsafe_get m in
  Array.init 16 (fun i ->
    match i / 4 with
    | 0 -> (at i *. c) +. (at (i + 4) *. s)
    | 1 -> (at (i - 4) *. ns) +. (at i *. c)
    | _ -> at i)

(** [rotate_mat ~m a u] returns a new matrix equal to [m] multiplied on the
    right by a rotation of [a] degrees about the axis [u], as glRotate does
    to the current matrix. [u] is normalised with normalise_vector and must
    not be zero. Column 3 (elements 12-15) is copied from [m] unchanged, and
    [m] is not modified.
    @raise Invalid_argument "rotate" if [m] does not have 16 elements. *)
let rotate_mat ~m a u =
  if Array.length m <> 16 then invalid_arg "rotate";
  let rad = a *. deg_to_rad in
  let c = cos rad and s = sin rad in
  let (x, y, z) = normalise_vector u in
  let k = 1.0 -. c in
  let xy = x *. y and xz = x *. z and yz = y *. z in
  let diag sq = sq +. (1.0 -. sq) *. c in
  (* upper-left 3x3 block of the axis/angle rotation, column by column:
     r.(3 * col + row) *)
  let r = [|
    diag (x *. x);     xy *. k +. z *. s; xz *. k -. y *. s;
    xy *. k -. z *. s; diag (y *. y);     yz *. k +. x *. s;
    xz *. k +. y *. s; yz *. k -. x *. s; diag (z *. z);
  |] in
  let at = Array.unsafe_get m in
  Array.init 16 (fun i ->
    if i >= 12 then at i
    else
      let col = i / 4 and row = i mod 4 in
      at row *. r.(3 * col)
      +. at (4 + row) *. r.(3 * col + 1)
      +. at (8 + row) *. r.(3 * col + 2))



(* high functional *)

(** Continuation-passing wrappers around the functional transforms: each
    one builds the transformed copy of [m] and passes it to the callback
    [f], returning whatever [f] returns. [m] itself is not modified. *)
let scale m v f =
  let scaled = scale_mat ~m v in
  f scaled

let translate m v f =
  let moved = translate_mat ~m v in
  f moved

let rotate_x m a f =
  let rotated = x_rotate_mat ~m a in
  f rotated

let rotate_y m a f =
  let rotated = y_rotate_mat ~m a in
  f rotated

let rotate_z m a f =
  let rotated = z_rotate_mat ~m a in
  f rotated

let rotate m a v f =
  let rotated = rotate_mat ~m a v in
  f rotated


(** [compute_normal_of_plane_feed ~normal ~vec_a ~vec_b] stores the cross
    product of [vec_a] and [vec_b] (elements 0-2 of each array) into
    elements 0-2 of [normal], in place.

    All six input components are read before [normal] is written, so
    [normal] may be the same array as [vec_a] or [vec_b]. Previously the
    first write clobbered an input still needed for the other two
    components, giving a wrong result in that case. *)
let compute_normal_of_plane_feed ~normal ~vec_a ~vec_b =
  let ax = vec_a.(0) and ay = vec_a.(1) and az = vec_a.(2)
  and bx = vec_b.(0) and by = vec_b.(1) and bz = vec_b.(2) in
  normal.(0) <- (ay *. bz) -. (az *. by);
  normal.(1) <- (az *. bx) -. (ax *. bz);
  normal.(2) <- (ax *. by) -. (ay *. bx)
;;

(** [compute_normal_of_plane ~vec_a ~vec_b] returns the cross product of
    the two 3-component tuples [vec_a] and [vec_b]. The result is not
    normalised. *)
let compute_normal_of_plane ~vec_a ~vec_b =
  let (ax, ay, az) = vec_a and (bx, by, bz) = vec_b in
  let nx = (ay *. bz) -. (az *. by) in
  let ny = (az *. bx) -. (ax *. bz) in
  let nz = (ax *. by) -. (ay *. bx) in
  (nx, ny, nz)

(** [compute_normal_of_plane_ta ~vec_a ~vec_b] returns, as a tuple, the
    cross product of the vectors held in elements 0-2 of the arrays [vec_a]
    and [vec_b]. *)
let compute_normal_of_plane_ta ~vec_a ~vec_b =
  let a i = vec_a.(i) and b i = vec_b.(i) in
  ( (a 1 *. b 2) -. (a 2 *. b 1),
    (a 2 *. b 0) -. (a 0 *. b 2),
    (a 0 *. b 1) -. (a 1 *. b 0) )

(** [compute_normal_of_plane_arr ~vec_a ~vec_b] returns, as a new
    3-element array, the cross product of the vectors held in elements 0-2
    of [vec_a] and [vec_b]. *)
let compute_normal_of_plane_arr ~vec_a ~vec_b =
  Array.init 3 (fun i ->
    (* component i uses the two other axes, taken in cyclic order *)
    let j = (i + 1) mod 3 and k = (i + 2) mod 3 in
    (vec_a.(j) *. vec_b.(k)) -. (vec_a.(k) *. vec_b.(j)))

(** [normalize_vector (x, y, z)] scales the vector to unit length by
    dividing each component by the vector length. A zero vector yields NaN
    components (0 / 0). *)
let normalize_vector (x, y, z) =
  let len = sqrt ((x *. x) +. (y *. y) +. (z *. z)) in
  let div v = v /. len in
  (div x, div y, div z)



(** [matrix_translate ~matrix (x, y, z)] multiplies [matrix] on the right
    by a translation of [(x, y, z)] in place, as glTranslate does. Only the
    last column (elements 12-15) is rewritten. *)
let matrix_translate ~matrix (x, y, z) =
  for row = 0 to 3 do
    matrix.(12 + row) <-
      matrix.(row) *. x +. matrix.(4 + row) *. y
      +. matrix.(8 + row) *. z +. matrix.(12 + row)
  done
;;



(** [look_at ~eye ~center ~up_vector] builds a viewing matrix the way
    gluLookAt does. The forward direction (from [eye] to [center]) is
    normalised, side = forward x up is normalised, and up is recomputed as
    side x forward. The resulting rotation is then translated by the
    negated [eye]. [up_vector] need not be unit length, but it must not be
    parallel to the viewing direction, or NaN entries result. *)
let look_at ~eye:(eye_x, eye_y, eye_z) ~center:(to_x, to_y, to_z) ~up_vector =
  let forward =
    normalize_vector (to_x -. eye_x, to_y -. eye_y, to_z -. eye_z) in
  let side =
    normalize_vector (compute_normal_of_plane ~vec_a:forward ~vec_b:up_vector)
  in
  let up = compute_normal_of_plane ~vec_a:side ~vec_b:forward in
  let (fx, fy, fz) = forward
  and (sx, sy, sz) = side
  and (ux, uy, uz) = up in
  (* column-major: the rows of the rotation are side, up and -forward *)
  let view =
    [| sx;  ux;  -. fx; 0.0;
       sy;  uy;  -. fy; 0.0;
       sz;  uz;  -. fz; 0.0;
       0.0; 0.0; 0.0;   1.0 |]
  in
  matrix_translate ~matrix:view (-. eye_x, -. eye_y, -. eye_z);
  view
;;


(** [mult_matrices a b] returns a new 16-element array [r] with
    [r.(4*i + j) = sum over k of a.(4*i + k) *. b.(4*k + j)].
    This is the index order of the Mesa GLU helper used by unproject. For
    column-major storage the result is the product [b a], the reverse
    operand order of mult_matrix. No length check is made; short arrays
    raise Invalid_argument from the array access. *)
let mult_matrices a b =
  Array.init 16 (fun idx ->
    let i = idx / 4 and j = idx mod 4 in
    let term k = a.(i * 4 + k) *. b.(k * 4 + j) in
    term 0 +. term 1 +. term 2 +. term 3)

(** [mult_matrix_vec matrix in_] returns the 4-element vector
    [matrix in_], with [matrix] in column-major storage:
    [out.(i) = sum over k of in_.(k) *. matrix.(4*k + i)]. *)
let mult_matrix_vec matrix in_ =
  let component i =
    let term k = in_.(k) *. matrix.(4 * k + i) in
    term 0 +. term 1 +. term 2 +. term 3
  in
  [| component 0; component 1; component 2; component 3 |]

(** [invert_matrix m invOut] inverts the 4x4 matrix [m] (16 floats) by
    cofactor expansion and writes the inverse into [invOut].

    Returns [false], leaving [invOut] untouched, when the determinant is
    exactly [0.0]. Returns [true] after writing all 16 elements otherwise.
    Near-singular matrices are not detected and may yield very large
    entries.

    [invOut] may be the same array as [m] (unproject does this). Every
    input element is read into a local, and the cofactors are built in a
    separate array, before anything is written.

    @raise Invalid_argument "invert_matrix" if [m] does not have exactly 16
    elements.
    NOTE(review): the length of [invOut] is not checked; an array shorter
    than 16 receives a partial write before the bounds check fails.

    Port of the inversion contributed by David Moore (see Mesa bug #6748). *)
let invert_matrix m invOut =
  (* Rebind the arithmetic operators to their float versions for the rest
     of this function, so the long cofactor formulas stay readable. *)
  let ( * ) = ( *. ) in
  let ( + ) = ( +. ) in
  let ( - ) = ( -. ) in

  if Array.length m <> 16 then invalid_arg "invert_matrix";

  (* length checked above, so unchecked reads are in bounds *)
  let m00 = Array.unsafe_get m 0
  and m01 = Array.unsafe_get m 1
  and m02 = Array.unsafe_get m 2
  and m03 = Array.unsafe_get m 3
  and m04 = Array.unsafe_get m 4
  and m05 = Array.unsafe_get m 5
  and m06 = Array.unsafe_get m 6
  and m07 = Array.unsafe_get m 7
  and m08 = Array.unsafe_get m 8
  and m09 = Array.unsafe_get m 9
  and m10 = Array.unsafe_get m 10
  and m11 = Array.unsafe_get m 11
  and m12 = Array.unsafe_get m 12
  and m13 = Array.unsafe_get m 13
  and m14 = Array.unsafe_get m 14
  and m15 = Array.unsafe_get m 15
  in
  (* inv is the adjugate of m: the inverse is inv scaled by 1 / det *)
  let inv = [|
      m05*m10*m15 - m05*m11*m14 - m09*m06*m15 + m09*m07*m14 + m13*m06*m11 - m13*m07*m10;
    -.m01*m10*m15 + m01*m11*m14 + m09*m02*m15 - m09*m03*m14 - m13*m02*m11 + m13*m03*m10;
      m01*m06*m15 - m01*m07*m14 - m05*m02*m15 + m05*m03*m14 + m13*m02*m07 - m13*m03*m06;
    -.m01*m06*m11 + m01*m07*m10 + m05*m02*m11 - m05*m03*m10 - m09*m02*m07 + m09*m03*m06;
    -.m04*m10*m15 + m04*m11*m14 + m08*m06*m15 - m08*m07*m14 - m12*m06*m11 + m12*m07*m10;
      m00*m10*m15 - m00*m11*m14 - m08*m02*m15 + m08*m03*m14 + m12*m02*m11 - m12*m03*m10;
    -.m00*m06*m15 + m00*m07*m14 + m04*m02*m15 - m04*m03*m14 - m12*m02*m07 + m12*m03*m06;
      m00*m06*m11 - m00*m07*m10 - m04*m02*m11 + m04*m03*m10 + m08*m02*m07 - m08*m03*m06;
      m04*m09*m15 - m04*m11*m13 - m08*m05*m15 + m08*m07*m13 + m12*m05*m11 - m12*m07*m09;
    -.m00*m09*m15 + m00*m11*m13 + m08*m01*m15 - m08*m03*m13 - m12*m01*m11 + m12*m03*m09;
      m00*m05*m15 - m00*m07*m13 - m04*m01*m15 + m04*m03*m13 + m12*m01*m07 - m12*m03*m05;
    -.m00*m05*m11 + m00*m07*m09 + m04*m01*m11 - m04*m03*m09 - m08*m01*m07 + m08*m03*m05;
    -.m04*m09*m14 + m04*m10*m13 + m08*m05*m14 - m08*m06*m13 - m12*m05*m10 + m12*m06*m09;
      m00*m09*m14 - m00*m10*m13 - m08*m01*m14 + m08*m02*m13 + m12*m01*m10 - m12*m02*m09;
    -.m00*m05*m14 + m00*m06*m13 + m04*m01*m14 - m04*m02*m13 - m12*m01*m06 + m12*m02*m05;
      m00*m05*m10 - m00*m06*m09 - m04*m01*m10 + m04*m02*m09 + m08*m01*m06 - m08*m02*m05;
  |] in

  (* determinant: expansion of m.(0..3) against cofactors 0, 4, 8 and 12 *)
  let det = m00 * inv.(0) + m01 * inv.(4) + m02 * inv.(8) + m03 * inv.(12) in
  if det = 0.0 then (false) else
  begin
    let det = 1.0 /. det in
    let get = Array.unsafe_get inv in
    (* invOut is written only now, after every read of m is done *)
    for i = 0 to pred 16 do
      invOut.(i) <- (get i) *. det;
    done;
    (true)
  end

(** [unproject ~win_x ~win_y ~win_z ~model ~proj ~viewport] maps window
    coordinates back to object coordinates, following gluUnProject.
    [model] and [proj] are 16-float matrices; [viewport] is an int array
    holding x, y, width and height. Returns the object-space point
    [(x, y, z)].
    @raise Failure "unproject" if the combined matrix is singular, or if
    the homogeneous w of the result is zero. *)
let unproject ~win_x ~win_y ~win_z ~model ~proj ~viewport =
  let inverse = mult_matrices model proj in
  if not (invert_matrix inverse inverse) then failwith "unproject";
  (* window -> [0, 1] (x and y only) -> normalized device range [-1, 1] *)
  let to_ndc v = v *. 2. -. 1. in
  let nx = to_ndc ((win_x -. float viewport.(0)) /. float viewport.(2)) in
  let ny = to_ndc ((win_y -. float viewport.(1)) /. float viewport.(3)) in
  let nz = to_ndc win_z in
  let clip = mult_matrix_vec inverse [| nx; ny; nz; 1.0 |] in
  let w = clip.(3) in
  if w = 0.0 then failwith "unproject";
  let inv_w = 1.0 /. w in
  (clip.(0) *. inv_w, clip.(1) *. inv_w, clip.(2) *. inv_w)