#!/usr/bin/env python3
"""
Extract phenotype values from a parquet tar.gz archive for use in contingency table computation.

Runs on DNAnexus swiss-army-knife. Installs pyarrow at runtime if needed.

Usage:
    python3 extract_phenotype.py \
        --parquet-tar /mnt/in/clinical_phenotypes_parquet.tar.gz \
        --parquet-tar /mnt/in/prescription_phenotypes_parquet.tar.gz \
        --parquet-tar /mnt/in/assessment_centre_phenotypes_parquet.tar.gz \
        --samples /mnt/in/train_set.txt \
        --phenotype zopiclone__M796 \
        --output /home/dnanexus/out/out/zopiclone__M796_pheno.tsv
"""

import argparse
import subprocess
import sys
import os
import tarfile
import tempfile


def _load_samples(path):
    """Read the train-set file (``FID IID`` per line) and return the set of IIDs.

    The header row (first field ``"FID"``) and any line with fewer than two
    whitespace-separated fields are skipped.
    """
    samples = set()
    with open(path) as f:
        for line in f:
            parts = line.split()
            if len(parts) >= 2 and parts[0] != "FID":
                samples.add(parts[1])
    return samples


def _extract_tar(tar_path, dest):
    """Extract a ``.tar.gz`` archive into *dest*.

    Uses the PEP 706 ``filter="data"`` extraction filter where available so a
    crafted archive cannot escape *dest* via ``../`` members or absolute paths.
    """
    with tarfile.open(tar_path, "r:gz") as tar:
        try:
            tar.extractall(dest, filter="data")
        except TypeError:
            # Older Python without the `filter` argument. Archives here are
            # trusted pipeline outputs, so plain extraction is acceptable.
            tar.extractall(dest)


def _collect_phenotype(pq, tar_paths, phenotype, samples):
    """Scan parquet archives for *phenotype* and return ``{iid: float_value}``.

    Archives are searched in order; the search stops after the first archive
    in which at least one parquet file contains the column (all partitions of
    that archive are read). Only IDs present in *samples* with non-null values
    are kept. Unreadable parquet files are reported and skipped.

    *pq* is the ``pyarrow.parquet`` module (passed in because it may be
    installed at runtime by the caller).
    """
    pheno_values = {}
    for tar_path in tar_paths:
        with tempfile.TemporaryDirectory() as tmpdir:
            print(f"Extracting {tar_path}...", file=sys.stderr)
            _extract_tar(tar_path, tmpdir)

            n_files_with_col = 0
            for root, _dirs, files in os.walk(tmpdir):
                for fname in sorted(files):
                    if not fname.endswith(".parquet"):
                        continue
                    fpath = os.path.join(root, fname)
                    try:
                        pf = pq.ParquetFile(fpath)
                        col_names = [f.name for f in pf.schema_arrow]
                        if phenotype not in col_names:
                            continue

                        # Prefer the conventional UKB "eid" ID column; fall
                        # back to the first column otherwise.
                        id_col = "eid" if "eid" in col_names else col_names[0]
                        table = pq.read_table(fpath, columns=[id_col, phenotype])
                        df = table.to_pydict()

                        n_before = len(pheno_values)
                        for sid, val in zip(df[id_col], df[phenotype]):
                            sid_str = str(sid)
                            if sid_str in samples and val is not None:
                                pheno_values[sid_str] = float(val)

                        n_new = len(pheno_values) - n_before
                        n_files_with_col += 1
                        if n_new > 0:
                            print(f"  {fname}: +{n_new} samples (total: {len(pheno_values)})", file=sys.stderr)
                    except Exception as e:
                        # Best-effort: report and keep scanning other partitions.
                        print(f"Error reading {fname}: {e}", file=sys.stderr)
                        continue

            if n_files_with_col > 0:
                print(f"Found {phenotype} in {n_files_with_col} parquet files from {tar_path}: {len(pheno_values)} samples total", file=sys.stderr)
                return pheno_values
    return pheno_values


def _write_keep_files(output_dir, phenotype, pheno_values, phenotype_type):
    """Write ``<phenotype>_cases.txt`` / ``<phenotype>_controls.txt`` for plink2.

    Binary mode: cases are samples with value 1.0, controls with value 0.0.
    Continuous mode: every sample with a value goes to the cases file and the
    controls file contains only the header.

    Returns ``(n_cases, n_controls, cases_path, controls_path)``.
    """
    cases_path = os.path.join(output_dir, f"{phenotype}_cases.txt")
    controls_path = os.path.join(output_dir, f"{phenotype}_controls.txt")

    if phenotype_type == "continuous":
        case_ids = sorted(pheno_values)
        control_ids = []
    else:
        case_ids = sorted(s for s, v in pheno_values.items() if v == 1.0)
        control_ids = sorted(s for s, v in pheno_values.items() if v == 0.0)

    with open(cases_path, "w") as f:
        f.write("FID\tIID\n")
        for sid in case_ids:
            f.write(f"{sid}\t{sid}\n")
    with open(controls_path, "w") as f:
        f.write("FID\tIID\n")
        for sid in control_ids:
            f.write(f"{sid}\t{sid}\n")
    return len(case_ids), len(control_ids), cases_path, controls_path


def main():
    """Command-line entry point; see the module docstring for example usage."""
    parser = argparse.ArgumentParser(description="Extract phenotype from parquet archive")
    parser.add_argument("--parquet-tar", required=True, action="append", help="Parquet tar.gz archive path (can specify multiple)")
    parser.add_argument("--samples", required=True, help="Train set file (FID IID format)")
    parser.add_argument("--phenotype", required=True, help="Phenotype column name")
    parser.add_argument("--output", required=True, help="Output TSV path")
    parser.add_argument("--phenotype-type", default="binary", choices=["binary", "continuous"],
                        help="Phenotype type: binary (impute missing as 0) or continuous (only keep actual values)")
    args = parser.parse_args()

    # Install pyarrow at runtime if the worker image lacks it (swiss-army-knife).
    try:
        import pyarrow.parquet as pq
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "pyarrow", "-q"])
        import pyarrow.parquet as pq

    samples = _load_samples(args.samples)
    print(f"Loaded {len(samples)} samples from train set", file=sys.stderr)

    pheno_values = _collect_phenotype(pq, args.parquet_tar, args.phenotype, samples)
    if not pheno_values:
        print(f"ERROR: Phenotype {args.phenotype} not found in any parquet file", file=sys.stderr)
        sys.exit(1)

    n_from_parquet = len(pheno_values)
    if args.phenotype_type == "continuous":
        print(f"Continuous phenotype: {n_from_parquet} samples with values", file=sys.stderr)
    else:
        n_cases = sum(1 for v in pheno_values.values() if v == 1.0)
        n_controls = sum(1 for v in pheno_values.values() if v == 0.0)
        # Only explicit cases and controls from parquet — no imputation.
        # REGENIE uses only samples with actual phenotype values (~14K, not 138K).
        print(f"From parquet: {n_from_parquet} samples ({n_cases} cases, {n_controls} explicit controls)", file=sys.stderr)
        print(f"Total: {len(pheno_values)} samples ({n_cases} cases, {n_controls} controls)", file=sys.stderr)

    # BUGFIX: dirname is "" when --output is a bare filename; makedirs("")
    # raises FileNotFoundError, so fall back to the current directory.
    output_dir = os.path.dirname(args.output) or "."
    os.makedirs(output_dir, exist_ok=True)
    with open(args.output, "w") as f:
        f.write("SAMPLE_ID\tPHENO_VALUE\n")
        for sid, val in sorted(pheno_values.items()):
            f.write(f"{sid}\t{val}\n")
    print(f"Written {len(pheno_values)} samples to {args.output}", file=sys.stderr)

    n_cases_written, n_controls_written, cases_path, controls_path = _write_keep_files(
        output_dir, args.phenotype, pheno_values, args.phenotype_type)
    if args.phenotype_type == "continuous":
        print(f"Written {len(pheno_values)} samples to cases (continuous, all present)", file=sys.stderr)
        print(f"Written 0 controls (continuous mode)", file=sys.stderr)
    else:
        print(f"Written {n_cases_written} cases to {cases_path}", file=sys.stderr)
        print(f"Written {n_controls_written} controls to {controls_path}", file=sys.stderr)


# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
