#include "statistics_draw.h"
#include <complex>

/// Constructs the statistics view.
///
/// All five arguments are forwarded unchanged to the Draw_Base constructor;
/// afterwards the coordinate statistics and the drawable primitives are built.
///
/// @param parent  Parent widget, forwarded to Draw_Base.
/// @param atom    Atom list, forwarded to Draw_Base.
/// @param mol     Molecule list, forwarded to Draw_Base.
/// @param serial  Serial-number map, forwarded to Draw_Base
///                (presumably serial -> index into the atom list; confirm in Draw_Base).
/// @param prm     Shared parameters, forwarded to Draw_Base.
Statistics_Draw::Statistics_Draw(
        QWidget *parent,
        std::vector<Atom> *atom,
        const std::vector<Mol> *mol,
        const std::map<int, int> *serial,
        const Shared_Prm *prm)
 : Draw_Base(parent, atom, mol, serial, prm)
{
 // The order below is significant: the eigen decomposition reads the
 // covariance matrix, and the cones are built from the eigen results.
 calc_mean_var_cov();
 calc_eigen_vect();
 create_objects();
}

void Statistics_Draw::calc_mean_var_cov() {
 //平均
 Mean.fill(0.0);
 for(const Atom &atom : *AtomVect)
  for(int i = 0; i < 3; ++i)
   Mean[i] += atom.XYZ[i];

 double n = AtomVect->size();
 for(double &m : Mean)
  m /= n;

 //分散
 for(int i = 0; i < 3; ++i)
  calc_var(i);

 //共分散
 for(int i = 0; i < 3; ++i)
  for(int j = i; j < 3; ++j) {
   calc_cov(i, j);
  }
}

void Statistics_Draw::calc_var(int c0) {
 CovMtx(c0, c0) = 0.0;
 for(const Atom &atom : *AtomVect)
  CovMtx(c0, c0) += pow((atom.XYZ[c0] - Mean[c0]), 2.0);
 CovMtx(c0, c0) /= (double)AtomVect->size();
}

void Statistics_Draw::calc_cov(int c0, int c1) {
 for(const Atom &atom : *AtomVect)
  CovMtx(c0, c1)
        += (atom.XYZ[c0] - Mean[c0]) * (atom.XYZ[c1] - Mean[c1]);
 CovMtx(c0, c1) /= (double)AtomVect->size();
 CovMtx(c1, c0) = CovMtx(c0, c1);
}

void Statistics_Draw::calc_eigen_vect() {
 Eigen::EigenSolver<Eigen::Matrix3d> solv;
 solv.compute(CovMtx);

 Eigen::MatrixXcd e_vect(3, 3);  //複素数型でないとコンパイルエラー
 e_vect = solv.eigenvectors();

 Eigen::VectorXcd e_val(3);  //複素数型でないとコンパイルエラー
 e_val = solv.eigenvalues();

 for(int i = 0; i < 3; ++i) {
  EigenVal(i) = e_val(i).real();
  for(int j = 0; j < 3; ++j)
   EigenVect(i, j) = e_vect(i, j).real();
 }
}

void Statistics_Draw::create_objects() {
 create_wire();
 create_cone();
}

/// Rebuilds WireForDraw with one wire per entry of every atom's BondUp list.
///
/// Each wire stores pointers to the XYZ members of its two end atoms (so the
/// coordinates are referenced, not copied) and the RGB colour of the owning
/// molecule. Colours are taken from molColor and cycle through its first six
/// rows, one row per molecule. Serial numbers are resolved with
/// SerialToIdx->at(), which throws std::out_of_range for an unknown serial.
void Statistics_Draw::create_wire() {
 // Discard the old wires and release their storage.
 WireForDraw.clear();
 WireForDraw.shrink_to_fit();

 Wire segment;
 int palette_idx = 0;
 for(const Mol &mol : *MolVect) {
  for(const int from_serial : mol.atomSerial) {
   const int from_idx = SerialToIdx->at(from_serial);
   for(const int to_serial : (*AtomVect)[from_idx].BondUp) {
    const int to_idx = SerialToIdx->at(to_serial);

    segment.Pos0 = &(*AtomVect)[from_idx].XYZ;
    segment.Pos1 = &(*AtomVect)[to_idx].XYZ;
    for(int c = 0; c < 3; ++c)
     segment.Color[c] = molColor[palette_idx][c];

    WireForDraw.push_back(segment);
   }
  }

  // Advance to the next colour, wrapping around after six molecules.
  if(++palette_idx == 6)
   palette_idx = 0;
 }
}
void Statistics_Draw::create_cone() {
 ConeForDraw.clear();
 ConeForDraw.shrink_to_fit();

 std::array<double, 3> val{0.0, 0.0, 0.0};
 for(int i = 0; i < 3; ++i) {
  if(EigenVal[i] < 0.0)
   val[i] = -sqrt(-EigenVal[i]) * 2.0;
  else
   val[i] = sqrt(EigenVal[i]) * 2.0;
 }

 const Eigen::Vector3d *center = Affine.get_center();
 for(int i = 0; i < 6; ++i)
  ConeXYZ[i][0] = *center;

 ConeXYZ[0][1]
        = EigenVect.col(0) * val[0] + *center;
 ConeXYZ[1][1]
        = -EigenVect.col(0) * val[0] + *center;
 ConeXYZ[2][1]
        = EigenVect.col(1) * val[1] + *center;
 ConeXYZ[3][1]
        = -EigenVect.col(1) * val[1] + *center;
 ConeXYZ[4][1]
        = EigenVect.col(2) * val[2] + *center;
 ConeXYZ[5][1]
         = -EigenVect.col(2) * val[2] + *center;

 Cone cone;
 cone.Radius = 1.0;
 ConeForDraw.assign(6, cone);

 for(int i = 0; i < 2; ++i) {
  ConeForDraw[i].Pos0 = &ConeXYZ[i][0];
  ConeForDraw[i].Pos1 = &ConeXYZ[i][1];
  for(int j = 0; j < 3; ++j)
   ConeForDraw[i].Color[j] = molColor[0][j];
 }
 for(int i = 2; i < 4; ++i) {
  ConeForDraw[i].Pos0 = &ConeXYZ[i][0];
  ConeForDraw[i].Pos1 = &ConeXYZ[i][1];
  for(int j = 0; j < 3; ++j)
   ConeForDraw[i].Color[j] = molColor[1][j];
 }
 for(int i = 4; i < 6; ++i) {
  ConeForDraw[i].Pos0 = &ConeXYZ[i][0];
  ConeForDraw[i].Pos1 = &ConeXYZ[i][1];
  for(int j = 0; j < 3; ++j)
   ConeForDraw[i].Color[j] = molColor[2][j];
 }
}

void Statistics_Draw::rotate_other_object_xy(int ox, int oy) {
 double dx = (double)(ox - MouseX) / (double)height() * 360.0;
 double dy = (double)(oy - MouseY) / (double)width() * 360.0;
 for(std::array<Eigen::Vector3d, 2> &row : ConeXYZ) {
  Affine.rotate_point_xy(&row[0], dx, dy);
  Affine.rotate_point_xy(&row[1], dx, dy);
 }
}

