import { Injectable } from '@angular/core';
import { Observable, throwError } from 'rxjs';
import { map, switchMap } from 'rxjs/operators';
import { GraphUtilService, GridData, StandardCodonsEnum } from './graph-util.service';
import {
  AggregatedDocumentQuery,
  AggregatedDocumentResponse,
  DocumentServiceDocumentQuery,
} from '../../../nucleus/services/documentService/document-service.v1.http';
import { AbstractGraphDataService } from './abstract-graph-data.service';
import { GraphOption, HeatmapGraphOption } from './graph-aba-data.service';
import { IGridResourceResponse } from '../../../nucleus/services/models/response.model';
import { ResultDocumentFieldEnum as FIELDS } from './result-document-field.enum';
import { IBarChartInfo } from './column-chart/BarChartInfo.model';
import { SeriesColumnOptions } from 'highcharts';
import { IBarChartData } from './column-chart/column-chart.component';
import { HeatmapData, isHeatmapData } from './graph-heatmap/graph-heatmap.component';
import { SummaryGraphParams } from '../../core/ngs/ngs-summary-graph-viewer/ngs-summary-graph-viewer.component';
import { HeatmapInfo } from './graph-heatmap/heatmap-info.model';
import { DocumentTableQueryService } from '../../core/document-table-service/document-table-state/document-table-query.service';
import {
  Aggregation,
  AggregationKind,
  OrderByTableQueryKind,
  SearchTable,
} from '@geneious/nucleus-api-client';
import { sanitizeDTSTableOrColumnName } from '../../../nucleus/services/documentService/document-service.v1';
import { DocumentTableStateService } from '../../core/document-table-service/document-table-state/document-table-state.service';

/**
 * Builds chart-ready data (bar charts, amino acid distributions and codon usage
 * heatmaps/tables) for a result document by querying its document tables.
 */
@Injectable({
  providedIn: 'root',
})
export class GraphDocumentDataService implements AbstractGraphDataService {
  /** Name of the document table that the codon usage queries read from. */
  CODON_USAGE_TABLE = 'DOCUMENT_TABLE_CODON_USAGE';

  constructor(
    private documentTableQueryService: DocumentTableQueryService,
    private documentTableStateService: DocumentTableStateService,
  ) {}

  /**
   * Builds the bar chart for a region's cluster table.
   *
   * `params.type` selects the chart:
   * - `'cluster_diversity'`: number of clusters per cluster size;
   * - `'cluster_sizes'`: the 25 largest clusters and their totals;
   * - anything else: summed totals per cluster length.
   *
   * @param docID ID of the document whose table is queried.
   * @param params Graph parameters; only `params.type` is read here.
   * @param region Region option supplying the table (`region.table`) and display name (`region.name`).
   * @returns An observable emitting the formatted bar chart info.
   */
  public getBarChart(docID: string, params: any, region: GraphOption): Observable<IBarChartInfo> {
    const table = region.table;
    return this.documentTableStateService.getTable(docID, table).pipe(
      switchMap((thisTable) => {
        // Guard the metadata lookup the same way as the column lookup below so that a
        // missing table falls back to the default total column instead of throwing.
        const metadata = thisTable?.metadata;
        const totalField = getTotalColumn(
          thisTable?.columns?.map((col) => col.displayName) ?? [],
          metadata?.clusters?.usedBeforeCollapsingFrequencies,
        );
        if (params.type === 'cluster_diversity') {
          const xKey = totalField;
          const yKey = 'yValue';
          const query: AggregatedDocumentQuery = {
            fields: [xKey, `COUNT('${sanitizeDTSTableOrColumnName(region.name)} ID') ${yKey}`],
            orderBy: [{ kind: 'ascending', field: totalField }],
            groupBy: [xKey],
          };
          return this.documentTableQueryService.queryTableAggregate(docID, table, query).pipe(
            map((response: AggregatedDocumentResponse) => {
              const name = `${region.name} Cluster Diversity`;
              const xLabel = 'Cluster size';
              const yLabel = 'Number of Clusters';
              return GraphDocumentDataService.formatAggregateBarChartData(
                response,
                name,
                xLabel,
                yLabel,
                xKey,
                yKey,
              );
            }),
          );
        } else if (params.type === 'cluster_sizes') {
          const xKey = sanitizeDTSTableOrColumnName(region.name);
          const yKey = totalField;

          const query: DocumentServiceDocumentQuery = {
            fields: [xKey, yKey],
            orderBy: [{ kind: 'descending', field: totalField }],
            // Must stay in sync with the "(top 25)" axis label below.
            limit: 25,
          };
          return this.documentTableQueryService
            .queryTable(docID, table, query)
            .pipe(
              map((data) =>
                GraphDocumentDataService.formatBarChartData(
                  data,
                  `${region.name} counts`,
                  `${region.name} (top 25)`,
                  'Number of sequences',
                  xKey,
                  yKey,
                ),
              ),
            );
        } else {
          // Cluster lengths.
          const xKey = FIELDS.LENGTH;
          const yKey = 'yValue';
          const query: AggregatedDocumentQuery = {
            fields: [xKey, `SUM(${totalField}) ${yKey}`],
            orderBy: [{ kind: 'ascending', field: FIELDS.LENGTH }],
            groupBy: [xKey],
          };
          return this.documentTableQueryService
            .queryTableAggregate(docID, table, query)
            .pipe(
              map((data) =>
                GraphDocumentDataService.formatAggregateBarChartData(
                  data,
                  `${region.name} Cluster lengths`,
                  `${region.name} length`,
                  'Number of Sequences',
                  xKey,
                  yKey,
                ),
              ),
            );
        }
      }),
    );
  }

  /**
   * Computes, for clusters of one length, how often each residue character occurs at
   * each position, weighted by the table's total column.
   *
   * The search is restricted to rows whose length field equals `length`. Filtering on a
   * specific length also excludes clusters of 'Unknown' residues.
   *
   * @param documentID ID of the document whose table is searched.
   * @param region Region option supplying the table and the column holding the cluster sequence.
   * @param length Cluster length to chart; a falsy value falls back to 10.
   * @returns An observable emitting one column series per residue character (sorted by
   *          character), each with `{ x: position, y: total }` points (positions are 1-based).
   */
  public getAminoAcidDistribution(
    documentID: string,
    region: GraphOption,
    length: number,
  ): Observable<SeriesColumnOptions[]> {
    // Use one effective length for both the row filter and the number of per-position
    // aggregations; previously only the filter applied the default of 10.
    const clusterLength = length || 10;
    return this.documentTableStateService.getTable(documentID, region.table).pipe(
      switchMap((thisTable) => {
        const metadata = thisTable?.metadata;
        const totalField = getTotalColumn(
          thisTable?.columns?.map((col) => col.displayName) ?? [],
          metadata?.clusters?.usedBeforeCollapsingFrequencies,
        );
        // One terms aggregation per position; the aggregation id is the 1-based position.
        const aggs: Array<Aggregation> = Array.from({ length: clusterLength }, (_, i) => ({
          kind: AggregationKind.Terms,
          id: (i + 1).toString(),
          script: {
            source:
              "doc[params['clusterField']].getValue().substring(params['positionIndex'],params['positionIndex']+1)",
            params: {
              clusterField: `${sanitizeDTSTableOrColumnName(region.name)}.keyword`,
              positionIndex: i,
            },
          },
          // Assuming that the number of possible amino acid characters will never exceed this value.
          size: 50,
          sort: [{ field: '_count', kind: OrderByTableQueryKind.Descending }],
          subAggregations: [
            {
              id: 'sum',
              kind: AggregationKind.Sum,
              field: totalField,
            },
          ],
        }));
        const query: SearchTable = {
          fields: [],
          query: {
            queryString: `${FIELDS.LENGTH}:${clusterLength}`,
          },
          aggregations: aggs,
          // Only the aggregations are needed, not the matching rows.
          limit: 0,
        };
        return this.documentTableQueryService.searchTable(documentID, region.table, query).pipe(
          map((response) => {
            // Treat a response without aggregations as "no data" rather than throwing.
            const aggregations = (response.data.aggregations ?? {}) as Record<
              string,
              { buckets: AggregationBucket[] }
            >;

            // residue character -> (position -> summed total)
            const aaDist = new Map<string, Map<number, number>>();
            for (const [position, agg] of Object.entries(aggregations)) {
              const pos = parseInt(position, 10);
              for (const bucket of agg.buckets) {
                let countByPos = aaDist.get(bucket.key);
                if (!countByPos) {
                  countByPos = new Map<number, number>();
                  aaDist.set(bucket.key, countByPos);
                }
                countByPos.set(pos, (countByPos.get(pos) ?? 0) + bucket.sum.value);
              }
            }

            const orderedAminos = Array.from(aaDist.keys()).sort();
            return orderedAminos.map((aminoAcid): SeriesColumnOptions => {
              const countByPos = aaDist.get(aminoAcid) ?? new Map<number, number>();
              return {
                name: aminoAcid.toUpperCase(),
                data: Array.from(countByPos, ([position, total]) => ({
                  x: position,
                  y: total,
                })),
                type: 'column',
              };
            });
          }),
        );
      }),
    );
  }

  /**
   * Find out all the lengths for an entire dataset; it's important that we get ALL the lengths
   * and not just some because otherwise it's misleading for our customers.
   *
   * @param documentID ID of the document whose table is queried.
   * @param region Region option supplying the table to query.
   * @returns An observable emitting the distinct, truthy lengths in ascending numeric order.
   */
  public getClusterLengths(documentID: string, region: GraphOption): Observable<number[]> {
    const query: AggregatedDocumentQuery = {
      // No need to get any other fields back so reduce data sent from the server.
      fields: [FIELDS.LENGTH],
      // To get all the lengths for all clusters we must set this number higher than the max length.
      // 10K is a bit higher than we need but this ensures that for any known molecule we get them (no
      // molecule we have yet encountered (even with accompanying sequence around it) is 10K residues long).
      limit: 10000,
      // It's critical that we request grouped rows. If we don't then to be sure we'd found ALL
      // the lengths we'd have to get every row - which would take too long for large datasets.
      groupBy: [FIELDS.LENGTH],
    };

    return this.documentTableQueryService.queryTableAggregate(documentID, region.table, query).pipe(
      map((response: AggregatedDocumentResponse) =>
        response.data.Length.buckets.map((bucket) => bucket.key),
      ),
      map((lengths: number[]) =>
        lengths
          // Drop falsy keys (e.g. 0 or a missing length).
          .filter((length) => !!length)
          // By default, the sort method sorts elements alphabetically.
          // @see https://stackoverflow.com/questions/1063007/how-to-sort-an-array-of-integers-correctly
          .sort((a, b) => a - b),
      ),
    );
  }

  /**
   * Fetches the codon usage rows for one cluster length and formats them either as a
   * grid (table) or as heatmap data.
   *
   * @param documentID ID of the document whose codon usage table is queried.
   * @param region Region option; `region.name` selects the rows.
   * @param length Cluster length to fetch (see `getCodonUsageDataByLength` for the default).
   * @param tableFormat `true` for grid data, `false` for heatmap data.
   * @returns An observable emitting the formatted grid or heatmap data.
   */
  public getFormattedCodonUsageDataByLength(
    documentID: string,
    region: GraphOption,
    length: any,
    tableFormat: boolean,
  ): Observable<HeatmapData | GridData> {
    const dataRequest = this.getCodonUsageDataByLength(documentID, region, length);
    const excludedCols = [FIELDS.LENGTH, FIELDS.REGION, FIELDS.TOTAL_SEQS, FIELDS.ROW_NUMBER];
    if (tableFormat) {
      return dataRequest.pipe(
        map((data) => {
          // Relabel copies of the rows: `data` is the query response's own array, so
          // mutating it in place could prefix "Position " twice if the same row objects
          // were ever delivered again.
          const labelledRows = data.map((item) => ({
            ...item,
            [FIELDS.POSITION]: `Position ${item[FIELDS.POSITION]}`,
          }));
          const flipped = GraphUtilService.transposeDataList(
            labelledRows,
            FIELDS.POSITION,
            FIELDS.CODON,
            excludedCols,
          ).map((row) => {
            row[FIELDS.CODON] = row[FIELDS.CODON].toUpperCase();
            // Add corresponding amino acid value to each row to form a new column.
            row[FIELDS.AMINO] = GraphUtilService.getFormattedAminoAcid(
              StandardCodonsEnum[row[FIELDS.CODON] as keyof typeof StandardCodonsEnum],
            );
            return row;
          });
          return GraphUtilService.rowsToTable(flipped, FIELDS.AMINO);
        }),
      );
    } else {
      // Heatmap format.
      return dataRequest.pipe(
        map((data) => {
          const options = { hasCodonCols: true };
          return GraphUtilService.rowsToHeatmap(data, excludedCols, FIELDS.POSITION, options);
        }),
      );
    }
  }

  /**
   * Builds the codon distribution chart, either summarised over all lengths or for a
   * single length, as a table or a heatmap depending on `options.isTable`.
   *
   * @param docID ID of the document to chart.
   * @param params Graph parameters; only `params.source === 'codon'` is supported.
   * @param options Region, length (`'All lengths'` selects the summary) and table/heatmap flag.
   * @returns An observable emitting the chart info. For any other source the observable
   *          errors with the string `'Unsupported table source.'`.
   */
  public getHeatmapChart(
    docID: string,
    params: SummaryGraphParams,
    options: HeatmapGraphOption,
  ): Observable<HeatmapInfo> {
    if (params.source === 'codon') {
      const isAllLengthsChart = options.length === 'All lengths';
      const data = isAllLengthsChart
        ? this.getFormattedCodonUsageSummaryData(docID, options.region, options.isTable)
        : this.getFormattedCodonUsageDataByLength(
            docID,
            options.region,
            options.length,
            options.isTable,
          );

      const sharedInfo = {
        title: 'Codon Distribution Chart',
        xAxisTitle: 'Codon',
        yAxisTitle: isAllLengthsChart ? 'Length' : 'Position',
      };

      return data.pipe(
        map((graphData) => {
          // The type guard decides the chart type from the data actually produced.
          if (!isHeatmapData(graphData)) {
            return {
              ...sharedInfo,
              data: graphData,
              type: 'table',
            };
          }

          return {
            ...sharedInfo,
            data: graphData,
            type: 'heatmap',
          };
        }),
      );
    }

    // NOTE(review): this emits a plain string rather than an Error; kept as-is because
    // subscribers may rely on receiving the message directly.
    return throwError(() => 'Unsupported table source.');
  }

  /**
   * Fetches the per-length codon usage summary and formats it as a grid (table) or as
   * percentage-based heatmap data.
   *
   * @param documentID ID of the document whose codon usage table is queried.
   * @param region Region option; `region.name` selects the rows.
   * @param isTable `true` for grid data, `false` for heatmap data.
   * @returns An observable emitting the formatted grid or heatmap data.
   */
  public getFormattedCodonUsageSummaryData(
    documentID: string,
    region: GraphOption,
    isTable: boolean,
  ): Observable<GridData | HeatmapData> {
    const dataRequest = this.getCodonUsageSummaryData(documentID, region);
    if (isTable) {
      return dataRequest.pipe(map((data) => this.formatCodonUsageSummaryTable(data)));
    } else {
      // Is heatmap.
      return dataRequest.pipe(
        map((data) => {
          // Put the totals over all lengths first, without mutating the emitted array.
          const rows = [this.calculateOverallCodonDistribution(data, 'All lengths'), ...data];
          const dataAsPercentage = GraphUtilService.rowsToPercentage(
            rows,
            [FIELDS.LENGTH],
            FIELDS.TOTAL_CODONS,
            1,
            6,
          );
          const options = { isPercent: true, hasCodonCols: true };
          return GraphUtilService.rowsToHeatmap(
            dataAsPercentage,
            [FIELDS.TOTAL_CODONS],
            FIELDS.LENGTH,
            options,
          );
        }),
      );
    }
  }

  /**
   * Converts raw rows into a single column series.
   *
   * Each row contributes a `[label, value]` point where the label is `row[xKey]`
   * (stringified) and the value is `parseFloat(row[yKey])`. Large all-integer data sets
   * are binned first (see `binXAxis`).
   *
   * @param rawData Rows to chart.
   * @param xKey Row property used as the label; also used as the series name.
   * @param yKey Row property parsed as the numeric value.
   * @returns An array holding the one column series.
   */
  static transformBarChartData(rawData: any[], xKey: string, yKey: string): SeriesColumnOptions[] {
    const formatted = rawData.map((row) => ({ label: row[xKey], yValue: parseFloat(row[yKey]) }));
    // Reference the class explicitly (not `this`) so the method also works when detached.
    const binned = GraphDocumentDataService.binXAxis(formatted);
    return [
      {
        name: xKey,
        data: binned.map((d) => [`${d.label}`, d.yValue]),
        type: 'column',
        color: '#03a9f4',
      },
    ];
  }

  /**
   * Given an array of data with labels and values, determines whether there are too many
   * points to display and, if every label is an integer, sums them into preset bins.
   *
   * Data is returned unchanged when it has 75 points or fewer, or when any label is not a
   * non-negative integer. Otherwise empty bins are trimmed from both ends (empty bins
   * between populated ones are kept). Labels below the first bin (1) fall into no bin.
   *
   * @param data Points to bin.
   * @returns The original points, or the binned points labelled by bin.
   */
  private static binXAxis(data: IBarChartData[]): IBarChartData[] {
    // Only bin if there are more elements than can be easily visualised.
    if (data.length <= 75) {
      return data;
    }

    // Only bin if ALL labels are integers. (`some` would bin mixed label sets, and `*`
    // would accept an empty label.)
    const integerLabel = /^[0-9]+$/;
    if (!data.every((element: any) => integerLabel.test(element.label))) {
      return data;
    }

    const bins = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 1000];
    const points = data.map((element: any) => ({
      value: Number(element.label),
      yValue: element.yValue,
    }));

    // Sum the data values that fall in each bin: [binBottom, binTop), open-ended for the last bin.
    const binnedData = bins.map((binBottom, i) => {
      const binTop = bins[i + 1];
      const binTotal = points.reduce(
        (total, point) =>
          point.value >= binBottom && (!binTop || point.value < binTop)
            ? total + point.yValue
            : total,
        0,
      );

      // Conditional labels for bins: single digits maintained and upper bound noted, otherwise range.
      let binLabel = '';
      if (binTop === undefined) {
        binLabel = binBottom + '+';
      } else if (binTop <= 10) {
        binLabel = String(binBottom);
      } else {
        binLabel = binBottom + '-' + binTop;
      }

      return { label: binLabel, yValue: binTotal };
    });

    // Remove all bins from the start that have no observed values.
    const firstNonzeroIndex = binnedData.findIndex((bin) => bin.yValue !== 0);
    if (firstNonzeroIndex === -1) {
      return [];
    }

    // Remove all bins from the end that have no observed values, preserving empty bins
    // that sit between populated ones.
    let lastNonzeroIndex = binnedData.length - 1;
    while (binnedData[lastNonzeroIndex].yValue === 0) {
      lastNonzeroIndex--;
    }

    return binnedData.slice(firstNonzeroIndex, lastNonzeroIndex + 1);
  }

  /**
   * Formats the per-length codon usage summary rows as a grid: one row per codon with
   * its amino acid, a "% of same AA" column, an overall "% of any AA" column and one
   * column per length.
   *
   * @param data Summary rows, one per length. Not modified.
   * @returns The grid data.
   */
  private formatCodonUsageSummaryTable(data: any[]): GridData {
    // Get statistics averaged over all lengths.
    const overallRowRaw = this.calculateOverallCodonDistribution(data, '% of any AA');
    const aminoComparison = this.calculateCodonDistributionByAA(overallRowRaw);

    // Relabel copies so the caller's rows are left untouched.
    const labelledRows = data.map((row) => ({
      ...row,
      [FIELDS.LENGTH]: `Length ${row[FIELDS.LENGTH]} (% of any AA)`,
    }));

    // Convert everything to percentages so that it is more intuitive.
    const summaryRowsPercent = GraphUtilService.rowsToPercentage(
      [overallRowRaw, ...labelledRows],
      [FIELDS.LENGTH],
      FIELDS.TOTAL_CODONS,
      100,
      3,
    );

    // Amino Comparison is already a percentage.
    summaryRowsPercent.unshift(aminoComparison);

    // Flip columns and rows so that it looks better.
    const flipped = GraphUtilService.transposeDataList(
      summaryRowsPercent,
      FIELDS.LENGTH,
      FIELDS.CODON,
      [FIELDS.TOTAL_CODONS],
    ).map((row) => {
      // Add corresponding amino acid value to each row to form a new column.
      row[FIELDS.AMINO] = GraphUtilService.getFormattedAminoAcid(
        StandardCodonsEnum[row[FIELDS.CODON] as keyof typeof StandardCodonsEnum],
      );
      return row;
    });

    // Keep the amino acid and codon columns first, in that order.
    const FIXED_PARAMS = { [FIELDS.AMINO]: 1, [FIELDS.CODON]: 2 };
    const sortRowHeaders = GraphUtilService.getCompareWithFixed(FIXED_PARAMS);
    return GraphUtilService.rowsToTable(flipped, FIELDS.AMINO, sortRowHeaders);
  }

  /**
   * Sums every column except the length column across all rows.
   *
   * @param data Rows to sum; each value is parsed with `parseFloat`.
   * @param name Value stored in the result's length column, used as the row's label.
   * @returns A new row holding the column sums and the given label.
   */
  private calculateOverallCodonDistribution(data: any, name: string) {
    const overallRow = data.reduce((sumRow: any, row: any) => {
      // Object.keys only yields own enumerable properties, so no hasOwnProperty check is needed.
      Object.keys(row)
        .filter((key) => key !== FIELDS.LENGTH)
        .forEach((codon) => (sumRow[codon] = (sumRow[codon] || 0) + parseFloat(row[codon])));
      return sumRow;
    }, {});
    overallRow[FIELDS.LENGTH] = name;
    return overallRow;
  }

  /**
   * Creates a row with the percentage prevalence of each codon amongst all codons coding
   * for the same amino acid.
   *
   * @param overallRow Row of codon totals (as produced by `calculateOverallCodonDistribution`).
   * @returns A row labelled '% of same AA' with one percentage (3 decimal places) per
   *          upper-cased codon, plus the overall total codon count.
   */
  private calculateCodonDistributionByAA(overallRow: any) {
    // Only keys that name a standard codon take part.
    const codonKeys = Object.keys(overallRow).filter(
      (key) => !!StandardCodonsEnum[key as keyof typeof StandardCodonsEnum],
    );
    const aminoTotals = codonKeys.reduce((totals: any, codon) => {
      const aminoAcid = StandardCodonsEnum[codon as keyof typeof StandardCodonsEnum];
      totals[aminoAcid] = (totals[aminoAcid] || 0) + overallRow[codon];
      return totals;
    }, {});

    const aminoComparison: Record<string, string | number> = {
      Length: '% of same AA',
      totalCodons: overallRow[FIELDS.TOTAL_CODONS],
    };
    codonKeys.forEach((codon) => {
      const aminoAcid = StandardCodonsEnum[codon as keyof typeof StandardCodonsEnum];
      const aminoTotal = parseFloat(aminoTotals[aminoAcid]);
      // Avoid dividing by zero (or NaN) when an amino acid has no observed codons.
      const percentage =
        aminoTotal && aminoTotal > 0 ? (parseFloat(overallRow[codon]) / aminoTotal) * 100 : 0;
      aminoComparison[codon.toUpperCase()] = +percentage.toFixed(3);
    });

    return aminoComparison;
  }

  /**
   * Queries the codon usage table for the rows of one region and cluster length.
   *
   * @param documentID ID of the document whose codon usage table is queried.
   * @param region Region option; `region.name` is matched against the region column.
   * @param length Cluster length to match; a falsy value falls back to 10.
   * @returns An observable emitting the response's rows.
   */
  private getCodonUsageDataByLength(
    documentID: string,
    region: GraphOption,
    length: number,
  ): Observable<any[]> {
    const query = {
      fields: [] as any[],
      // NOTE(review): region.name is interpolated unescaped; a name containing a single
      // quote would break this clause - confirm names can never contain one.
      where: `${FIELDS.LENGTH}=${length || 10} AND ${FIELDS.REGION}='${region.name}'`,
    };
    return this.documentTableQueryService
      .queryTable(documentID, this.CODON_USAGE_TABLE, query)
      .pipe(map((response) => response.data));
  }

  /**
   * Queries the codon usage table for one region, summing every codon column and the
   * total sequence count per cluster length.
   *
   * @param documentID ID of the document whose codon usage table is queried.
   * @param region Region option; `region.name` is matched against the region column.
   * @returns An observable emitting one flat row per length.
   */
  private getCodonUsageSummaryData(documentID: string, region: GraphOption): Observable<any[]> {
    // Formulate a query sub-string representing a list of Sums of each column that we want returned.
    const codonSums = Object.keys(StandardCodonsEnum).map((codon) => `SUM(${codon}) ${codon}`);

    const query = {
      fields: [FIELDS.LENGTH, `SUM(${FIELDS.TOTAL_SEQS}) totalCodons`].concat(codonSums),
      // NOTE(review): region.name is interpolated unescaped; a name containing a single
      // quote would break this clause - confirm names can never contain one.
      where: `${FIELDS.REGION}='${region.name}'`,
      groupBy: [FIELDS.LENGTH],
    };

    return this.documentTableQueryService
      .queryTableAggregate(documentID, this.CODON_USAGE_TABLE, query)
      .pipe(
        map((result) =>
          GraphDocumentDataService.getRowsFromAggregationResponse(result, FIELDS.LENGTH),
        ),
      );
  }

  /**
   * Formats an aggregated (grouped) query response as bar chart info.
   *
   * @param response Aggregated response holding buckets under `xKey`.
   * @param name Chart title.
   * @param xLabel X axis label.
   * @param yLabel Y axis label.
   * @param xKey Aggregation key; becomes each point's label.
   * @param yKey Bucket property whose value becomes each point's value.
   * @returns The bar chart info.
   */
  static formatAggregateBarChartData(
    response: AggregatedDocumentResponse,
    name: string,
    xLabel: string,
    yLabel: string,
    xKey: string,
    yKey: string,
  ): IBarChartInfo {
    const data = GraphDocumentDataService.getRowsFromAggregationResponse(response, xKey);
    // Wrap the flattened rows in the same shape as a plain (non-aggregated) response.
    const metadata = { total: data.length, offset: 0, limit: data.length };
    return GraphDocumentDataService.formatBarChartData(
      { data, metadata },
      name,
      xLabel,
      yLabel,
      xKey,
      yKey,
    );
  }

  /**
   * Formats a row-based query response as bar chart info.
   *
   * @param response Response whose `data` rows are charted.
   * @param name Chart title.
   * @param xLabel X axis label.
   * @param yLabel Y axis label.
   * @param xKey Row property used as each point's label.
   * @param yKey Row property used as each point's value.
   * @returns The bar chart info.
   */
  static formatBarChartData(
    response: IGridResourceResponse<any>,
    name: string,
    xLabel: string,
    yLabel: string,
    xKey: string,
    yKey: string,
  ): IBarChartInfo {
    return {
      title: name,
      xLabel: xLabel,
      yLabel: yLabel,
      data: GraphDocumentDataService.transformBarChartData(response.data, xKey, yKey),
    };
  }

  /**
   * Flattens the buckets of one aggregation into plain rows.
   *
   * For each bucket, every property shaped like `{ value: ... }` becomes a column holding
   * that value, and the bucket's `key` is stored under `aggregationKey`.
   *
   * @param response Aggregated response; buckets are read from `response.data[aggregationKey].buckets`.
   * @param aggregationKey Name of the aggregation to flatten.
   * @returns One row per bucket.
   */
  static getRowsFromAggregationResponse(response: any, aggregationKey: string) {
    return response.data[aggregationKey].buckets.map((bucket: any) => {
      const row: any = {};
      Object.keys(bucket)
        // `?.` skips null/undefined properties (e.g. a bucket with a null key) instead of throwing.
        .filter((prop) => bucket[prop]?.value !== undefined)
        .forEach((prop) => (row[prop] = bucket[prop].value));
      row[aggregationKey] = bucket.key;
      return row;
    });
  }
}

/**
 * Shape of one terms-aggregation bucket as read by `getAminoAcidDistribution`.
 */
interface AggregationBucket {
  /** Bucket term; used there as the residue character that names a series. */
  key: string;
  /** Result of the 'sum' sub-aggregation; `value` is added to the per-position total. */
  sum: { value: number };
}

/**
 * Picks the column that holds the per-row total for a cluster table.
 *
 * @param columnNames Display names of the table's columns.
 * @param useBeforeCollapsing When truthy, '# of Sequences' is always chosen.
 * @returns '# of Sequences' when `useBeforeCollapsing` is truthy; otherwise '# of Clones'
 *          if the table has that column; otherwise the default total field.
 */
export function getTotalColumn(columnNames: string[], useBeforeCollapsing: boolean): string {
  if (useBeforeCollapsing) {
    return '# of Sequences';
  }

  const clonesColumn = '# of Clones';
  if (columnNames.includes(clonesColumn)) {
    return clonesColumn;
  }

  return FIELDS.TOTAL;
}
