import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import r2_score, mean_absolute_error
import pickle
from pathlib import Path
from typing import Dict, Tuple, Union, Optional, List
from collections import defaultdict

from shan_pack.reporting import ModelEvaluator
from shan_pack.transfernet import IPFTransferNet
from shan_pack.clock import IPFClock
from shan_pack.other import IPFPathways, UKBDataset

class OLINKScaler:
    """Combined scaler and formatter for OLINK protein data"""

    def __init__(self, required_features: List[str]):
        """Initialize scaler with optional required features

        Args:
            required_features: List of protein features required by the model.
                             If None, will accept all features found in data.
        """
        self.data_min = None
        self.data_max = None
        self.data_median = None
        self.feature_means = None
        self.feature_stds = None
        self.required_features = required_features

    def long_to_wide(self,
                         data: pd.DataFrame,
                         id_col: str = 'patient_id',
                         gene_col: str = 'gene_symbol',
                         value_col: str = 'NPX') -> pd.DataFrame:
        """Format long-format protein data into wide format

        Args:
            data: DataFrame in long format (one measurement per row)
            id_col: Name of column containing sample/patient IDs
            gene_col: Name of column containing gene/protein names
            value_col: Name of column containing measurement values

        Returns:
            DataFrame in wide format with samples as rows and proteins as columns
        """
        # Validate input columns
        required_cols = [id_col, gene_col, value_col]
        missing_cols = [col for col in required_cols if col not in data.columns]
        if missing_cols:
            raise ValueError(f"Missing required columns: {missing_cols}")

        # Pivot data to wide format
        wide_data = data.pivot(
            index=id_col,
            columns=gene_col,
            values=value_col
        )

        # Check for required features if specified
        missing_features = set(self.required_features) - set(wide_data.columns)
        if missing_features:
            print(f"Warning: Missing required features: {missing_features}")
            # Add missing columns
            for feature in missing_features:
                wide_data[feature] = np.nan

        # Ensure columns are in correct order
        wide_data = wide_data[list(self.required_features)]

        return wide_data

    def fit(self,
            data: Union[pd.DataFrame, np.ndarray]) -> 'OLINKScaler':
        """Fit scaler to training data and compute feature means

        Args:
            data: Training data to fit scaler on

        Returns:
            Self for chaining
        """

        if isinstance(data, pd.DataFrame):
            data = data.loc[:, self.required_features]
            data = data.values

        # Store raw means for imputation
        self.feature_means = np.nanmean(data, axis=0)
        self.feature_stds = np.nanstd(data, axis=0)

        # Compute scaling parameters
        self.data_min = np.nanmin(data, axis=0)
        self.data_max = np.nanmax(data, axis=0)
        self.data_range = (self.data_max - self.data_min)
        # Avoid division by zero
        self.data_range[self.data_range == 0] = 1

        # Compute median after min-max scaling
        scaled = (data - self.data_min) / self.data_range
        self.data_median = np.nanmedian(scaled, axis=0)
        return self

    def prepare_wide(self,
                  data: Union[pd.DataFrame, np.ndarray],
                  impute_missing: bool = True,
                  transform: Optional[str] = None) -> np.ndarray:
        """Transform data using fitted parameters with optional mean imputation

        Args:
            data: Data to transform
            impute_missing: Whether to impute missing values using training means

        Returns:
            Transformed data as numpy array
        """
        if self.data_min is None:
            raise ValueError("Scaler must be fitted before transforming data")

        # Convert DataFrame to array if needed
        if isinstance(data, pd.DataFrame):
            data = data.loc[:, self.required_features]
            data = data.values

        # Make a copy to avoid modifying input
        data = data.copy()

        # Impute missing values if requested
        if impute_missing and self.feature_means is not None:
            data = self.impute_means(data)

        if not transform is None:
            match transform:
                case 'minmax':
                    data = self.transform_minmax(data)
                case 'standard':
                    data = self.transform_standard(data)
                case _:
                    pass
        return data.astype(np.float32)

    def transform_minmax(self, data: np.ndarray) -> np.ndarray:
        data = data.copy()

        sample_min = np.nanmin(data, axis=0)
        sample_max = np.nanmax(data, axis=0)

        var_mask = sample_max != sample_min
        data[:, var_mask] = data[:, var_mask] - sample_min[var_mask] / (sample_max[var_mask] - sample_min[var_mask])
        data[:, var_mask] = data[:, var_mask] * (self.data_max[var_mask] - self.data_min[var_mask]) + self.data_min[var_mask]

        return(data)

    def transform_standard(self, data: np.ndarray) -> np.ndarray:

        data = data.copy()

        sample_mean = np.nanmean(data, axis=0)
        sample_std = np.nanstd(data, axis=0)

        var_mask = sample_std != 0
        data[:, var_mask] = (data[:, var_mask] - sample_mean[var_mask])/sample_std[var_mask]
        data[:, var_mask] = (data[:, var_mask] * self.feature_stds[var_mask]) + self.feature_means[var_mask]

        return(data)
    def impute_means(self, data: np.ndarray) -> np.ndarray:
        data = data.copy()

        missing_mask = np.isnan(data)
        if missing_mask.any():
            n_missing = missing_mask.sum()
            print(f"Imputing {n_missing} missing values using training means")
            data[missing_mask] = np.take(self.feature_means,
                                         np.where(missing_mask)[1])
        return(data)

    def fit_transform(self,
                      data: Union[pd.DataFrame, np.ndarray],
                      impute_missing: bool = True) -> np.ndarray:
        """Convenience method to fit and transform in one step

        Args:
            data: Data to fit and transform
            impute_missing: Whether to impute missing values

        Returns:
            Transformed data as numpy array
        """
        return self.fit(data).transform(data, impute_missing)

    def save(self, filepath: str):
        """Save scaler parameters including feature means

        Args:
            filepath: Path to save scaler parameters
        """
        with open(filepath, 'wb') as f:
            pickle.dump({
                'data_min': self.data_min,
                'data_max': self.data_max,
                'data_median': self.data_median,
                'feature_means': self.feature_means,
                'feature_stds': self.feature_stds,
                'required_features': self.required_features
            }, f)

    @classmethod
    def load(cls, filepath: str) -> 'OLINKScaler':
        """Load saved scaler parameters

        Args:
            filepath: Path to saved scaler parameters

        Returns:
            Loaded scaler instance
        """
        with open(filepath, 'rb') as f:
            params = pickle.load(f)
        scaler = cls(required_features=params.get('required_features'))
        scaler.data_min = params['data_min']
        scaler.data_max = params['data_max']
        scaler.data_median = params['data_median']
        scaler.feature_means = params['feature_means']
        scaler.feature_stds = params['feature_stds']
        scaler.data_range = (scaler.data_max - scaler.data_min)
        scaler.data_range[scaler.data_range == 0] = 1
        return scaler

class AgingClockPredictor:
    """Wrapper class for making predictions with both types of aging clock models.

    Loads an ``IPFTransferNet`` or ``IPFClock`` checkpoint, runs batched
    inference, and offers plotting and CSV export of the results.
    """

    # Column labels used for pathway outputs when the model emits exactly this
    # many pathways; any other width falls back to generic ``pathway_<i>``.
    PATHWAY_NAMES = ('TGF-β', 'ECM', 'Inflammation', 'Oxidative Stress')

    def __init__(self,
                 model_path: str,
                 feats: Dict[str, str],
                 model_type: str = 'auto',
                 device: Optional[str] = None):
        """Initialize predictor

        Args:
            model_path: Path to a checkpoint loaded with ``torch.load``; the
                weights must be stored under the ``'model_state_dict'`` key.
            feats: Mapping passed to the model as ``olink_to_gene``; its keys,
                in iteration order, become ``self.olink_features``.
            model_type: Type of model ('auto', 'ipf_clock', or 'ipf_transfer').
                'auto' picks 'ipf_transfer' when any state-dict key contains
                'pathway_head', otherwise 'ipf_clock'.
            device: Device to run predictions on; defaults to 'cuda' when
                available, otherwise 'cpu'.
        """
        self.device = torch.device(
            device or ('cuda' if torch.cuda.is_available() else 'cpu'))

        self._load_model(model_path, model_type, feats)

        self.evaluator = ModelEvaluator(self.model)

    def _load_model(self,
                    model_path: str,
                    model_type: str,
                    feats: Dict[str, str]) -> None:
        """Load the trained model into ``self.model`` in eval mode.

        Args:
            model_path: Path to saved model
            model_type: Type of model to load ('auto', 'ipf_clock', 'ipf_transfer')
            feats: OLINK-feature mapping passed as ``olink_to_gene``

        Raises:
            ValueError: If ``model_type`` is not a known type.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        model_dict = torch.load(model_path, map_location=self.device)
        state_dict = model_dict['model_state_dict']

        self.olink_features = list(feats.keys())

        if model_type == 'auto':
            has_pathway_head = any('pathway_head' in key for key in state_dict.keys())
            model_type = 'ipf_transfer' if has_pathway_head else 'ipf_clock'

        if model_type == 'ipf_transfer':
            model_cls = IPFTransferNet
        elif model_type == 'ipf_clock':
            model_cls = IPFClock
        else:
            raise ValueError(f"Unknown model_type: {model_type}")

        self.model = model_cls(
            olink_dim=len(self.olink_features),
            olink_to_gene=feats,
            pathway_definitions=IPFPathways()
        ).float()
        self.model.to(self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()

        self.model_type = model_type

    def _predict_batch(self,
                       batch: torch.Tensor) -> Tuple:
        """Make predictions for a single batch

        Args:
            batch: Input tensor; presumably (batch_size, n_features) — confirm
                against the model's forward.

        Returns:
            Tuple (age, pathway_scores, attention_weights); pathway_scores is
            None for 'ipf_clock' models.
        """
        if self.model_type == 'ipf_transfer':
            age_pred, pathway_pred, attention = self.model(batch)
            return age_pred, pathway_pred, attention
        age_pred, attention = self.model(batch)
        return age_pred, None, attention

    def predict(self,
                data: Union[pd.DataFrame, np.ndarray],
                true_ages: Optional[np.ndarray] = None,
                batch_size: int = 32) -> Dict[str, np.ndarray]:
        """Make predictions on new data

        Args:
            data: OLINK protein measurements, passed unchanged to UKBDataset
                (presumably the prepared wide matrix — confirm with UKBDataset).
            true_ages: Optional array-like of true ages for evaluation
            batch_size: Batch size for prediction

        Returns:
            Dictionary with 'age_predictions' and 'attention_weights'. For
            'ipf_transfer' models it also has 'pathway_predictions', plus
            'true_pathway_scores' and 'pathway_correlations' when the model has
            ``compute_pathway_scores``. With ``true_ages`` it also has
            'true_ages', 'mae' and 'r2'.

        Raises:
            ValueError: If ``data`` yields no samples.
        """
        if true_ages is not None:
            true_ages = np.asarray(true_ages)
            targets = true_ages.astype(np.float32)
        else:
            # Dummy zero targets when no ages are known.
            targets = np.zeros(len(data), dtype=np.float32)
        dataset = UKBDataset(data, targets)

        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False
        )

        is_transfer = self.model_type == 'ipf_transfer'
        want_true_scores = is_transfer and hasattr(self.model, 'compute_pathway_scores')

        age_preds, attention_weights = [], []
        pathway_preds, true_pathway_scores = [], []

        with torch.no_grad():
            for batch in loader:
                olink = batch['olink'].to(self.device)
                age_pred, pathway_pred, attention = self._predict_batch(olink)

                # atleast_1d: squeeze() turns a final batch of one sample into
                # a 0-d array, which np.concatenate cannot join.
                age_preds.append(np.atleast_1d(age_pred.squeeze().cpu().numpy()))
                attention_weights.append(attention.cpu().numpy())
                if pathway_pred is not None:
                    pathway_preds.append(pathway_pred.cpu().numpy())
                if want_true_scores:
                    # Same pass over the data instead of re-iterating the loader.
                    scores = self.model.compute_pathway_scores(olink)
                    true_pathway_scores.append(scores.cpu().numpy())

        if not age_preds:
            raise ValueError("No samples were provided for prediction")

        results = {
            'age_predictions': np.concatenate(age_preds),
            'attention_weights': np.concatenate(attention_weights)
        }

        if is_transfer:
            pathway_out = np.concatenate(pathway_preds)
            results['pathway_predictions'] = pathway_out

            if want_true_scores:
                true_scores = np.concatenate(true_pathway_scores)
                results['true_pathway_scores'] = true_scores
                # Per-pathway Pearson correlation between reference and predicted scores.
                results['pathway_correlations'] = np.array([
                    np.corrcoef(true_scores[:, i], pathway_out[:, i])[0, 1]
                    for i in range(pathway_out.shape[1])
                ])

        if true_ages is not None:
            results['true_ages'] = true_ages
            results['mae'] = mean_absolute_error(true_ages, results['age_predictions'])
            results['r2'] = r2_score(true_ages, results['age_predictions'])

        return results

    def plot_predictions(self,
                         predictions: Dict[str, np.ndarray],
                         output_dir: Optional[Union[str, Path]] = None):
        """Plot prediction results

        Args:
            predictions: Dictionary of predictions from predict()
            output_dir: Optional directory to save plots

        Raises:
            ValueError: If ``predictions`` has no 'true_ages'.
        """
        if 'true_ages' not in predictions:
            raise ValueError("True ages required for plotting predictions")

        save_dir = Path(output_dir) if output_dir else None

        age_fig = self.evaluator.plot_age_predictions(
            predictions['true_ages'],
            predictions['age_predictions']
        )
        if save_dir is not None:
            age_fig.savefig(save_dir / 'age_predictions.png',
                            bbox_inches='tight',
                            dpi=300)

        # Correlations exist only when the model could compute reference scores.
        if 'pathway_correlations' in predictions:
            pathway_fig = self.evaluator.plot_pathway_correlations(
                predictions['pathway_correlations']
            )
            if save_dir is not None:
                pathway_fig.savefig(save_dir / 'pathway_predictions.png',
                                    bbox_inches='tight',
                                    dpi=300)
        plt.show()

    def _pathway_names(self, n_pathways: int) -> List[str]:
        """Return column labels for ``n_pathways`` pathway outputs."""
        if n_pathways == len(self.PATHWAY_NAMES):
            return list(self.PATHWAY_NAMES)
        return [f'pathway_{i}' for i in range(n_pathways)]

    def save_predictions(self,
                         predictions: Dict[str, np.ndarray],
                         output_dir: Union[str, Path],
                         sample_ids: Optional[List[str]] = None):
        """Save predictions to CSV files

        Writes 'age_predictions.csv' and, when available,
        'pathway_predictions.csv' and 'pathway_correlations.csv'.

        Args:
            predictions: Dictionary of predictions from predict()
            output_dir: Directory to save outputs (created with parents if needed)
            sample_ids: Optional list of sample IDs; defaults to 'Sample_<i>'
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        if sample_ids is None:
            sample_ids = [f'Sample_{i}' for i in range(len(predictions['age_predictions']))]

        outputs = pd.DataFrame({
            'sample_id': sample_ids,
            'predicted_age': predictions['age_predictions']
        })
        if 'true_ages' in predictions:
            outputs['true_age'] = predictions['true_ages']
        outputs.to_csv(output_dir / 'age_predictions.csv', index=False)

        if 'pathway_predictions' not in predictions:
            return

        pathway_out = predictions['pathway_predictions']
        pathway_names = self._pathway_names(pathway_out.shape[1])
        pathway_df = pd.DataFrame(
            pathway_out,
            columns=[f'{p}_score' for p in pathway_names],
            index=sample_ids
        )

        if 'true_pathway_scores' in predictions:
            true_scores = pd.DataFrame(
                predictions['true_pathway_scores'],
                columns=[f'{p}_true_score' for p in pathway_names],
                index=sample_ids
            )
            pathway_df = pd.concat([pathway_df, true_scores], axis=1)

        pathway_df.to_csv(output_dir / 'pathway_predictions.csv')

        if 'pathway_correlations' in predictions:
            corr_summary = pd.DataFrame({
                'pathway': pathway_names,
                'correlation': predictions['pathway_correlations']
            })
            corr_summary.to_csv(output_dir / 'pathway_correlations.csv', index=False)


def main(feats: Optional[Dict[str, str]] = None,
         feats_path: Union[str, Path] = "olink_features.json"):
    """Example usage: predict ages for long-format OLINK data.

    Args:
        feats: OLINK-feature -> gene mapping the model was trained with (the
            ``feats`` argument of AgingClockPredictor). If None, it is read
            from the JSON object stored at ``feats_path``.
        feats_path: Placeholder location of that JSON mapping; adjust to
            wherever the training run stored it. Used only when ``feats`` is None.
    """
    import json

    # Set paths
    data_path = "./protein_data.csv"
    model_path = "best_clock.pt"
    output_dir = Path("./clock_outputs")
    output_dir.mkdir(parents=True, exist_ok=True)

    # AgingClockPredictor requires the feature mapping; it cannot be omitted.
    if feats is None:
        with open(feats_path, encoding="utf-8") as f:
            feats = json.load(f)

    predictor = AgingClockPredictor(model_path, feats)

    # Scaler selects and orders columns exactly as the model expects.
    scaler = OLINKScaler(required_features=predictor.olink_features)

    # Load long-format data and pivot to samples x proteins.
    raw_data = pd.read_csv(data_path)
    formatted_data = scaler.long_to_wide(raw_data)

    # Fit the scaler and persist it. NOTE: this example fits on the same data
    # it predicts on; for new cohorts load a scaler fitted on training data
    # with OLINKScaler.load instead.
    scaler.fit(formatted_data)
    scaler.save(output_dir / 'olink_scaler.pkl')

    # Impute missing values and convert to a float32 matrix.
    scaled_data = scaler.prepare_wide(formatted_data)

    results = predictor.predict(scaled_data)

    output = pd.DataFrame({
        'patient_id': formatted_data.index,
        'predicted_age': results['age_predictions']
    })
    output.to_csv(output_dir / 'predictions.csv', index=False)

    print(f"Results saved to {output_dir}")
    print(f"Predicted ages range: [{results['age_predictions'].min():.1f}, "
          f"{results['age_predictions'].max():.1f}]")