#include "align3d.h"

Align3D::Align3D(Shared_Prm *prm)
 : SharedPrm(prm),
   aaReplMtx(nullptr),
   aaOrder(AA_ORDER.aa_order()),
   aaReplMaxVal(0.0),
   Window(3),
   aaWeight(1.0),
   Threshold(0.0),
   Harr(false)
{}

void Align3D::clear() {
 //TBPoint は，search_traceback_point() 内で初期化して push_back

 Dynamic.clear();

 Window = 3;
 aaWeight = 0.0;
 Threshold = 0.0;
 Harr = false;
 aaReplMaxVal = 0.0;

 for(int i = 0; i < 2; ++i) {
  OrigSeq[i] = nullptr;
  orgXYZ[i] = nullptr;
  winXYZ[i].clear();
  winXYZ[i].shrink_to_fit();
  Alignment[i].clear();
  Alignment[i].shrink_to_fit();
 }
}

void Align3D::init() {
 if(SharedPrm->AaRepl == 0) {
  aaReplMtx = &PAM250;
  aaReplMaxVal = 25.0;
 }
 else {
  aaReplMtx = &BLOSUM62;
  aaReplMaxVal = 15.0;
 }

 //スコアマトリックスの初期化（スコアは 0.0 で初期化しておけばよい）
 int col_cnt = 0;
 int row_cnt = 0;
 if(aaWeight) {
  col_cnt = OrigSeq[1]->size() - Window + 2;
  row_cnt = OrigSeq[0]->size() - Window + 2;
 }
 else {
  col_cnt = orgXYZ[1]->size() - Window + 2;
  row_cnt = orgXYZ[0]->size() - Window + 2;
 }
 Dynamic.init_mtx(row_cnt, col_cnt, SharedPrm->GapIni, SharedPrm->GapElg);

 //winXYZ[0] のサンプリング
 --row_cnt;  //winXYZ[0].assign(row_cnt);
 for(int i = 0; i != row_cnt; ++i) {
  int j_end = i + Window;
  std::vector<Eigen::Vector3d> row;
  for(int j = i; j != j_end; ++j) {
   for(int k = 0; k != 4; ++k) {
    row.push_back((*orgXYZ[0])[j][k]);
   }
  }
  centering(&row);
  winXYZ[0].push_back(row);
 }

 //winXYZ[1] のサンプリング
 --col_cnt;  //winXYZ[1].assign(row_cnt);
 for(int i = 0; i != col_cnt; ++i) {
  int j_end = i + Window;
  std::vector<Eigen::Vector3d> row;
  for(int j = i; j != j_end; ++j) {
   for(int k = 0; k != 4; ++k) {
    row.push_back((*orgXYZ[1])[j][k]);
   }
  }
  centering(&row);
  winXYZ[1].push_back(row);
 }
}

void Align3D::trace() {
 int row_cnt = Dynamic.get_row_cnt();//scoreMtx.size();
 int col_cnt = Dynamic.get_col_cnt();//scoreMtx[0].size();
 if(aaWeight == 1.0) {
  for(int i = 1; i < row_cnt; ++i)
   for(int j = 1; j < col_cnt; ++j) {
    Dynamic.trace_1_step(i, j, gain_aa(i - 1, j - 1));
   }
 }
 else if(aaWeight == 0.0) {
  for(int i = 1; i < row_cnt; ++i)
   for(int j = 1; j < col_cnt; ++j) {
    Dynamic.trace_1_step(i, j, gain_3d(i - 1, j - 1));
   }
 }
 else {
  for(int i = 1; i < row_cnt; ++i)
   for(int j = 1; j < col_cnt; ++j) {
    double gain = aaWeight * gain_aa(i - 1, j - 1);
    gain += (1.0 - aaWeight) * gain_3d(i - 1, j - 1);
    Dynamic.trace_1_step(i, j, gain);
   }
 }
}

double Align3D::gain_aa(int pos0, int pos1) {
 double gain = 0.0;
 for(int i = 0; i != Window; ++i) {
  char code0 = (*OrigSeq[0])[pos0 + i];
  char code1 = (*OrigSeq[1])[pos1 + i];
  //不明なアミノ酸は 'X' としている
  if(code0 !='X' && code1 != 'X')
   gain += (*aaReplMtx)[aaOrder.at(code0)][aaOrder.at(code1)];
 }

 return gain / (double)Window;
}

double Align3D::gain_3d(int pos0, int pos1) {
 //誤差の蓄積を避けるため，回転結果はこれに格納する
 std::vector<Eigen::Vector3d> rotated(
        Window * 4,
        Eigen::Vector3d::Zero(3));

 //サンプリング時に centering 済み
 Sup.init();
 //winXYZ[1][pos1] を基準にして winXYZ[0][pos0] を回転
 Sup.create_rotation_matrix_core(
        &winXYZ[0][pos0],
        &winXYZ[1][pos1]);

 //winXYZ[0][pos0] を回転して rotated に格納する
 Sup.rotation_core(&winXYZ[0][pos0], &rotated);

 double rms = calc_rms(&rotated, &winXYZ[1][pos1]);

 //0 除算を回避しつつ大小関係を逆転させる
 //そのままだと値の範囲が 0.0 〜1.0 なので，アミノ酸置換行列の範囲に合わせる
 return 1.0 / exp(sqrt(rms)) * aaReplMaxVal - Threshold;
}

void Align3D::prepare_traceback_point()
{
 const std::vector<TB_Point> *tb_point
        = Dynamic.search_traceback_point();

 //閾値より高いスコアが格納された点をコピーする
 TBPoint.clear();
 TBPoint.shrink_to_fit();
 for(const TB_Point & point : *tb_point)
  if(point.Score > Threshold)
   TBPoint.push_back(point);
}

void Align3D::traceback(int num) {
 num = 0;  //今は最大スコアの点から

 for(int i = 0; i != 2; ++i) {
  Alignment[i].clear();
  Alignment[i].shrink_to_fit();
 }
 const std::array<std::vector<int>, 2> *aligned_num
        = Dynamic.traceback(num);

 //バグが発生したので，ていねいにコーディング
 int length = (*aligned_num)[0].size();
 for(int i = 0; i < length; ++i) {
  int aa_pos = (*aligned_num)[0][i];
  if(aa_pos > -1) {  //ギャップでない
   int code = (*OrigSeq[0])[aa_pos];
   Alignment[0].push_back(code);
  }
  else {  //ギャップ
   Alignment[0].push_back('-');
  }
 }
 for(int i = 0; i < length; ++i) {
  int aa_pos = (*aligned_num)[1][i];
  if(aa_pos > -1) {  //ギャップでない
   int code = (*OrigSeq[1])[aa_pos];
   Alignment[1].push_back(code);
  }
  else {  //ギャップ
   Alignment[1].push_back('-');
  }
 }
}

void Align3D::calc_harr() {
//trace とほぼ同じコードを使ったので，scoreMtx の第 0 行，第 0 列は空白になる
 int row_cnt = Dynamic.get_row_cnt();//scoreMtx.size();
 int col_cnt = Dynamic.get_col_cnt();//scoreMtx[0].size();
 if(aaWeight == 1.0) {
  for(int i = 1; i < row_cnt; ++i)
   for(int j = 1; j < col_cnt; ++j) {
    Dynamic.set_score(i, j, gain_aa(i - 1, j - 1));
   }
 }
 else if(aaWeight == 0.0) {
  for(int i = 1; i < row_cnt; ++i)
   for(int j = 1; j < col_cnt; ++j) {
    Dynamic.set_score(i, j, gain_3d(i - 1, j - 1));
   }
 }
 else {
  for(int i = 1; i < row_cnt; ++i)
   for(int j = 1; j < col_cnt; ++j) {
    double gain = aaWeight * gain_aa(i - 1, j - 1);
    gain += (1.0 - aaWeight) * gain_3d(i - 1, j - 1);
    Dynamic.set_score(i, j, gain);
   }
 }
}

/*
void Align3D::get_xyz_from_alignment(int layer, std::vector<Eigen::Vector3d> &xyz) {
 int cnt = alignAANum[layer].size();
 if(!cnt) {
  std::cerr << __PRETTY_FUNCTION__ << '\n';
  exit(1);
 }

 for(int i = 0; i != cnt; ++i)
  if(alignAANum[0][i] != -1 && alignAANum[1][i] != -1) {
   int aa_pos = alignAANum[layer][i];
   for(const Eigen::Vector3d &vect : (*orgXYZ[layer])[aa_pos])
    xyz.push_back(vect);
  }
}
*/
