import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split

from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import PreTrainedTokenizerFast
import numpy as np
from typing import List, Dict, Optional
import re
from collections import Counter
import json
import pickle

import os
from collections import defaultdict

from tokenizers import Tokenizer, normalizers, pre_tokenizers, processors
from tokenizers.models import WordLevel
from tokenizers.processors import TemplateProcessing
from transformers import PreTrainedTokenizerFast

from .tokenizer import XMLAwareTokenizer

class MiniTransformerConfig:
    """Hyperparameters for the miniature transformer language model.

    Every constructor argument is stored unchanged as an attribute of the
    same name.
    """

    def __init__(
        self,
        vocab_size: int = 3500,
        d_model: int = 128,
        n_heads: int = 4,
        n_layers: int = 3,
        dropout: float = 0.1,
        max_seq_len: int = 512,
        n_instructions: int = 3,
    ):
        """Store the hyperparameters after checking head divisibility.

        Args:
            vocab_size: Vocabulary size (genes plus special tokens).
            d_model: Hidden dimension size.
            n_heads: Number of attention heads.
            n_layers: Number of transformer layers.
            dropout: Dropout rate.
            max_seq_len: Maximum sequence length.
            n_instructions: Number of instruction types.

        Raises:
            AssertionError: If ``d_model`` is not a multiple of ``n_heads``.
        """
        # The hidden size has to split evenly across the attention heads.
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        # Vocabulary and conditioning sizes.
        self.vocab_size = vocab_size
        self.n_instructions = n_instructions

        # Network shape.
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len

        # Regularisation.
        self.dropout = dropout
        
class MiniTransformerBlock(nn.Module):
    """One post-norm transformer layer.

    Multi-head self-attention followed by a GELU feed-forward network; each
    sub-layer is wrapped in a residual connection and a LayerNorm.
    """

    def __init__(self, config: MiniTransformerConfig):
        """Build the attention, normalisation and feed-forward sub-modules.

        Args:
            config: Supplies ``d_model``, ``n_heads`` and ``dropout``.
        """
        super().__init__()
        self.attention = nn.MultiheadAttention(
            config.d_model,
            config.n_heads,
            dropout=config.dropout,
            batch_first=True
        )
        self.norm1 = nn.LayerNorm(config.d_model)
        self.norm2 = nn.LayerNorm(config.d_model)

        self.ff = nn.Sequential(
            nn.Linear(config.d_model, config.d_model * 4),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model * 4, config.d_model),
            nn.Dropout(config.dropout)
        )

    def forward(self, x, attention_mask=None, causal=True):
        """Apply self-attention and the feed-forward network to ``x``.

        Args:
            x: Hidden states laid out batch-first, i.e.
                ``(batch, seq, d_model)``.
            attention_mask: Optional mask forwarded to the attention layer as
                ``key_padding_mask``; ``True`` marks key positions that must
                be ignored (padding).
            causal: When true (the default) position ``i`` may only attend to
                positions ``<= i``. The model is trained on next-token
                labels, so without this restriction every position could
                read its own target. Pass ``False`` for bidirectional
                attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        causal_mask = None
        if causal:
            seq_len = x.size(1)
            # True above the diagonal == "not allowed to attend".
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            )

        # Self-attention with residual connection
        attn_out, _ = self.attention(
            x, x, x,
            attn_mask=causal_mask,
            key_padding_mask=attention_mask,
        )
        x = self.norm1(x + attn_out)

        # Feedforward with residual connection
        ff_out = self.ff(x)
        x = self.norm2(x + ff_out)

        return x

class MiniP3GPT(nn.Module):
    """Small transformer language model conditioned on an instruction id.

    A learned instruction embedding is added to every token embedding before
    the stack of ``MiniTransformerBlock`` layers; a final linear layer maps
    the hidden states to vocabulary logits.
    """

    def __init__(self, config: MiniTransformerConfig):
        """Create embeddings, transformer blocks and the output projection.

        Args:
            config: Model hyperparameters; also kept on ``self.config``.
        """
        super().__init__()
        self.config = config
        hidden_size = config.d_model

        # Lookup tables for tokens and for the per-sample instruction.
        self.token_embedding = nn.Embedding(config.vocab_size, hidden_size)
        self.instruction_embedding = nn.Embedding(config.n_instructions, hidden_size)

        # Instruction name -> row of ``instruction_embedding``.
        self.instruction_map = {
            "age_group2diff2age_group": 0,
            "disease2diff2disease": 1,
            "compound2diff2compound": 2
        }

        self.dropout = nn.Dropout(config.dropout)

        # Stack of identical transformer layers.
        layers = [MiniTransformerBlock(config) for _ in range(config.n_layers)]
        self.blocks = nn.ModuleList(layers)

        # Projection from hidden states to vocabulary logits.
        self.out = nn.Linear(hidden_size, config.vocab_size)

    def forward(self, input_ids, instruction_ids, attention_mask=None):
        """Compute vocabulary logits for every position.

        Args:
            input_ids: Token ids; assumed ``(batch, seq)`` — TODO confirm.
            instruction_ids: One instruction id per sequence.
            attention_mask: Optional mask handed unchanged to every block.

        Returns:
            Logits whose last dimension is ``config.vocab_size``.
        """
        token_states = self.token_embedding(input_ids)

        # One vector per sequence, repeated along the sequence axis.
        task_vector = self.instruction_embedding(instruction_ids).unsqueeze(1)
        task_states = task_vector.expand(-1, token_states.size(1), -1)

        hidden = self.dropout(token_states + task_states)
        for layer in self.blocks:
            hidden = layer(hidden, attention_mask)

        return self.out(hidden)

class FibrosisDataset(Dataset):
    """Prompt/instruction pairs tokenised for next-token prediction.

    Each item is padded to ``max_length`` and carries labels that are the
    input ids shifted left by one position.
    """

    # Instruction name -> integer id fed to the model.
    instruction_map = {
        "age_group2diff2age_group": 0,
        "disease2diff2disease": 1,
        "compound2diff2compound": 2,
    }

    def __init__(
        self,
        prompts: List[str],
        instructions: List[str],
        tokenizer: PreTrainedTokenizerFast,
        max_length: int = 512,
    ):
        """Store the data after validating it.

        Args:
            prompts: Raw prompt strings.
            instructions: Instruction name for each prompt (same length as
                ``prompts``); every name must be a key of
                ``instruction_map``.
            tokenizer: Object exposing ``encode(...)`` and ``pad_token_id``.
            max_length: Length every example is truncated/padded to.

        Raises:
            ValueError: On a length mismatch, a non-positive ``max_length``
                or an unknown instruction name.
        """
        if len(prompts) != len(instructions):
            raise ValueError(
                f"Got {len(prompts)} prompts but {len(instructions)} instructions"
            )
        if max_length < 1:
            raise ValueError(f"max_length must be positive, got {max_length}")

        unknown_instructions = set(instructions) - set(self.instruction_map)
        if unknown_instructions:
            raise ValueError(f"Found unknown instructions: {unknown_instructions}")

        self.prompts = prompts
        self.instructions = instructions
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of prompts."""
        return len(self.prompts)

    def __getitem__(self, idx):
        """Tokenise one example.

        Returns:
            dict with 1-D int64 tensors ``input_ids``, ``attention_mask``
            (1 = real token, 0 = padding) and ``labels`` of length
            ``max_length``, plus a scalar ``instruction_id``. ``labels[i]``
            is ``input_ids[i + 1]``; the last position and every position
            whose target would be padding are set to -100 so that the loss
            ignores them.

        Raises:
            ValueError: If padding is required but the tokenizer has no
                ``pad_token_id``.
        """
        prompt = self.prompts[idx]
        instruction_id = torch.tensor(self.instruction_map[self.instructions[idx]])

        token_ids = self.tokenizer.encode(
            prompt,
            add_special_tokens=True,
            max_length=self.max_length,
            truncation=True
        )
        # Explicit dtype: an empty id list would otherwise become float32.
        input_ids = torch.tensor(token_ids, dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)

        padding_length = self.max_length - input_ids.numel()
        if padding_length > 0:
            pad_id = self.tokenizer.pad_token_id
            if pad_id is None:
                raise ValueError("Tokenizer has no pad_token_id; cannot pad the sequence")
            input_ids = torch.cat(
                [input_ids, torch.full((padding_length,), pad_id, dtype=torch.long)]
            )
            # Same dtype as the unpadded mask so batches collate consistently.
            attention_mask = torch.cat(
                [attention_mask, torch.zeros(padding_length, dtype=torch.long)]
            )

        # Shift left by one for next-token prediction and blank out targets
        # that fall on padding.
        labels = torch.full_like(input_ids, -100)
        labels[:-1] = input_ids[1:].masked_fill(attention_mask[1:] == 0, -100)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'instruction_id': instruction_id,
            'labels': labels
        }

    @classmethod
    def validate_instruction_distribution(cls, train_dataset, val_dataset):
        """Check that both splits cover the same set of instruction types.

        The reference set is the union of the instructions observed in the
        two splits (not ``instruction_map``). A message is printed for each
        split that lacks some of them.

        Args:
            train_dataset: Training dataset instance (or Subset)
            val_dataset: Validation dataset instance (or Subset)

        Returns:
            bool: True if neither split is missing an instruction type that
            the other one has.
        """
        # Handle both direct FibrosisDataset and Subset instances
        train_data = train_dataset.dataset if hasattr(train_dataset, 'dataset') else train_dataset
        val_data = val_dataset.dataset if hasattr(val_dataset, 'dataset') else val_dataset

        # Get indices for each subset if using random_split
        train_indices = train_dataset.indices if hasattr(train_dataset, 'indices') else range(len(train_data))
        val_indices = val_dataset.indices if hasattr(val_dataset, 'indices') else range(len(val_data))

        train_instructions = {train_data.instructions[i] for i in train_indices}
        val_instructions = {val_data.instructions[i] for i in val_indices}
        all_instructions = train_instructions | val_instructions

        missing_train = all_instructions - train_instructions
        if missing_train:
            print(f"Training set is missing instructions: {missing_train}")

        missing_val = all_instructions - val_instructions
        if missing_val:
            print(f"Validation set is missing instructions: {missing_val}")

        return not missing_train and not missing_val
    
def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scaler: torch.cuda.amp.GradScaler,
    device: str
) -> float:
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Called as ``model(input_ids, instruction_id, key_padding_mask)``
            and expected to return logits.
        dataloader: Iterable of batches with the keys ``input_ids``,
            ``attention_mask``, ``instruction_id`` and ``labels``.
        optimizer: Optimizer over ``model``'s parameters.
        scaler: Gradient scaler used for the backward pass and the step.
        device: Device the batch tensors are moved to.

    Returns:
        Mean of the per-batch losses.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()
    # Mixed precision is only applied on CUDA; on any other device the
    # autocast context below is a no-op.
    amp_enabled = torch.device(device).type == "cuda"

    total_loss = 0.0
    num_batches = 0

    for batch in dataloader:
        optimizer.zero_grad()

        # Move batch to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        instruction_id = batch['instruction_id'].to(device)
        labels = batch['labels'].to(device)
        # True where the position is padding.
        key_padding_mask = attention_mask < 0.5

        with torch.autocast(device_type="cuda", enabled=amp_enabled):
            logits = model(input_ids, instruction_id, key_padding_mask)
            # Cross-entropy is invariant to adding a constant to the logits,
            # so no epsilon is needed here.
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
                ignore_index=-100
            )

        scaler.scale(loss).backward()
        # Unscale first so the clipping threshold applies to true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")

    return total_loss / num_batches

def validate(
    model: nn.Module,
    dataloader: DataLoader,
    device: str) -> float:
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: Called as ``model(input_ids, instruction_id, key_padding_mask)``
            and expected to return logits.
        dataloader: Iterable of batches with the keys ``input_ids``,
            ``attention_mask``, ``instruction_id`` and ``labels``.
        device: Device the batch tensors are moved to.

    Returns:
        Mean of the per-batch cross-entropy losses (labels equal to -100 are
        ignored).

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            instruction_id = batch['instruction_id'].to(device)
            labels = batch['labels'].to(device)

            # True where the position is padding.
            key_padding_mask = attention_mask < 0.5

            logits = model(input_ids, instruction_id, key_padding_mask)
            # Cross-entropy is invariant to adding a constant to the logits,
            # so no epsilon is needed here.
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
                ignore_index=-100
            )

            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")

    return total_loss / num_batches


def save_model(model: nn.Module,
    tokenizer: PreTrainedTokenizerFast,
    config: MiniTransformerConfig,
    save_dir: str):
    """Write the model weights, tokenizer and config into ``save_dir``.

    Produces ``state.pt`` (the model's ``state_dict``), whatever files
    ``tokenizer.save_pretrained`` emits, and ``config.json``.

    Args:
        model: Model whose ``state_dict`` is saved.
        tokenizer: Tokenizer saved via ``save_pretrained``.
        config: Hyperparameters serialised to ``config.json``.
        save_dir: Target directory; created if it does not exist.
    """
    # Create save directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    # Save model state
    model_path = os.path.join(save_dir, "state.pt")
    torch.save(model.state_dict(), model_path)

    # Save tokenizer
    tokenizer.save_pretrained(save_dir)

    # Every constructor argument is persisted. n_instructions sizes the
    # instruction embedding, so omitting it would make a model with a
    # non-default value impossible to rebuild from this file.
    config_dict = {
        "vocab_size": config.vocab_size,
        "d_model": config.d_model,
        "n_heads": config.n_heads,
        "n_layers": config.n_layers,
        "dropout": config.dropout,
        "max_seq_len": config.max_seq_len,
        "n_instructions": config.n_instructions,
    }

    config_path = os.path.join(save_dir, "config.json")
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(config_dict, f, indent=4)

def load_model(load_dir: str) -> tuple[MiniP3GPT,  MiniTransformerConfig, PreTrainedTokenizerFast,]:
    """Rebuild a model, its config and its tokenizer from ``load_dir``.

    Reads ``config.json`` and ``state.pt`` from the directory and loads the
    tokenizer with ``XMLAwareTokenizer.from_pretrained``.

    Args:
        load_dir: Directory containing ``config.json``, ``state.pt`` and the
            tokenizer files.

    Returns:
        ``(model, config, tokenizer)``. The model is freshly constructed and
        has not been moved to any device by this function.
    """
    # Load config; keys absent from the file fall back to the constructor
    # defaults.
    config_path = os.path.join(load_dir, "config.json")
    with open(config_path, 'r', encoding='utf-8') as f:
        config_dict = json.load(f)
    config = MiniTransformerConfig(**config_dict)

    # Initialize model
    model = MiniP3GPT(config)

    # map_location="cpu" lets a checkpoint written from a GPU be restored on
    # a host without one; the tensors are copied into the new model's
    # parameters anyway.
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    model_path = os.path.join(load_dir, "state.pt")
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)

    # Load tokenizer
    tokenizer = XMLAwareTokenizer.from_pretrained(load_dir)

    return model, config, tokenizer



######
def create_weighted_sampler(dataset):
    """
    Build a sampler that draws each instruction type equally often on average.

    Every sample is weighted by ``total / count(its instruction)``, so rare
    instruction types are drawn more often.

    Args:
        dataset: Object with an ``instructions`` sequence, or a Subset-like
            wrapper exposing ``dataset`` and ``indices``

    Returns:
        WeightedRandomSampler drawing as many samples as the split contains
    """
    # Unwrap a Subset-like object when present.
    if hasattr(dataset, 'dataset'):
        source = dataset.dataset
    else:
        source = dataset

    if hasattr(dataset, 'indices'):
        selected = dataset.indices
    else:
        selected = range(len(source))

    # Instruction name of every sample in this split.
    names = [source.instructions[position] for position in selected]

    # Inverse-frequency weight per sample.
    frequency = Counter(names)
    split_size = len(names)
    per_sample = torch.DoubleTensor([split_size / frequency[name] for name in names])

    return WeightedRandomSampler(per_sample, len(per_sample))


def validate_by_instruction(model, val_loader, device):
    """Compute token-level loss and accuracy separately per instruction type.

    Args:
        model: Called as ``model(input_ids, instruction_id, key_padding_mask)``
            and expected to return logits.
        val_loader: Iterable of batches with the keys ``input_ids``,
            ``attention_mask``, ``instruction_id`` and ``labels``.
        device: Device the batch tensors are moved to.

    Returns:
        dict mapping each instruction name seen in ``val_loader`` to
        ``{'loss': ..., 'accuracy': ...}``, both averaged over the tokens
        whose label is not -100. An instruction with no such token gets
        ``nan`` for both values.
    """
    # Invert name -> id explicitly rather than relying on dict ordering.
    id_to_name = {idx: name for name, idx in FibrosisDataset.instruction_map.items()}

    # Running totals; keeping sums avoids storing one float per token.
    loss_sum = defaultdict(float)
    correct_count = defaultdict(int)
    token_count = defaultdict(int)

    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            instruction_id = batch['instruction_id'].to(device)
            labels = batch['labels'].to(device)

            # True where the position is padding.
            key_padding_mask = attention_mask < 0.5

            logits = model(input_ids, instruction_id, key_padding_mask)

            token_loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
                reduction='none'
            ).view(labels.shape)

            scored = labels != -100
            hits = (logits.argmax(dim=-1) == labels) & scored

            # Per-sequence totals, accumulated in float64.
            row_loss = token_loss.double().masked_fill(~scored, 0.0).sum(dim=1).tolist()
            row_hits = hits.sum(dim=1).tolist()
            row_tokens = scored.sum(dim=1).tolist()

            for inst_id, seq_loss, seq_hits, seq_tokens in zip(
                instruction_id.tolist(), row_loss, row_hits, row_tokens
            ):
                inst_name = id_to_name[inst_id]
                loss_sum[inst_name] += seq_loss
                correct_count[inst_name] += seq_hits
                token_count[inst_name] += seq_tokens

    # Calculate metrics for each instruction
    metrics = {}
    for inst_name, n_tokens in token_count.items():
        if n_tokens == 0:
            metrics[inst_name] = {'loss': float('nan'), 'accuracy': float('nan')}
            continue
        metrics[inst_name] = {
            'loss': loss_sum[inst_name] / n_tokens,
            'accuracy': correct_count[inst_name] / n_tokens
        }

    return metrics


######

def prepare_datasets(data_dict, tokenizer, val_split=0.1):
    """Flatten ``data_dict`` and split it into train and validation subsets.

    Args:
        data_dict: Maps an instruction name (or a tuple of names) to a list
            of prompts. Prompts under a tuple key are added once per name.
        tokenizer: Passed through to ``FibrosisDataset``.
        val_split: Fraction of examples placed in the validation subset.

    Returns:
        ``(train_dataset, val_dataset)`` from ``random_split`` with a fixed
        seed of 42.

    Raises:
        ValueError: If no prompt is attached to a known instruction.
    """
    known_names = FibrosisDataset.instruction_map

    print(f"Input data keys: {list(data_dict.keys())}")  # Debug print
    print(f"Valid instruction names: {list(FibrosisDataset.instruction_map.keys())}")  # Debug print

    all_prompts, all_instructions = [], []

    for key, prompts in data_dict.items():
        # A key may name one instruction or several.
        if isinstance(key, str):
            names = [key]
        elif isinstance(key, tuple):
            names = list(key)
        else:
            names = key

        print(f"Processing instructions: {names} with {len(prompts)} prompts")  # Debug print

        for name in names:
            if name not in known_names:
                print(f"Skipping unknown instruction: {name}")  # Debug print
                continue
            # One copy of every prompt per instruction.
            all_prompts.extend(prompts)
            all_instructions.extend([name] * len(prompts))

    print(f"Total prompts: {len(all_prompts)}")  # Debug print
    print(f"Total instructions: {len(all_instructions)}")  # Debug print
    print(f"Unique instructions: {set(all_instructions)}")  # Debug print

    if not all_prompts:
        raise ValueError("No valid prompts found! Check that instruction names match between data and FibrosisDataset.instruction_map")

    # Validation size is rounded down; training takes the remainder.
    n_examples = len(all_prompts)
    n_val = int(n_examples * val_split)
    n_train = n_examples - n_val

    combined = FibrosisDataset(all_prompts, all_instructions, tokenizer)
    seeded = torch.Generator().manual_seed(42)
    train_dataset, val_dataset = random_split(combined, [n_train, n_val], generator=seeded)

    return train_dataset, val_dataset


def load_data(pickle_path: str):
    """Load a pickled ``{file key: prompts}`` dict and re-key it by instruction.

    File keys with no entry in the internal mapping are reported with a
    printed warning and skipped. A file key mapped to several instructions
    contributes its prompts to each of them.

    NOTE: ``pickle.load`` can execute arbitrary code; only call this on
    files from a trusted source.

    Args:
        pickle_path: Path of the pickle file.

    Returns:
        dict mapping instruction name to the list of its prompts.
    """
    with open(pickle_path, 'rb') as handle:
        raw = pickle.load(handle)

    # File key -> instruction name, or a tuple of names.
    instruction_mapping = {
        "30Dec2024_disease2diff_plus_compound2diff": ("disease2diff2disease", "compound2diff2compound"),
        "30Dec2024_disease2diff": "disease2diff2disease",
        "30Dec2024_compound2diff": "compound2diff2compound",
        "30Dec2024_age_group2diff": "age_group2diff2age_group",
        "curated_new_diffs": "disease2diff2disease"
    }

    processed_data = {}
    for file_key, prompts in raw.items():
        targets = instruction_mapping.get(file_key)
        if not targets:
            print(f"Warning: No mapping found for {file_key}")
            continue

        # Treat a single name like a one-element tuple.
        if not isinstance(targets, tuple):
            targets = (targets,)

        for name in targets:
            processed_data.setdefault(name, []).extend(prompts)

    return processed_data

# Example usage:
if __name__ == "__main__":

    # load_tokenizer is not defined in this file; presumably it is provided
    # by this star import — confirm.
    from ipfP3GPT import *

    # Imported explicitly so logging does not depend on the star import.
    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    logger.info("Loading data...")
    data = load_data("31Dec2024_fibro_training.pckl")

    tokenizer = load_tokenizer("ipf_model")

    # Initialize model
    logger.info("Initializing model...")
    # NOTE(review): for Hugging Face tokenizers vocab_size does not count
    # added tokens — confirm every id the tokenizer emits is < vocab_size.
    config = MiniTransformerConfig(
        vocab_size=tokenizer.vocab_size,
        d_model=128,
        n_heads=4,
        n_layers=3,
        dropout=0.1,
        max_seq_len=512,
        n_instructions=3
    )
    model = MiniP3GPT(config)
    # Fall back to the CPU when no GPU is present.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    train_dataset, val_dataset = prepare_datasets(data, tokenizer, val_split=0.1)
    batch_size = 32
    # Sample training examples so instruction types are balanced.
    train_sampler = create_weighted_sampler(train_dataset)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False
    )

    # Setup optimizer; gradient scaling is only active on CUDA.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler(enabled=device.type == 'cuda')

    # Training loop
    num_epochs = 25
    best_val_loss = float('inf')
    save_path = 'best_model'

    train_losses = []
    val_losses = []
    instruction_type_metrics = defaultdict(lambda: {'losses': [], 'accuracies': []})

    for epoch in range(num_epochs):
        train_loss = train_epoch(model, train_loader, optimizer, scaler, device)
        val_loss = validate(model, val_loader, device)

        # Store the losses
        train_losses.append(train_loss)
        val_losses.append(val_loss)

        # Keep only the checkpoint with the lowest validation loss so far.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_model(model, tokenizer, config, save_path)

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {train_loss:.4f}')
        print(f'Val Loss: {val_loss:.4f}')

        # Validate by instruction type
        instruction_metrics = validate_by_instruction(model, val_loader, device)
        for inst, metrics in instruction_metrics.items():
            instruction_type_metrics[inst]['losses'].append(metrics['loss'])
            instruction_type_metrics[inst]['accuracies'].append(metrics['accuracy'])
            print(f'{inst} - Loss: {metrics["loss"]:.4f}, Accuracy: {metrics["accuracy"]:.4f}')
