#!/usr/bin/env python3
"""
Compute contingency table for a specific variant x phenotype combination.

Usage (on DNAnexus swiss-army-knife):
    python3 contingency_table.py \
        --vcf /mnt/in/chr7_lof.vcf.bgz \
        --parquet-tar /mnt/in/clinical_phenotypes_parquet.tar.gz \
        --samples /mnt/in/train_set.txt \
        --variant STEAP1B_lof \
        --phenotype zopiclone__M796 \
        --variant-type dosage

Variant types:
    hardcall  - discrete GT (0/0, 0/1, 1/1) for SNV, cypmicro, HLA
    dosage    - continuous DS field for MPC, LOF, cypdosage
"""

import argparse
import subprocess
import sys
import os
import tarfile
import tempfile


def extract_phenotype_from_parquet(parquet_tar_path, phenotype_name, samples_path):
    """Extract phenotype values for the listed samples from a parquet tar.gz archive.

    The archive is unpacked into a temporary directory. The first ``.parquet``
    file (in sorted directory-walk order) whose schema contains
    *phenotype_name* is read, and the search stops there. Sample IDs come from
    column ``eid`` when present, otherwise from the file's first column. They
    are compared as strings against the IIDs (second whitespace-separated
    field) of *samples_path*; a line whose first field is ``FID`` is skipped
    as a header.

    Args:
        parquet_tar_path: path to the ``.tar.gz`` holding the parquet files.
        phenotype_name: column to extract.
        samples_path: whitespace-delimited sample list with FID and IID columns.

    Returns:
        dict mapping sample ID (str) -> phenotype value (float). Null and NaN
        values are omitted.

    Exits with status 1 if no parquet file contains *phenotype_name*, or if
    none of the listed samples has a non-missing value for it.
    """
    import math

    try:
        import pyarrow.parquet as pq
    except ImportError:
        # pyarrow may be absent on the worker image; install it on the fly.
        subprocess.check_call([sys.executable, "-m", "pip", "install", "pyarrow", "-q"])
        import pyarrow.parquet as pq

    # Load sample IIDs (second column), skipping an "FID ..." header line.
    samples = set()
    with open(samples_path) as fh:
        for line in fh:
            parts = line.split()
            if len(parts) >= 2 and parts[0] != "FID":
                samples.add(parts[1])

    pheno_values = {}  # sample_id -> phenotype_value
    found_in = None
    with tempfile.TemporaryDirectory() as tmpdir:
        with tarfile.open(parquet_tar_path, "r:gz") as tar:
            # Prefer the "data" extraction filter (when this Python has it) so
            # archive members cannot escape tmpdir or create special files.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(tmpdir, filter="data")
            else:
                tar.extractall(tmpdir)

        for root, dirs, files in os.walk(tmpdir):
            dirs.sort()  # in-place sort makes os.walk's traversal deterministic
            for fname in sorted(files):
                if not fname.endswith(".parquet"):
                    continue
                fpath = os.path.join(root, fname)
                try:
                    schema = pq.ParquetFile(fpath).schema_arrow
                    col_names = [field.name for field in schema]
                    if phenotype_name not in col_names:
                        continue

                    # Read only the ID + phenotype columns.
                    id_col = "eid" if "eid" in col_names else col_names[0]
                    data = pq.read_table(fpath, columns=[id_col, phenotype_name]).to_pydict()

                    # Collect into a local dict so a failure mid-file leaves
                    # no partial results behind.
                    values = {}
                    for sid, val in zip(data[id_col], data[phenotype_name]):
                        sid_str = str(sid)
                        if sid_str not in samples or val is None:
                            continue
                        fval = float(val)
                        # NaN is a common on-disk encoding of "missing".
                        if not math.isnan(fval):
                            values[sid_str] = fval
                except Exception as e:
                    # Best-effort search: an unreadable file is reported and skipped.
                    print(f"Error reading {fname}: {e}", file=sys.stderr)
                    continue

                pheno_values = values
                found_in = fname
                break
            if found_in is not None:
                # Stop the whole walk, not just this directory's file loop.
                break

    if found_in is None:
        print(f"ERROR: Phenotype {phenotype_name} not found in any parquet file", file=sys.stderr)
        sys.exit(1)

    print(f"Found {phenotype_name} in {found_in}: {len(pheno_values)} samples with values", file=sys.stderr)
    if not pheno_values:
        print(
            f"ERROR: Phenotype {phenotype_name} has no non-missing values for the samples in {samples_path}",
            file=sys.stderr,
        )
        sys.exit(1)

    return pheno_values


def _run_bcftools_query(query_args):
    """Run ``bcftools query`` with *query_args* and return its stdout as text.

    Exits with status 1 (printing bcftools' stderr) if the binary is missing
    or returns a non-zero status. This keeps a bad path or unreadable VCF from
    being misreported downstream as "variant not found".
    """
    cmd = ["bcftools", "query", *query_args]
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True)
    except FileNotFoundError:
        print("ERROR: bcftools not found on PATH", file=sys.stderr)
        sys.exit(1)
    if proc.returncode != 0:
        print(
            f"ERROR: {' '.join(cmd)} exited with status {proc.returncode}: {proc.stderr.strip()}",
            file=sys.stderr,
        )
        sys.exit(1)
    return proc.stdout


def extract_genotypes_bcftools(vcf_path, variant_id, variant_type):
    """Extract per-sample genotype (GT) or dosage (DS) strings for one variant.

    Args:
        vcf_path: VCF/BCF path readable by ``bcftools``.
        variant_id: value of the ID column to select.
        variant_type: "hardcall" reads FORMAT/GT; any other value reads FORMAT/DS.

    Returns:
        dict mapping sample name -> raw FORMAT string. Samples whose value is
        missing (".", "./.", ".|.") are omitted. If several records share
        *variant_id*, later non-missing values overwrite earlier ones (a
        warning is printed).

    Exits with status 1 if bcftools fails or no record has *variant_id*.
    """
    field = "GT" if variant_type == "hardcall" else "DS"
    # Backslash sequences are passed through literally for bcftools to expand.
    fmt = f"%ID[\\t%{field}]\\n"

    sample_names = _run_bcftools_query(["-l", vcf_path]).splitlines()

    # NOTE(review): a variant_id containing a double quote breaks this expression.
    output = _run_bcftools_query(["-i", f'ID="{variant_id}"', "-f", fmt, vcf_path])
    records = [line.split("\t") for line in output.splitlines()]
    records = [parts for parts in records if parts[0] == variant_id]

    if not records:
        print(f"ERROR: Variant {variant_id} not found in VCF", file=sys.stderr)
        sys.exit(1)
    if len(records) > 1:
        print(
            f"WARNING: {len(records)} records share ID {variant_id}; "
            "later non-missing values overwrite earlier ones",
            file=sys.stderr,
        )

    missing = {".", "./.", ".|."}
    genotypes = {}  # sample_id -> value
    for parts in records:
        values = parts[1:]
        if len(values) != len(sample_names):
            print(
                f"WARNING: {len(values)} values for {len(sample_names)} samples in record {variant_id}",
                file=sys.stderr,
            )
        for sample, val in zip(sample_names, values):
            if val not in missing:
                genotypes[sample] = val

    return genotypes


def classify_gt(gt_str):
    """Map a hard-call GT string to "HOM_REF", "HET", "HOM_ALT" or "OTHER".

    Phased ("|") and unphased ("/") separators are treated the same. Any call
    that is not one of 0/0, 0/1, 1/0, 1/1 yields "OTHER" (e.g. missing,
    haploid or multi-allelic calls).
    """
    unphased = gt_str.replace("|", "/")
    lookup = {
        "0/0": "HOM_REF",
        "0/1": "HET",
        "1/0": "HET",
        "1/1": "HOM_ALT",
    }
    return lookup.get(unphased, "OTHER")


def classify_dosage(ds_str, mode_val, median_nonzero):
    """Bin one dosage string relative to the cohort mode and non-mode median.

    Returns "AT_MODE" when the value equals *mode_val*. Otherwise returns
    "HIGH" when *median_nonzero* is set and the value is >= it, and "LOW" in
    every other case. Returns None if *ds_str* cannot be parsed as a float.
    """
    try:
        dosage = float(ds_str)
    except (ValueError, TypeError):
        return None

    if dosage == mode_val:
        return "AT_MODE"
    reaches_high = median_nonzero is not None and dosage >= median_nonzero
    return "HIGH" if reaches_high else "LOW"


def compute_mode_and_median(values):
    """Return ``(mode, median_of_non_mode)`` over the numeric entries of *values*.

    Entries that cannot be parsed by ``float()``, or that parse to NaN, are
    ignored.

    Ties for the mode are broken toward the smallest value. This makes the
    result independent of input order; *values* may arrive in arbitrary order,
    e.g. when built by iterating a set of sample IDs.

    The "median" is the middle element of the sorted values that differ from
    the mode (the upper-middle one for an even count), so it is always an
    observed value. It is None when every value equals the mode.

    Returns ``(0.0, None)`` when no entry is numeric.
    """
    import math
    from collections import Counter

    float_vals = []
    for v in values:
        try:
            x = float(v)
        except (ValueError, TypeError):
            continue
        if not math.isnan(x):
            float_vals.append(x)

    if not float_vals:
        return 0.0, None

    counts = Counter(float_vals)
    top = max(counts.values())
    # Deterministic tie-break; Counter.most_common would pick the first-seen value.
    mode_val = min(v for v, c in counts.items() if c == top)

    non_mode = sorted(v for v in float_vals if v != mode_val)
    median_nonzero = non_mode[len(non_mode) // 2] if non_mode else None

    return mode_val, median_nonzero


def _build_contingency_table(pheno, geno, common, variant_type):
    """Tally case/control counts per genotype category for the *common* samples.

    Args:
        pheno: sample_id -> phenotype value; 1.0 counts as case, 0.0 as control.
        geno: sample_id -> raw GT/DS string as returned by extract_genotypes_bcftools.
        common: sample IDs present in both *pheno* and *geno*.
        variant_type: "hardcall" classifies GT strings; any other value classifies DS.

    Returns:
        ``(categories, table, n_excluded)``:
        - the ordered category names;
        - ``table[category] == {"case": int, "control": int}``;
        - the number of samples left out because their genotype was
          unclassifiable or their phenotype was neither 1.0 nor 0.0.
    """
    if variant_type == "hardcall":
        categories = ["HOM_REF", "HET", "HOM_ALT"]
        assigned = {sid: classify_gt(geno[sid]) for sid in common}
    else:
        # Dosage bins are data-driven: the mode, then below vs at-or-above the
        # median of the non-mode values.
        mode_val, median_nonzero = compute_mode_and_median([geno[sid] for sid in common])
        print(f"Dosage stats: mode={mode_val}, median_of_non_mode={median_nonzero}", file=sys.stderr)
        if median_nonzero is not None:
            categories = ["AT_MODE", "LOW", "HIGH"]
        else:
            categories = ["AT_MODE", "LOW"]
        assigned = {sid: classify_dosage(geno[sid], mode_val, median_nonzero) for sid in common}

    table = {cat: {"case": 0, "control": 0} for cat in categories}
    status_of = {1.0: "case", 0.0: "control"}
    n_excluded = 0
    for sid, cat in assigned.items():
        status = status_of.get(pheno[sid])
        # "OTHER" hard-calls and unparsable dosages (None) have no column.
        if cat not in table or status is None:
            n_excluded += 1
            continue
        table[cat][status] += 1

    return categories, table, n_excluded


def _print_contingency_table(variant, phenotype, categories, table):
    """Print the table, row/column totals and the minimum-cell check to stdout."""
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"CONTINGENCY TABLE: {variant} x {phenotype}")
    print(rule)

    print(f"{'':>15}" + "".join(f" {c:>10}" for c in categories) + f" {'TOTAL':>10}")
    print("-" * 60)
    for status in ("case", "control"):
        counts = [table[cat][status] for cat in categories]
        print(f"{status:>15}" + "".join(f" {c:>10}" for c in counts) + f" {sum(counts):>10}")

    totals = [table[cat]["case"] + table[cat]["control"] for cat in categories]
    print("-" * 60)
    print(f"{'TOTAL':>15}" + "".join(f" {t:>10}" for t in totals) + f" {sum(totals):>10}")

    # Small expected counts make chi-square style tests unreliable.
    min_cell = min(table[cat][status] for cat in categories for status in ("case", "control"))
    print(f"\nMinimum cell count: {min_cell}")
    if min_cell < 5:
        print("WARNING: Minimum cell count < 5 - statistical test may be unreliable")

    print(rule)


def main():
    """Command-line entry point: build and print one variant x phenotype table.

    Progress and diagnostics go to stderr; the table itself goes to stdout.
    """
    parser = argparse.ArgumentParser(description="Compute contingency table for variant x phenotype")
    parser.add_argument("--vcf", required=True, help="VCF/BCF file path")
    parser.add_argument("--parquet-tar", required=True, help="Parquet tar.gz archive")
    parser.add_argument("--samples", required=True, help="Samples list (FID IID)")
    parser.add_argument("--variant", required=True, help="Variant ID to extract")
    parser.add_argument("--phenotype", required=True, help="Phenotype column name")
    parser.add_argument("--variant-type", choices=["hardcall", "dosage"], default="dosage",
                        help="Variant type: hardcall (GT) or dosage (DS)")
    args = parser.parse_args()

    print(f"=== Contingency Table: {args.variant} x {args.phenotype} ===", file=sys.stderr)

    # Step 1: Extract phenotype
    print("Extracting phenotype...", file=sys.stderr)
    pheno = extract_phenotype_from_parquet(args.parquet_tar, args.phenotype, args.samples)
    n_cases = sum(1 for v in pheno.values() if v == 1.0)
    n_controls = sum(1 for v in pheno.values() if v == 0.0)
    print(f"Phenotype: {len(pheno)} samples ({n_cases} cases, {n_controls} controls)", file=sys.stderr)
    n_other = len(pheno) - n_cases - n_controls
    if n_other:
        print(f"WARNING: {n_other} samples have a phenotype value other than 0/1 and will not be tabulated",
              file=sys.stderr)

    # Step 2: Extract genotypes
    print("Extracting genotypes...", file=sys.stderr)
    geno = extract_genotypes_bcftools(args.vcf, args.variant, args.variant_type)
    print(f"Genotype: {len(geno)} samples", file=sys.stderr)

    # Step 3: Join on sample ID
    common = set(pheno.keys()) & set(geno.keys())
    print(f"Common samples: {len(common)}", file=sys.stderr)

    if not common:
        print("ERROR: No overlapping samples between phenotype and genotype", file=sys.stderr)
        sys.exit(1)

    # Step 4: Build contingency table
    categories, table, n_excluded = _build_contingency_table(pheno, geno, common, args.variant_type)
    if n_excluded:
        print(f"Excluded {n_excluded} common samples (unclassifiable genotype or phenotype not 0/1)",
              file=sys.stderr)

    # Step 5: Print results
    _print_contingency_table(args.variant, args.phenotype, categories, table)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
