#include "unit.h"
#include <queue>
#include <openbabel/mol.h>
#include <openbabel/obconversion.h>

Unit::Unit()
 : covRadius(ATOM_DISTANCE.cov_radius()),
   vdwRadius(ATOM_DISTANCE.vdw_radius()),
   ionRadius(ATOM_DISTANCE.ion_radius()),
   unitNum(-1),
   AtomNear(&atomVect, &serialToIdx, &molVect),
   CHarge(&atomVect, &serialToIdx, &molVect),
   covLim(ATOM_DISTANCE.cov_lim())
{
 for(int i = 1; i != ELMT_CNT; ++i)
  symbolToNum[SYMBOL[i]] = i;
}

bool Unit::read_file(const std::string *name, const std::string *type) {
 bool result = true;

 clear();

 if(*type == "xyz")
  result = read_xyz(name);
 else if(*type == "bcl")
  result = read_bcl(name);
 else if(*type == "pdb")
  result = read_pdb(name);
 else if(*type == "cif")
  result = read_cif(name);
 if(!result) {
  std::cerr << "error_ : " << __PRETTY_FUNCTION__ << std::endl;
  return result;
 }
 set_atom_num();
 create_serial_to_idx_map(&serialToIdx, &atomVect);

 if(*type == "bcl")
  detect_one_way_bond();
 else
  detect_bond();

 return result;
}

void Unit::create_data_structure() {
 for(Atom &atom : atomVect)
  atom.molIdx = -1;
 create_mol_vect();
 detect_main_chain();
 sequencing();
 capture_side_chain();
 update_protein_num_data();
}

void Unit::clear() {
 atomVect.clear();
 atomVect.shrink_to_fit();
 serialToIdx.clear();
 molVect.clear();
 molVect.shrink_to_fit();
 AtomNear.clear();
 Comment = "void Unit::clear()";
}

bool Unit::read_xyz(const std::string *name) {
 File_IO file;
 if(!file.read_xyz(name, &atomVect, &Comment)) {
  clear();
  return false;
 }

 return true;
}

bool Unit::read_bcl(const std::string *name) {
 File_IO file;
 if(!file.read_bcl(name, &atomVect, &Comment)) {
  clear();
  return false;
 }

 return true;
}

bool Unit::read_pdb(const std::string *name) {
 File_IO file;
 if(!file.read_pdb(name, &atomVect)) {
  clear();
  return false;
 }

 return true;
}

bool Unit::read_cif(const std::string *name) {
 File_IO file;
 if(!file.read_cif(name, &atomVect)) {
  clear();
  return false;
 }

 return true;
}

void Unit::set_atom_num() {
 for(Atom &atom : atomVect) {
  atom.Symbol[0] = (char)std::toupper(atom.Symbol[0]);
  int size = atom.Symbol.size();
  if(size > 1)
   atom.Symbol[1] = (char)std::tolower(atom.Symbol[1]);
  if(size > 2)
   atom.Symbol[2] = (char)std::tolower(atom.Symbol[2]);
  atom.Elmt = symbolToNum[atom.Symbol];
 }
}

void Unit::detect_bond() {
 const std::array<std::array<double, ELMT_CNT>, ELMT_CNT>
        &vdw_contact = ATOM_DISTANCE.vdw_contact();
 int atom_cnt = atomVect.size();
 double dist0 = 0;
 double dist1 = 0;
 int elmt0 = 0;
 int elmt1 = 0;
 for(int i = 1; i < atom_cnt; ++i) {
  elmt0 = atomVect[i].Elmt;
  for(int j = 0; j != i; ++j) {
   //分岐限定法．少しだけ早くなる
   dist0 = fabs(atomVect[j].XYZ[0] - atomVect[i].XYZ[0]);
   elmt1 = atomVect[j].Elmt;
   if(fabs(dist0) < vdw_contact[elmt0][elmt1]) {
    dist1 = (atomVect[j].XYZ - atomVect[i].XYZ).norm();
    //同じ組み合わせは 1 回しか走査しないので，push_back() は 2 行必要
    if(dist1 < covLim[elmt0][elmt1]
        && judgement_cov(elmt0, elmt1)) {
     atomVect[i].Bond.push_back(atomVect[j].serialNum);
     atomVect[j].Bond.push_back(atomVect[i].serialNum);
     if(atomVect[i].serialNum < atomVect[j].serialNum)
      atomVect[i].BondUp.push_back(atomVect[j].serialNum);
     else
      atomVect[j].BondUp.push_back(atomVect[i].serialNum);
    }
   }
  }
 }
}

void Unit::detect_one_way_bond() {  //BCL 形式のファイルを開くとき専用
 int atom_cnt = atomVect.size();
 for(int i = 0; i != atom_cnt; ++i) {
  int serial_i = atomVect[i].serialNum;
  const std::vector<int> &bond_bi = atomVect[i].Bond;
  int bond_cnt = bond_bi.size();
  for(int j = 0; j != bond_cnt; ++j) {
   int serial_j = bond_bi[j];
   //同じ組み合わせを 2 回走査することになるので，push_back() は 1 行でよい
   if(serial_i < serial_j)
    atomVect[i].BondUp.push_back(serial_j);
  }
 }
}

bool Unit::judgement_cov(int elmt0, int elmt1) {
 //COO Na
 if((elmt0 == 8 && elmt1 == 11) || (elmt1 == 8 && elmt0 == 11))
  return false;
 //COO K
 if((elmt0 == 8 && elmt1 == 12) || (elmt1 == 8 && elmt0 == 12))
  return false;
 //COO Mg
 if((elmt0 == 8 && elmt1 == 19) || (elmt1 == 8 && elmt0 == 19))
  return false;
 //COO Ca
 if((elmt0 == 8 && elmt1 == 20) || (elmt1 == 8 && elmt0 == 20))
  return false;

 //NH4 Cl
 if((elmt0 == 1 && elmt1 == 17) || (elmt1 == 1 && elmt0 == 17))
  return false;
 //NH4 Cl
 if((elmt0 == 7 && elmt1 == 17) || (elmt1 == 7 && elmt0 == 17))
  return false;

 return true;
}

void Unit::create_mol_vect() {
 //molVect をクリア
 molVect.clear();
 molVect.shrink_to_fit();

 //atomVect[i].molIdx を順次チェック．molIdx == -1 なら分子をキャプチャ
 int atom_cnt = atomVect.size();
 for(int i = 0; i < atom_cnt; ++i)
  atomVect[i].molIdx = -1;

 int mol_idx = 0;
 for(int i = 0; i < atom_cnt; ++i) {
  if(atomVect[i].molIdx < 0) {
   molVect.push_back(Mol(&atomVect, &serialToIdx));  //この段階では空
   capture_mol(i, mol_idx++);  //最初の原子インデックスと分子インデックス
  }
 }
}

void Unit::capture_mol(int atom_idx, int mol_idx) {
 std::queue<int> qserial;  //処理済みの原子のシリアル番号を格納するキュー

 //引数の原子は処理済みだからそのまま格納
 atomVect[atom_idx].molIdx = mol_idx;
 int atom_serial = atomVect[atom_idx].serialNum;
 molVect[mol_idx].atomSerial.push_back(atom_serial);
 qserial.push(atom_serial);

 //残りは結合先をたどる
 while(!qserial.empty()) {
  //処理済みの原子を取得
  int serial = qserial.front();
  qserial.pop();
  int idx = serialToIdx[serial];

  //結合先の原子を処理
  const std::vector<int> &target_bond = atomVect[idx].Bond;
  for(int target_serial : target_bond) {
   int target_idx = serialToIdx[target_serial];  //serial -> idx
   if(atomVect[target_idx].molIdx != mol_idx) {  //結合相手が未走査なら
    atomVect[target_idx].molIdx = mol_idx;
    molVect[mol_idx].atomSerial.push_back(target_serial);
    qserial.push(target_serial);  //結合相手の結合相手は，走査未完了
   }
  }
 }
}

void Unit::detect_main_chain() {
 for(Mol &mol : molVect)
  mol.detect_main_chain();
}

void Unit::sequencing() {
 for(Mol &mol : molVect)
  mol.sequencing();
}

void Unit::capture_side_chain() {
 for(Mol &mol : molVect)
  mol.capture_side_chain();
}

void Unit::update_protein_num_data() {
 //原子の変更などでは，分子単位で処理する場合がある
 int mol_cnt = molVect.size();
 for(int i = 0; i != mol_cnt; ++i) {
  molVect[i].update_protein_num_data();
 }
}

bool Unit::save(const std::string *name, const std::string *type) {
 File_IO io;
 if(*type != "pdb")
  return io.write_bcl_or_xyz(name, type, &atomVect, &serialToIdx, &Comment);
 else
  return io.write_pdb(name, &molVect, &atomVect, &serialToIdx, &Comment);

 return true;
}

bool Unit::export_fasta(const std::string *name) {
 File_IO io;
 return io.write_fasta(name, &molVect);
}

void Unit::translate(const Eigen::Vector3d *vect) {
 for(Atom &atom : atomVect)
  atom.XYZ += *vect;
}

bool Unit::contain_peptide() const {
 for(const Mol &mol : molVect)
  for(const std::vector<Amino_Acid> &peptide : mol.aaTable)
   if(!peptide.empty())
    return true;
 return false;
}

void Unit::get_xyz_for_superimpose
        (int layer, int mol, int pept,
        const std::array<std::vector<int>, 2> *aa_num,
        std::vector<Eigen::Vector3d> *container) const
{
 //面倒なので Mol には降りていかない
 const std::vector<Amino_Acid> &peptide =
        molVect[mol].aaTable[pept];
 int cnt = (*aa_num)[0].size();  //0 も 1 も同じサイズ
 for(int i = 0; i < cnt; ++i) {
  //ともにギャップでなければ重ね合わせの回転行列の作成に使える
  if((*aa_num)[0][i] > -1 && (*aa_num)[1][i] > -1) {
   int pept_num = (*aa_num)[layer][i];
   int serial = peptide[pept_num].Annot.at(N);
   int idx = serialToIdx.at(serial);
   container->push_back(atomVect[idx].XYZ);

   serial = peptide[pept_num].Annot.at(CA);
   idx = serialToIdx.at(serial);
   container->push_back(atomVect[idx].XYZ);

   serial = peptide[pept_num].Annot.at(C);
   idx = serialToIdx.at(serial);
   container->push_back(atomVect[idx].XYZ);
  }
 }
}

void Unit::get_xyz_for_superimpose(const std::vector<int> *serial,
        std::vector<Eigen::Vector3d> *container) const
{
 for(int num : *serial) {
  int idx = serialToIdx.at(num);
  container->push_back(atomVect[idx].XYZ);
 }
}

bool Unit::import_via_babel(const std::string *name, const std::string *type)
{
 //Open Babel で XYZ 形式に変換
 std::string type_out("xyz");
 std::string tmp_name = *name + "." + type_out;
 try {
  std::ifstream ifs(name->data());
  std::ofstream ofs(tmp_name.data());
  OpenBabel::OBConversion conv(&ifs, &ofs);
  conv.SetInAndOutFormats(type->data(), "XYZ");
  OpenBabel::OBMol mol;
  conv.Read(&mol);
  conv.Write(&mol);
 }
 catch(...) {
  std::cerr << "Open Babel で一時ファイルの出力に失敗しました．" << std::endl;
  return false;
 }

 //XYZ 形式に変換したファイルを読み込み
 if(!read_file(&tmp_name, &type_out))
  return false;

 //XYZ 形式の一時ファイルを削除
 std::remove(tmp_name.data());
 return true;
}

bool Unit::export_via_babel(const std::string *name, const std::string *type)
{
 //一時ファイルを PDB 形式で出力
 std::string tmp_type("pdb");
 std::string tmp_name = *name + "." + tmp_type;
 if(!save(&tmp_name, &tmp_type))
  return false;

 //Open Babel で type 形式に変換
 try {
  std::ifstream ifs(tmp_name.data());
  std::ofstream ofs(name->data());
  OpenBabel::OBConversion conv(&ifs, &ofs);
  conv.SetInAndOutFormats("PDB", type->data());
  OpenBabel::OBMol mol;
  conv.Read(&mol);
  conv.Write(&mol);
 }
 catch(...) {
  std::cerr << "Open Babel でファイル変換に失敗しました．" << std::endl;
  return false;
 }

 //PDB 形式の一時ファイルを削除
 std::remove(tmp_name.data());
 return true;
}

void Unit::set_orbital() {
 for(Atom &atom : atomVect) {
  int bond_cnt = atom.Bond.size();
  if(bond_cnt == 4) {
    atom.Orbital = SP3;
  }
  else if(bond_cnt == 3) {
   if(atom.Elmt == 6)
    atom.Orbital = SP2;
   else
    atom.Orbital = SP3;
  }
  else if(bond_cnt == 2) {
   if(atom.Elmt == 6)
    atom.Orbital = SP;
   else if(atom.Elmt == 7)
    atom.Orbital = SP2;
   else if(atom.Elmt == 8)
    atom.Orbital = SP3;
  }
  else if(bond_cnt == 1) {
   if(atom.Elmt == 1)
    atom.Orbital = SP3;
   else if(atom.Elmt == 8)
    atom.Orbital = SP2;
  }
 }
}

