#!/usr/bin/env python3

import pandas as pd
import requests
import xmltodict
import time
import matplotlib.pyplot as plt
import os
import seaborn as sns
import scipy.stats as stats
import argparse
from tqdm import tqdm
import logging
from typing import Tuple, Dict, Optional, List
import plotly.express as px
import plotly.io as pio

# Logging setup: timestamped, level-tagged messages at INFO and above
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# NCBI E-utilities efetch endpoint
NCBI_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"

# Names of the per-organism output subfolders
METADATA_FOLDER_NAME = "metadata_output"
FIGURES_FOLDER_NAME = "figures"
SEQUENCE_FOLDER_NAME = "sequence"

# In-memory cache of already-fetched metadata tuples, keyed by BioSample ID
metadata_cache: Dict[str, Tuple] = {}


def load_data(input_file: str) -> pd.DataFrame:
    """Read a tab-separated file into a DataFrame.

    Any failure while reading is logged and then re-raised unchanged.
    """
    try:
        table = pd.read_csv(input_file, sep="\t")
    except Exception as exc:
        logging.error(f"Error loading data from {input_file}: {exc}")
        raise
    else:
        logging.info(f"Data loaded successfully from {input_file}")
        return table


def filter_data(df: pd.DataFrame, checkm_threshold: float, ani_status_list: list) -> pd.DataFrame:
    """Keep rows that pass the CheckM completeness and ANI status criteria.

    A row is kept when its "CheckM completeness" is present and strictly
    greater than ``checkm_threshold``. Unless ``ani_status_list`` contains
    "all", the row's "ANI Check status" must also be one of the listed values.
    Errors are logged and re-raised.
    """
    try:
        completeness = df["CheckM completeness"]
        passes_checkm = completeness.notna() & (completeness > checkm_threshold)
        result = df[passes_checkm]

        # "all" disables the ANI filter entirely
        if "all" not in ani_status_list:
            passes_ani = result["ANI Check status"].isin(ani_status_list)
            result = result[passes_ani]

        logging.info(f"Data filtered with CheckM threshold {checkm_threshold} and ANI status {ani_status_list}")
        return result
    except Exception as exc:
        logging.error(f"Error filtering data: {exc}")
        raise


def create_output_directory(output_directory: str, organism_name: str) -> Tuple[str, str, str, str]:
    """Create the per-organism folder tree and return its paths.

    The organism folder is ``output_directory`` joined with the organism name
    (spaces replaced by underscores); the metadata, figures and sequence
    subfolders are created beneath it. Existing folders are left as they are.

    Returns:
        ``(organism_folder, metadata_folder, figures_folder, sequence_folder)``.
    """
    try:
        organism_folder = os.path.join(output_directory, organism_name.replace(" ", "_"))
        subfolders = tuple(
            os.path.join(organism_folder, name)
            for name in (METADATA_FOLDER_NAME, FIGURES_FOLDER_NAME, SEQUENCE_FOLDER_NAME)
        )
        for folder in subfolders:
            os.makedirs(folder, exist_ok=True)

        logging.info(f"Output directories created: {organism_folder}")
        return (organism_folder, *subfolders)
    except Exception as exc:
        logging.error(f"Error creating output directories: {exc}")
        raise



def fetch_metadata(biosample_id: str, sleep_time: float) -> Tuple:
    """Fetch selected attributes of one BioSample record from NCBI.

    Results of successful lookups are stored in ``metadata_cache`` so that a
    repeated ID does not trigger another request. Failed or empty lookups are
    not cached.

    Args:
        biosample_id: BioSample identifier passed as the ``id`` query parameter.
        sleep_time: Seconds to pause after a successful HTTP response.

    Returns:
        ``(isolation_source, collection_date, geo_location, host)``. Any value
        that is not present in the record is ``pd.NA``; on a network or
        parsing error all four are ``pd.NA`` (the error is logged, not raised).
    """
    if biosample_id in metadata_cache:
        return metadata_cache[biosample_id]

    missing = (pd.NA, pd.NA, pd.NA, pd.NA)
    url = f"{NCBI_URL}?db=biosample&id={biosample_id}&retmode=xml"
    try:
        # Without a timeout a stalled connection would block the run forever.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        time.sleep(sleep_time)
        data = xmltodict.parse(response.text)

        sample_set = data.get("BioSampleSet")
        if not sample_set:
            logging.warning(f"No 'BioSampleSet' found for BioSample {biosample_id}")
            return missing

        biosample = sample_set.get("BioSample")
        if not biosample:
            logging.warning(f"No 'BioSample' found for BioSample {biosample_id}")
            return missing

        # An empty <Attributes/> element parses to None, hence the "or {}".
        attributes = (biosample.get("Attributes") or {}).get("Attribute", [])
        if not attributes:
            logging.warning(f"No 'Attributes' found for BioSample {biosample_id}")
            return missing

        # xmltodict returns a lone <Attribute> as a dict rather than a
        # one-element list; normalise so the loop below always sees dicts.
        if isinstance(attributes, dict):
            attributes = [attributes]

        # Position of each wanted attribute in the returned tuple
        slots = {
            "isolation_source": 0,
            "collection_date": 1,
            "geo_loc_name": 2,
            "host": 3,
        }
        values = [pd.NA, pd.NA, pd.NA, pd.NA]
        for attr in attributes:
            if not isinstance(attr, dict):
                continue
            slot = slots.get(attr.get("@attribute_name"))
            if slot is not None:
                values[slot] = attr.get("#text", pd.NA)

        result = tuple(values)
        metadata_cache[biosample_id] = result
        return result
    except requests.exceptions.RequestException as e:
        logging.error(f"Network error fetching BioSample {biosample_id}: {e}")
        return missing
    except Exception as e:
        logging.error(f"Unexpected error fetching BioSample {biosample_id}: {e}")
        return missing


def standardize_date(date: str) -> str:
    """Reduce a collection date to its four-digit year.

    The value is stripped of surrounding whitespace. If it is a range written
    with "/", only the part before the first "/" is used. That part is split
    on "-" and the first component made of exactly four digits is returned, so
    "2015-03-01", "2015", "15-Oct-2015" and "Oct-2015" all give "2015".

    Args:
        date: Raw collection date; may be missing (``None``/``pd.NA``/NaN).

    Returns:
        The year as a string, or "absent" when the value is missing or no
        four-digit component is found (this covers placeholders such as
        "unknown", "missing", "NA" and "not collected").
    """
    if pd.isna(date):
        return "absent"

    text = str(date).strip()
    # For ranges ("2015/2016", "2008-01-23/2008-02-15") use the start date.
    start = text.split("/")[0]
    for component in start.split("-"):
        component = component.strip()
        if len(component) == 4 and component.isdigit():
            return component
    return "absent"


def standardize_location(location: str) -> str:
    """Reduce a geographic location to its country part.

    The country is the text before the first ":" (e.g. "USA: New York" gives
    "USA"), with surrounding whitespace removed.

    Args:
        location: Raw location value; may be missing or a non-string.

    Returns:
        The country text, or "absent" when the value is missing, empty, or a
        placeholder ("missing", "unknown", "not applicable", "not collected",
        compared case-insensitively after stripping).
    """
    if pd.isna(location):
        return "absent"

    # str() keeps non-string cell values from raising on .lower()/.split()
    text = str(location).strip()
    if text.lower() in ("", "missing", "unknown", "not applicable", "not collected"):
        return "absent"

    country = text.split(":")[0].strip()
    # A value such as ": Texas" has no country part at all
    return country or "absent"


def standardize_host(host: str) -> str:
    """Map missing or placeholder host values to "absent".

    Placeholders ("unknown", "missing", "not applicable", "not collected" and
    the empty string) are recognised regardless of letter case or surrounding
    whitespace, matching how locations are standardized.

    Args:
        host: Raw host value; may be missing.

    Returns:
        "absent" for missing/placeholder values, otherwise ``host`` unchanged.
    """
    if pd.isna(host):
        return "absent"
    if str(host).strip().lower() in ("unknown", "missing", "not applicable", "not collected", ""):
        return "absent"
    return host


def save_summary(df: pd.DataFrame, output_file: str) -> None:
    """Write ``df`` to ``output_file`` as tab-separated text without the index.

    A failure is logged and re-raised.
    """
    try:
        df.to_csv(output_file, index=False, sep="\t")
    except Exception as exc:
        logging.error(f"Error saving data to {output_file}: {exc}")
        raise
    else:
        logging.info(f"Data saved to {output_file}")


def plot_bar_charts(variable: str, frequency: pd.Series, percentage: pd.Series, figures_folder: str) -> None:
    """Draw frequency and percentage bar plots side by side and save them.

    The figure is written to ``<figures_folder>/<variable>_bar_plots.tiff`` at
    300 dpi. The figure is always closed afterwards, including when plotting
    or saving fails, so repeated calls do not accumulate open figures.

    Args:
        variable: Name of the variable; used for titles, x labels and the file name.
        frequency: Values to plot in the left panel, indexed by category.
        percentage: Values to plot in the right panel, indexed by category.
        figures_folder: Folder the image is written to.

    Raises:
        Exception: Any plotting or saving error, after it has been logged.
    """
    fig = None
    try:
        fig = plt.figure(figsize=(12, 6))

        # Left panel: frequency; right panel: percentage
        panels = (("Frequency", frequency), ("Percentage", percentage))
        for position, (label, series) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            sns.barplot(x=series.index, y=series.values, palette="viridis")
            plt.title(f"{label} of {variable}")
            plt.xlabel(variable)
            plt.ylabel(label)
            plt.xticks(rotation=45, ha="right")

        plt.tight_layout()
        figure_path = os.path.join(figures_folder, f"{variable}_bar_plots.tiff")
        plt.savefig(figure_path, format="tiff", dpi=300)
        logging.info(f"Bar plots saved for {variable}")
    except Exception as e:
        logging.error(f"Error generating bar plots for {variable}: {e}")
        raise
    finally:
        # Release the figure on success and on failure alike
        if fig is not None:
            plt.close(fig)

def plot_geo_choropleth(variable: str, frequency: pd.Series, figures_folder: str) -> None:
    """Render the counts in ``frequency`` as a choropleth map and save it.

    The index of ``frequency`` supplies the location names (matched by plotly
    as country names) and its values colour the map. The image is written to
    ``<figures_folder>/<variable>_map.jpg``. Errors are logged and re-raised.
    """
    try:
        # Two-column table: location name and its count
        counts = frequency.reset_index()
        counts.columns = [variable, 'Frequency']

        choropleth_options = dict(
            locations=variable,
            locationmode="country names",  # "ISO-3" would be needed for 3-letter codes
            color="Frequency",
            color_continuous_scale="Viridis",
            title=f"Distribution of {variable}",
            template="plotly_white",
        )
        figure = px.choropleth(counts, **choropleth_options)

        # Static JPEG export (original note: requires kaleido installed)
        target = os.path.join(figures_folder, f"{variable}_map.jpg")
        pio.write_image(figure, target, format="jpg", scale=3)

        logging.info(f"Map plot saved for {variable}")
    except Exception as exc:
        logging.error(f"Error generating map plot for {variable}: {exc}")
        raise


def plot_distribution(column: str, data: pd.Series, title: str, figures_folder: str) -> None:
    """Plot a histogram with a KDE curve for ``data`` and save it.

    The figure is written to ``<figures_folder>/<column>_distribution.tiff``
    at 300 dpi. The figure is always closed afterwards, including when
    plotting or saving fails.

    Args:
        column: Column name; used in the output file name and log messages.
        data: Values to plot.
        title: Human-readable name used in the plot title and x label.
        figures_folder: Folder the image is written to.

    Raises:
        Exception: Any plotting or saving error, after it has been logged.
    """
    fig = None
    try:
        fig = plt.figure(figsize=(8, 6))
        sns.histplot(data, kde=True, color="blue")
        plt.title(f"Distribution of {title}")
        plt.xlabel(title)
        plt.ylabel("Frequency")
        plt.tight_layout()
        figure_path = os.path.join(figures_folder, f"{column}_distribution.tiff")
        plt.savefig(figure_path, format="tiff", dpi=300)
        logging.info(f"Distribution plot saved for {column}")
    except Exception as e:
        logging.error(f"Error generating distribution plot for {column}: {e}")
        raise
    finally:
        # Release the figure on success and on failure alike
        if fig is not None:
            plt.close(fig)


def plot_scatter_with_trend_and_corr(
    x: pd.Series, y: pd.Series, xlabel: str, ylabel: str, title: str, filename: str, figures_folder: str
) -> None:
    """Save a scatter plot with a regression line and the Pearson correlation.

    ``x`` and ``y`` are coerced to numbers (unparseable values are dropped) and
    ``x`` is rounded to whole numbers. The data is validated before any figure
    is created; when fewer than two valid points remain a warning is logged
    and nothing is plotted. The figure is written to
    ``<figures_folder>/<filename>`` as a 300 dpi TIFF and is always closed
    afterwards.

    Args:
        x: Values for the horizontal axis (years, per the axis formatting).
        y: Values for the vertical axis.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        title: Plot title; also used in log messages.
        filename: Output file name inside ``figures_folder``.
        figures_folder: Folder the image is written to.

    Raises:
        Exception: Any computation, plotting or saving error, after it has
            been logged.
    """
    fig = None
    try:
        x_numeric = pd.to_numeric(x, errors="coerce").round()
        y_numeric = pd.to_numeric(y, errors="coerce")

        # Plain float64 columns so SciPy and seaborn receive ordinary NumPy
        # data rather than pandas nullable-integer arrays.
        plot_data = pd.DataFrame({'x': x_numeric, 'y': y_numeric}).dropna().astype("float64")

        if len(plot_data) == 0:
            logging.warning(f"No valid data points for {title}")
            return

        # A correlation needs at least two points
        if len(plot_data) < 2:
            logging.warning(f"Not enough valid data points for {title}")
            return

        # Compute correlation
        r_value, p_value = stats.pearsonr(plot_data['x'], plot_data['y'])

        # Create the figure only once there is something to draw
        fig = plt.figure(figsize=(10, 6))
        sns.regplot(x=plot_data['x'], y=plot_data['y'],
                    scatter_kws={"alpha": 0.5},
                    line_kws={"color": "red"})

        # Force integer ticks for years
        plt.gca().xaxis.set_major_locator(plt.MaxNLocator(integer=True))
        plt.xticks(rotation=90)  # Improve readability

        # Correlation annotation in the top-left corner of the data range
        plt.text(
            plot_data['x'].min(), plot_data['y'].max(),
            f"r = {r_value:.3f} (p={p_value:.3f})",
            fontsize=12, color="black", bbox=dict(facecolor="white", alpha=0.7)
        )

        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.tight_layout()

        figure_path = os.path.join(figures_folder, filename)
        plt.savefig(figure_path, format="tiff", dpi=300)
        logging.info(f"Scatter plot saved: {title}")
    except Exception as e:
        logging.error(f"Error generating scatter plot for {title}: {e}")
        raise
    finally:
        # Release the figure on success and on failure alike
        if fig is not None:
            plt.close(fig)


def save_clean_data(df: pd.DataFrame, columns_to_keep: List[str], output_file: str) -> None:
    """Write only the selected columns of ``df`` to a comma-separated file.

    The index is not written. A failure (for example an unknown column name)
    is logged and re-raised.
    """
    try:
        df[columns_to_keep].to_csv(output_file, index=False)
    except Exception as exc:
        logging.error(f"Error saving filtered dataset to {output_file}: {exc}")
        raise
    else:
        logging.info(f"Filtered dataset saved to {output_file}")


def generate_metadata_summary(df: pd.DataFrame, output_file: str) -> None:
    """Tabulate value frequencies of the metadata columns and save them as CSV.

    For each of "Geographic Location", "Host" and "Collection Date" one row
    per distinct value is written, holding the count and its share of the
    column's counted values formatted as a percentage with two decimals.
    Errors are logged and re-raised.
    """
    try:
        rows = []
        for column in ("Geographic Location", "Host", "Collection Date"):
            counts = df[column].value_counts()
            total = counts.sum()
            rows.extend(
                [column, value, count, f"{(count / total) * 100:.2f}%"]
                for value, count in counts.items()
            )

        header = ["Variable", "Value", "Frequency", "Percentage"]
        pd.DataFrame(rows, columns=header).to_csv(output_file, index=False)
        logging.info(f"Metadata summary saved to {output_file}")
    except Exception as exc:
        logging.error(f"Error generating metadata summary: {exc}")
        raise


def generate_annotation_summary(df: pd.DataFrame, output_file: str) -> None:
    """Summarise the gene annotation count columns and save the result as CSV.

    For each of the three "Annotation Count Gene ..." columns the highest,
    mean, median and lowest value are written, computed over the entries that
    can be parsed as numbers. A column with no parseable entry is reported as
    "No Data" for all four statistics.

    Side effect: each of these columns in ``df`` is replaced by its numeric
    version (unparseable entries become NaN).

    Args:
        df: Table containing the three annotation count columns.
        output_file: Path of the CSV file to write.

    Raises:
        Exception: Any error (e.g. a missing column), after it has been logged.
    """
    columns = (
        "Annotation Count Gene Total",
        "Annotation Count Gene Protein-coding",
        "Annotation Count Gene Pseudogene",
    )
    try:
        summary_data = []
        for column in columns:
            numeric = pd.to_numeric(df[column], errors="coerce")
            df[column] = numeric

            # Statistics come from a separate NaN-free series: assigning a
            # dropna() result back to the frame would realign it to the full
            # index and bring the NaN rows straight back.
            valid = numeric.dropna()
            if valid.empty:
                highest = mean = median = lowest = "No Data"
            else:
                highest = valid.max()
                mean = valid.mean()
                median = valid.median()
                lowest = valid.min()

            summary_data.append([column, "Highest", highest])
            summary_data.append([column, "Mean", mean])
            summary_data.append([column, "Median", median])
            summary_data.append([column, "Lowest", lowest])

        summary_df = pd.DataFrame(summary_data, columns=["Variable", "Summary", "Value"])
        summary_df.to_csv(output_file, index=False)
        logging.info(f"Annotation summary saved to {output_file}")
    except Exception as e:
        logging.error(f"Error generating annotation summary: {e}")
        raise


def generate_assembly_summary(df: pd.DataFrame, output_file: str) -> None:
    """Summarise the total sequence length column and save the result as CSV.

    The highest, mean, median and lowest "Assembly Stats Total Sequence
    Length" are written, computed over the entries that can be parsed as
    numbers. If no entry can be parsed, all four statistics are "No Data".

    Side effect: the column in ``df`` is replaced by its numeric version
    (unparseable entries become NaN).

    Args:
        df: Table containing the "Assembly Stats Total Sequence Length" column.
        output_file: Path of the CSV file to write.

    Raises:
        Exception: Any error (e.g. a missing column), after it has been logged.
    """
    column = "Assembly Stats Total Sequence Length"
    try:
        numeric = pd.to_numeric(df[column], errors="coerce")
        df[column] = numeric

        # Statistics come from a separate NaN-free series: assigning a
        # dropna() result back to the frame would realign it to the full
        # index and bring the NaN rows straight back.
        valid = numeric.dropna()
        if valid.empty:
            highest = mean = median = lowest = "No Data"
        else:
            highest = valid.max()
            mean = valid.mean()
            median = valid.median()
            lowest = valid.min()

        summary_data = [
            [column, "Highest", highest],
            [column, "Mean", mean],
            [column, "Median", median],
            [column, "Lowest", lowest],
        ]

        summary_df = pd.DataFrame(summary_data, columns=["Variable", "Summary", "Value"])
        summary_df.to_csv(output_file, index=False)
        logging.info(f"Assembly summary saved to {output_file}")
    except Exception as e:
        logging.error(f"Error generating assembly summary: {e}")
        raise


# Lookup table from canonical country name to its "Continent" and
# "Subcontinent" labels. add_geo_columns looks names up here exactly;
# normalize_country_name matches its keys case-insensitively.
# The counts in the section comments are the number of entries in this table.
COUNTRY_MAPPING = {
    # Africa (54 entries)
    "Algeria": {"Continent": "Africa", "Subcontinent": "Northern Africa"},
    "Angola": {"Continent": "Africa", "Subcontinent": "Middle Africa"},
    "Benin": {"Continent": "Africa", "Subcontinent": "Western Africa"},
    "Botswana": {"Continent": "Africa", "Subcontinent": "Southern Africa"},
    "Burkina Faso": {"Continent": "Africa", "Subcontinent": "Western Africa"},
    "Burundi": {"Continent": "Africa", "Subcontinent": "Eastern Africa"},
    "Cabo Verde": {"Continent": "Africa", "Subcontinent": "Western Africa"},
    "Cameroon": {"Continent": "Africa", "Subcontinent": "Middle Africa"},
    "Central African Republic": {"Continent": "Africa", "Subcontinent": "Middle Africa"},
    "Chad": {"Continent": "Africa", "Subcontinent": "Middle Africa"},
    "Comoros": {"Continent": "Africa", "Subcontinent": "Eastern Africa"},
    "Congo": {"Continent": "Africa", "Subcontinent": "Middle Africa"},
    "Democratic Republic of the Congo": {"Continent": "Africa", "Subcontinent": "Middle Africa"},
    "Djibouti": {"Continent": "Africa", "Subcontinent": "Eastern Africa"},
    "Egypt": {"Continent": "Africa", "Subcontinent": "Northern Africa"},
    "Equatorial Guinea": {"Continent": "Africa", "Subcontinent": "Middle Africa"},
    "Eritrea": {"Continent": "Africa", "Subcontinent": "Eastern Africa"},
    "Eswatini": {"Continent": "Africa", "Subcontinent": "Southern Africa"},
    "Ethiopia": {"Continent": "Africa", "Subcontinent": "Eastern Africa"},
    "Gabon": {"Continent": "Africa", "Subcontinent": "Middle Africa"},
    "Gambia": {"Continent": "Africa", "Subcontinent": "Western Africa"},
    "Ghana": {"Continent": "Africa", "Subcontinent": "Western Africa"},
    "Guinea": {"Continent": "Africa", "Subcontinent": "Western Africa"},
    "Guinea-Bissau": {"Continent": "Africa", "Subcontinent": "Western Africa"},
    "Ivory Coast": {"Continent": "Africa", "Subcontinent": "Western Africa"},
    "Kenya": {"Continent": "Africa", "Subcontinent": "Eastern Africa"},
    "Lesotho": {"Continent": "Africa", "Subcontinent": "Southern Africa"},
    "Liberia": {"Continent": "Africa", "Subcontinent": "Western Africa"},
    "Libya": {"Continent": "Africa", "Subcontinent": "Northern Africa"},
    "Madagascar": {"Continent": "Africa", "Subcontinent": "Eastern Africa"},
    "Malawi": {"Continent": "Africa", "Subcontinent": "Eastern Africa"},
    "Mali": {"Continent": "Africa", "Subcontinent": "Western Africa"},
    "Mauritania": {"Continent": "Africa", "Subcontinent": "Western Africa"},
    "Mauritius": {"Continent": "Africa", "Subcontinent": "Eastern Africa"},
    "Morocco": {"Continent": "Africa", "Subcontinent": "Northern Africa"},
    "Mozambique": {"Continent": "Africa", "Subcontinent": "Eastern Africa"},
    "Namibia": {"Continent": "Africa", "Subcontinent": "Southern Africa"},
    "Niger": {"Continent": "Africa", "Subcontinent": "Western Africa"},
    "Nigeria": {"Continent": "Africa", "Subcontinent": "Western Africa"},
    "Rwanda": {"Continent": "Africa", "Subcontinent": "Eastern Africa"},
    "Sao Tome and Principe": {"Continent": "Africa", "Subcontinent": "Middle Africa"},
    "Senegal": {"Continent": "Africa", "Subcontinent": "Western Africa"},
    "Seychelles": {"Continent": "Africa", "Subcontinent": "Eastern Africa"},
    "Sierra Leone": {"Continent": "Africa", "Subcontinent": "Western Africa"},
    "Somalia": {"Continent": "Africa", "Subcontinent": "Eastern Africa"},
    "South Africa": {"Continent": "Africa", "Subcontinent": "Southern Africa"},
    "South Sudan": {"Continent": "Africa", "Subcontinent": "Eastern Africa"},
    "Sudan": {"Continent": "Africa", "Subcontinent": "Northern Africa"},
    "Tanzania": {"Continent": "Africa", "Subcontinent": "Eastern Africa"},
    "Togo": {"Continent": "Africa", "Subcontinent": "Western Africa"},
    "Tunisia": {"Continent": "Africa", "Subcontinent": "Northern Africa"},
    "Uganda": {"Continent": "Africa", "Subcontinent": "Eastern Africa"},
    "Zambia": {"Continent": "Africa", "Subcontinent": "Eastern Africa"},
    "Zimbabwe": {"Continent": "Africa", "Subcontinent": "Eastern Africa"},

    # Asia (49 entries, including Russia)
    # NOTE(review): Russia is filed under Asia / "Northern Asia" here; the
    # other labels look like UN geoscheme regions, which would place it in
    # Eastern Europe - confirm this is intended.
    "Afghanistan": {"Continent": "Asia", "Subcontinent": "Southern Asia"},
    "Armenia": {"Continent": "Asia", "Subcontinent": "Western Asia"},
    "Azerbaijan": {"Continent": "Asia", "Subcontinent": "Western Asia"},
    "Bahrain": {"Continent": "Asia", "Subcontinent": "Western Asia"},
    "Bangladesh": {"Continent": "Asia", "Subcontinent": "Southern Asia"},
    "Bhutan": {"Continent": "Asia", "Subcontinent": "Southern Asia"},
    "Brunei": {"Continent": "Asia", "Subcontinent": "South-Eastern Asia"},
    "Cambodia": {"Continent": "Asia", "Subcontinent": "South-Eastern Asia"},
    "China": {"Continent": "Asia", "Subcontinent": "Eastern Asia"},
    "Cyprus": {"Continent": "Asia", "Subcontinent": "Western Asia"},
    "Georgia": {"Continent": "Asia", "Subcontinent": "Western Asia"},
    "India": {"Continent": "Asia", "Subcontinent": "Southern Asia"},
    "Indonesia": {"Continent": "Asia", "Subcontinent": "South-Eastern Asia"},
    "Iran": {"Continent": "Asia", "Subcontinent": "Southern Asia"},
    "Iraq": {"Continent": "Asia", "Subcontinent": "Western Asia"},
    "Israel": {"Continent": "Asia", "Subcontinent": "Western Asia"},
    "Japan": {"Continent": "Asia", "Subcontinent": "Eastern Asia"},
    "Jordan": {"Continent": "Asia", "Subcontinent": "Western Asia"},
    "Kazakhstan": {"Continent": "Asia", "Subcontinent": "Central Asia"},
    "Kuwait": {"Continent": "Asia", "Subcontinent": "Western Asia"},
    "Kyrgyzstan": {"Continent": "Asia", "Subcontinent": "Central Asia"},
    "Laos": {"Continent": "Asia", "Subcontinent": "South-Eastern Asia"},
    "Lebanon": {"Continent": "Asia", "Subcontinent": "Western Asia"},
    "Malaysia": {"Continent": "Asia", "Subcontinent": "South-Eastern Asia"},
    "Maldives": {"Continent": "Asia", "Subcontinent": "Southern Asia"},
    "Mongolia": {"Continent": "Asia", "Subcontinent": "Eastern Asia"},
    "Myanmar": {"Continent": "Asia", "Subcontinent": "South-Eastern Asia"},
    "Nepal": {"Continent": "Asia", "Subcontinent": "Southern Asia"},
    "North Korea": {"Continent": "Asia", "Subcontinent": "Eastern Asia"},
    "Oman": {"Continent": "Asia", "Subcontinent": "Western Asia"},
    "Pakistan": {"Continent": "Asia", "Subcontinent": "Southern Asia"},
    "Palestine": {"Continent": "Asia", "Subcontinent": "Western Asia"},
    "Philippines": {"Continent": "Asia", "Subcontinent": "South-Eastern Asia"},
    "Qatar": {"Continent": "Asia", "Subcontinent": "Western Asia"},
    "Russia": {"Continent": "Asia", "Subcontinent": "Northern Asia"},
    "Saudi Arabia": {"Continent": "Asia", "Subcontinent": "Western Asia"},
    "Singapore": {"Continent": "Asia", "Subcontinent": "South-Eastern Asia"},
    "South Korea": {"Continent": "Asia", "Subcontinent": "Eastern Asia"},
    "Sri Lanka": {"Continent": "Asia", "Subcontinent": "Southern Asia"},
    "Syria": {"Continent": "Asia", "Subcontinent": "Western Asia"},
    "Tajikistan": {"Continent": "Asia", "Subcontinent": "Central Asia"},
    "Thailand": {"Continent": "Asia", "Subcontinent": "South-Eastern Asia"},
    "Timor-Leste": {"Continent": "Asia", "Subcontinent": "South-Eastern Asia"},
    "Turkey": {"Continent": "Asia", "Subcontinent": "Western Asia"},
    "Turkmenistan": {"Continent": "Asia", "Subcontinent": "Central Asia"},
    "United Arab Emirates": {"Continent": "Asia", "Subcontinent": "Western Asia"},
    "Uzbekistan": {"Continent": "Asia", "Subcontinent": "Central Asia"},
    "Vietnam": {"Continent": "Asia", "Subcontinent": "South-Eastern Asia"},
    "Yemen": {"Continent": "Asia", "Subcontinent": "Western Asia"},

    # Europe (43 entries)
    "Albania": {"Continent": "Europe", "Subcontinent": "Southern Europe"},
    "Andorra": {"Continent": "Europe", "Subcontinent": "Southern Europe"},
    "Austria": {"Continent": "Europe", "Subcontinent": "Western Europe"},
    "Belarus": {"Continent": "Europe", "Subcontinent": "Eastern Europe"},
    "Belgium": {"Continent": "Europe", "Subcontinent": "Western Europe"},
    "Bosnia and Herzegovina": {"Continent": "Europe", "Subcontinent": "Southern Europe"},
    "Bulgaria": {"Continent": "Europe", "Subcontinent": "Eastern Europe"},
    "Croatia": {"Continent": "Europe", "Subcontinent": "Southern Europe"},
    "Czech Republic": {"Continent": "Europe", "Subcontinent": "Eastern Europe"},
    "Denmark": {"Continent": "Europe", "Subcontinent": "Northern Europe"},
    "Estonia": {"Continent": "Europe", "Subcontinent": "Northern Europe"},
    "Finland": {"Continent": "Europe", "Subcontinent": "Northern Europe"},
    "France": {"Continent": "Europe", "Subcontinent": "Western Europe"},
    "Germany": {"Continent": "Europe", "Subcontinent": "Western Europe"},
    "Greece": {"Continent": "Europe", "Subcontinent": "Southern Europe"},
    "Hungary": {"Continent": "Europe", "Subcontinent": "Eastern Europe"},
    "Iceland": {"Continent": "Europe", "Subcontinent": "Northern Europe"},
    "Ireland": {"Continent": "Europe", "Subcontinent": "Northern Europe"},
    "Italy": {"Continent": "Europe", "Subcontinent": "Southern Europe"},
    "Latvia": {"Continent": "Europe", "Subcontinent": "Northern Europe"},
    "Liechtenstein": {"Continent": "Europe", "Subcontinent": "Western Europe"},
    "Lithuania": {"Continent": "Europe", "Subcontinent": "Northern Europe"},
    "Luxembourg": {"Continent": "Europe", "Subcontinent": "Western Europe"},
    "Malta": {"Continent": "Europe", "Subcontinent": "Southern Europe"},
    "Moldova": {"Continent": "Europe", "Subcontinent": "Eastern Europe"},
    "Monaco": {"Continent": "Europe", "Subcontinent": "Western Europe"},
    "Montenegro": {"Continent": "Europe", "Subcontinent": "Southern Europe"},
    "Netherlands": {"Continent": "Europe", "Subcontinent": "Western Europe"},
    "North Macedonia": {"Continent": "Europe", "Subcontinent": "Southern Europe"},
    "Norway": {"Continent": "Europe", "Subcontinent": "Northern Europe"},
    "Poland": {"Continent": "Europe", "Subcontinent": "Eastern Europe"},
    "Portugal": {"Continent": "Europe", "Subcontinent": "Southern Europe"},
    "Romania": {"Continent": "Europe", "Subcontinent": "Eastern Europe"},
    "San Marino": {"Continent": "Europe", "Subcontinent": "Southern Europe"},
    "Serbia": {"Continent": "Europe", "Subcontinent": "Southern Europe"},
    "Slovakia": {"Continent": "Europe", "Subcontinent": "Eastern Europe"},
    "Slovenia": {"Continent": "Europe", "Subcontinent": "Southern Europe"},
    "Spain": {"Continent": "Europe", "Subcontinent": "Southern Europe"},
    "Sweden": {"Continent": "Europe", "Subcontinent": "Northern Europe"},
    "Switzerland": {"Continent": "Europe", "Subcontinent": "Western Europe"},
    "Ukraine": {"Continent": "Europe", "Subcontinent": "Eastern Europe"},
    "United Kingdom": {"Continent": "Europe", "Subcontinent": "Northern Europe"},
    "Vatican City": {"Continent": "Europe", "Subcontinent": "Southern Europe"},

    # North America (23 entries)
    "Antigua and Barbuda": {"Continent": "North America", "Subcontinent": "Caribbean"},
    "Bahamas": {"Continent": "North America", "Subcontinent": "Caribbean"},
    "Barbados": {"Continent": "North America", "Subcontinent": "Caribbean"},
    "Belize": {"Continent": "North America", "Subcontinent": "Central America"},
    "Canada": {"Continent": "North America", "Subcontinent": "Northern America"},
    "Costa Rica": {"Continent": "North America", "Subcontinent": "Central America"},
    "Cuba": {"Continent": "North America", "Subcontinent": "Caribbean"},
    "Dominica": {"Continent": "North America", "Subcontinent": "Caribbean"},
    "Dominican Republic": {"Continent": "North America", "Subcontinent": "Caribbean"},
    "El Salvador": {"Continent": "North America", "Subcontinent": "Central America"},
    "Grenada": {"Continent": "North America", "Subcontinent": "Caribbean"},
    "Guatemala": {"Continent": "North America", "Subcontinent": "Central America"},
    "Haiti": {"Continent": "North America", "Subcontinent": "Caribbean"},
    "Honduras": {"Continent": "North America", "Subcontinent": "Central America"},
    "Jamaica": {"Continent": "North America", "Subcontinent": "Caribbean"},
    "Mexico": {"Continent": "North America", "Subcontinent": "Central America"},
    "Nicaragua": {"Continent": "North America", "Subcontinent": "Central America"},
    "Panama": {"Continent": "North America", "Subcontinent": "Central America"},
    "Saint Kitts and Nevis": {"Continent": "North America", "Subcontinent": "Caribbean"},
    "Saint Lucia": {"Continent": "North America", "Subcontinent": "Caribbean"},
    "Saint Vincent and the Grenadines": {"Continent": "North America", "Subcontinent": "Caribbean"},
    "Trinidad and Tobago": {"Continent": "North America", "Subcontinent": "Caribbean"},
    "United States": {"Continent": "North America", "Subcontinent": "Northern America"},

    # South America (12 entries)
    "Argentina": {"Continent": "South America", "Subcontinent": "South America"},
    "Bolivia": {"Continent": "South America", "Subcontinent": "South America"},
    "Brazil": {"Continent": "South America", "Subcontinent": "South America"},
    "Chile": {"Continent": "South America", "Subcontinent": "South America"},
    "Colombia": {"Continent": "South America", "Subcontinent": "South America"},
    "Ecuador": {"Continent": "South America", "Subcontinent": "South America"},
    "Guyana": {"Continent": "South America", "Subcontinent": "South America"},
    "Paraguay": {"Continent": "South America", "Subcontinent": "South America"},
    "Peru": {"Continent": "South America", "Subcontinent": "South America"},
    "Suriname": {"Continent": "South America", "Subcontinent": "South America"},
    "Uruguay": {"Continent": "South America", "Subcontinent": "South America"},
    "Venezuela": {"Continent": "South America", "Subcontinent": "South America"},

    # Oceania (14 entries)
    "Australia": {"Continent": "Oceania", "Subcontinent": "Australia and New Zealand"},
    "Fiji": {"Continent": "Oceania", "Subcontinent": "Melanesia"},
    "Kiribati": {"Continent": "Oceania", "Subcontinent": "Micronesia"},
    "Marshall Islands": {"Continent": "Oceania", "Subcontinent": "Micronesia"},
    "Micronesia": {"Continent": "Oceania", "Subcontinent": "Micronesia"},
    "Nauru": {"Continent": "Oceania", "Subcontinent": "Micronesia"},
    "New Zealand": {"Continent": "Oceania", "Subcontinent": "Australia and New Zealand"},
    "Palau": {"Continent": "Oceania", "Subcontinent": "Micronesia"},
    "Papua New Guinea": {"Continent": "Oceania", "Subcontinent": "Melanesia"},
    "Samoa": {"Continent": "Oceania", "Subcontinent": "Polynesia"},
    "Solomon Islands": {"Continent": "Oceania", "Subcontinent": "Melanesia"},
    "Tonga": {"Continent": "Oceania", "Subcontinent": "Polynesia"},
    "Tuvalu": {"Continent": "Oceania", "Subcontinent": "Polynesia"},
    "Vanuatu": {"Continent": "Oceania", "Subcontinent": "Melanesia"},
}

# Alternate spellings and abbreviations of country names, mapped to the
# canonical key used in COUNTRY_MAPPING. Every value must be a key of that
# table, otherwise the alias resolves to 'Unknown' continent/subcontinent.
ALIASES = {
    "USA": "United States",
    "U.S.A.": "United States",
    "US": "United States",
    "United States of America": "United States",
    "UK": "United Kingdom",
    "U.K.": "United Kingdom",
    # An unqualified "Korea" in sample metadata conventionally denotes the
    # Republic of Korea, not North Korea.
    "Korea": "South Korea",
    "Republic of Korea": "South Korea",
    "DRC": "Democratic Republic of the Congo",
    "Republic of the Congo": "Congo",
    # Spellings used by the INSDC country vocabulary and older records
    "Viet Nam": "Vietnam",
    "Czechia": "Czech Republic",
    "Cote d'Ivoire": "Ivory Coast",
    "Cape Verde": "Cabo Verde",
    "Swaziland": "Eswatini",
    "Burma": "Myanmar",
    "Macedonia": "North Macedonia",
    "East Timor": "Timor-Leste",
    "Russian Federation": "Russia",
}

def normalize_country_name(country):
    """Map a raw country name to its canonical spelling.

    The name is converted to a string and stripped, then resolved in this
    order: exact alias match, case-insensitive alias match, case-insensitive
    match against the keys of ``COUNTRY_MAPPING``.

    Args:
        country: Raw country name; may be missing (``None``/``pd.NA``/NaN).

    Returns:
        ``None`` for a missing value, the canonical name when a match is
        found, otherwise the stripped input unchanged.
    """
    if pd.isna(country):
        return None

    name = str(country).strip()

    # Exact alias hit is the common case ("USA", "UK", ...)
    if name in ALIASES:
        return ALIASES[name]

    folded = name.lower()

    # Aliases in any letter case ("usa", "Uk")
    for alias, canonical in ALIASES.items():
        if alias.lower() == folded:
            return canonical

    # Canonical names in any letter case ("united states")
    for canonical in COUNTRY_MAPPING:
        if canonical.lower() == folded:
            return canonical

    # No match: hand the name back as-is
    return name

def extract_country(geo_location):
    """Return the normalized country part of a geographic location string.

    The country is the text before the first ":" ("USA: New York" -> "USA");
    a value without a colon is used whole. Missing values yield ``None``.
    """
    if pd.isna(geo_location):
        return None
    country_part, _, _ = geo_location.partition(":")
    return normalize_country_name(country_part.strip())

def add_geo_columns(df):
    """Add 'Continent' and 'Subcontinent' columns derived from 'Geographic Location'.

    The country of each row is extracted and normalized, then looked up in
    ``COUNTRY_MAPPING``; rows whose country is missing or not in the table get
    'Unknown' in both columns. The country names are held in a local Series,
    so no temporary 'Country' column is written to ``df`` and an existing
    'Country' column is left untouched.

    Args:
        df: Table with a 'Geographic Location' column. The two new columns
            are added to this frame in place.

    Returns:
        ``df`` with the 'Continent' and 'Subcontinent' columns set.
    """
    countries = df['Geographic Location'].apply(extract_country)

    def region_of(country, level):
        # Unknown or missing countries fall back to 'Unknown'
        return COUNTRY_MAPPING.get(country, {}).get(level, 'Unknown')

    df['Continent'] = countries.apply(lambda c: region_of(c, 'Continent'))
    df['Subcontinent'] = countries.apply(lambda c: region_of(c, 'Subcontinent'))
    return df

def _literal_pattern(values):
    """Build a regex alternation matching any of *values* as literal text."""
    import re

    return '|'.join(re.escape(str(value)) for value in values)


def _collection_years(dates):
    """Return the collection year for each entry of *dates* (NaN if unknown).

    Bare years (2015, "2015", 2015.0) are taken as-is. Handing a numeric
    column to pd.to_datetime would interpret the numbers as nanoseconds
    since the epoch and turn every year into 1970. Everything else is
    parsed as a date string, with unparseable values becoming NaN.
    """
    numeric = pd.to_numeric(dates, errors='coerce')
    is_year = numeric.notna() & (numeric % 1 == 0) & numeric.between(1000, 9999)
    # Only the entries that are not bare years go through date parsing.
    as_text = dates.astype(str).where(dates.notna() & ~is_year)
    parsed_years = pd.to_datetime(as_text, errors='coerce').dt.year
    return numeric.where(is_year, parsed_years)


def _parse_year_ranges(year_args):
    """Turn arguments like "2015" or "2018-2025" into (start, end) tuples."""
    ranges = []
    for year_arg in year_args:
        start_text, separator, end_text = str(year_arg).partition('-')
        try:
            start = int(start_text)
            end = int(end_text) if separator else start
        except ValueError as exc:
            raise ValueError(f"Invalid year or year range: {year_arg!r}") from exc
        # Accept reversed ranges ("2025-2018") instead of matching nothing.
        ranges.append((min(start, end), max(start, end)))
    return ranges


def filter_dataframe(df, args):
    """Filter DataFrame based on command line arguments.

    Filters are applied in this order and only when the corresponding
    argument is set: host, year, country, continent, subcontinent. Host,
    country and continent values are matched as case-insensitive literal
    substrings; subcontinent values must match exactly (case-insensitive).

    Args:
        df: DataFrame with the columns 'Host', 'Collection Date',
            'Geographic Location', 'Continent' and 'Subcontinent' (only the
            columns of the active filters are accessed).
        args: Namespace with the attributes host, year, country, cont and
            subcont (each None or a list of strings).

    Returns:
        The filtered DataFrame. When a year filter is given, the result has
        an additional integer 'Year' column and rows without a parseable
        collection date are dropped. The input frame is not modified.

    Raises:
        ValueError: If the collection dates cannot be parsed or a year
            argument is malformed.
    """
    if args.host:
        df = df[df['Host'].str.contains(_literal_pattern(args.host), case=False, na=False)]

    if args.year:
        try:
            years = _collection_years(df['Collection Date'])
        except Exception as e:
            logging.error(f"Error parsing dates: {e}")
            raise ValueError("Could not parse dates from 'Collection Date' column") from e

        # assign() returns a new frame, so the caller's data is never mutated.
        df = df.assign(Year=years).dropna(subset=['Year'])
        df = df.assign(Year=df['Year'].astype(int))

        # A row is kept when its year falls into any of the requested ranges.
        in_range = pd.Series(False, index=df.index)
        for start, end in _parse_year_ranges(args.year):
            in_range |= df['Year'].between(start, end)
        df = df[in_range]

    if args.country:
        df = df[df['Geographic Location'].str.contains(
            _literal_pattern(args.country), case=False, na=False)]

    if args.cont:
        df = df[df['Continent'].str.contains(
            _literal_pattern(args.cont), case=False, na=False)]

    if args.subcont:
        # Normalize input for consistent matching
        subcont_filter = [s.strip().lower() for s in args.subcont]
        df = df[df['Subcontinent'].str.lower().isin(subcont_filter)]
    return df

def download_genome_fasta_ftp(assembly_accession: str, assembly_name: str, output_folder: str) -> None:
    """Download a genome FASTA file from the NCBI genomes site and unpack it.

    The compressed ``*_genomic.fna.gz`` file is fetched with ``wget`` and
    decompressed next to it as ``*_genomic.fna``; both files are kept.
    Failures are logged, never raised, so one bad record does not stop a
    batch of downloads.

    Args:
        assembly_accession: Assembly accession such as "GCF_000005845.2" or
            "GCA_000005845.2".
        assembly_name: Assembly name as listed by NCBI (e.g. "ASM584v2").
        output_folder: Directory for the downloaded files (created if needed).
    """
    import gzip
    import shutil
    import subprocess

    base_url = "https://ftp.ncbi.nlm.nih.gov/genomes/all"
    partial_files = []  # files started in this call, removed again on failure

    try:
        # Validate inside the try block: a malformed accession must only be
        # logged, not abort the caller's download loop.
        accession = str(assembly_accession)
        prefix, _, remainder = accession.partition("_")
        digits = remainder.split(".")[0]
        if prefix not in ("GCA", "GCF") or len(digits) < 9 or not digits.isdigit():
            raise ValueError(f"Unrecognized assembly accession: {assembly_accession!r}")
        dir1, dir2, dir3 = digits[:3], digits[3:6], digits[6:9]

        # NCBI directory and file names use underscores where the assembly
        # name contains spaces.
        stem = f"{accession}_{str(assembly_name).replace(' ', '_')}"
        # GenBank (GCA) and RefSeq (GCF) assemblies live in separate trees,
        # so the tree is taken from the accession prefix.
        url = f"{base_url}/{prefix}/{dir1}/{dir2}/{dir3}/{stem}/{stem}_genomic.fna.gz"

        # Create the output folder if it doesn't exist
        os.makedirs(output_folder, exist_ok=True)
        gz_filename = os.path.join(output_folder, f"{stem}_genomic.fna.gz")
        fna_filename = os.path.join(output_folder, f"{stem}_genomic.fna")

        # Download the .fna.gz file. An argument list (no shell) keeps unusual
        # characters in names from being interpreted by a shell, and
        # check=True surfaces failed downloads.
        partial_files.append(gz_filename)
        subprocess.run(["wget", url, "-O", gz_filename], check=True)

        # Unzip the .fna.gz file, keeping the compressed copy
        partial_files.append(fna_filename)
        with gzip.open(gz_filename, "rb") as compressed, open(fna_filename, "wb") as fasta:
            shutil.copyfileobj(compressed, fasta)

        logging.info(f"Downloaded genome FASTA for {assembly_accession} to {fna_filename}")
    except Exception as e:
        logging.error(f"Error downloading genome FASTA for {assembly_accession}: {e}")
        # Do not leave empty or truncated files behind.
        for leftover in partial_files:
            try:
                os.remove(leftover)
            except OSError:
                pass


def main():
    """Main function to execute the script.

    Workflow:
      1. Load the input TSV and filter it by CheckM completeness / ANI status.
      2. Fetch BioSample metadata for every remaining genome and standardize it.
      3. Write the metadata, assembly and annotation summaries and the plots.
      4. Write the clean table (with Continent / Subcontinent columns).
      5. With --seq: apply the optional host/year/country/continent/
         subcontinent filters and download the genome FASTA files.

    Any exception is logged (with traceback) instead of propagating.
    """
    parser = argparse.ArgumentParser(description="Metadata")
    parser.add_argument("--input", required=True, help="Path to the input TSV file")
    parser.add_argument("--outdir", required=True, help="Path to the output directory")
    parser.add_argument("--sleep", type=float, default=0.5, help="Time to wait between requests (default: 0.5s)")
    parser.add_argument("--ani", nargs='+', choices=['OK', 'Inconclusive', 'Failed', 'all'], default=['OK'],
                        help="Filter genomes by ANI status. Choices: OK, Inconclusive, Failed, all. Default is OK.")
    parser.add_argument("--checkm", type=float, default=95, help="Minimum CheckM completeness threshold (default: 95)")
    parser.add_argument("--seq", action="store_true", help="Run the script to download sequences")
    parser.add_argument('--host', nargs='+', help='Filter by host species (e.g., "Homo sapiens" "Bos taurus")')
    parser.add_argument('--year', nargs='+', help='Filter by year or year range (e.g., "2015" "2018-2025")')
    parser.add_argument('--country', nargs='+', help='Filter by country (e.g., "Bangladesh" "United States")')
    parser.add_argument('--cont', nargs='+', help='Filter by continent (e.g., "Asia" "Africa")')
    parser.add_argument('--subcont', nargs='+', help='Filter by subcontinent (e.g., "Southern Asia" "Western Africa")')
    args = parser.parse_args()

    try:
        # Load and filter data. Work on a copy: the columns added below must
        # not be written onto a slice of the loaded frame.
        df = load_data(args.input)
        df = filter_data(df, args.checkm, args.ani).copy()

        # Without this check the organism lookup below fails with an opaque
        # "single positional indexer is out-of-bounds" error.
        if df.empty:
            raise ValueError("No genomes left after CheckM/ANI filtering; nothing to process")

        # Create output directories
        organism_name = df["Organism Name"].iloc[0].replace(" ", "_")
        organism_folder, metadata_folder, figures_folder, sequence_folder = create_output_directory(args.outdir, organism_name)

        # Fetch metadata
        df["Isolation Source"] = pd.NA
        df["Collection Date"] = pd.NA
        df["Geographic Location"] = pd.NA
        df["Host"] = pd.NA

        for index, row in tqdm(df.iterrows(), total=len(df), desc="Fetching metadata"):
            biosample_id = row["Assembly BioSample Accession"]
            if pd.notna(biosample_id):
                isolation_source, collection_date, geo_location, host = fetch_metadata(biosample_id, args.sleep)
                df.at[index, "Isolation Source"] = isolation_source
                df.at[index, "Collection Date"] = collection_date
                df.at[index, "Geographic Location"] = geo_location
                df.at[index, "Host"] = host

        # Standardize columns
        df["Collection Date"] = df["Collection Date"].apply(standardize_date)
        df["Geographic Location"] = df["Geographic Location"].apply(standardize_location)
        df["Host"] = df["Host"].apply(standardize_host)

        # Save updated data
        output_file = os.path.join(metadata_folder, "ncbi_dataset_updated.tsv")
        save_summary(df, output_file)

        # Save metadata summary
        # Sort the DataFrame to prioritize rows with "GCF" in "Assembly Accession"
        df_sorted = df.sort_values(by="Assembly Accession", key=lambda x: x.str.startswith("GCF"), ascending=False)
        # Drop duplicates based on "Assembly Name", keeping the first occurrence (which will be the "GCF" row)
        df2 = df_sorted.drop_duplicates(subset=["Assembly Name"], keep="first").copy()
        output_file = os.path.join(metadata_folder, "metadata_summary.csv")
        generate_metadata_summary(df2, output_file)

        # Save assembly summary
        output_file = os.path.join(metadata_folder, "assembly_summary.csv")
        generate_assembly_summary(df2, output_file)

        # Generate distribution plots for assembly columns
        for column in ["Assembly Stats Total Sequence Length"]:
            # Replace the column outright; `.loc[:, column] = ...` can keep the
            # old (object) dtype instead of the numeric one.
            df2[column] = pd.to_numeric(df2[column], errors="coerce")
            if not df2[column].empty:
                plot_distribution(column, df2[column], column, figures_folder)

        # Generate and save bar plots
        for variable in ["Geographic Location", "Host", "Collection Date"]:
            frequency = df2[variable].value_counts()
            percentage = (frequency / frequency.sum()) * 100
            plot_bar_charts(variable, frequency, percentage, figures_folder)

        # Save map plots
        for variable in ["Geographic Location"]:
            frequency = df2[variable].value_counts()
            plot_geo_choropleth(variable, frequency, figures_folder)

        # Save annotation summary
        df3 = df[df["Annotation Name"] == "NCBI Prokaryotic Genome Annotation Pipeline (PGAP)"].copy()
        output_file = os.path.join(metadata_folder, "annotation_summary.csv")
        generate_annotation_summary(df3, output_file)

        # Generate distribution plots for annotation columns
        for column in ["Annotation Count Gene Total", "Annotation Count Gene Protein-coding", "Annotation Count Gene Pseudogene"]:
            df3[column] = pd.to_numeric(df3[column], errors="coerce")
            if not df3[column].empty:
                plot_distribution(column, df3[column], column, figures_folder)

        # Filter and sort data for scatter plots
        df3_filtered = df3[df3["Collection Date"] != "absent"].copy()
        df3_filtered["Collection Date"] = pd.to_numeric(df3_filtered["Collection Date"], errors="coerce")
        df3_filtered = df3_filtered.dropna(subset=["Collection Date"])
        df3_filtered = df3_filtered.sort_values(by="Collection Date", ascending=True)

        # Give the assembly scatter plot the same numeric, sorted collection
        # dates as the annotation plots; previously placeholder strings such
        # as "absent" were passed through to the plot for this frame only.
        df2_clean = df2[["Collection Date", "Assembly Stats Total Sequence Length"]].copy()
        df2_clean["Collection Date"] = pd.to_numeric(df2_clean["Collection Date"], errors="coerce")
        df2_clean = df2_clean.dropna().sort_values(by="Collection Date", ascending=True)

        # Generate scatter plots with trend lines
        plot_scatter_with_trend_and_corr(
            x=df2_clean["Collection Date"],
            y=df2_clean["Assembly Stats Total Sequence Length"],
            xlabel="Collection Date",
            ylabel="Total Sequence Length",
            title="Scatter Plot: Total Sequence Length vs Collection Date",
            filename="scatter_plot_total_sequence_length_vs_collection_date.tiff",
            figures_folder=figures_folder
        )
        plot_scatter_with_trend_and_corr(
            x=df3_filtered["Collection Date"],
            y=df3_filtered["Annotation Count Gene Total"],
            xlabel="Collection Date",
            ylabel="Annotation Count Gene Total",
            title="Scatter Plot: Annotation Count Gene Total vs Collection Date",
            filename="scatter_plot_gene_total_vs_collection_date.tiff",
            figures_folder=figures_folder
        )
        plot_scatter_with_trend_and_corr(
            x=df3_filtered["Collection Date"],
            y=df3_filtered["Annotation Count Gene Protein-coding"],
            xlabel="Collection Date",
            ylabel="Annotation Count Gene Protein-coding",
            title="Scatter Plot: Annotation Count Gene Protein-coding vs Collection Date",
            filename="scatter_plot_gene_protein_coding_vs_collection_date.tiff",
            figures_folder=figures_folder
        )

        # Save clean data
        columns_to_keep = [
            "Organism Name", "Assembly BioSample Accession", "Assembly Accession", "Assembly Name", "Assembly BioProject Accession",
            "Organism Infraspecific Names Strain", "Assembly Stats Total Sequence Length",
            "Isolation Source", "Collection Date", "Geographic Location", "Host"
        ]
        # Copy the column subset so the geo columns are never written onto a
        # slice of df2.
        df4 = df2[columns_to_keep].copy()
        df4 = add_geo_columns(df4)  # Add Continent / Subcontinent
        clean_data_file = os.path.join(metadata_folder, "ncbi_clean.csv")
        save_clean_data(df4, columns_to_keep + ['Continent', 'Subcontinent'], clean_data_file)

        # Generate and save bar plots for Continent and Subcontinent
        for variable in ["Continent", "Subcontinent"]:
            frequency = df4[variable].value_counts()
            percentage = (frequency / frequency.sum()) * 100
            plot_bar_charts(variable, frequency, percentage, figures_folder)

        # Sequence download is opt-in. Return instead of calling exit(): the
        # latter is a `site` convenience that is not guaranteed to exist and
        # would also terminate any program that imported main().
        if not args.seq:
            print("Please use the --seq argument to download the sequences")
            return

        input_file = os.path.join(metadata_folder, "ncbi_clean.csv")
        filtered_file = os.path.join(metadata_folder, "ncbi_filtered.csv")

        if not os.path.isfile(input_file):
            logging.error(f"Input file not found at: {input_file}")
            raise FileNotFoundError(f"Input file not found at: {input_file}")

        df_clean = pd.read_csv(input_file)

        required_columns = {"Assembly Accession", "Assembly Name", "Host", "Collection Date", "Geographic Location", "Continent", "Subcontinent"}
        if not required_columns.issubset(df_clean.columns):
            missing = required_columns - set(df_clean.columns)
            raise ValueError(f"Required columns {missing} not found in the CSV file.")

        # Apply filters if any were provided
        if args.host or args.year or args.country or args.cont or args.subcont:
            df_filtered = filter_dataframe(df_clean, args)
            logging.info(f"Filtered to {len(df_filtered)} records based on provided criteria")

            # Save filtered dataframe
            df_filtered.to_csv(filtered_file, index=False)
            logging.info(f"Saved filtered data to: {filtered_file}")
        else:
            df_filtered = df_clean
            logging.info("No filters provided, using all available records")

        if df_filtered.empty:
            logging.error("No records match the filtering criteria")
            return

        # Download genome FASTA files
        for index, row in tqdm(df_filtered.iterrows(), total=len(df_filtered),
                               desc="Downloading genome FASTA files"):
            assembly_accession = row["Assembly Accession"]
            assembly_name = row["Assembly Name"]
            download_genome_fasta_ftp(assembly_accession, assembly_name, sequence_folder)

        logging.info(f"Sequence downloading completed. Downloaded {len(df_filtered)} sequences.")
        logging.info("Script completed successfully.")

    except Exception as e:
        # Top-level boundary: log with traceback so failures can be diagnosed.
        logging.exception(f"Script failed: {e}")


if __name__ == "__main__":
    main()
