# V4

from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import numpy as np
from scipy import stats
import pandas as pd


@dataclass
class WeightingFactor:
    """Configuration for a single weighting factor (e.g., tissue, disease type).

    Attributes:
        name: Identifier of this factor (used as the key inside ContextWeights).
        weights: Mapping from category label to weight. After construction the
            keys are normalised (converted to str, lower-cased, surrounding
            whitespace stripped), so lookups are case/whitespace-insensitive.
        column: Column name in the study data that holds this factor's value.
        default_value: Weight returned for labels missing from ``weights`` and
            for missing values (None / NaN).
    """
    name: str
    weights: Dict[str, float]
    column: str  # Column name in the data
    default_value: float = 0.4

    def __post_init__(self):
        # Normalise keys once so get_weight is a single dict lookup.
        self.weights = {self._normalize(label): weight for label, weight in self.weights.items()}

    @staticmethod
    def _normalize(label: object) -> str:
        """Canonical lookup form of a label: str, lower-cased, stripped."""
        return str(label).lower().strip()

    def get_weight(self, value: object) -> float:
        """Get weight for a specific value of this factor.

        Missing values (None or a float NaN, as pandas produces for empty
        cells) and unknown labels both yield ``default_value`` instead of
        raising.
        """
        if value is None or (isinstance(value, float) and np.isnan(value)):
            return self.default_value
        return self.weights.get(self._normalize(value), self.default_value)


class ContextWeights:
    """Configuration for multiple weighting factors to evaluate RNA study relevance.

    A study's relevance is obtained by looking up one weight per configured
    factor (only for factors whose column is present in the study data) and
    combining those weights with ``combination_method``.
    """

    def __init__(self,
                 factors: List['WeightingFactor'],
                 combination_method: str = 'geometric_mean'):
        """
        Args:
            factors: List of WeightingFactor configurations. They are stored
                keyed by ``factor.name``; a later factor with the same name
                replaces an earlier one.
            combination_method: How to combine multiple weights
                              ('geometric_mean', 'arithmetic_mean', 'product')
        """
        self.factors = {factor.name: factor for factor in factors}
        self.combination_method = combination_method

    def combine_weights(self, weights: List[float]) -> float:
        """Combine multiple weights into a single score.

        Returns 0.0 when ``weights`` is empty (checked before the method name).

        Raises:
            ValueError: if ``combination_method`` is not one of the supported names.
        """
        if len(weights) == 0:
            return 0.0

        if self.combination_method == 'geometric_mean':
            return np.prod(weights) ** (1 / len(weights))
        elif self.combination_method == 'arithmetic_mean':
            return np.mean(weights)
        elif self.combination_method == 'product':
            return np.prod(weights)
        else:
            raise ValueError(f"Unknown combination method: {self.combination_method}")

    def calculate_study_relevance(self, study_data: Union[pd.Series, Dict]) -> float:
        """Calculate relevance score for a single RNA study.

        Args:
            study_data: Study metadata indexed by column name; a pandas Series
                (e.g. a DataFrame row) or a plain dict both work. Factors whose
                column is absent are skipped.

        Returns:
            The combined weight of the present factors (0.0 if none are present).
        """
        weights = [
            factor.get_weight(study_data[factor.column])
            for factor in self.factors.values()
            if factor.column in study_data
        ]
        return self.combine_weights(weights)

    def get_weighted_evidence(self,
                              data: pd.DataFrame,
                              gene_column: str,
                              gene: str,
                              fc_column: str,
                              pvalue_column: str) -> tuple:
        """Get weighted evidence for a gene across RNA studies.

        Every row of ``data`` whose ``gene_column`` equals ``gene`` is treated as
        one study. Each study is weighted by its relevance score times
        ``-log10(p)``; the fold change is the weighted average under those
        weights. The p-values are combined with weighted Stouffer's method,
        using the relevance scores as weights.

        P-values are clipped into ``[smallest positive float, 1]`` first, so a
        reported p of 0 yields a large but finite weight instead of ``inf``
        (which would make the weighted average NaN).

        Returns:
            ``(weighted_fc, combined_p)``, or ``(None, None)`` if no row matches
            ``gene``. If every study has zero total weight (e.g. all p-values
            are 1), ``weighted_fc`` falls back to the unweighted mean.
        """
        gene_studies = data[data[gene_column] == gene]

        if len(gene_studies) == 0:
            return None, None

        # Relevance score of each study (row) for this gene.
        relevance_scores = np.array(
            [self.calculate_study_relevance(study) for _, study in gene_studies.iterrows()],
            dtype=float,
        )

        fold_changes = gene_studies[fc_column].to_numpy(dtype=float)
        pvalues = np.clip(gene_studies[pvalue_column].to_numpy(dtype=float),
                          np.finfo(float).tiny, 1.0)

        # Combine relevance with statistical significance.
        total_weights = relevance_scores * -np.log10(pvalues)

        if total_weights.sum() == 0:
            # np.average raises ZeroDivisionError when the weights sum to zero.
            # (A NaN sum is not == 0, so NaN inputs still propagate below.)
            weighted_fc = np.mean(fold_changes)
        else:
            weighted_fc = np.average(fold_changes, weights=total_weights)

        # Combine p-values using Stouffer's method with relevance weights.
        weighted_p = stats.combine_pvalues(
            pvalues,
            weights=relevance_scores,
            method='stouffer'
        )[1]

        return weighted_fc, weighted_p


def create_default_ipf_weights() -> ContextWeights:
    """Create default weighting configuration for IPF studies.

    Three factors are combined with a geometric mean:
      * ``disease`` (column 'disease', default 0.5)
      * ``tissue``  (column 'tissue',  default 0.5)
      * ``setting`` (column 'setting', default 0.5)

    Labels are matched case-insensitively by WeightingFactor.
    """

    disease_factor = WeightingFactor(
        name='disease',
        column='disease',
        weights={"IPF": 1.0,
                 "Alcoholic liver disease": 0.8,
                 'Atrial fibrillation': 0.6,
                 'Biliary Atresia': 0.7,
                 'Cirrhosis': 1.0,
                 "CKD": 0.9,
                 'COPD': 0.9,
                 "Crohn's disease": 0.6,
                 'Fibrosis': 0.7,
                 'Frozen shoulder': 0.7,
                 'Heart failure': 0.6,
                 "Ischemia": 0.8,
                 'NAFLD': 0.8,
                 # NOTE(review): presumably an alias for a misspelling present in the
                 # source data; kept so such rows still map to the NAFLD weight.
                 'NALFD': 0.8,
                 "Steatohepatitis": 1.0,
                 'Systemic scleroderma': 1.0},
        default_value=0.5
    )

    comparison_factor = WeightingFactor(
        name='setting',
        column='setting',
        weights={"case_control": 0.8,
                 "case_control_advanced": 0.9,
                 "fibrosis_vs_severe_fibrosis": 0.6,
                 "cell_model": 0.5,
                 "case_control_severe": 1.0},
        default_value=0.5
    )

    tissue_factor = WeightingFactor(
        name='tissue',
        column='tissue',
        weights={'Lung': 1.0,
                 'Liver': 0.7,
                 'PBMC': 0.5,
                 'Heart': 0.6,
                 'Trachea': 0.6,
                 'Kidney': 0.7,
                 'Skin': 0.7,
                 "Synovium": 0.4,
                 'Macrophage': 0.3,
                 'Intestine': 0.6,
                 # Correct spelling, previously missing, so "Tendon" rows fell back
                 # to the default 0.5.
                 'Tendon': 0.6,
                 # Original (misspelled) key, kept for backward compatibility.
                 'Tendom': 0.6},
        default_value=0.5
    )

    return ContextWeights(
        factors=[disease_factor, tissue_factor, comparison_factor],
        combination_method='geometric_mean'
    )


# Function to process RNA study data
def process_rna_studies(total_df: pd.DataFrame,
                        study_col='file',
                        weights_config: Optional[ContextWeights] = None) -> pd.DataFrame:
    """Process multiple RNA studies and attach a relevance score to every row.

    Rows are grouped by ``study_col``. For each study, one metadata value per
    configured factor is read from that factor's ``column``, taking the
    study's first row. A warning is issued if a study has more than one
    distinct value in such a column. Factor columns absent from ``total_df``
    are skipped, consistent with ``calculate_study_relevance``. The resulting
    score is stored in a new ``study_relevance`` column.

    Args:
        total_df: Results from one or more studies; must contain ``study_col``.
        study_col: Column identifying the study each row belongs to.
        weights_config: Weighting configuration; defaults to
            ``create_default_ipf_weights()``.

    Returns:
        A new DataFrame with the rows grouped by study (in order of first
        appearance), a fresh RangeIndex, and an added ``study_relevance``
        column. ``total_df`` itself is not modified. Empty input yields an
        empty frame with that column instead of raising.
    """
    import warnings

    if weights_config is None:
        weights_config = create_default_ipf_weights()

    if len(total_df) == 0:
        # pd.concat([]) would raise; return an empty frame with the new column.
        return total_df.assign(study_relevance=pd.Series(dtype=float)).reset_index(drop=True)

    # Read metadata by each factor's *column* (its name may differ) and only
    # for columns the table actually has.
    metadata_columns = [factor.column for factor in weights_config.factors.values()
                        if factor.column in total_df.columns]

    all_results = []
    # sort=False keeps studies in order of first appearance. dropna=False keeps
    # rows whose study id is missing instead of silently dropping them.
    for study, study_df in total_df.groupby(study_col, sort=False, dropna=False, observed=True):
        study_df = study_df.copy()

        study_metadata = {}
        for column in metadata_columns:
            values = study_df[column]
            if values.nunique(dropna=False) > 1:
                warnings.warn(
                    f"Study {study!r} has multiple values in column {column!r}; "
                    f"using the first one ({values.iloc[0]!r}).",
                    stacklevel=2,
                )
            study_metadata[column] = values.iloc[0]

        study_df['study_relevance'] = weights_config.calculate_study_relevance(study_metadata)
        all_results.append(study_df)

    # Combine all studies
    return pd.concat(all_results, ignore_index=True)


# Example usage
if __name__ == "__main__":
    # Create weights configuration
    weights_config = create_default_ipf_weights()

    # Example study metadata. Keys must match each factor's ``column``
    # ('disease', 'tissue', 'setting'); factors whose column is absent are
    # skipped when scoring, so a misnamed key silently drops that factor.
    example_study = pd.Series({
        'disease': 'IPF',
        'tissue': 'Lung',
        'setting': 'case_control'
    })

    # Calculate study relevance
    relevance = weights_config.calculate_study_relevance(example_study)
    print(f"Study relevance score: {relevance:.3f}")

    # Example gene evidence calculation (one row per study/gene result)
    study_data = pd.DataFrame({
        'gene': ['TGFB1', 'TGFB1', 'COL1A1'],
        'disease': ['IPF', 'Fibrosis', 'IPF'],
        'tissue': ['Lung', 'Lung', 'PBMC'],
        'setting': ['case_control', 'cell_model', 'case_control'],
        'log2FC': [1.5, 0.8, 2.3],
        'pvalue': [0.001, 0.01, 0.0001]
    })

    # Get weighted evidence for TGFB1
    fc, pval = weights_config.get_weighted_evidence(
        data=study_data,
        gene_column='gene',
        gene='TGFB1',
        fc_column='log2FC',
        pvalue_column='pvalue'
    )
    print("\nWeighted evidence for TGFB1:")
    print(f"Fold change: {fc:.2f}")
    print(f"Combined p-value: {pval:.6f}")