#!/usr/bin/env python3
"""
Compute contingency tables for ALL variants in a VCF/pfile against a phenotype.

Supports two modes:
- bcftools mode (default): Streams variants via bcftools. Works for hardcall (GT) and dosage (DS).
- plink2 mode (--use-plink2): Uses plink2 --geno-counts for fast hardcall computation.

Usage (bcftools mode):
    python3 contingency_all.py \
        --vcf /mnt/in/chr7_lof.vcf.bgz \
        --phenotype-tsv /mnt/in/zopiclone__M796_pheno.tsv \
        --variant-type hardcall \
        --vcf-source wes/lof/chr7_lof

Usage (plink2 mode):
    python3 contingency_all.py \
        --use-plink2 \
        --pfile-prefix /mnt/project/results/gemo/snp/chr1/chr1_snv \
        --phenotype-tsv /mnt/in/zopiclone__M796_pheno.tsv \
        --variant-type hardcall \
        --vcf-source snp/chr1/chr1_snv \
        --output /home/dnanexus/out/out/result.tsv

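Output (TSV with a header row, then one row per variant):
    VARIANT_ID  VCF_SOURCE  FORMAT_TYPE  CAT1_NAME  CAT2_NAME  CAT1_CASE  CAT2_CASE  CAT1_CTRL  CAT2_CTRL  TOTAL_CASE  TOTAL_CTRL  MIN_CELL  EXTRA

    Hardcall rows use CAT1/CAT2 = HOM_REF/CARRIER (het + hom-alt); MIN_CELL is the
    smaller of the case and control counts of the overall minor allele, and EXTRA is
    HET_CASE:HET_CTRL:HOM_ALT_CASE:HOM_ALT_CTRL. Dosage rows use AT_MODE/NOT_AT_MODE
    relative to the most common DS value; MIN_CELL is the smaller of the case and
    control NOT_AT_MODE counts, and EXTRA is "mode=<value>". With
    --phenotype-type continuous every phenotyped sample counts as a case and
    MIN_CELL is the case count alone.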
"""

import argparse
import os
import shutil
import subprocess
import sys
import tempfile
from collections import Counter


def load_phenotype(tsv_path, phenotype_type="binary"):
    """Load a phenotype TSV into a dict ``{sample_id: 0 or 1}``.

    The first line is treated as a header and skipped. Every following line
    needs at least two tab-separated fields: sample ID, then phenotype value.

    - ``binary``: a value of 1 marks a case (1) and 0 marks a control (0);
      any other numeric value is dropped.
    - ``continuous``: every sample with a usable value is marked 1 ("present").

    Values that are not numbers (e.g. ``NA`` or an empty field), and NaN, are
    treated as missing: the sample is omitted instead of aborting the run.

    Args:
        tsv_path: Path to the phenotype TSV.
        phenotype_type: ``"binary"`` or ``"continuous"``.

    Returns:
        dict mapping sample ID to 0 (control) or 1 (case / present).
    """
    import math

    pheno = {}
    with open(tsv_path) as f:
        f.readline()  # skip the header row
        for line in f:
            parts = line.strip().split("\t")
            if len(parts) < 2:
                continue
            sid = parts[0]
            try:
                val = float(parts[1])
            except ValueError:
                continue  # non-numeric missing marker such as "NA" or ""
            if math.isnan(val):
                continue  # "nan" parses as a float but carries no value
            if phenotype_type == "continuous":
                pheno[sid] = 1  # mark as present
            elif val == 1.0:
                pheno[sid] = 1
            elif val == 0.0:
                pheno[sid] = 0
            # NOTE(review): other codings (e.g. PLINK 1=control/2=case) are dropped.
    return pheno


def get_sample_names(vcf_path):
    """Return the sample names of *vcf_path* in VCF column order.

    Runs ``bcftools query -l``. Exits the process with status 1 if bcftools
    cannot be launched or returns a non-zero status.

    Returns:
        list[str]: sample names; an empty list when the VCF has no samples.
    """
    try:
        result = subprocess.run(
            ["bcftools", "query", "-l", vcf_path],
            capture_output=True, text=True
        )
    except OSError as exc:
        # e.g. bcftools not on PATH: report it like any other bcftools failure
        print(f"ERROR: could not run bcftools: {exc}", file=sys.stderr)
        sys.exit(1)
    if result.returncode != 0:
        print(f"ERROR: bcftools query -l failed: {result.stderr}", file=sys.stderr)
        sys.exit(1)
    listing = result.stdout.strip()
    # "".split("\n") would yield [""], i.e. one phantom empty-named sample
    return listing.split("\n") if listing else []


def build_sample_index(sample_names, pheno):
    """Map every VCF sample position to its phenotype status.

    Returns a list parallel to *sample_names*: the value stored in *pheno*
    for that sample (0 or 1), or -1 when the sample has no phenotype entry.
    """
    no_phenotype = -1
    return [pheno.get(sample, no_phenotype) for sample in sample_names]


# Diploid biallelic GT string -> genotype class: 0 = hom-ref, 1 = het,
# 2 = hom-alt. Phased ("|") and unphased ("/") forms are both listed; any
# string not in this table (missing, multi-allelic, haploid, ...) is ignored.
_HARDCALL_GENOTYPE_CLASS = {
    "0/0": 0, "0|0": 0,
    "0/1": 1, "0|1": 1, "1/0": 1, "1|0": 1,
    "1/1": 2, "1|1": 2,
}


def process_hardcall_line(values, sample_index, phenotype_type="binary"):
    """Build the HOM_REF / CARRIER contingency counts for one GT variant line.

    Args:
        values: per-sample GT strings, aligned with *sample_index*; extra
            entries beyond the length of *sample_index* are ignored.
        sample_index: per-sample status, 1 = case, -1 = no phenotype (skipped),
            anything else = control.
        phenotype_type: "binary" or "continuous"; for "continuous" MIN_CELL is
            the case-side minor-allele count instead of the case/control minimum.

    Returns:
        ``(hom_ref_case, carrier_case, hom_ref_ctrl, carrier_ctrl,
        total_case, total_ctrl, min_cell, extra)`` where carrier = het + hom-alt
        and extra is ``"het_case:het_ctrl:hom_alt_case:hom_alt_ctrl"``.
    """
    # Tallies indexed by genotype class (hom-ref, het, hom-alt).
    case_tally = [0, 0, 0]
    ctrl_tally = [0, 0, 0]

    for gt_str, status in zip(values, sample_index):
        if status == -1:
            continue
        genotype = _HARDCALL_GENOTYPE_CLASS.get(gt_str)
        if genotype is None:
            continue
        tally = case_tally if status == 1 else ctrl_tally
        tally[genotype] += 1

    hom_ref_case, het_case, hom_alt_case = case_tally
    hom_ref_ctrl, het_ctrl, hom_alt_ctrl = ctrl_tally
    carrier_case = het_case + hom_alt_case
    carrier_ctrl = het_ctrl + hom_alt_ctrl

    # Allele counts: a het sample contributes one REF and one ALT allele.
    ref_case = 2 * hom_ref_case + het_case
    ref_ctrl = 2 * hom_ref_ctrl + het_ctrl
    alt_case = 2 * hom_alt_case + het_case
    alt_ctrl = 2 * hom_alt_ctrl + het_ctrl

    # The minor allele is the rarer one across all samples; ties go to REF.
    ref_is_minor = ref_case + ref_ctrl <= alt_case + alt_ctrl
    minor_case, minor_ctrl = (ref_case, ref_ctrl) if ref_is_minor else (alt_case, alt_ctrl)

    # Continuous phenotypes treat every sample as a case.
    min_cell = minor_case if phenotype_type == "continuous" else min(minor_case, minor_ctrl)

    return (
        hom_ref_case, carrier_case, hom_ref_ctrl, carrier_ctrl,
        hom_ref_case + carrier_case, hom_ref_ctrl + carrier_ctrl,
        min_cell,
        f"{het_case}:{het_ctrl}:{hom_alt_case}:{hom_alt_ctrl}",
    )


def process_dosage_line(values, sample_index, phenotype_type="binary"):
    """Process one dosage (DS) variant line into a mode-binned contingency table.

    Samples are split into AT_MODE (dosage equal to the most common dosage
    among usable samples) and NOT_AT_MODE (any other dosage). Samples with
    status -1, a missing value (``"."`` or ``""``), an unparseable value, or
    a non-finite value (``nan``/``inf``) are ignored.

    Args:
        values: per-sample DS strings, aligned with *sample_index*; extra
            entries beyond the length of *sample_index* are ignored.
        sample_index: per-sample status, 1 = case, -1 = no phenotype,
            anything else = control.
        phenotype_type: "binary" or "continuous"; for "continuous" MIN_CELL is
            the NOT_AT_MODE case count instead of the case/control minimum.

    Returns:
        ``(at_mode_case, not_at_mode_case, at_mode_ctrl, not_at_mode_ctrl,
        total_case, total_ctrl, min_cell, extra)`` with ``extra`` =
        ``"mode=<value>"``, or ``None`` when no sample has a usable dosage.
    """
    import math

    # First pass: collect usable (dosage, status) pairs.
    ds_values = []
    ds_status = []
    for ds_str, status in zip(values, sample_index):
        if status == -1 or ds_str in (".", ""):
            continue
        try:
            ds = float(ds_str)
        except (ValueError, TypeError):
            continue
        if not math.isfinite(ds):
            # "nan"/"inf" parse as floats but are not dosages. A NaN would also
            # never equal the mode and would be miscounted as NOT_AT_MODE.
            continue
        ds_values.append(ds)
        ds_status.append(status)

    if not ds_values:
        return None

    # Most common dosage; Counter.most_common orders ties by first occurrence.
    mode_val = Counter(ds_values).most_common(1)[0][0]

    # Second pass: classify AT_MODE vs NOT_AT_MODE.
    at_mode_case = 0
    not_at_mode_case = 0
    at_mode_ctrl = 0
    not_at_mode_ctrl = 0
    for ds, status in zip(ds_values, ds_status):
        if ds == mode_val:
            if status == 1:
                at_mode_case += 1
            else:
                at_mode_ctrl += 1
        elif status == 1:
            not_at_mode_case += 1
        else:
            not_at_mode_ctrl += 1

    total_case = at_mode_case + not_at_mode_case
    total_ctrl = at_mode_ctrl + not_at_mode_ctrl
    if phenotype_type == "continuous":
        min_cell = not_at_mode_case  # every sample is a "case"
    else:
        min_cell = min(not_at_mode_case, not_at_mode_ctrl)
    extra = f"mode={mode_val}"

    return (at_mode_case, not_at_mode_case, at_mode_ctrl, not_at_mode_ctrl,
            total_case, total_ctrl, min_cell, extra)


# Tab-separated column header written as the first line of every output file.
HEADER = "\t".join((
    "VARIANT_ID",
    "VCF_SOURCE",
    "FORMAT_TYPE",
    "CAT1_NAME",
    "CAT2_NAME",
    "CAT1_CASE",
    "CAT2_CASE",
    "CAT1_CTRL",
    "CAT2_CTRL",
    "TOTAL_CASE",
    "TOTAL_CTRL",
    "MIN_CELL",
    "EXTRA",
))


def parse_gcount(path):
    """Parse a plink2 ``--geno-counts`` (.gcount) file.

    Columns are located by header name; a leading ``#`` on the first header
    field is removed. The ID, HOM_REF_CT, HET_REF_ALT_CTS and
    TWO_ALT_GENO_CTS columns are required, and the process exits with status 1
    if any is missing.

    For multi-allelic variants plink2 writes HET_REF_ALT_CTS and
    TWO_ALT_GENO_CTS as comma-separated per-genotype counts. These are summed,
    so "het" means any REF/ALT genotype and "hom_alt" any genotype with two
    non-REF alleles. For biallelic variants this is the plain count. Blank
    lines are ignored.

    Returns:
        list of ``(variant_id, {"hom_ref": int, "het": int, "hom_alt": int})``
        in file order.
    """
    def _count(field):
        # Sum comma-separated per-allele counts; a single integer passes through.
        return sum(int(c) for c in field.split(","))

    variants = []
    with open(path) as f:
        cols = f.readline().strip().split("\t")
        # plink2 prefixes the first header column with '#'
        if cols[0].startswith("#"):
            cols[0] = cols[0].lstrip("#")
        col_map = {name: i for i, name in enumerate(cols)}

        id_idx = col_map.get("ID")
        hom_ref_idx = col_map.get("HOM_REF_CT")
        het_idx = col_map.get("HET_REF_ALT_CTS")
        hom_alt_idx = col_map.get("TWO_ALT_GENO_CTS")

        if any(idx is None for idx in [id_idx, hom_ref_idx, het_idx, hom_alt_idx]):
            print(f"ERROR: Missing columns in {path}. Found: {cols}", file=sys.stderr)
            sys.exit(1)

        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate trailing/blank lines
            parts = line.split("\t")
            variants.append((parts[id_idx], {
                "hom_ref": int(parts[hom_ref_idx]),
                "het": _count(parts[het_idx]),
                "hom_alt": _count(parts[hom_alt_idx]),
            }))
    return variants


def open_output(args):
    """Return the text stream results should be written to.

    When ``args.output`` is set, open it for writing and create any missing
    parent directories first. Otherwise return ``sys.stdout``. Callers close
    the returned stream only when ``args.output`` was set.
    """
    if not args.output:
        return sys.stdout
    parent = os.path.dirname(args.output)
    # A bare filename has no directory part, and os.makedirs("") raises.
    if parent:
        os.makedirs(parent, exist_ok=True)
    return open(args.output, "w")


def _run_plink2_geno_counts(pfile_prefix, keep_path, out_prefix, label):
    """Run ``plink2 --geno-counts`` restricted to the samples in *keep_path*.

    *label* ("cases"/"controls") is used only in log messages. Exits the
    process with status 1 if plink2 cannot be started or fails.

    Returns:
        Path of the ``.gcount`` file written for *out_prefix*.
    """
    cmd = ["plink2", "--pfile", pfile_prefix,
           "--keep", keep_path, "--geno-counts", "--out", out_prefix]
    print(f"Running: {' '.join(cmd)}", file=sys.stderr)
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except OSError as exc:
        # e.g. plink2 not on PATH
        print(f"ERROR plink2 ({label}): {exc}", file=sys.stderr)
        sys.exit(1)
    if result.returncode != 0:
        print(f"ERROR plink2 ({label}): {result.stderr}", file=sys.stderr)
        sys.exit(1)
    print(f"plink2 {label} done", file=sys.stderr)
    return out_prefix + ".gcount"


def _pair_gcount_rows(case_variants, ctrl_variants):
    """Pair case and control .gcount rows by position; exit(1) if they disagree.

    Both plink2 runs read the same pfile and differ only in ``--keep``, so
    their rows are expected to list the same variants in the same order.
    Pairing by position (rather than through a dict keyed by ID) keeps
    variants with duplicate or '.' IDs distinct.

    Returns:
        list of ``(variant_id, case_counts, ctrl_counts)``.
    """
    if len(case_variants) != len(ctrl_variants):
        print(f"ERROR: gcount row count mismatch: {len(case_variants)} (cases) vs "
              f"{len(ctrl_variants)} (controls)", file=sys.stderr)
        sys.exit(1)
    rows = []
    for row_no, ((case_id, case_cts), (ctrl_id, ctrl_cts)) in enumerate(
            zip(case_variants, ctrl_variants), start=1):
        if case_id != ctrl_id:
            print(f"ERROR: gcount variant mismatch at row {row_no}: "
                  f"{case_id} (cases) vs {ctrl_id} (controls)", file=sys.stderr)
            sys.exit(1)
        rows.append((case_id, case_cts, ctrl_cts))
    return rows


def _plink2_hardcall_fields(case_cts, ctrl_cts, phenotype_type):
    """Compute the output columns CAT1_CASE .. EXTRA for one variant.

    Returns, as strings: the HOM_REF and CARRIER (het + hom-alt) counts for
    cases and controls, TOTAL_CASE, TOTAL_CTRL, MIN_CELL and
    ``het_case:het_ctrl:hom_alt_case:hom_alt_ctrl``.
    """
    hom_ref_case, het_case, hom_alt_case = (
        case_cts["hom_ref"], case_cts["het"], case_cts["hom_alt"])
    hom_ref_ctrl, het_ctrl, hom_alt_ctrl = (
        ctrl_cts["hom_ref"], ctrl_cts["het"], ctrl_cts["hom_alt"])
    carrier_case = het_case + hom_alt_case
    carrier_ctrl = het_ctrl + hom_alt_ctrl

    # Allele counting: a het contributes one REF and one ALT allele.
    ref_case = 2 * hom_ref_case + het_case
    ref_ctrl = 2 * hom_ref_ctrl + het_ctrl
    alt_case = 2 * hom_alt_case + het_case
    alt_ctrl = 2 * hom_alt_ctrl + het_ctrl

    # MIN_CELL is taken on the allele that is rarer overall (REF on ties).
    if ref_case + ref_ctrl <= alt_case + alt_ctrl:
        minor_case, minor_ctrl = ref_case, ref_ctrl
    else:
        minor_case, minor_ctrl = alt_case, alt_ctrl
    if phenotype_type == "continuous":
        min_cell = minor_case  # every sample is a "case"
    else:
        min_cell = min(minor_case, minor_ctrl)

    return [
        str(hom_ref_case), str(carrier_case),
        str(hom_ref_ctrl), str(carrier_ctrl),
        str(hom_ref_case + carrier_case), str(hom_ref_ctrl + carrier_ctrl),
        str(min_cell),
        f"{het_case}:{het_ctrl}:{hom_alt_case}:{hom_alt_ctrl}",
    ]


def run_plink2_mode(args):
    """Compute hardcall contingency rows with ``plink2 --geno-counts``.

    Writes keep files for cases (status 1) and controls (status 0). plink2 is
    run once per group; for continuous phenotypes only the case run happens
    and control counts are all zero. One HOM_REF/CARRIER row is written per
    variant. The temporary directory is removed even when a step fails.
    """
    pheno = load_phenotype(args.phenotype_tsv, args.phenotype_type)
    n_cases = sum(1 for v in pheno.values() if v == 1)
    n_controls = sum(1 for v in pheno.values() if v == 0)
    print(f"Phenotype: {len(pheno)} samples ({n_cases} cases, {n_controls} controls)", file=sys.stderr)

    tmpdir = tempfile.mkdtemp()
    n_variants = 0
    try:
        # Keep files list each sample as FID=IID.
        # NOTE(review): assumes the .psam FID equals IID — confirm for new datasets.
        cases_path = os.path.join(tmpdir, "cases.txt")
        controls_path = os.path.join(tmpdir, "controls.txt")
        with open(cases_path, "w") as fc, open(controls_path, "w") as fk:
            fc.write("FID\tIID\n")
            fk.write("FID\tIID\n")
            for sid, val in pheno.items():
                if val == 1:
                    fc.write(f"{sid}\t{sid}\n")
                elif val == 0:
                    fk.write(f"{sid}\t{sid}\n")
        print(f"Keep files: {n_cases} cases, {n_controls} controls", file=sys.stderr)

        case_variants = parse_gcount(_run_plink2_geno_counts(
            args.pfile_prefix, cases_path, os.path.join(tmpdir, "cases"), "cases"))

        if args.phenotype_type == "continuous":
            # All samples are cases: pair every variant with all-zero control counts.
            zero_cts = {"hom_ref": 0, "het": 0, "hom_alt": 0}
            rows = [(vid, case_cts, zero_cts) for vid, case_cts in case_variants]
            print("Continuous mode: skipping controls plink2 run (all samples are cases)", file=sys.stderr)
            print(f"Variants: {len(case_variants)} (cases gcount), {len(rows)} (dummy controls)", file=sys.stderr)
        else:
            ctrl_variants = parse_gcount(_run_plink2_geno_counts(
                args.pfile_prefix, controls_path, os.path.join(tmpdir, "controls"), "controls"))
            print(f"Variants: {len(case_variants)} (cases gcount), {len(ctrl_variants)} (controls gcount)", file=sys.stderr)
            rows = _pair_gcount_rows(case_variants, ctrl_variants)

        out = open_output(args)
        try:
            out.write(HEADER + "\n")
            for vid, case_cts, ctrl_cts in rows:
                out.write("\t".join([
                    vid, args.vcf_source, "hardcall",
                    "HOM_REF", "CARRIER",
                    *_plink2_hardcall_fields(case_cts, ctrl_cts, args.phenotype_type),
                ]) + "\n")
                n_variants += 1
                if n_variants % 100000 == 0:
                    print(f"Written {n_variants} variants...", file=sys.stderr)
        finally:
            if args.output:
                out.close()
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)

    print(f"Done. {n_variants} variants from {args.vcf_source} (plink2 mode)", file=sys.stderr)


def run_bcftools_mode(args):
    """Stream variants with ``bcftools query`` and write one contingency row each.

    Hardcall mode reads GT and writes HOM_REF/CARRIER rows. Dosage mode reads
    DS and writes AT_MODE/NOT_AT_MODE rows. Variants for which the line
    processor returns None are skipped. Exits with status 1 when no VCF sample
    has a phenotype. A non-zero bcftools exit is reported only as a warning.
    """
    pheno = load_phenotype(args.phenotype_tsv, args.phenotype_type)
    n_cases = sum(1 for v in pheno.values() if v == 1)
    n_controls = sum(1 for v in pheno.values() if v == 0)
    print(f"Phenotype: {len(pheno)} samples ({n_cases} cases, {n_controls} controls)", file=sys.stderr)

    sample_names = get_sample_names(args.vcf)
    sample_index = build_sample_index(sample_names, pheno)
    n_mapped = sum(1 for s in sample_index if s != -1)
    print(f"VCF samples: {len(sample_names)}, mapped to phenotype: {n_mapped}", file=sys.stderr)

    if n_mapped == 0:
        print("ERROR: No samples overlap between VCF and phenotype", file=sys.stderr)
        sys.exit(1)

    # Choose categories, query format and line processor once, not per variant.
    if args.variant_type == "hardcall":
        cat1_name, cat2_name = "HOM_REF", "CARRIER"
        fmt = r"%ID[\t%GT]\n"
        process_line = process_hardcall_line
    else:
        cat1_name, cat2_name = "AT_MODE", "NOT_AT_MODE"
        fmt = r"%ID[\t%DS]\n"
        process_line = process_dosage_line
    cmd = ["bcftools", "query", "-f", fmt, args.vcf]

    out = open_output(args)
    n_variants = 0
    try:
        out.write(HEADER + "\n")
        print(f"Running: {' '.join(cmd)}", file=sys.stderr)
        # bcftools stderr goes to a temp file rather than a pipe. A stderr pipe
        # that is only read after stdout is drained can fill up, blocking
        # bcftools while we wait on stdout (deadlock).
        with tempfile.TemporaryFile() as err_file, subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=err_file, text=True) as proc:
            for line in proc.stdout:
                line = line.rstrip("\n")
                if not line:
                    continue

                variant_id, *values = line.split("\t")
                result = process_line(values, sample_index, args.phenotype_type)
                if result is None:
                    continue
                cat1_case, cat2_case, cat1_ctrl, cat2_ctrl, total_case, total_ctrl, min_cell, extra = result

                out.write("\t".join([
                    variant_id, args.vcf_source, args.variant_type,
                    cat1_name, cat2_name,
                    str(cat1_case), str(cat2_case), str(cat1_ctrl), str(cat2_ctrl),
                    str(total_case), str(total_ctrl), str(min_cell), extra
                ]) + "\n")

                n_variants += 1
                if n_variants % 1000 == 0:
                    print(f"Processed {n_variants} variants...", file=sys.stderr)

            proc.wait()
            if proc.returncode != 0:
                err_file.seek(0)
                stderr_out = err_file.read().decode(errors="replace")
                print(f"WARNING: bcftools exited with code {proc.returncode}: {stderr_out}", file=sys.stderr)
    finally:
        if args.output:
            out.close()

    print(f"Done. {n_variants} variants from {args.vcf_source}", file=sys.stderr)


def _build_arg_parser():
    """Create the command-line parser for this script."""
    p = argparse.ArgumentParser(description="Compute contingency tables for all variants")
    p.add_argument("--vcf", help="VCF/BCF file path (bcftools mode)")
    p.add_argument("--phenotype-tsv", required=True, help="Phenotype TSV (SAMPLE_ID, PHENO_VALUE)")
    p.add_argument(
        "--variant-type",
        choices=["hardcall", "dosage"],
        required=True,
        help="hardcall (GT) or dosage (DS)",
    )
    p.add_argument("--vcf-source", required=True, help="VCF source identifier (e.g. wes/lof/chr7_lof)")
    p.add_argument("--use-plink2", action="store_true", help="Use plink2 --geno-counts (hardcall only)")
    p.add_argument("--pfile-prefix", help="plink2 pfile prefix (without .pgen/.psam/.pvar extension)")
    p.add_argument("--output", help="Output file path (default: stdout)")
    p.add_argument(
        "--phenotype-type",
        default="binary",
        choices=["binary", "continuous"],
        help="Phenotype type: binary (case/control) or continuous (carrier count only)",
    )
    return p


def main():
    """CLI entry point: validate arguments, then run the chosen backend."""
    parser = _build_arg_parser()
    args = parser.parse_args()

    if not args.use_plink2:
        # bcftools streaming backend
        if not args.vcf:
            parser.error("bcftools mode requires --vcf")
        run_bcftools_mode(args)
        return

    # plink2 backend: needs a pfile and only handles hardcalls
    if not args.pfile_prefix:
        parser.error("--use-plink2 requires --pfile-prefix")
    if args.variant_type != "hardcall":
        parser.error("--use-plink2 only supports --variant-type hardcall")
    run_plink2_mode(args)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
