import torch
import torch.nn.functional as F

import numpy as np


# Prompt completion
def complete_metadata(input_text, input_instr,
                      model, device, tokenizer,
                      up_opened = False,
                      up_closed = False,
                      down_opened = False,
                      down_closed = False,
                      genes_per_section = 50,
                      temperature=0.7,
                      verbose = False):
    """Extend a metadata prompt with sampled ``<up>`` and ``<down>`` gene sections.

    The sequence grows one token per step. Tags are forced in the order
    ``<up>`` -> ``</up>`` -> ``<down>`` -> ``</down>``. Between the tags,
    tokens are sampled with temperature until each section holds
    ``genes_per_section`` gene tokens.

    Two kinds of tokens are never sampled:

    - tokens whose decoded text starts with ``<`` or ``[``;
    - gene texts already generated in either section.

    Args:
        input_text: prompt string to complete.
        input_instr: key into ``model.instruction_map``.
        model: called as ``model(input_ids, instruction_ids, None)``. Its
            output is indexed ``[:, -1, :]`` for next-token logits (presumably
            shaped (batch, seq, vocab) -- confirm against the model).
        device: torch device for the generated tensors.
        tokenizer: must provide ``encode``, ``decode`` and
            ``convert_tokens_to_ids``.
        up_opened, up_closed, down_opened, down_closed: set to True for tags
            already present in ``input_text`` so they are not emitted again.
        genes_per_section: number of gene tokens per section.
        temperature: softmax temperature used for sampling (must be > 0).
        verbose: print every generation decision.

    Returns:
        The decoded sequence (prompt plus generated tokens), decoded with
        ``skip_special_tokens=True``. Generation stops early once the
        sequence reaches 512 tokens.
    """
    inputs = tokenizer.encode(
                            input_text,
                            add_special_tokens=True,
                            truncation=True,
                            return_tensors='pt'
                        )
    # encode(..., return_tensors='pt') already returns a tensor; as_tensor avoids
    # the copy-construct warning that torch.tensor() emits for tensor input.
    input_ids = torch.as_tensor(inputs).to(device)
    instruction_id = model.instruction_map[input_instr]
    instruction_ids = torch.tensor([instruction_id]).to(device)

    up_genes_count = 0
    down_genes_count = 0
    used_genes = set()
    banned_ids = []       # vocab ids whose decoded text is already in used_genes
    vocab_tables = None   # built on the first step, once the vocab size is known

    with torch.no_grad():
        for _ in range(4 + genes_per_section*2):
            outputs = model(input_ids, instruction_ids, None)
            next_token_logits = outputs[:, -1, :].clone()  # Clone to modify logits

            if vocab_tables is None:
                # Decoding the whole vocabulary is expensive: do it once, not every step.
                vocab_tables = _build_vocab_tables(
                    tokenizer, next_token_logits.size(-1), next_token_logits.device)
            special_mask, text_to_ids = vocab_tables

            # Repetition penalty: a gene that was already generated cannot be sampled again.
            if banned_ids:
                next_token_logits[0, banned_ids] = float('-inf')
            # Tags / special tokens are never sampled as genes.
            next_token_logits[0].masked_fill_(special_mask, float('-inf'))
            # Re-allow a closing tag only while its own section is still open.
            # Without the `not *_closed` guard, '</up>' stayed sampleable during
            # the down section and ate into the step budget.
            up_close_ids = text_to_ids.get('</up>')
            if up_close_ids and up_genes_count >= genes_per_section and not up_closed:
                next_token_logits[0, up_close_ids] = 0
            down_close_ids = text_to_ids.get('</down>')
            if down_close_ids and down_genes_count >= genes_per_section and not down_closed:
                next_token_logits[0, down_close_ids] = 0

            # Apply temperature and convert to probabilities
            next_token_probs = F.softmax(next_token_logits / temperature, dim=-1)

            # Advance the tag state machine: <up> genes </up> <down> genes </down>
            if not up_opened:
                if verbose:
                    print("Forcing <up> tag")
                next_token = torch.tensor([[tokenizer.convert_tokens_to_ids("<up>")]], device=device)
                up_opened = True
            elif not up_closed:
                next_token = torch.multinomial(next_token_probs, num_samples=1)
                token_text = tokenizer.decode([next_token.squeeze().item()])
                if verbose:
                    print(f"Generated token in up section: {token_text}")

                if up_genes_count >= genes_per_section:
                    if verbose:
                        print("Forcing </up> tag")
                    next_token = torch.tensor([[tokenizer.convert_tokens_to_ids("</up>")]], device=device)
                    up_closed = True
                elif not token_text.startswith(('<', '[')):
                    used_genes.add(token_text)
                    banned_ids.extend(text_to_ids.get(token_text, []))
                    up_genes_count += 1
                    if verbose:
                        print(f"Counted as gene, total up genes: {up_genes_count}")
            elif not down_opened:
                if verbose:
                    print("Forcing <down> tag")
                next_token = torch.tensor([[tokenizer.convert_tokens_to_ids("<down>")]], device=device)
                down_opened = True
            elif not down_closed:
                next_token = torch.multinomial(next_token_probs, num_samples=1)
                token_text = tokenizer.decode([next_token.squeeze().item()])
                if verbose:
                    print(f"Generated token in down section: {token_text}")

                if down_genes_count >= genes_per_section:
                    if verbose:
                        print("Forcing </down> tag")
                    next_token = torch.tensor([[tokenizer.convert_tokens_to_ids("</down>")]], device=device)
                    down_closed = True
                elif not token_text.startswith(('<', '[')):
                    used_genes.add(token_text)
                    banned_ids.extend(text_to_ids.get(token_text, []))
                    down_genes_count += 1
                    if verbose:
                        print(f"Counted as gene, total down genes: {down_genes_count}")
            else:
                if verbose:
                    print("All sections complete, breaking")
                break

            input_ids = torch.cat([input_ids, next_token], dim=1)

            if input_ids.size(1) >= 512:
                if verbose:
                    print("Reached max length, breaking")
                break

    out_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return(out_text)


def _build_vocab_tables(tokenizer, vocab_size, device):
    """Decode every vocabulary id once and index the results.

    Returns:
        special_mask: bool tensor of length ``vocab_size`` on ``device``. An
            entry is True when the decoded text of that id starts with ``<``
            or ``[``.
        text_to_ids: dict mapping each decoded text to the list of ids that
            decode to it (several ids may share one surface form).
    """
    texts = [tokenizer.decode([token_id]) for token_id in range(vocab_size)]
    special_mask = torch.tensor(
        [text.startswith(('<', '[')) for text in texts],
        dtype=torch.bool,
        device=device,
    )
    text_to_ids = {}
    for token_id, text in enumerate(texts):
        text_to_ids.setdefault(text, []).append(token_id)
    return special_mask, text_to_ids

def add_is_fibrosis(input_text, input_instr,
                    model, device, tokenizer,
                    temperature=0.7,
                    verbose = False):
    """Append a sampled ``<is_fibrosis>Yes|No</is_fibrosis>`` prediction to a prompt.

    The model is run once on ``input_text`` followed by the ``<is_fibrosis>``
    tag. The next token is then sampled only from the "Yes" and "No" tokens,
    with temperature. Only the first sub-token of each answer word is used.

    Args:
        input_text: prompt string, encoded without adding special tokens.
        input_instr: key into ``model.instruction_map``.
        model: called as ``model(input_ids, instruction_ids, None)``. Its
            output is indexed ``[:, -1, :]`` for next-token logits.
        device: torch device for the generated tensors.
        tokenizer: must provide ``encode`` and ``decode``.
        temperature: softmax temperature used for sampling (must be > 0).
        verbose: print the Yes/No probabilities before sampling.

    Returns:
        The prompt plus opening tag, answer token and closing tag, decoded with
        ``skip_special_tokens=True``.
    """
    inputs = tokenizer.encode(
                        input_text,
                        add_special_tokens=False,
                        truncation=True,
                        return_tensors='pt'
                    )
    # encode(..., return_tensors='pt') already returns a tensor; as_tensor avoids
    # the copy-construct warning that torch.tensor() emits for tensor input.
    input_ids = torch.as_tensor(inputs).to(device)
    instruction_id = model.instruction_map[input_instr]
    instruction_ids = torch.tensor([instruction_id]).to(device)

    yes_token_id = tokenizer.encode("Yes", add_special_tokens=False)[0]
    no_token_id = tokenizer.encode("No", add_special_tokens=False)[0]

    with torch.no_grad():
        # Add opening tag
        open_tag = tokenizer.encode("<is_fibrosis>", add_special_tokens=False)
        input_ids = torch.cat([input_ids, torch.tensor([open_tag], device=device)], dim=1)

        # Generate prediction token
        outputs = model(input_ids, instruction_ids, None)
        next_token_logits = outputs[:, -1, :]

        # Restrict the distribution to the two answer tokens.
        masked_logits = torch.full_like(next_token_logits, float('-inf'))
        masked_logits[0, yes_token_id] = next_token_logits[0, yes_token_id]
        masked_logits[0, no_token_id] = next_token_logits[0, no_token_id]

        # Apply temperature and sample
        next_token_probs = F.softmax(masked_logits / temperature, dim=-1)
        if verbose:
            no_prob = next_token_probs[0, no_token_id]
            yes_prob = next_token_probs[0, yes_token_id]
            print(f"The case cohort is fibrotic:\nYes: {100*yes_prob:.0f}%\nNo: {100*no_prob:.0f}%")
        next_token = torch.multinomial(next_token_probs, num_samples=1)

        # Add prediction token and closing tag
        close_tag = tokenizer.encode("</is_fibrosis>", add_special_tokens=False)
        input_ids = torch.cat(
            [input_ids, next_token, torch.tensor([close_tag], device=device)], dim=1)

    out_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return(out_text)



def get_attention_hook(attention_patterns):
    """
    Build a forward hook that records attention weights into a list.

    Args:
        attention_patterns: list that captured tensors are appended to
            (detached and moved to CPU).

    Returns:
        A ``(module, inputs, output)`` forward-hook callable. It behaves as
        follows:

        - Tuple output: it is unpacked as ``(attn_output, attn_weights)`` and
          the weights are stored unless they are ``None``. This happens, for
          example, when the module was called with ``need_weights=False``.
        - Non-tuple output that has ``detach``: the output itself is stored.
        - Anything else: a warning is printed.
    """
    def _capture(module, inputs, output):
        if not isinstance(output, tuple):
            # Fallback: store the raw output if it looks like a tensor.
            if hasattr(output, 'detach'):
                attention_patterns.append(output.detach().cpu())
            else:
                print(f"Warning: Unexpected output type in attention hook: {type(output)}")
            return

        _, weights = output
        if weights is None:
            return
        attention_patterns.append(weights.detach().cpu())

    return _capture

def add_is_fibrosis_with_attention(input_text, input_instr,
                                    model, device, tokenizer,
                                    temperature=0.7,
                                    verbose=False):
    """Predict fibrosis, then re-run the model once to capture attention weights.

    The prediction is produced by ``add_is_fibrosis``. The completed text is
    re-encoded and passed through the model one more time with forward hooks
    on every ``torch.nn.MultiheadAttention`` module.

    Returns:
        out_text: the completed text returned by ``add_is_fibrosis``.
        tokens: ``tokenizer.convert_ids_to_tokens`` of the re-encoded text.
        attention_patterns: tensors captured by the hooks during the re-run,
            in hook firing order. The list is empty if the modules do not
            return attention weights (e.g. when called with
            ``need_weights=False``).
    """
    # Generate first, *before* registering hooks, so the captured patterns
    # come only from the single re-run below and not from the generation pass.
    out_text = add_is_fibrosis(input_text, input_instr, model, device, tokenizer, temperature, verbose)

    # NOTE(review): out_text was decoded with skip_special_tokens=True, so
    # re-encoding it may not reproduce the exact ids used during generation.
    input_ids = tokenizer.encode(out_text, add_special_tokens=False, return_tensors='pt').to(device)
    instruction_id = model.instruction_map[input_instr]
    instruction_ids = torch.tensor([instruction_id]).to(device)

    attention_patterns = []
    hooks = [
        module.register_forward_hook(get_attention_hook(attention_patterns))
        for module in model.modules()
        if isinstance(module, torch.nn.MultiheadAttention)
    ]
    try:
        with torch.no_grad():
            model(input_ids, instruction_ids, None)
    finally:
        # Always detach the hooks, even if the forward pass raises.
        for hook in hooks:
            hook.remove()

    # Get the tokens for visualization
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    if verbose:
        print("\nAttention patterns captured from layers:", len(attention_patterns))
        if attention_patterns:
            print("Shape of attention in first layer:", attention_patterns[0].shape)

    return out_text, tokens, attention_patterns

def visualize_attention_ascii(attention_scores, tokens, threshold=0.1):
    """
    Print token attention strengths as ASCII art.

    Scores are divided by their maximum (when it is positive). Each value is
    then drawn with one of the intensity characters " ░▒▓█".

    Args:
        attention_scores: either 1-D scores (one query row) or a 2-D matrix
            (rows = attending token, columns = attended token). May be a torch
            tensor, a numpy array or a nested sequence of numbers.
        tokens: list of token strings aligned with the score positions.
        threshold: minimum normalised score to show (for readability).

    Returns:
        None -- everything is printed to stdout.
    """
    attention_chars = " ░▒▓█"  # low -> high intensity

    if torch.is_tensor(attention_scores):
        # detach()/cpu() so tensors that track gradients or live on a GPU work too.
        attention_scores = attention_scores.detach().cpu().numpy()
    else:
        attention_scores = np.asarray(attention_scores)

    print("\nAttention scores shape:", attention_scores.shape)

    def intensity(score):
        # Clamp both ends: a negative index would pick characters from the end.
        idx = int(score * len(attention_chars))
        return attention_chars[max(0, min(idx, len(attention_chars) - 1))]

    max_score = np.max(attention_scores) if attention_scores.size else 0
    norm_scores = attention_scores / max_score if max_score > 0 else attention_scores

    if attention_scores.ndim == 1:
        print("\nToken attention strengths (normalized):")
        for score, token in zip(norm_scores, tokens):
            if score >= threshold:
                print(f"{token:30} | {intensity(score)} {score:.4f}")

    elif attention_scores.ndim == 2:
        row_tokens = [str(t) for t in tokens[:attention_scores.shape[0]]]
        label = "From tokens ↓"
        # Row labels and the header label share one width so the grid lines up.
        label_width = max([len(label)] + [len(t) for t in row_tokens])

        print("\nTo tokens →")
        # Column headers are clipped/centred to the 3-character cell width used below.
        header_cells = "".join(f"{t[:3]:^3}" for t in row_tokens)
        print(f"{label:{label_width}} {header_cells}")
        for row_token, row in zip(row_tokens, norm_scores):
            cells = "".join(
                f" {intensity(score) if score >= threshold else ' '} " for score in row
            )
            print(f"{row_token:{label_width}} {cells}")

    else:
        print(f"Unsupported attention shape {attention_scores.shape}: expected 1-D or 2-D scores.")

def summarize_key_attentions(attention_scores, tokens, top_k=5):
    """
    Print the strongest attention entries, highest score first.

    Args:
        attention_scores: either 1-D scores (one query row) or a 2-D matrix
            (rows = attending token, columns = attended token). May be a torch
            tensor, a numpy array or a nested sequence.
        tokens: list of token strings aligned with the score positions.
        top_k: number of entries to show; clamped to [0, number of scores].

    Returns:
        None -- results are printed to stdout.
    """
    if torch.is_tensor(attention_scores):
        # detach()/cpu() so tensors that track gradients or live on a GPU work too.
        attention_scores = attention_scores.detach().cpu().numpy()
    else:
        attention_scores = np.asarray(attention_scores)

    print("\nAttention scores shape:", attention_scores.shape)

    if attention_scores.ndim not in (1, 2):
        print(f"Unsupported attention shape {attention_scores.shape}: expected 1-D or 2-D scores.")
        return

    # np.argpartition needs 0 < k <= size. Clamp so an oversized top_k shows
    # everything and a non-positive one shows nothing (-0 used to select all).
    k = max(0, min(int(top_k), attention_scores.size))
    flat = attention_scores.ravel()
    top_indices = np.argpartition(flat, -k)[-k:] if k else np.empty(0, dtype=np.intp)
    # argpartition leaves the top-k unordered; sort them by descending score.
    top_indices = top_indices[np.argsort(-flat[top_indices], kind="stable")]

    if attention_scores.ndim == 1:
        print(f"\nTop {k} attended tokens:")
        for idx in top_indices:
            print(f"'{tokens[idx]}': {attention_scores[idx]:.4f}")
    else:
        rows, cols = np.unravel_index(top_indices, attention_scores.shape)
        print(f"\nTop {k} attention patterns:")
        for row, col in zip(rows, cols):
            print(f"From '{tokens[row]}' to '{tokens[col]}': {attention_scores[row, col]:.3f}")

def analyze_gene_attention(attention_matrix, tokens, yes_idx, threshold=0.001):
    """
    Rank genes in the <up>/<down> sections by the attention the 'Yes' token pays them.

    Section bounds are the first occurrences of '<up>', '</up>', '<down>' and
    '</down>' in ``tokens``. Inside each section, blank tokens and that
    section's own tags are skipped.

    Args:
        attention_matrix: indexable as ``attention_matrix[row][col]``. Row
            ``yes_idx`` provides the score for each token position.
        tokens: list of token strings aligned with the matrix columns.
        yes_idx: row index of the prediction token.
        threshold: minimum raw attention score a gene must reach to be kept.

    Returns:
        List of ``(gene, percent_of_max, direction)`` tuples, sorted by
        percentage (highest = 100) in descending order. ``direction`` is
        'up' or 'down'. The list is empty in three cases:

        - a section tag is missing (an error message is printed);
        - no genes are found;
        - the top score is 0.
    """
    try:
        # Evaluated left to right: '<up>', '</up>', '<down>', '</down>'.
        sections = (
            ('up', tokens.index('<up>'), tokens.index('</up>')),
            ('down', tokens.index('<down>'), tokens.index('</down>')),
        )

        scored = []
        for direction, start, end in sections:
            own_tags = ('<' + direction + '>', '</' + direction + '>')
            scored.extend(
                (tok, attention_matrix[yes_idx][pos], direction)
                for pos, tok in enumerate(tokens[start + 1:end], start=start + 1)
                if tok.strip() and tok not in own_tags
            )

        if not scored:
            return []

        top = max(score for _, score, _ in scored)
        if top == 0:
            return []

        ranked = sorted(
            ((gene, (score / top) * 100, direction) for gene, score, direction in scored),
            key=lambda item: item[1],
            reverse=True,
        )
        # Express the absolute threshold on the same percentage scale.
        cutoff = (threshold / top) * 100
        return [item for item in ranked if item[1] >= cutoff]

    except ValueError as err:
        print(f"Error finding gene sections: {err}")
        return []

def main():
    """Demo: complete a metadata prompt, predict fibrosis and inspect attention.

    NOTE(review): ``model``, ``device`` and ``tokenizer`` are read as
    module-level globals that are not defined in the visible code -- confirm
    they are set before calling.
    """
    input_text = ("<age_group2diff2age_group><omics>transcriptomics</omics>"
                  "<tissue>UBERON_0002048 </tissue><cell></cell>"
                  "<tissue_ancestors>UBERON_0000061 UBERON_0002075 </tissue_ancestors>"
                  # "<tissue>UBERON_0000178 </tissue>"
                  # "<tissue_ancestors>UBERON_0006314 </tissue_ancestors>"
                  "<disease></disease>"
                  "<disease_ancestors></disease_ancestors>"
                  "<disease_tissue></disease_tissue>"
                  "<drug></drug><dose></dose><time></time><case>70.0-80.0</case><control>30.0-40.0</control>"
                  "<age></age><species>human </species><gender></gender>")

    # 50 genes is the cutoff used in training; better not to go lower than that.
    out1 = complete_metadata(input_text, 'disease2diff2disease',
                             model, device, tokenizer,
                             genes_per_section=50)

    # add_is_fibrosis_with_attention calls add_is_fibrosis itself (and, with
    # verbose=True, prints the Yes/No probabilities), so no separate call is needed.
    _, tokens, attention_patterns = add_is_fibrosis_with_attention(
        out1,
        "disease2diff2disease",
        model,
        device,
        tokenizer,
        verbose=True,
    )

    # Only the 'Yes' lookup belongs in the try: a ValueError from the analysis
    # below must not be misreported as a missing 'Yes' token.
    try:
        yes_idx = tokens.index('Yes')
    except ValueError:
        print("Could not find 'Yes' token in the sequence. This might mean the model predicted 'No'")
        return

    if not attention_patterns:
        print("No attention weights were captured; cannot analyse gene attention.")
        return

    # Last captured MultiheadAttention call -- presumably the last layer.
    attention_matrix = attention_patterns[-1].squeeze(0)

    threshold = 0.0001
    print("\nAnalyzing gene contributions to 'Yes' prediction...")
    gene_attention = analyze_gene_attention(attention_matrix, tokens, yes_idx, threshold=threshold)

    if gene_attention:
        print(f"\nTop genes contributing to the prediction (attention score >= {threshold}):")
        print("\n{:<15} {:<10} {:<8}".format("Gene", "Direction", "Attention"))
        print("-" * 35)
        for gene, score, direction in gene_attention:
            print("{:<15} {:<10} {:.3f}".format(gene, direction, score))
    else:
        print(f"No genes found with significant attention scores (>= {threshold})")

    # Original attention analysis
    print("\nOverall attention patterns:")
    yes_attention = attention_matrix[yes_idx]
    print("\nWhat 'Yes' is attending to:")
    # visualize_attention_ascii prints directly and returns None; don't wrap it in print().
    visualize_attention_ascii(yes_attention, tokens, threshold=0.1)

    print("\nKey attention patterns for 'Yes' token:")
    summarize_key_attentions(yes_attention, tokens, top_k=10)