#include "atom_near.h"

Atom_Near::Atom_Near(std::vector<Atom> *atom_vect, const std::map<int, int> *idx_map, std::vector<Mol> *mol_vect)
 : atomVect(atom_vect),
   serialToIdx(idx_map),
   molVect(mol_vect)
{}

/// Empties all three result buffers and asks each of them to release its
/// storage (shrink_to_fit is a non-binding request to the container).
void Atom_Near::clear() {
 // {molIdx, peptIdx, aaIdx} rows filled by the *_to_neighbor_aa queries.
 checkedAANum.clear();
 checkedAANum.shrink_to_fit();

 // Per-atom flags written by check_near_atom / mol_to_neighbor_atom.
 checkedAtomIdx.clear();
 checkedAtomIdx.shrink_to_fit();

 // Serial numbers filled by fused_serial / excess_bond_serial.
 checkedSerial.clear();
 checkedSerial.shrink_to_fit();
}

const std::vector<std::array<int, 3>> *Atom_Near::atom_to_neighbor_aa(
        const std::vector<int> *atom_serial,
        double range) {

 //atom_serial 近傍原子をチェック
 checkedAtomIdx.assign(atomVect->size(), 0);
 for(int serial : *atom_serial)
  check_near_atom(serial, range);

 //チェックされた原子のアミノ酸を特定
 checkedAANum.clear();
 checkedAANum.shrink_to_fit();
 std::array<int, 3> row {-1, -1, -1};
 int atom_cnt = atomVect->size();
 for(int i = 0; i < atom_cnt; ++i)
  if(checkedAtomIdx[i] && (*atomVect)[i].aaIdx > -1) {
   row[0] = (*atomVect)[i].molIdx;
   row[1] = (*atomVect)[i].peptIdx;
   row[2] = (*atomVect)[i].aaIdx;
   checkedAANum.push_back(row);
  }

 return &checkedAANum;
}

void Atom_Near::check_near_atom(int serial, double range) {
 int atom_cnt = atomVect->size();
 int idx = serialToIdx->at(serial);
 Eigen::Vector3d &v0 = (*atomVect)[idx].XYZ;
 for(int i = 0; i < idx; ++i) {
  Eigen::Vector3d &v1 = (*atomVect)[i].XYZ;
  if(fabs(v1[0] - v0[0]) < range)
   if((v1 - v0).norm() < range)
    checkedAtomIdx[i] = 1;
 }
 for(int i = idx + 1; i < atom_cnt; ++i) {
  Eigen::Vector3d &v1 = (*atomVect)[i].XYZ;
  if(fabs(v1[0] - v0[0]) < range)
   if((v1 - v0).norm() < range)
    checkedAtomIdx[i] = 1;
 }
}

void Atom_Near::check_near_atom(int mol0, int mol1, double range) {
 const std::vector<int> &serial0 = (*molVect)[mol0].atomSerial;
 const std::vector<int> &serial1 = (*molVect)[mol1].atomSerial;
 for(int s0 : serial0) {
  int idx0 = serialToIdx->at(s0);
  Eigen::Vector3d &v0 = (*atomVect)[idx0].XYZ;
  for(int s1 : serial1) {
   int idx1 = serialToIdx->at(s1);
   Eigen::Vector3d &v1 = (*atomVect)[idx1].XYZ;
   if(fabs(v1[0] - v0[0]) < range)
    if((v1 - v0).norm() < range)
     checkedAtomIdx[idx1] = 1;
  }
 }
}

const std::vector<std::array<int, 3>> *Atom_Near::aa_to_neighbor_aa(
        const std::vector<std::array<int, 3>> *aa,
        double range) {
 checkedAANum.clear();
 checkedAANum.shrink_to_fit();
 checkedAtomIdx.assign(atomVect->size(), 0);

 //アミノ酸情報から原子を特定
 std::vector<int> aa_atom;
 for(const std::array<int, 3> &row : *aa) {
  const Amino_Acid &aa = (*molVect)[row[0]].aaTable[row[1]][row[2]];
  for(int serial : aa.allSerial)
   check_near_atom(serial, range);
 }

 //チェックされた原子のアミノ酸を特定
 std::array<int, 3> row {-1, -1, -1};
 int atom_cnt = atomVect->size();
 for(int i = 0; i < atom_cnt; ++i)
  if(checkedAtomIdx[i] && (*atomVect)[i].aaIdx > -1) {
   row[0] = (*atomVect)[i].molIdx;
   row[1] = (*atomVect)[i].peptIdx;
   row[2] = (*atomVect)[i].aaIdx;
   checkedAANum.push_back(row);
  }

 return &checkedAANum;
}

const std::vector<std::array<int, 3>> *Atom_Near::mol_to_neighbor_aa(
        int mol_num,
        double range) {
 checkedAANum.clear();
 checkedAANum.shrink_to_fit();
 int atom_cnt = atomVect->size();
 checkedAtomIdx.assign(atom_cnt, 0);

 int mol_cnt = molVect->size();
 for(int i = 0; i < mol_num; ++i)
  check_near_atom(mol_num, i, range);
 for(int i = mol_num + 1; i < mol_cnt; ++i)
  check_near_atom(mol_num, i, range);

 std::array<int, 3> row {-1, -1, -1};
 for(int i = 0; i < atom_cnt; ++i)
  if(checkedAtomIdx[i] && (*atomVect)[i].aaIdx > -1) {
   row[0] = (*atomVect)[i].molIdx;
   row[1] = (*atomVect)[i].peptIdx;
   row[2] = (*atomVect)[i].aaIdx;
   checkedAANum.push_back(row);
  }

 return &checkedAANum;
}

const std::vector<int> *Atom_Near::mol_to_neighbor_atom(int mol0, int mol1, double range) {
 checkedAtomIdx.assign(atomVect->size(), 0);

 for(int serial0 : (*molVect)[mol0].atomSerial) {
  int idx0 = serialToIdx->at(serial0);
  Eigen::Vector3d v0 = (*atomVect)[idx0].XYZ;
  for(int serial1 : (*molVect)[mol1].atomSerial) {
   int idx1 = serialToIdx->at(serial1);
   Eigen::Vector3d v1 = (*atomVect)[idx1].XYZ;
   if(fabs(v1[0] - v0[0]) < range)
    if((v1 - v0).norm() < range) {
     checkedAtomIdx[idx0] = 1;
     checkedAtomIdx[idx1] = 1;
    }
  }
 }

 return &checkedAtomIdx;
}

const std::vector<int> *Atom_Near::fused_serial() {
 //標準的な共有結合距離（分岐限定法で使う）
 std::vector<std::vector<double>> cov_table;
 int atom_cnt = atomVect->size();
 cov_table.assign(atom_cnt, std::vector<double>(atom_cnt, 0.0));

 const std::array<double, ELMT_CNT> &cov_radius
        = ATOM_DISTANCE.cov_radius();
 int i_end = atom_cnt - 1;
 for(int i = 0; i < i_end; ++i) {
  int i_elmt = (*atomVect)[i].Elmt;
  for(int j = i + 1; j < atom_cnt; ++j) {
   int j_elmt = (*atomVect)[j].Elmt;
   cov_table[i][j] = cov_table[j][i]
        = cov_radius[i_elmt] + cov_radius[j_elmt];
  }
 }

 //"fused atom" のチェックはここから
 std::vector<int> checked_idx(atom_cnt, 0);
 for(int i = 0; i < i_end; ++i) {
  Eigen::Vector3d &xyz_i = (*atomVect)[i].XYZ;
  for(int j = i + 1; j < atom_cnt; ++j) {
   Eigen::Vector3d &xyz_j = (*atomVect)[j].XYZ;
  //X 座標がある程度近くにあり，
   if(fabs( xyz_i[0] - xyz_j[0]) < 5.568) {
    //原子間距離が共有結合距離の半分未満であり，
    if((xyz_i - xyz_j).norm() < cov_table[i][j] * 0.5) {
     //同元素なら "fused atom" とみなす
     if((*atomVect)[i].Elmt == (*atomVect)[j].Elmt) {
      int serial_i = (*atomVect)[i].serialNum;
      int serial_j = (*atomVect)[j].serialNum;
      //シリアル番号の大きい方の idx をリストアップ
      checked_idx[serial_i > serial_j ? i : j] = 1;
     }
    }
   }
  }
 }

 //チェックされた原子のシリアル番号を返却
 checkedSerial.clear();
 checkedSerial.shrink_to_fit();
 for(int i = 0; i < atom_cnt; ++i)
  if(checked_idx[i])
   checkedSerial.push_back((*atomVect)[i].serialNum);
 return &checkedSerial;
}

const std::vector<int> *Atom_Near::excess_bond_serial() {
 checkedSerial.clear();
 checkedSerial.shrink_to_fit();

 int atom_cnt = atomVect->size();
 for(int i = 0; i < atom_cnt; ++i) {
  if((*atomVect)[i].Elmt == 1 && (*atomVect)[i].Bond.size() > 1)
   checkedSerial.push_back((*atomVect)[i].serialNum);
  if((*atomVect)[i].Elmt == 6 && (*atomVect)[i].Bond.size() > 4)
   checkedSerial.push_back((*atomVect)[i].serialNum);
  if((*atomVect)[i].Elmt == 7 && (*atomVect)[i].Bond.size() > 4)
   checkedSerial.push_back((*atomVect)[i].serialNum);
  if((*atomVect)[i].Elmt == 8 && (*atomVect)[i].Bond.size() > 3)
   checkedSerial.push_back((*atomVect)[i].serialNum);
 }

 return &checkedSerial;
}

