import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional
from sklearn.metrics import mean_absolute_error, r2_score

from shan_pack.context_weight import ContextWeights, create_default_ipf_weights
from shan_pack.losses import TripleLoss, LossConfig
from shan_pack.other import IPFPathways, ScaleShift, UKBDataset

class IPFTransferNet(nn.Module):
    """Multi-task network mapping OLINK protein measurements to age and IPF pathway scores.

    A shared MLP backbone feeds three heads:

    * ``attention`` - softmax weights over four pathways;
    * ``age_transform`` + ``age_head`` - age regression;
    * ``pathway_transform`` + ``pathway_head`` - tanh-bounded pathway scores,
      multiplied element-wise by the attention weights in ``forward``.

    The four pathway columns follow ``_PATHWAY_ORDER`` (tgf_beta, ecm,
    inflammation, oxidative), which is also the column order produced by
    ``compute_pathway_scores``.
    """

    # (short name, key in pathway_definitions.get_all_pathways()) for each of
    # the four pathway outputs, in column order.
    _PATHWAY_ORDER = (
        ('tgf_beta', 'TGF_BETA'),
        ('ecm', 'ECM_REMODELING'),
        ('inflammation', 'INFLAMMATION'),
        ('oxidative', 'OXIDATIVE_STRESS'),
    )

    # Smallest positive float64; p-values are floored to this so that a
    # reported p == 0 gives a large but finite -log10(p) weight instead of inf.
    _MIN_PVALUE = np.finfo(np.float64).tiny

    def __init__(self,
                 olink_dim: int,
                 olink_to_gene: dict,
                 pathway_definitions,
                 feature_dim: int = 128,
                 dropout_rate: float = 0.2,
                 attention_reg_strength: float = 0.25,
                 loss_config: Optional[LossConfig] = None):
        """Build the network.

        Args:
            olink_dim: Number of input features per sample.
            olink_to_gene: Mapping OLINK ID -> gene symbol. Its insertion order
                defines the column index assigned to each OLINK ID; this is
                assumed to match the column order of the input tensor.
            pathway_definitions: Object whose ``get_all_pathways()`` returns a
                mapping that must contain the keys TGF_BETA, ECM_REMODELING,
                INFLAMMATION and OXIDATIVE_STRESS (each a list of genes).
            feature_dim: Width of the shared feature vector.
            dropout_rate: Dropout probability used in the backbone and age branch.
            attention_reg_strength: Multiplier applied in
                ``get_attention_regularization_loss``.
            loss_config: Configuration passed to ``TripleLoss``. When ``None``,
                ``LossConfig(age_weight=2, pathway_weight=1,
                attention_reg_weight=0.05)`` is used.
        """
        super().__init__()

        # Basic attributes
        self.feature_dim = feature_dim
        self.olink_to_gene = olink_to_gene
        self.pathway_definitions = pathway_definitions
        self.pathway_knowledge = None
        self.pathway_attention = None

        # Map OLINK IDs to column indices (insertion order of olink_to_gene).
        self.olink_columns = {olink_id: idx for idx, olink_id in enumerate(self.olink_to_gene)}
        self.setup_protein_indices()

        # Honour a caller-supplied loss configuration; fall back to the defaults.
        if loss_config is None:
            loss_config = LossConfig(
                age_weight=2,
                pathway_weight=1,
                attention_reg_weight=0.05,
            )
        self.loss_fn = TripleLoss(loss_config)

        # Feature extraction backbone
        self.feature_extractor = nn.Sequential(
            nn.Linear(olink_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, feature_dim),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU()
        )

        # Attention over the four pathways (rows sum to 1).
        self.attention = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.BatchNorm1d(128),
            nn.Tanh(),
            nn.Linear(128, 4),
            nn.Softmax(dim=1)
        )

        # Age prediction branch
        self.age_transform = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        # Softplus output is >= 0 and unbounded above, so (assuming ScaleShift
        # computes scale * x + shift) predictions are >= 40 with no upper bound.
        self.age_head = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Softplus(),
            ScaleShift(scale=85, shift=40)
        )

        # Pathway prediction branch. Dropout(0.0) is a no-op, kept so the
        # module layout is unchanged.
        self.pathway_transform = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Dropout(0.0)
        )

        self.pathway_head = nn.Sequential(
            nn.Linear(128, 4),
            nn.Tanh()  # Bound pathway scores between -1 and 1
        )

        # Attention regularization (initial weights set by encode_pathway_knowledge).
        self.initial_attention_weights = None
        self.attention_reg_strength = attention_reg_strength

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the network.

        Args:
            x: Input batch; presumably (batch, olink_dim) - TODO confirm.

        Returns:
            ``(age_pred, pathway_scores, attention_weights)``. ``pathway_scores``
            is the pathway head output multiplied element-wise by
            ``attention_weights`` (both with 4 columns).
        """
        features = self.feature_extractor(x)
        attention_weights = self.attention(features)

        # Separate task-specific transforms on top of the shared features.
        age_features = self.age_transform(features)
        pathway_features = self.pathway_transform(features)

        age_pred = self.age_head(age_features)
        pathway_scores = self.pathway_head(pathway_features)

        # Apply attention weights to pathway scores
        pathway_scores = attention_weights * pathway_scores

        return age_pred, pathway_scores, attention_weights

    def setup_protein_indices(self):
        """Populate ``self.pathway_proteins`` with input column indices per pathway.

        Each pathway gene is translated to an OLINK ID via the inverse of
        ``olink_to_gene`` (if several IDs map to one gene, the last wins), and
        then to that ID's column index. Genes without an OLINK measurement are
        skipped. Raises KeyError if a required pathway key is missing.
        """
        gene_to_olink = {gene: olink_id for olink_id, gene in self.olink_to_gene.items()}
        pathways = self.pathway_definitions.get_all_pathways()

        self.pathway_proteins = {short_name: [] for short_name, _ in self._PATHWAY_ORDER}
        for short_name, definition_key in self._PATHWAY_ORDER:
            columns = self.pathway_proteins[short_name]
            for gene in pathways[definition_key]:
                if gene not in gene_to_olink:
                    continue
                olink_id = gene_to_olink[gene]
                if olink_id in self.olink_columns:
                    columns.append(self.olink_columns[olink_id])

    def encode_pathway_knowledge(self,
                                 rna_data: pd.DataFrame,
                                 pathway_annotations: Dict[str, List[str]],
                                 weights_config) -> None:
        """Summarise RNA differential-expression evidence per pathway.

        For every pathway in ``pathway_annotations``, each gene's
        ``(fc, pval)`` is fetched with ``weights_config.get_weighted_evidence``
        (columns 'gene', 'log2FC', 'pvalue'). Genes with ``fc is None`` are
        ignored. For each pathway with at least one gene of evidence, with
        ``w = -log10(pval)`` (p-values floored at ``_MIN_PVALUE``):

        * embedding: a ``feature_dim`` vector accumulating ``w * fc`` per gene,
          divided by the gene count (when fc is a scalar - presumably - every
          element holds the same value);
        * importance: mean ``w`` over the pathway's genes.

        Sets ``self.pathway_knowledge`` (one row per pathway with evidence)
        and ``self.initial_attention_weights`` (importances normalised to sum
        to 1, or uniform if they sum to 0). Nothing is set if no pathway has
        evidence.

        NOTE(review): pathways without evidence are skipped, so fewer than four
        weights can be produced, and the order follows ``pathway_annotations``
        rather than ``_PATHWAY_ORDER``; the attention head has four outputs in
        ``_PATHWAY_ORDER``. Confirm both line up for your data.
        """
        pathway_embeddings = []
        pathway_importances = []

        for genes in pathway_annotations.values():
            evidence = []
            for gene in genes:
                fc, pval = weights_config.get_weighted_evidence(
                    rna_data,
                    gene_column='gene',
                    gene=gene,
                    fc_column='log2FC',
                    pvalue_column='pvalue'
                )
                if fc is not None:
                    evidence.append((fc, pval))

            if not evidence:
                continue

            embedding = np.zeros(self.feature_dim)
            significance = 0.0
            for fc, pval in evidence:
                weight = -np.log10(np.maximum(pval, self._MIN_PVALUE))
                embedding += weight * fc
                significance += weight

            pathway_embeddings.append(embedding / len(evidence))
            pathway_importances.append(significance / len(evidence))

        if not pathway_embeddings:
            return

        # Keep derived tensors on the same device as the model parameters.
        device = next(self.parameters()).device

        # np.stack first: building a tensor from a list of ndarrays is very slow.
        self.pathway_knowledge = torch.tensor(
            np.stack(pathway_embeddings),
            dtype=torch.float32,
            device=device
        )

        importances = np.asarray(pathway_importances, dtype=np.float64)
        total = importances.sum()
        if total > 0:
            normalized_importances = importances / total
        else:
            # All p-values equal to 1 would otherwise produce 0/0 = NaN.
            normalized_importances = np.full_like(importances, 1.0 / len(importances))

        self.initial_attention_weights = torch.tensor(
            normalized_importances,
            dtype=torch.float32,
            requires_grad=False,  # fixed reference distribution, never trained
            device=device
        )

    def compute_pathway_scores(self, olink_data: torch.Tensor) -> torch.Tensor:
        """Compute target pathway scores as the mean of each pathway's protein columns.

        Args:
            olink_data: Input batch indexed as ``olink_data[:, column]``.

        Returns:
            ``(batch, 4)`` tensor in ``_PATHWAY_ORDER``; a pathway with no
            measured proteins is left at 0.
        """
        batch_size = olink_data.shape[0]
        pathway_scores = torch.zeros((batch_size, len(self._PATHWAY_ORDER)), device=olink_data.device)

        for col, (short_name, _) in enumerate(self._PATHWAY_ORDER):
            proteins = self.pathway_proteins[short_name]
            if proteins:
                pathway_scores[:, col] = olink_data[:, proteins].mean(dim=1)

        return pathway_scores

    def get_attention_regularization_loss(self, current_attention: torch.Tensor) -> torch.Tensor:
        """Return ``attention_reg_strength * KL(initial || current)`` averaged over the batch.

        Returns a zero tensor when ``initial_attention_weights`` has not been
        set by ``encode_pathway_knowledge``.
        """
        if self.initial_attention_weights is None:
            return torch.tensor(0.0, device=current_attention.device)

        initial_weights = self.initial_attention_weights.to(current_attention.device)

        # Softmax outputs can underflow to exactly 0; clamp before log() so the
        # KL term stays finite instead of becoming inf/NaN.
        tiny = torch.finfo(current_attention.dtype).tiny
        log_current = current_attention.clamp_min(tiny).log()

        kl_div = F.kl_div(
            log_current,
            initial_weights.expand_as(current_attention),
            reduction='batchmean'
        )

        return self.attention_reg_strength * kl_div


def _print_loss_debug(loss_dict: Dict) -> None:
    """Print the pathway and attention debug breakdown stored under ``loss_dict['debug']``."""
    debug_info = loss_dict['debug']['pathway_scores']
    print("\n==> Pathway Debugging Info:")
    print(f"Raw loss per pathway: {debug_info['raw_loss_per_pathway']}")
    print(f"Raw total loss before scaling: {debug_info['raw_total_loss']}")
    print(f"Scaled total loss after log1p: {debug_info['scaled_total_loss']}")
    print(f"Variance loss: {debug_info['variance_loss']}")
    print(f"Final loss: {debug_info['final_loss']}")
    debug_info = loss_dict['debug']['attention_reg']
    print("\n==> Attention Debugging Info:")
    print(f"Raw loss: {debug_info['attention_kl_loss']}")
    print(f"Entropy term: {debug_info['attention_entropy_reg']}")
    print(f"Final loss: {debug_info['final_loss']}")


def train_epoch(model: nn.Module,
                train_loader: torch.utils.data.DataLoader,
                optimizer: torch.optim.Optimizer,
                device: str,
                debugging_on: bool = True,
                debug_every: int = 2000) -> Dict[str, float]:
    """Run one optimisation pass over ``train_loader``.

    For each batch: compute target pathway scores from the OLINK input, run
    the model, evaluate ``model.loss_fn``, back-propagate ``loss_dict['total']``,
    clip the global gradient norm to 1.0 and step the optimiser.

    Args:
        model: Network exposing ``compute_pathway_scores``, ``loss_fn`` and
            ``initial_attention_weights``.
        train_loader: Yields dicts with ``'olink'`` and ``'age'`` tensors.
        optimizer: Optimiser over the model's parameters.
        device: Device each batch is moved to.
        debugging_on: When True, print the loss function's debug breakdown
            every ``debug_every`` batches.
        debug_every: Interval, in 1-based batch count, between debug prints.

    Returns:
        Sums of ``loss_fn.compute_batch_losses`` for 'total', 'age', 'pathway'
        and 'attention_reg', divided by ``len(train_loader.dataset)``.
    """
    model.train()
    total_losses = {
        'total': 0,
        'age': 0,
        'pathway': 0,
        'attention_reg': 0
    }

    for batch_idx, batch in enumerate(train_loader, start=1):
        optimizer.zero_grad()

        olink = batch['olink'].to(device)
        ages = batch['age'].to(device)

        true_pathway_scores = model.compute_pathway_scores(olink)
        age_pred, pathway_pred, attention_weights = model(olink)

        _, loss_dict = model.loss_fn(
            pred_ages=age_pred,
            true_ages=ages,
            pred_pathways=pathway_pred,
            true_pathways=true_pathway_scores,
            attention_weights=attention_weights,
            initial_attention=model.initial_attention_weights
        )

        # The debug breakdown comes from the loss call above; re-invoking
        # loss_fn with identical inputs only repeated the computation.
        if debugging_on and batch_idx % debug_every == 0:
            _print_loss_debug(loss_dict)

        batch_size = model.loss_fn.get_batch_size(ages)
        batch_losses = model.loss_fn.compute_batch_losses(batch_size, loss_dict)

        loss_dict['total'].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        for key in total_losses:
            total_losses[key] += batch_losses[key]

    n_samples = len(train_loader.dataset)
    return {key: value / n_samples for key, value in total_losses.items()}


def analyze_samples(model: nn.Module,
                    data_loader: torch.utils.data.DataLoader,
                    device: str) -> Dict[str, float]:
    """Evaluate ``model`` on ``data_loader`` in eval mode without gradients.

    Args:
        model: Network exposing ``compute_pathway_scores``, ``loss_fn`` and
            ``initial_attention_weights``.
        data_loader: Yields dicts with ``'olink'`` and ``'age'`` tensors.
        device: Device each batch is moved to.

    Returns:
        Sums of ``loss_fn.compute_batch_losses`` for 'total', 'age', 'pathway'
        and 'attention_reg' divided by ``len(data_loader.dataset)``, plus
        'r2' and 'mae' of the age predictions.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.eval()
    total_losses = {
        'total': 0,
        'age': 0,
        'pathway': 0,
        'attention_reg': 0
    }

    age_preds = []
    true_ages = []

    with torch.no_grad():
        for batch in data_loader:
            olink = batch['olink'].to(device)
            ages = batch['age'].to(device)

            age_pred, pathway_pred, attention_weights = model(olink)

            # reshape(-1) instead of squeeze(): squeeze() collapses a final
            # batch of one sample into a 0-d array that np.concatenate rejects.
            age_preds.append(age_pred.reshape(-1).cpu().numpy())
            true_ages.append(ages.reshape(-1).cpu().numpy())

            true_pathway_scores = model.compute_pathway_scores(olink)

            _, loss_dict = model.loss_fn(
                pred_ages=age_pred,
                true_ages=ages,
                pred_pathways=pathway_pred,
                true_pathways=true_pathway_scores,
                attention_weights=attention_weights,
                initial_attention=model.initial_attention_weights
            )

            batch_size = model.loss_fn.get_batch_size(ages)
            batch_losses = model.loss_fn.compute_batch_losses(batch_size, loss_dict)

            for key in total_losses:
                total_losses[key] += batch_losses[key]

    if not age_preds:
        raise ValueError("analyze_samples: data_loader produced no batches")

    age_preds = np.concatenate(age_preds)
    true_ages = np.concatenate(true_ages)

    n_samples = len(data_loader.dataset)
    results = {key: value / n_samples for key, value in total_losses.items()}
    results['r2'] = r2_score(true_ages, age_preds)
    results['mae'] = mean_absolute_error(true_ages, age_preds)
    return results


def analyze_attention_drift(model: nn.Module,
                            data_loader: torch.utils.data.DataLoader,
                            device: str) -> Dict[str, np.ndarray]:
    """Compare the model's mean attention over ``data_loader`` with its initial attention.

    The model is put in eval mode and run without gradients. The attention
    output (third element of ``model(olink)``) of every batch is pooled and
    averaged over all samples.

    Returns:
        Dict with 'current_weights' (pooled mean attention), 'initial_weights'
        (``model.initial_attention_weights`` as a NumPy array),
        'absolute_drift' (``|current - initial|``) and 'relative_drift'
        (``|(current - initial) / initial|``).
    """
    model.eval()

    per_batch = []
    with torch.no_grad():
        for batch in data_loader:
            inputs = batch['olink'].to(device)
            _, _, attn = model(inputs)
            per_batch.append(attn.cpu().numpy())

    current = np.concatenate(per_batch, axis=0).mean(axis=0)
    initial = model.initial_attention_weights.cpu().numpy()
    delta = current - initial

    return {
        'current_weights': current,
        'initial_weights': initial,
        'absolute_drift': np.abs(delta),
        'relative_drift': np.abs(delta / initial),
    }


def main(olink_data: pd.DataFrame,
         ages: np.ndarray,
         overlapped_data: pd.DataFrame,
         olink_to_gene: Dict[str, str],
         batch_size: int = 32,
         num_epochs: int = 100,
         learning_rate: float = 0.001,
         validation_split: float = 0.2,
         random_seed: int = 42):
    """Train an ``IPFTransferNet`` with early stopping and return it with its history.

    Steps:
      1. Seed torch/NumPy.
      2. Randomly split samples into train/validation (``validation_split``).
      3. Build the model and encode RNA pathway evidence from ``overlapped_data``.
      4. Train with Adam and a cosine-annealed learning rate (stepped once per
         epoch), validating every epoch.
      5. Stop after 5 epochs without improvement in validation 'total' loss.
      6. Restore the weights from the best epoch.

    Each new best epoch is also checkpointed to ``best_model.pt``.

    Args:
        olink_data: Samples x OLINK features.
        ages: Target ages, indexable by integer position arrays.
        overlapped_data: RNA differential-expression table passed to
            ``encode_pathway_knowledge``.
        olink_to_gene: OLINK ID -> gene symbol mapping.
        batch_size: Mini-batch size for both loaders.
        num_epochs: Maximum number of epochs (also the cosine ``T_max``).
        learning_rate: Initial Adam learning rate.
        validation_split: Fraction of samples held out for validation.
        random_seed: Seed for torch and NumPy.

    Returns:
        Dict with 'model', 'train_dataset', 'val_dataset', 'train_idx',
        'val_idx', 'training_history' and 'training_report'.
    """
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Random train/validation split
    n_samples = len(olink_data)
    n_val = int(n_samples * validation_split)
    indices = np.random.permutation(n_samples)
    train_idx = indices[n_val:]
    val_idx = indices[:n_val]

    train_dataset = UKBDataset(
        olink_data.iloc[train_idx].values,
        ages[train_idx]
    )
    val_dataset = UKBDataset(
        olink_data.iloc[val_idx].values,
        ages[val_idx]
    )

    # BatchNorm1d raises in training mode on a single-sample batch, so drop the
    # trailing training batch exactly when it would contain one sample.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=len(train_idx) % batch_size == 1
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False
    )

    pathway_defs = IPFPathways()
    model = IPFTransferNet(
        olink_dim=olink_data.shape[1],
        olink_to_gene=olink_to_gene,
        pathway_definitions=pathway_defs
    ).to(device)
    # NOTE(review): init_weights is not defined or imported in this module.
    model.age_head.apply(init_weights)

    model.encode_pathway_knowledge(
        rna_data=overlapped_data,
        pathway_annotations=pathway_defs.get_all_pathways(),
        weights_config=create_default_ipf_weights()
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # CosineAnnealingLR lives in torch.optim.lr_scheduler, not torch.optim.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=num_epochs,
        eta_min=1e-5
    )

    best_val_loss = float('inf')
    best_state = None
    patience = 5
    patience_counter = 0

    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        # train_epoch takes no scheduler; the LR is stepped once per epoch here.
        train_epoch_losses = train_epoch(
            model=model,
            train_loader=train_loader,
            optimizer=optimizer,
            device=device
        )
        scheduler.step()

        val_epoch_losses = analyze_samples(
            model=model,
            data_loader=val_loader,
            device=device
        )

        train_losses.append(train_epoch_losses)
        val_losses.append(val_epoch_losses)

        print(f"\nEpoch {epoch + 1}/{num_epochs}:")
        print(f"Train - Age Loss: {train_epoch_losses['age']:.4f}, "
              f"Pathway Loss: {train_epoch_losses['pathway']:.4f}")
        print(f"Val - Age Loss: {val_epoch_losses['age']:.4f}, "
              f"Pathway Loss: {val_epoch_losses['pathway']:.4f}")
        print(f"Learning Rate: {scheduler.get_last_lr()[0]:.6f}")

        if (epoch + 1) % 10 == 0:
            # NOTE(review): diagnostic_checks / check_data_scaling are not
            # defined or imported in this module.
            diagnostic_checks(model, val_loader, device)
            check_data_scaling(train_loader, val_loader)

        val_loss = val_epoch_losses['total']
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # Keep an in-memory copy so restoring the best weights never
            # depends on reading the checkpoint file back.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_losses': train_losses,
                'val_losses': val_losses,
            }, 'best_model.pt')
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("\nEarly stopping triggered")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        # Validation loss never improved (e.g. it was NaN every epoch).
        print("\nWarning: validation loss never improved; keeping final weights")

    # NOTE(review): generate_training_report is not defined or imported in this module.
    report = generate_training_report(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        train_history=train_losses,
        val_history=val_losses,
        device=device
    )

    return {
        'model': model,
        'train_dataset': train_dataset,
        'val_dataset': val_dataset,
        'train_idx': train_idx,
        'val_idx': val_idx,
        'training_history': {
            'train_losses': train_losses,
            'val_losses': val_losses
        },
        'training_report': report
    }


if __name__ == "__main__":
    # main() has four required in-memory data arguments, so calling it bare
    # from the command line can only fail with a TypeError; exit with a usage
    # message instead.
    raise SystemExit(
        "main() requires olink_data, ages, overlapped_data and olink_to_gene; "
        "import this module and call main(...) with your data."
    )