import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as stats
# --- 1. CONFIGURE PATHS ---
# Default locations are resolved from this script's own location (not the current
# working directory): the project root is two directories above the script's folder.
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(script_dir))
# An optional first command-line argument overrides the default results table.
# Previously argv[1] was mandatory (IndexError without it) and then silently
# overwritten by the default, so the argument was never actually used.
if len(sys.argv) > 1:
    results_file = sys.argv[1]
else:
    results_file = os.path.join(project_root, "results", "tables", "all_subjects_results.csv")
ref_file = os.path.join(project_root, "data", "references", "centiloid_values.csv")
output_dir = os.path.join(project_root, "results", "reports")
os.makedirs(output_dir, exist_ok=True)
# --- 2. LOAD DATA ---
print(f"Reading results from: {results_file}")
df_calc = pd.read_csv(results_file)
print(f"Reading reference from: {ref_file}")
df_ref = pd.read_csv(ref_file)
# Strip stray whitespace from the reference header names so lookups such as
# 'Subject', 'SUVR' and 'Centiloid' match even if the CSV has padded headers.
df_ref.columns = [c.strip() for c in df_ref.columns]
print("Calculated Data Preview:")
# The heading above was previously printed with nothing after it.
print(df_calc.head())
# --- 3. MERGE DATASETS ---
# Compare subject IDs as stripped strings on both sides, regardless of the dtype
# pandas inferred for each column, so e.g. trailing spaces do not prevent a match.
df_calc['subject_id'] = df_calc['subject_id'].astype(str).str.strip()
df_ref['Subject'] = df_ref['Subject'].astype(str).str.strip()
# Inner join: only subjects present in both tables are analysed.
df_merged = pd.merge(df_calc, df_ref, left_on='subject_id', right_on='Subject', how='inner')
if df_merged.empty:
    # Previously this message was printed unconditionally and execution continued.
    print("ERROR: No matching subjects found.")
    sys.exit(1)
print(f"Successfully merged {len(df_merged)} subjects.")
if len(df_merged) < 2:
    # Pearson correlation and the linear fit both need at least two points;
    # stop with a clear message instead of crashing inside scipy/numpy.
    print("ERROR: At least 2 matched subjects are required for correlation analysis.")
    sys.exit(1)
# --- 4. CORRELATION ANALYSIS ---
# Each pair is (calculated value, reference value) over the merged subjects;
# Pearson r and its p-value are computed per metric.
x_suvr, y_suvr = df_merged['global_cortical_suvr'], df_merged['SUVR']
r_suvr, p_suvr = stats.pearsonr(x_suvr, y_suvr)

x_cl, y_cl = df_merged['global_cortical_centiloid'], df_merged['Centiloid']
r_cl, p_cl = stats.pearsonr(x_cl, y_cl)
# --- 5. PRINT RESULTS ---
# Console summary: subject count, then Pearson r and p-value for each metric.
print("\n=== STATISTICAL ANALYSIS RESULTS ===")
print(f"Number of Subjects: {len(df_merged)}")
for heading, r_value, p_value in (
    ("\n1. SUVR Correlation:", r_suvr, p_suvr),
    ("\n2. Centiloid Correlation:", r_cl, p_cl),
):
    print(heading)
    print(f" Pearson r: {r_value:.4f}")
    print(f" p-value: {p_value:.4e}")
print("====================================\n")
# --- 6. GENERATE PLOTS ---


def _plot_correlation(ax, x, y, r, p, title, xlabel, ylabel, color=None):
    """Draw a calculated-vs-reference scatter with a least-squares line on *ax*.

    The panel title is *title* followed by the supplied Pearson r and p-value.
    *color* sets the scatter colour; None keeps matplotlib's default.
    """
    ax.scatter(x, y, color=color, alpha=0.7)
    ax.set_title(f'{title}\nr={r:.3f}, p={p:.3e}')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    slope, intercept = np.polyfit(x, y, 1)
    # Draw the fit between the x extremes only; plotting it through the unsorted
    # data points retraces the line and garbles the dash pattern.
    x_line = np.array([np.min(x), np.max(x)], dtype=float)
    ax.plot(x_line, slope * x_line + intercept, color='red', linestyle='--')
    ax.grid(True, linestyle=':', alpha=0.6)


fig, axes = plt.subplots(1, 2, figsize=(12, 5))
_plot_correlation(axes[0], x_suvr, y_suvr, r_suvr, p_suvr,
                  'SUVR Correlation', 'Calculated SUVR', 'Reference SUVR')
_plot_correlation(axes[1], x_cl, y_cl, r_cl, p_cl,
                  'Centiloid Correlation', 'Calculated Centiloid', 'Reference Centiloid',
                  color='green')
# Keep the two-line titles and axis labels from overlapping.
fig.tight_layout()
output_plot = os.path.join(output_dir, "correlation_plots.png")
# Previously the figure was never written even though the save message was printed.
fig.savefig(output_plot)
plt.close(fig)
print(f"Plots saved to: {output_plot}")
if __name__ == "__main__":