from tokenizers import Tokenizer, normalizers, pre_tokenizers, processors
from tokenizers.models import WordLevel
from tokenizers.processors import TemplateProcessing
from transformers import PreTrainedTokenizerFast

from tokenizers.normalizers import Strip, Precompiled
from tokenizers.pre_tokenizers import Split

from typing import List, Dict, Set
import re
import logging

logger = logging.getLogger(__name__)

class XMLAwareTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer whose ``encode`` pads angle-bracket tags with spaces.

    Only ``encode`` is overridden in this class; the spacing step runs on
    the raw text before the parent implementation is invoked.
    """

    def _preprocess_text(self, text):
        """Return *text* with whitespace inserted around ``<...>`` tags."""
        # Applied strictly in this order: split back-to-back tags first,
        # then pad the right side of each tag, then the left side.
        substitutions = (
            (r'>(<)', r'> <'),
            (r'(<[^>]+>)(?=[^\s])', r'\1 '),
            (r'(?<=[^\s])(<[^>]+>)', r' \1'),
        )
        for pattern, replacement in substitutions:
            text = re.sub(pattern, replacement, text)
        return text

    def encode(self, text, *args, **kwargs):
        """Space out tags in *text*, then delegate to the parent ``encode``.

        A ``str`` is preprocessed directly; for a ``list`` only its ``str``
        items are preprocessed. Any other input is forwarded untouched.
        """
        if isinstance(text, str):
            prepared = self._preprocess_text(text)
        elif isinstance(text, list):
            prepared = [
                self._preprocess_text(item) if isinstance(item, str) else item
                for item in text
            ]
        else:
            prepared = text
        return super().encode(prepared, *args, **kwargs)

def debug_print_vocab(vocab: Dict[str, int], label: str = ""):
    """Log the vocabulary size and its first 20 entries at INFO level.

    Args:
        vocab: Mapping from token string to token id.
        label: Optional text appended to the header line.
    """
    logger.info(f"\n=== Vocabulary Contents {label} ===")
    logger.info(f"Total vocabulary size: {len(vocab)}")
    logger.info("First 20 tokens:")
    # Entries are shown in the dict's own iteration order.
    for position, (token, idx) in enumerate(vocab.items()):
        if position >= 20:
            break
        logger.info(f"{idx}: {token}")

class PromptStructureTokenizer:
    """Build a word-level tokenizer in which XML-style tags are single tokens.

    The vocabulary is seeded, in this order, with ``SPECIAL_TOKENS``, the tags
    derived from ``INSTRUCTIONS`` and ``FIELD_TAGS``, and ``COMMON_VALUES``.
    Every other token is collected from the training prompts.
    """

    # Basic special tokens
    SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[BOS]", "[EOS]"]

    # Instructions (without XML tags)
    INSTRUCTIONS = [
        "disease2diff2disease",
        "age_group2diff2age_group",
        "compound2diff2compound"
    ]

    # Field tags
    FIELD_TAGS = [
        "tissue",
        "cell",
        "disease",
        "drug",
        "dose",
        "time",
        "case",
        "control",
        "age",
        "species",
        "gender",
        "up",
        "down",
        "disease_ancestors",
        "disease_tissue",
        "tissue_ancestors",
        "omics",
        "is_fibrosis"
    ]

    # Common field values
    COMMON_VALUES = [
        "Yes",
        "No",
        "human",
        "mouse",
        "transcriptomics",
        "proteomics",
        "expression",
        "methylation"
    ]

    # Matches either a whole ``<...>`` tag or a run of characters that
    # contains neither '<' nor whitespace. Compiled once for all prompts.
    _TOKEN_PATTERN = re.compile(r'(<[^>]+>)|([^<\s]+)')

    def create_xml_tokens(self) -> List[str]:
        """Return the tag tokens derived from the class-level name lists.

        Each instruction yields an opening tag only; each field yields an
        opening tag followed by its closing tag.
        """
        xml_tokens = [f"<{instr}>" for instr in self.INSTRUCTIONS]
        for field in self.FIELD_TAGS:
            xml_tokens.append(f"<{field}>")
            xml_tokens.append(f"</{field}>")
        return xml_tokens

    @staticmethod
    def preprocess_prompt(prompt: str) -> str:
        """Return *prompt* with whitespace inserted around ``<...>`` tags."""
        # Add spaces between consecutive tags
        prompt = re.sub(r'>(<)', r'> <', prompt)
        # Add a space after a tag that is directly followed by text
        prompt = re.sub(r'(<[^>]+>)(?=[^\s])', r'\1 ', prompt)
        # Add a space before a tag that is directly preceded by text
        prompt = re.sub(r'(?<=[^\s])(<[^>]+>)', r' \1', prompt)
        return prompt

    def build_vocabulary(self, prompts: List[str]) -> Dict[str, int]:
        """Build a token -> id mapping with XML tags as single tokens.

        Args:
            prompts: Raw prompt strings; each one is passed through
                ``preprocess_prompt`` before its tokens are collected.

        Returns:
            A dict whose ids are assigned densely, starting at 0, in
            first-seen order: special tokens, XML tags, common values, then
            tokens found in the prompts.
        """
        vocab: Dict[str, int] = {}

        def register(token: str) -> None:
            # The next free id always equals the current vocabulary size.
            if token not in vocab:
                vocab[token] = len(vocab)

        for token in self.SPECIAL_TOKENS:
            register(token)

        xml_tokens = self.create_xml_tokens()
        for token in xml_tokens:
            register(token)

        for token in self.COMMON_VALUES:
            register(token)

        # Split each prompt into tokens while keeping XML tags intact.
        for prompt in prompts:
            processed = self.preprocess_prompt(prompt)
            for match in self._TOKEN_PATTERN.finditer(processed):
                register(match.group())

        logger.info("Vocabulary breakdown:")
        logger.info(f"- Special tokens: {len(self.SPECIAL_TOKENS)}")
        logger.info(f"- XML tags: {len(xml_tokens)}")
        logger.info(f"- Common values: {len(self.COMMON_VALUES)}")
        logger.info(f"- Total size: {len(vocab)}")

        return vocab

    def train_tokenizer(self, prompts: List[str], vocab_size: int = 3500) -> "PreTrainedTokenizerFast":
        """Build an XML-aware word-level tokenizer from *prompts*.

        Args:
            prompts: Training prompts used to collect the vocabulary. May be
                empty, in which case only the predefined tokens are known and
                the sample tokenization log is skipped.
            vocab_size: Soft limit. Exceeding it only logs a warning; the
                vocabulary is not truncated.

        Returns:
            An ``XMLAwareTokenizer`` wrapping the configured tokenizer.
        """
        # `Regex` ships with the `tokenizers` package this module already
        # depends on; importing it here keeps the change local to the method.
        from tokenizers import Regex

        # build_vocabulary() preprocesses every prompt itself, so the raw
        # prompts are handed over directly.
        vocab = self.build_vocabulary(prompts)

        if len(vocab) > vocab_size:
            logger.warning(f"Vocabulary size ({len(vocab)}) exceeds limit ({vocab_size})")

        # Initialize tokenizer
        tokenizer = Tokenizer(WordLevel(vocab=vocab, unk_token="[UNK]"))

        # Add basic normalizer
        tokenizer.normalizer = normalizers.Sequence([Strip()])

        # Isolate tags first, then split the remainder on whitespace.
        # Split() treats a plain str pattern as a literal string, so the
        # pattern has to be wrapped in Regex to be matched as a regex.
        xml_pattern = r"</?[^>\s]+>"
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
            pre_tokenizers.Split(pattern=Regex(xml_pattern), behavior="isolated"),
            pre_tokenizers.WhitespaceSplit(),
        ])

        # Add post-processor
        tokenizer.post_processor = TemplateProcessing(
            single="[BOS] $A [EOS]",
            special_tokens=[
                ("[BOS]", tokenizer.token_to_id("[BOS]")),
                ("[EOS]", tokenizer.token_to_id("[EOS]"))
            ]
        )

        # Convert to PreTrainedTokenizerFast
        pretrained_tokenizer = XMLAwareTokenizer(
            tokenizer_object=tokenizer,
            unk_token="[UNK]",
            pad_token="[PAD]",
            bos_token="[BOS]",
            eos_token="[EOS]"
        )

        # Log a sample tokenization; there is nothing to sample when no
        # prompts were supplied.
        if prompts:
            test_output = pretrained_tokenizer.encode(prompts[0], add_special_tokens=True)
            logger.info(f"\nTest tokenization:")
            logger.info(f"Input: {prompts[0]}")
            logger.info(f"Tokens: {pretrained_tokenizer.convert_ids_to_tokens(test_output)}")

        return pretrained_tokenizer
    
def test_tokenization(tokenizer: PreTrainedTokenizerFast, prompt: str):
    """Encode *prompt* and log its tokens, ids and decoded text.

    When any produced token equals "[UNK]", the positions are logged as a
    warning together with a regex-based split of the original prompt.
    """
    logger.info(f"\nTesting tokenization of: {prompt}")

    # Encode once and derive the token strings from the ids
    token_ids = tokenizer.encode(prompt, add_special_tokens=True)
    token_strings = tokenizer.convert_ids_to_tokens(token_ids)

    logger.info(f"Tokens: {token_strings}")
    logger.info(f"Token IDs: {token_ids}")
    logger.info(f"Decoded: {tokenizer.decode(token_ids)}")

    unknown_positions = [
        position
        for position, token in enumerate(token_strings)
        if token == "[UNK]"
    ]
    if not unknown_positions:
        return

    logger.warning(f"Found [UNK] tokens at positions: {unknown_positions}")

    # Show how the raw prompt splits into tags and plain words
    splitter = r'(<[^>]+>)|([^<\s]+)'
    parts = [found.group() for found in re.finditer(splitter, prompt)]
    logger.warning(f"Original parts: {parts}")

def test_tokenizer_structure(tokenizer: PreTrainedTokenizerFast):
    """Run a fixed set of tag-layout cases through *tokenizer* and log results.

    For every case the input is encoded, converted back to token strings and
    decoded. Special tokens are dropped, and the remaining tokens are compared
    with the expected list. Mismatches are reported as log warnings; nothing
    is raised or returned.
    """
    test_cases = {
        "Basic spaced tags": {
            "input": "<tissue> lung </tissue>",
            "expected_tokens": ["<tissue>", "lung", "</tissue>"]
        },
        "Concatenated tags": {
            "input": "<tissue>lung</tissue>",
            "expected_tokens": ["<tissue>", "lung", "</tissue>"]
        },
        "Nested tags with spaces": {
            "input": "<disease2diff2disease> <tissue> lung </tissue> <is_fibrosis> Yes </is_fibrosis>",
            "expected_tokens": ["<disease2diff2disease>", "<tissue>", "lung", "</tissue>", "<is_fibrosis>", "Yes", "</is_fibrosis>"]
        },
        "Nested tags without spaces": {
            "input": "<disease2diff2disease><tissue>lung</tissue><is_fibrosis>Yes</is_fibrosis>",
            "expected_tokens": ["<disease2diff2disease>", "<tissue>", "lung", "</tissue>", "<is_fibrosis>", "Yes", "</is_fibrosis>"]
        },
        "Empty tags": {
            "input": "<cell></cell><disease></disease>",
            "expected_tokens": ["<cell>", "</cell>", "<disease>", "</disease>"]
        },
        "Complex values": {
            "input": "<case>70.0-80.0</case><control>30.0-40.0</control>",
            "expected_tokens": ["<case>", "70.0-80.0", "</case>", "<control>", "30.0-40.0", "</control>"]
        }
    }

    logger.info("\nTesting tokenizer structure handling:")

    # Looked up once; the same collection filters every case below.
    special_tokens = set(tokenizer.all_special_tokens)

    for test_name, test_case in test_cases.items():
        logger.info(f"\n=== Testing {test_name} ===")
        input_text = test_case["input"]
        expected_tokens = test_case["expected_tokens"]

        # Encode and decode
        encoded = tokenizer.encode(input_text)
        tokens = tokenizer.convert_ids_to_tokens(encoded)
        decoded = tokenizer.decode(encoded)

        # Remove special tokens for comparison
        actual_tokens = [t for t in tokens if t not in special_tokens]

        logger.info(f"Input: {input_text}")
        logger.info(f"Expected tokens: {expected_tokens}")
        logger.info(f"Actual tokens: {actual_tokens}")
        logger.info(f"Decoded: {decoded}")

        # Check if all expected tokens are present
        missing_tokens = [t for t in expected_tokens if t not in actual_tokens]
        if missing_tokens:
            logger.warning(f"Missing expected tokens: {missing_tokens}")

        # Check if there are any unexpected tokens
        extra_tokens = [t for t in actual_tokens if t not in expected_tokens]
        if extra_tokens:
            logger.warning(f"Found unexpected tokens: {extra_tokens}")

        # Check if tokens appear in the correct order
        expected_str = " ".join(expected_tokens)
        actual_str = " ".join(actual_tokens)
        if expected_str != actual_str:
            logger.warning(f"Token order mismatch!\nExpected: {expected_str}\nActual: {actual_str}")
    
def load_tokenizer(load_dir: str) -> PreTrainedTokenizerFast:
    """Return the ``XMLAwareTokenizer`` stored in *load_dir*.

    The directory is passed to ``XMLAwareTokenizer.from_pretrained`` and a
    confirmation message is logged at INFO level once loading succeeds.
    """
    loaded = XMLAwareTokenizer.from_pretrained(load_dir)
    logger.info(f"Tokenizer loaded from {load_dir}")
    return loaded

# Usage example
if __name__ == "__main__":
    # Everything in this module reports through `logger` at INFO level.
    # Without a configured handler only warnings would be shown, so the
    # example would run silently.
    logging.basicConfig(level=logging.INFO)

    # Create tokenizer trainer
    structure_tokenizer = PromptStructureTokenizer()

    # Sample prompts
    sample_prompts = [
        "<disease2diff2disease><tissue>lung</tissue><is_fibrosis>Yes</is_fibrosis>",
        "<age_group2diff2age_group><case>70.0-80.0</case><control>30.0-40.0</control>"
    ]

    # Train tokenizer
    tokenizer = structure_tokenizer.train_tokenizer(sample_prompts)

    # Test structure handling
    test_tokenizer_structure(tokenizer)