import * as THREE from 'three';
import IKChain from './IKChain';

/**
 * Construction options for `IKSolver`.
 */
export type IKSolverOptions = {
  /**
   * Maximum number of solver passes over each chain per `solve()` call.
   * The solver may stop earlier once a pass changes no joint.
   * When omitted, the solver uses 1.
   */
  iterations?: number;
};

// Scratch objects shared by every IKSolver in this module. Their contents are
// overwritten on each use, so the solver does not allocate new vectors or
// quaternions per joint per iteration.

// World-space positions
const goalPosition = new THREE.Vector3();
const effectorPosition = new THREE.Vector3();
const jointPosition = new THREE.Vector3();

// Directions from the current joint, rotated into that joint's orientation
const joint2GoalVector = new THREE.Vector3();
const joint2EffectorVector = new THREE.Vector3();

// Outputs of decomposing the current joint's world matrix
const jointQuaternionInverse = new THREE.Quaternion();
const jointScale = new THREE.Vector3();

// Correction rotation applied to the current joint
const axis = new THREE.Vector3();
const quaternion = new THREE.Quaternion();

/**
 * Cyclic Coordinate Descent (CCD) inverse-kinematics solver.
 *
 * For each registered chain, every joint is rotated in turn so that the
 * direction from the joint to the chain's effector turns toward the direction
 * from the joint to the chain's target. This is repeated for up to
 * `iterations` passes.
 */
export class IKSolver {
  /** Chains solved, in insertion order, by `solve()`. */
  public chains: IKChain[] = [];

  /** Maximum number of passes per chain in each `solve()` call. */
  public iterations = 1;

  /**
   * @param options - Solver options. `iterations` falls back to 1 only when it
   *   is omitted (`undefined`/`null`); an explicit `0` is kept.
   */
  constructor(options: IKSolverOptions = {}) {
    // `??` rather than `||`: an explicit 0 must not be silently turned into 1.
    this.iterations = options.iterations ?? this.iterations;
  }

  /** Registers a chain to be solved by subsequent `solve()` calls. */
  public add(chain: IKChain) {
    this.chains.push(chain);
  }

  /**
   * Solves every registered chain. Afterwards, calls
   * `applyConstraints(chain, 'applyLazy')` on every joint of every chain.
   */
  public solve() {
    for (const chain of this.chains) {
      this._solveChain(chain, this.iterations);
    }
    // Second pass: the 'applyLazy' constraint call happens only after all
    // chains have been solved.
    for (const chain of this.chains) {
      for (const joint of chain.joints) {
        joint.applyConstraints(chain, 'applyLazy');
      }
    }
  }

  /**
   * Runs up to `iteration` CCD passes over `chain.joints`. Stops early once a
   * full pass leaves every joint unchanged.
   *
   * @param chain - Chain to solve; nothing happens if it lacks a target or effector.
   * @param iteration - Maximum number of passes over the chain's joints.
   */
  protected _solveChain(chain: IKChain, iteration = 1) {
    // Angles below this (radians) are treated as aligned to avoid jitter.
    const MIN_ANGLE = 1e-4;
    // Squared lengths below this are treated as zero-length (no direction).
    const MIN_LENGTH_SQ = 1e-12;

    const { target, effector } = chain;
    if (!target || !effector) return;

    // World-space goal, read once per call.
    // NOTE(review): assumes the target is not moved by rotating the chain's joints.
    // TODO: optimize
    target.getWorldPosition(goalPosition);

    for (let pass = iteration; pass > 0; pass--) {
      let didConverge = true;

      for (const joint of chain.joints) {
        // Read the effector first: getWorldPosition() refreshes the world
        // matrices of the effector's ancestors. When this joint is one of them
        // (presumably the usual chain layout; not checked here), its matrixWorld
        // is therefore current before it is decomposed below.
        effector.bone.getWorldPosition(effectorPosition);

        joint.bone.matrixWorld.decompose(jointPosition, jointQuaternionInverse, jointScale);
        jointQuaternionInverse.invert();

        // World-space offsets from the joint to the effector and to the goal.
        joint2EffectorVector.subVectors(effectorPosition, jointPosition);
        joint2GoalVector.subVectors(goalPosition, jointPosition);

        // A zero-length offset has no direction. normalize() would leave it at
        // zero, acos(0) would report a spurious 90° error, and the zero cross
        // product would yield a non-unit quaternion. Leave this joint alone.
        if (
          joint2EffectorVector.lengthSq() < MIN_LENGTH_SQ ||
          joint2GoalVector.lengthSq() < MIN_LENGTH_SQ
        ) {
          continue;
        }

        // Express both directions in the joint's own orientation, so the
        // correction can be post-multiplied onto bone.quaternion.
        joint2EffectorVector.applyQuaternion(jointQuaternionInverse).normalize();
        joint2GoalVector.applyQuaternion(jointQuaternionInverse).normalize();

        // Angle between the directions; the clamp keeps acos in its domain
        // despite rounding in the dot product of unit vectors.
        const cosAngle = Math.min(1, Math.max(-1, joint2GoalVector.dot(joint2EffectorVector)));
        const deltaAngle = Math.acos(cosAngle);

        // Already aligned: skip to avoid vibration from tiny corrections.
        if (deltaAngle < MIN_ANGLE) continue;

        // TODO: Limitation of min rotation amount

        // Rotation axis that turns the effector direction toward the goal.
        axis.crossVectors(joint2EffectorVector, joint2GoalVector);
        if (axis.lengthSq() < MIN_LENGTH_SQ) {
          // The cross product of unit vectors has length sin(deltaAngle). Since
          // deltaAngle >= MIN_ANGLE here, a ~zero result means the directions
          // are opposite (angle ≈ π). Any axis perpendicular to the effector
          // direction then rotates it onto the goal direction.
          axis.set(1, 0, 0).cross(joint2EffectorVector);
          if (axis.lengthSq() < MIN_LENGTH_SQ) {
            // The effector direction is (anti)parallel to X; use Y instead.
            axis.set(0, 1, 0).cross(joint2EffectorVector);
          }
        }
        axis.normalize();

        // Apply the correction in the joint's local frame.
        quaternion.setFromAxisAngle(axis, deltaAngle);
        joint.bone.quaternion.multiply(quaternion);

        joint.applyConstraints(chain);

        // TODO: optimize — refreshes this bone and all of its descendants.
        joint.bone.updateMatrixWorld(true);
        didConverge = false;
      }

      if (didConverge) break;
    }
  }
}
