import snapatac2 as snap
import snapatac2._snapatac2
import numpy as np
import pandas as pd
from collections import defaultdict
import matplotlib.pyplot as plt
from subprocess import run
import os
import click
import json
import gzip
import scipy.io


def makedict(chrlenfile):
    """Read a tab-separated chromosome-length table.

    Each non-blank line must hold the chromosome name in the first column
    and its integer length in the second column.

    Args:
        chrlenfile: Path to the chromosome length file.

    Returns:
        A tuple ``(d, size_list)`` where ``d`` maps chromosome name to its
        length and ``size_list`` holds one ``"<chrom>:1-<length>"`` region
        string per chromosome, in file order.

    Raises:
        IndexError: If a non-blank line has fewer than two columns.
        ValueError: If the second column is not an integer.
    """
    d = {}
    size_list = []
    with open(chrlenfile, 'r') as fh:
        for line in fh:
            stripped = line.strip()
            # A blank line (typically a trailing newline) carries no record;
            # skip it instead of failing on the missing second column.
            if not stripped:
                continue
            fields = stripped.split('\t')
            chrom = str(fields[0])
            lenth = int(fields[1])
            d[chrom] = lenth
            size_list.append(f'{chrom}:1-{lenth}')
    return d, size_list

def calulate_chunck_size(detail_file):
    """Pick a chunk size (in lines) for reading ``detail_file`` piecewise.

    The file's line count is divided into roughly equal pieces of about
    5,000,000 lines each (at least one piece); the returned size is the
    per-piece line count plus one.

    Args:
        detail_file: Path of the text file whose lines are counted.

    Returns:
        The chunk size as an int (1 for an empty file).
    """
    target_lines = 5000000
    n_lines = 0
    with open(detail_file) as handle:
        for _ in handle:
            n_lines += 1
    # Files shorter than the target still need one piece.
    pieces = int(n_lines / target_lines) or 1
    return int(n_lines / pieces) + 1

def get_int_type(max_value):
    """Return the narrowest signed NumPy integer type whose maximum covers ``max_value``.

    Only the upper bound is checked: candidates are tried from ``np.int8``
    to ``np.int32`` and ``np.int64`` is the fallback.

    Args:
        max_value: The largest value that has to be representable.

    Returns:
        One of ``np.int8``, ``np.int16``, ``np.int32`` or ``np.int64``.
    """
    for candidate in (np.int8, np.int16, np.int32):
        if max_value <= np.iinfo(candidate).max:
            return candidate
    return np.int64

def _json_default(o):
    """Convert NumPy scalars/arrays into plain Python values for ``json.dump``."""
    if isinstance(o, np.integer):
        return int(o)
    if isinstance(o, np.floating):
        return float(o)
    if isinstance(o, np.bool_):
        return bool(o)
    if isinstance(o, np.ndarray):
        return o.tolist()
    # Raising TypeError is the json contract for unsupported objects; returning
    # the object unchanged would surface as a misleading circular-reference error.
    raise TypeError(f'Object of type {type(o).__name__} is not JSON serializable')


def _write_peak_matrix_dir(matrix_dir, peak_mat):
    """Write ``peak_mat`` into ``matrix_dir`` as gzipped matrix.mtx / features.tsv / barcodes.tsv."""
    os.makedirs(matrix_dir, exist_ok=True)
    mtx_path = os.path.join(matrix_dir, "matrix.mtx")
    # mmwrite produces a plain-text file; gzip it afterwards and drop the plain copy.
    scipy.io.mmwrite(mtx_path, peak_mat.X.T.astype(np.float32))
    with open(mtx_path, 'rb') as f_in:
        with gzip.open(os.path.join(matrix_dir, "matrix.mtx.gz"), 'wb') as f_out:
            f_out.writelines(f_in)
    with gzip.open(os.path.join(matrix_dir, "features.tsv.gz"), 'wt') as f:
        for feature in peak_mat.var_names:
            f.write(f"{feature}\t{feature}\tpeaks\n")
    with gzip.open(os.path.join(matrix_dir, "barcodes.tsv.gz"), 'wt') as f:
        for cellbarcode in peak_mat.obs_names:
            f.write(f"{cellbarcode}\n")
    os.remove(mtx_path)


def _density_downsample(points_df, x_col, y_col):
    """De-duplicate ``(x_col, y_col)`` points and thin them to roughly 2000, proportionally per x-bin."""
    unique_df = points_df.drop_duplicates(subset=[x_col, y_col])
    sorted_df = unique_df.sort_values(by=x_col)
    # 20 equal-width bins between 0 and the largest x value.
    bins = np.linspace(0, sorted_df[x_col].max(), num=21)
    sampled_points = []
    for low, high in zip(bins[:-1], bins[1:]):
        bin_data = sorted_df[(sorted_df[x_col] > low) & (sorted_df[x_col] <= high)]
        if len(bin_data) > 50:
            n_keep = int((len(bin_data) / len(unique_df)) * 2000)
            # With fewer than 2000 unique points the proportional quota exceeds
            # the bin size; sampling without replacement would then raise.
            n_keep = min(n_keep, len(bin_data))
            sampled = bin_data.sample(n=n_keep, random_state=42)
        else:
            # Sparse bins are kept whole.
            sampled = bin_data
        sampled_points.append(sampled)
    return pd.concat(sampled_points)


@click.command(context_settings=dict(help_option_names=['-h', '--help']))
@click.option('--bam', required=True,help='atac. step2/bwa_pe/asample_mem_pe_Sort.bam')
@click.option('--atacjson', required=True,help='atac. summary.json ')
@click.option('--gexjson', required=True,help='gex. summary.json ')
@click.option('--outdir', required=True,help='output dir')
@click.option('--samplename', required=True,help='sample name')
@click.option('--countxls', required=True,help='gex. step3/counts.xls')
@click.option('--detailxls', required=True,help='gex. step3/detail.xls')
@click.option('--species', required=True,help='human or mouse')
@click.option('--refpath', required=True,help='reference path')
@click.option('--bedtoolspath', required=True,help='bedtools path')
@click.option("--core", default=8, show_default=True, help="Set max number of cpus that pipeline might request at the same time.")
@click.option("--qvalue", default=0.05, show_default=True, help="macs3 parameter")
@click.option("--nolambda", is_flag=True, default=False, show_default=True, help="macs3 parameter")
@click.option("--shift", default=0, show_default=True, help="macs3 parameter")
@click.option("--extsize", default=400, show_default=True, help="macs3 parameter")
@click.option("--min_len", default=400, show_default=True, help="macs3 parameter")
@click.option("--blacklist", default=None, show_default=True, help="macs3 parameter")
def runpipe(bam, atacjson, gexjson, outdir, samplename, countxls, detailxls, species, refpath, bedtoolspath, core, qvalue, nolambda, shift, extsize, min_len, blacklist):
    """Build fragments, call peaks, call joint ATAC/GEX cells and write QC metrics.

    Outputs are written into the output dir; the ATAC and GEX summary JSON
    files are updated in place. Fails with an error when the reference files
    are missing, when bedtools fails, or when no barcode is called as a cell.
    """
    chrNameLength = os.path.join(refpath, 'star/chrNameLength.txt')
    if not os.path.exists(chrNameLength):
        raise FileNotFoundError(f'{chrNameLength} not found!')
    gtf = os.path.join(refpath, 'genes/genes.gtf')
    if not os.path.exists(gtf):
        raise FileNotFoundError(f'{gtf} not found!')

    qc = {
        "Sequencing": {},
        "Cells": {},
        "Mapping": {},
        "Targeting": {},
        "median": {},
        "peaks_target": {},
        "tss": {},
        "insert": {},
        "joint_cell": {}
    }
    qc["refpath"] = refpath
    qc["Organism"] = species

    # ---------------make the fragment file and collect BAM-level QC---------------
    fragments_file = os.path.join(outdir, samplename+"_fragments.tsv.gz")
    bam_qc = snap.pp.make_fragment_file(
        bam_file=bam,
        output_file=fragments_file,
        barcode_regex = "^(.*?)_",
        compression = 'gzip',
        compression_level = 6)

    with open(atacjson, "r") as fh:
        atac_summary = json.load(fh)
    with open(gexjson, "r") as fh:
        gex_summary = json.load(fh)

    atac_stat = atac_summary["stat"]
    qc["Sequencing"]["Sequenced read pairs"] = int(atac_stat["total"])
    qc["Sequencing"]["Valid barcodes"] = atac_stat["valid"]/atac_stat["total"]
    qc["Sequencing"]["Too short"] = atac_stat["valid"] - atac_stat["step1_available"]
    # Each "barcode_q" value is indexed by quality score; entries from index 30 on count as Q30.
    b_total_base = sum([sum(v) for v in atac_summary["barcode_q"].values()])
    b30_base = sum([sum(v[30:]) for v in atac_summary["barcode_q"].values()])
    qc["Sequencing"]["Q30 bases in barcode"] = b30_base/b_total_base
    qc["Sequencing"]["Q30 bases in read 1"] = bam_qc["frac_q30_bases_read1"]
    qc["Sequencing"]["Q30 bases in read 2"] = bam_qc["frac_q30_bases_read2"]
    qc["Sequencing"]["Percent duplicates"] = bam_qc["frac_duplicates"]
    qc["Mapping"]["Confidently mapped read pairs"] = bam_qc["frac_confidently_mapped"]
    qc["Mapping"]["Unmapped read pairs"] = bam_qc["frac_unmapped"]
    qc["Mapping"]["Non-nuclear read pairs"] = bam_qc["frac_nonnuclear"]

    # ---------------build the AnnData object---------------
    size_dict, size_list = makedict(chrNameLength)
    atac = snap.pp.import_data(fragment_file=fragments_file,
                               chrom_sizes=size_dict,
                               min_num_fragments=0)

    # ---------------fragment size distribution: plot + JSON data---------------
    snap.pl.frag_size_distr(atac, out_file=os.path.join(outdir, samplename+"_fragments_size.png"))
    qc["insert"]["size"] = list(range(1001))
    qc["insert"]["count"] = list(atac.uns['frag_size_distr'])

    # ---------------TSS enrichment profile: plot + JSON data---------------
    snap.metrics.tsse(atac, gene_anno=gtf)
    # Keep the 2001 profile entries plotted as positions -1000..1000.
    listtss = list(atac.uns['TSS_profile'])[999:3000]
    max_value = max(listtss)
    min_value = min(listtss)
    # Normalise by the minimum so the curve reads as relative enrichment.
    listtssbz = [x / min_value for x in listtss]
    plt.plot(range(-1000, 1001), listtssbz)
    plt.xlabel('Relative Position (bp from TSS)')
    plt.ylabel('Relative Enrichment')
    plt.savefig(os.path.join(outdir, samplename+"_tss_enrichment.png"))
    # Release the figure so later matplotlib use does not draw on top of it.
    plt.close()
    qc["tss"]["position"] = list(range(-1000,1001))
    qc["tss"]["score"] = listtssbz

    # ---------------call peaks---------------
    genomelen=atac.uns['reference_sequences']['reference_seq_length'].sum()
    print(f'参数 : qvalue={qvalue}, nolambda={nolambda}, shift={shift}, extsize={extsize}, min_len={min_len}, blacklist={blacklist}, n_jobs={core}')
    print(f'Call Peak start ...')
    snap.tl.macs3(atac, qvalue=qvalue, nolambda=nolambda, shift=shift, extsize=extsize, min_len=min_len, blacklist=blacklist, n_jobs=core)
    print(f'Call Peak Done !!!')
    peaksortu=os.path.join(outdir, samplename+"_peaksort-u.bed")
    peakuniq=os.path.join(outdir, samplename+"_peaks.bed")
    # De-duplicate the called peaks in Python, then let bedtools sort them.
    unique_peak_lines = set()
    for index, row in atac.uns['macs3_pseudobulk'].iterrows():
        chrom = row['chrom']
        start = row['start']
        end = row['end']
        unique_peak_lines.add(f'{chrom}\t{start}\t{end}\n')
    with open(peaksortu, 'w') as fhout:
        fhout.writelines(sorted(unique_peak_lines))
    # Argument-list form keeps paths with spaces intact; check=True stops the
    # pipeline instead of continuing with an empty or partial peak file.
    with open(peakuniq, 'w') as fhout:
        run([os.path.expanduser(bedtoolspath), 'sort', '-i', peaksortu], stdout=fhout, check=True)
    os.remove(peaksortu)

    # Count peaks and the total number of bases they cover.
    peakuniqlen = 0
    peaknum=0
    peaks = []
    with open(peakuniq, 'r') as fh:
        for line in fh:
            if line.startswith('#'): continue
            peaknum+=1
            fields = line.strip().split('\t')
            chrom = fields[0]
            start = int(fields[1])
            end = int(fields[2])
            peakuniqlen += end - start
            peaks.append(f'{chrom}:{start}-{end}')

    print(f'snapatac2 call peak 的数量：{peaknum}')
    print(f'snapatac2 call peak 的长度：{peakuniqlen}')
    print(f'snapatac2 call peak 的占比：{peakuniqlen/genomelen:.2%}')

    # ---------------count metrics per barcode---------------
    snapatac2.metrics.frip(atac, {"n_frag_overlap_peak": peaks}, normalized=False)
    snapatac2.metrics.frip(atac, {"events_overlap_peak": peaks}, normalized=False, count_as_insertion=True, inplace=True)
    snapatac2.metrics.frip(atac, {"events_all": size_list}, normalized=False, count_as_insertion=True, inplace=True)
    atac.write(os.path.join(outdir, samplename+"_snapatac2_raw.h5ad"))
    raw_peak_mat = snap.pp.make_peak_matrix(atac, use_rep=peaks, counting_strategy='insertion')
    _write_peak_matrix_dir(os.path.join(outdir, "raw_peaks_bc_matrix"), raw_peak_mat)

    # Collect the per-barcode columns computed by snapatac2 into one table keyed by barcode.
    merged_df = None
    for metric in ('n_fragment', 'n_frag_overlap_peak', 'events_overlap_peak', 'events_all'):
        metric_df = pd.DataFrame(atac.obs[metric])
        metric_df['barcode'] = atac.obs.index
        if merged_df is None:
            merged_df = metric_df
        else:
            merged_df = pd.merge(merged_df, metric_df, on='barcode', how='inner')

    # ATAC fragments and reads per barcode, read back from the fragment file.
    print("read atac step3 fragments.tsv.gz ...")
    atac_counts = {}
    with gzip.open(fragments_file, 'rt') as fh, open(os.path.join(outdir, "frag_counts.xls"), 'w') as fhout:
        fhout.write(f'barcode\tfragment\tnum\treads\n')
        for row in fh:
            tmp = row.strip().split('\t')
            cb = tmp[3]
            fragment = f'{tmp[0]}_{tmp[1]}-{tmp[2]}'
            readsnum = int(tmp[4])
            stats = atac_counts.get(cb)
            if stats is None:
                stats = atac_counts[cb] = {'fragments_num':0, 'atac_reads':0}
            stats['fragments_num'] += 1
            stats['atac_reads'] += readsnum
            fhout.write(f'{cb}\t{fragment}\t1\t{readsnum}\n')
    atac_reads_df = pd.DataFrame.from_dict(atac_counts, orient='index')
    atac_reads_df.reset_index(inplace=True)
    atac_reads_df.rename(columns={'index': 'barcode'}, inplace=True)
    merged_df = pd.merge(merged_df, atac_reads_df, on='barcode', how='inner')
    merged_df['fraction_frag_overlap_peak'] = merged_df['n_frag_overlap_peak'] / merged_df['fragments_num']
    print("count atac fragments and reads completed.")

    # ---------------gex gene UMI and reads per barcode---------------
    print("read gex step3 counts.xls...")
    gex_counts = {}
    with open(countxls,"rt") as gex_file:
        for line in gex_file:
            # Skip the header line.
            if line.startswith("cellID"): continue
            ls = line.strip().split("\t")
            barcode = ls[0]
            umi = ls[2]
            reads = ls[3]
            stats = gex_counts.get(barcode)
            if stats is None:
                stats = gex_counts[barcode] = {'gex_gene':0, 'gex_umi':0, 'gex_reads':0}
            # One input line per (barcode, gene) pair, so each line adds one gene.
            stats['gex_gene'] += 1
            stats['gex_umi'] += int(umi)
            stats['gex_reads'] += int(reads)
    print("creat dict completed.")
    print("creat gex df...")
    gexdf = pd.DataFrame.from_dict(gex_counts, orient='index')
    gexdf.reset_index(inplace=True)
    gexdf.rename(columns={'index': 'barcode'}, inplace=True)
    print("creat gex df completed.")
    # Outer merge keeps barcodes seen in only one modality; their missing counts become 0.
    merged_df = pd.merge(gexdf, merged_df, on='barcode', how='outer')
    merged_df.fillna(0, inplace=True)
    for column in ('gex_gene', 'gex_umi', 'gex_reads', 'n_fragment', 'n_frag_overlap_peak',
                   'events_overlap_peak', 'events_all', 'fragments_num', 'atac_reads'):
        merged_df[column] = merged_df[column].astype(int)

    # ---------------filter gex cell barcode---------------
    # A barcode must be more peak-enriched than the genome-wide peak fraction.
    filter1_df = merged_df[merged_df['fraction_frag_overlap_peak'] > peakuniqlen/genomelen]
    filter2_df = filter1_df[(filter1_df['gex_umi'] > 1) & (filter1_df['events_overlap_peak'] > 1)]
    filter4_df = filter2_df[filter2_df['gex_umi'] >= 500]
    gex_cell_list = filter4_df['barcode'].tolist()
    print(f"gex call cell num: {len(gex_cell_list)}")
    filter3_df = filter2_df[filter2_df['events_overlap_peak'] >= 2500]
    snap_cell_list = filter3_df['barcode'].tolist()
    print(f"atac call cell num: {len(snap_cell_list)}")
    joint_cb_list = list(set(gex_cell_list).intersection(set(snap_cell_list)))
    print(f"joint call cell num: {len(joint_cb_list)}")

    # -----------------call cell------------------------
    merged_df['is_cell'] = merged_df['barcode'].isin(joint_cb_list).astype(int)
    merged_df.to_csv(os.path.join(outdir, "per_barcode_metrics.csv"), index=False)
    n_cells = len(joint_cb_list)
    if n_cells == 0:
        # Every per-cell metric below divides by the cell count; fail with a
        # clear message (per_barcode_metrics.csv is already on disk for inspection).
        raise RuntimeError(
            "No barcode passed the joint GEX/ATAC cell-calling thresholds; "
            "cannot compute per-cell metrics. See per_barcode_metrics.csv."
        )
    cell_merged_df = merged_df[merged_df['is_cell'] == 1]

    # -----------------replace the gex cell metrics with joint-cell values-------------------
    gex_summary["cells"]["Fraction Reads in Cells"] = cell_merged_df['gex_reads'].sum()/merged_df['gex_reads'].sum()
    gex_summary["cells"]["Mean Reads per Cell"] = int(gex_summary["stat"]["total"]/n_cells)
    gex_summary["cells"]["Median Genes per Cell"] = int(cell_merged_df['gex_gene'].median())
    gex_summary["cells"]["Median UMI Counts per Cell"] = int(cell_merged_df['gex_umi'].median())

    # ----------------------count median gene and saturation--------------------------------
    chunck_size = calulate_chunck_size(detailxls)
    # NOTE(review): the loop below reads the "UMI" and "Num" columns while the
    # dtype keys name "UMINum"/"ReadsNum" — confirm against the detail.xls header.
    csv_reader = pd.read_csv(detailxls,
                            dtype={
                                "cellID": "category",
                                "geneID": "category",
                                "UMINum": "category",
                                "ReadsNum": "int32"
                            },
                            sep="\t",
                            chunksize=chunck_size)
    saturation_tmp = defaultdict(lambda: defaultdict(int))
    median_tmp = defaultdict(list)
    # Temporary per-portion count files are written next to detail.xls.
    basedir = os.path.dirname(detailxls)
    # Sampling fractions "0.1" .. "1.0"; defined before the loop because they
    # are also needed after it, even when the reader yields no chunk.
    n_cols_key = [str((i+1)/ 10) for i in range(0,10)]

    chunk_count = 0
    max_val = 0
    dtypes = np.int16
    for df in csv_reader:
        df = df.loc[df["cellID"].isin(joint_cb_list), :].reset_index(drop=True)
        # Expand every row into one row per read so reads can be subsampled.
        rep = df["Num"]
        df = df.drop(["Num"], axis=1)
        idx = df.index.repeat(rep)
        df = df.iloc[idx].reset_index(drop=True)
        del rep, idx
        # shuffle
        df = df.sample(frac=1.0).reset_index(drop=True)
        if df.shape[0] < len(n_cols_key):
            # Too few cell reads in this chunk to cut ten non-empty portions.
            chunk_count += 1
            continue
        # downsample
        for n, interval in enumerate(np.array_split(np.arange(df.shape[0]), 10)):
            # Number of reads in the cumulative portion; +1 so the slice below
            # includes the last read of the interval.
            idx = interval[-1] + 1
            percentage = n_cols_key[n]
            tmp_file = "tmp_" + str(chunk_count) + "_" + percentage + ".xls"
            tmp_file = os.path.join(basedir, tmp_file)
            # calculate saturation for each portion
            sampled_df = df.iloc[:idx]
            sampled_df = sampled_df.assign(**{percentage: 1})
            sampled_df = sampled_df.groupby(['cellID', 'geneID', 'UMI'], observed=True) \
                                .sum() \
                                .reset_index()
            np.savetxt(tmp_file, sampled_df[percentage].to_numpy(), fmt='%d')
            saturation_tmp[percentage][tmp_file] = idx
            # calculate median for each portion
            median = sampled_df.groupby([sampled_df["cellID"]],observed=True)["geneID"] \
                            .nunique() \
                            .reset_index(drop=True) \
                            .median()
            median_tmp[percentage].append(int(median))
            # Track the largest count so the files can be reloaded with a fitting int dtype.
            max_curr = sampled_df[percentage].max()
            if max_val < max_curr:
                dtypes = get_int_type(max_curr)
                max_val = max_curr

        chunk_count +=1
    percentage_sampling = [0]
    saturation_sampling = [0]
    median_sampling = [0]
    for perc, files in saturation_tmp.items():
        arr_perc = np.array([], dtype=dtypes)
        all_obs = 0
        for file_path, count in files.items():
            arr = np.loadtxt(file_path, dtype=dtypes)
            arr_perc = np.append(arr_perc, arr)
            all_obs += count
            os.remove(file_path)
        # Saturation = share of sampled reads that repeat an already seen (cell, gene, UMI).
        saturation = (np.sum(arr_perc[arr_perc>1] -1) - 0.0)/ all_obs * 100
        saturation_sampling.append(saturation)
    for perc, medians in median_tmp.items():
        # Average the per-chunk medians.
        median = int(sum(medians)/len(medians))
        median_sampling.append(median)
        percentage_sampling.append(float(perc))
    mean_gexreads_list = [0] + [int(gex_summary["stat"]["total"]/n_cells * float(p))for p in n_cols_key]

    gex_summary["downsample"]["percentage"] = percentage_sampling
    gex_summary["downsample"]["saturation"] = saturation_sampling
    gex_summary["downsample"]["median"] = median_sampling
    gex_summary["downsample"]["Reads"] = mean_gexreads_list
    with open(gexjson, "w") as fh:
        json.dump(
            gex_summary,
            fh,
            indent=4,
            default=_json_default
        )

    # ---------------output filter matrix---------------
    filter_atac = atac[joint_cb_list, :]
    filter_atac.write(os.path.join(outdir, samplename+"_snapatac2_filter.h5ad"))
    filter_peak_mat = snap.pp.make_peak_matrix(filter_atac, use_rep=peaks, counting_strategy='insertion')
    _write_peak_matrix_dir(os.path.join(outdir, "filter_peaks_bc_matrix"), filter_peak_mat)

    # ---------------cell count---------------
    mean_raw_pairs = int(qc["Sequencing"]["Sequenced read pairs"]/n_cells)
    frac_frag_in_cells = cell_merged_df['fragments_num'].sum()/merged_df['fragments_num'].sum()
    frac_events_in_peaks = cell_merged_df['events_overlap_peak'].sum() / cell_merged_df['events_all'].sum()
    median_frag_per_cell = int(cell_merged_df['fragments_num'].median())
    print(f"Estimated number of cells : {n_cells}")
    print(f'Mean raw read pairs per cell : {mean_raw_pairs}')
    print(f"Fraction of high-quality fragments in cells : {frac_frag_in_cells:.2%}")
    print(f"Fraction of transposition events in peaks in cells : {frac_events_in_peaks:.2%}")
    print(f"Median_high-quality_fragments_per_cell : {median_frag_per_cell}")

    qc["Cells"]["Estimated number of cells"] = n_cells
    qc["Cells"]["Mean raw read pairs per cell"] = mean_raw_pairs
    qc["Cells"]["Fraction of high-quality fragments in cells"] = frac_frag_in_cells
    qc["Cells"]["Fraction of transposition events in peaks in cells"] = frac_events_in_peaks
    qc["Cells"]["Median high-quality fragments per cell"] = median_frag_per_cell

    # ---------------Targeting Count---------------
    frac_genome_in_peaks = peakuniqlen/genomelen
    tss_score = max_value/min_value
    frac_frag_in_peaks = cell_merged_df['n_frag_overlap_peak'].sum()/cell_merged_df['fragments_num'].sum()
    print(f"Number of peaks : {len(peaks)}")
    print(f"Fraction of genome in peaks : {frac_genome_in_peaks}")
    print(f"TSS enrichment score : {tss_score}")
    print(f"Fraction of high-quality fragments overlapping TSS : {atac.uns['frac_overlap_TSS']:.2%}")
    print(f"Fraction of high-quality fragments overlapping peaks : {frac_frag_in_peaks:.2%}")

    qc["Targeting"]["Number of peaks"] = len(peaks)
    qc["Targeting"]["Fraction of genome in peaks"] = frac_genome_in_peaks
    qc["Targeting"]["TSS enrichment score"] = tss_score
    qc["Targeting"]["Fraction of high-quality fragments overlapping TSS"] = atac.uns['frac_overlap_TSS']
    qc["Targeting"]["Fraction of high-quality fragments overlapping peaks"] = frac_frag_in_peaks

    # ---------------count median fragments---------------
    frag_counts_file = os.path.join(outdir, "frag_counts.xls")
    chunck_size = calulate_chunck_size(frag_counts_file)
    csv_reader = pd.read_csv(frag_counts_file,
                            dtype={
                                "barcode": "category",
                                "fragment": "category",
                                "num": "category",
                                "reads": "int32"
                            },
                            sep="\t",
                            chunksize=chunck_size)

    median_tmp = defaultdict(list)
    for df in csv_reader:
        df = df.loc[df["barcode"].isin(joint_cb_list), :].reset_index(drop=True)
        # Expand every fragment into one row per read so reads can be subsampled.
        rep = df["reads"]
        df = df.drop(["reads"], axis=1)
        idx = df.index.repeat(rep)
        df = df.iloc[idx].reset_index(drop=True)
        del rep, idx
        # shuffle
        df = df.sample(frac=1.0).reset_index(drop=True)
        if df.shape[0] < len(n_cols_key):
            # Too few cell reads in this chunk to cut ten non-empty portions.
            continue
        # downsample
        for n, interval in enumerate(np.array_split(np.arange(df.shape[0]), 10)):
            # +1 so the cumulative portion includes the interval's last read.
            idx = interval[-1] + 1
            percentage = n_cols_key[n]
            sampled_df = df.iloc[:idx]
            # calculate median for each portion
            median = sampled_df.groupby([sampled_df["barcode"]],observed=True)["fragment"] \
                            .nunique() \
                            .reset_index(drop=True) \
                            .median()
            median_tmp[percentage].append(int(median))

    median_fragments_list = [0]
    for perc, medians in median_tmp.items():
        # Average the per-chunk medians.
        median = int(sum(medians)/len(medians))
        median_fragments_list.append(median)
    mean_reads_list = [0] + [int(qc["Sequencing"]["Sequenced read pairs"]/n_cells * float(p))for p in n_cols_key]
    percentage_list = ['0'] + [str((i+1)/ 10) for i in range(0,10)]

    qc["median"]["percentage"] = percentage_list
    qc["median"]["mean_reads"] = mean_reads_list
    qc["median"]["median_fragments"] = median_fragments_list

    # ---------------joint call cell scatter plot: de-duplicate and density-downsample---------------
    final_joint_df = _density_downsample(
        merged_df[['events_overlap_peak', 'gex_umi', 'is_cell']],
        'events_overlap_peak', 'gex_umi')
    joint_cells = final_joint_df['is_cell'] == 1
    joint_nocells = final_joint_df['is_cell'] == 0
    qc["joint_cell"]["cell_umi"] = final_joint_df.loc[joint_cells, 'gex_umi'].tolist() # json x2
    qc["joint_cell"]["cell_events"] = final_joint_df.loc[joint_cells, 'events_overlap_peak'].tolist() # json y2
    qc["joint_cell"]["nocell_umi"] = final_joint_df.loc[joint_nocells, 'gex_umi'].tolist() # json x1
    qc["joint_cell"]["nocell_events"] = final_joint_df.loc[joint_nocells, 'events_overlap_peak'].tolist() # json y1
    qc["joint_cell"]["cell_num"] = (merged_df['is_cell'] == 1).sum() # json x2num
    qc["joint_cell"]["nocell_num"] = (merged_df['is_cell'] == 0).sum() # json x1num

    # ---------------fragment scatter plot: de-duplicate and density-downsample---------------
    final_frag_df = _density_downsample(
        merged_df[['fragments_num', 'fraction_frag_overlap_peak', 'is_cell']],
        'fragments_num', 'fraction_frag_overlap_peak')
    frag_cells = final_frag_df['is_cell'] == 1
    frag_nocells = final_frag_df['is_cell'] == 0
    qc["peaks_target"]["cell_fragments"] = final_frag_df.loc[frag_cells, 'fragments_num'].tolist()
    qc["peaks_target"]["cell_frac"] = final_frag_df.loc[frag_cells, 'fraction_frag_overlap_peak'].tolist()
    qc["peaks_target"]["nocell_fragments"] = final_frag_df.loc[frag_nocells, 'fragments_num'].tolist()
    qc["peaks_target"]["nocell_frac"] = final_frag_df.loc[frag_nocells, 'fraction_frag_overlap_peak'].tolist()

    with open(atacjson, "w") as fh:
        atac_summary["atac"] = qc
        json.dump(
            atac_summary,
            fh,
            indent=4,
            default=_json_default
        )


# Invoke the click command (which parses the command-line options itself)
# only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    runpipe()