#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Baseline models for extractive summarization on the Liputan6 dataset
"""

import numpy as np
from typing import List
import pickle
from abc import ABC, abstractmethod

# NLP libraries
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
import re

# scikit-learn
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

# TextRank
import networkx as nx

# BERT
try:
    from transformers import AutoTokenizer, AutoModel
    import torch
    BERT_AVAILABLE = True
except ImportError:
    BERT_AVAILABLE = False
    print("Warning: transformers tidak tersedia. BERT model tidak bisa digunakan.")

# Download NLTK data if it is not already present
def ensure_nltk_data():
    """Ensure the required NLTK data is available"""
    required_data = [
        ('tokenizers/punkt', 'punkt'),
        ('tokenizers/punkt_tab', 'punkt_tab'),
        ('corpora/stopwords', 'stopwords')
    ]
    
    for path, name in required_data:
        try:
            nltk.data.find(path)
        except LookupError:
            try:
                print(f"Downloading NLTK {name}...")
                nltk.download(name, quiet=True)
            except Exception:
                # Fallback: older NLTK releases ship 'punkt' instead of 'punkt_tab'
                if name == 'punkt_tab':
                    try:
                        nltk.download('punkt', quiet=True)
                    except Exception:
                        pass

# Set up NLTK data
ensure_nltk_data()

class BaseExtractiveModel(ABC):
    """
    Base class for all extractive summarization models
    """
    
    def __init__(self):
        self.is_trained = False
        
    @abstractmethod
    def fit(self, articles: List[List[str]], labels: List[List[int]]):
        """Train model"""
        pass
    
    @abstractmethod
    def predict(self, article: List[str], num_sentences: int = 3) -> List[int]:
        """Prediksi kalimat yang akan dipilih untuk summary"""
        pass
    
    def save_model(self, filepath: str):
        """Save model ke file"""
        with open(filepath, 'wb') as f:
            pickle.dump(self, f)
    
    @classmethod
    def load_model(cls, filepath: str):
        """Load model dari file"""
        with open(filepath, 'rb') as f:
            return pickle.load(f)

class TextRankModel(BaseExtractiveModel):
    """
    TextRank algorithm for extractive summarization
    """
    
    def __init__(self, language='indonesian'):
        super().__init__()
        self.language = language
        try:
            self.stop_words = set(stopwords.words(self.language))
        except LookupError:
            self.stop_words = set()
        # Note: PorterStemmer targets English; for Indonesian it is only a
        # rough approximation (a dedicated stemmer such as Sastrawi would
        # be more accurate).
        self.stemmer = PorterStemmer()
        
    def preprocess_text(self, text: str) -> str:
        """Preprocess text for TextRank"""
        # Lowercase
        text = text.lower()
        # Keep letters and whitespace only
        text = re.sub(r'[^a-zA-Z\s]', '', text)
        # Tokenize, drop stopwords and short tokens, then stem
        words = word_tokenize(text)
        words = [self.stemmer.stem(word) for word in words if word not in self.stop_words and len(word) > 2]
        return ' '.join(words)
    
    def calculate_sentence_similarity(self, sent1: str, sent2: str) -> float:
        """Compute the Jaccard similarity between two sentences"""
        # Preprocess both sentences
        sent1_processed = self.preprocess_text(sent1)
        sent2_processed = self.preprocess_text(sent2)
        
        # Tokenize into word sets
        words1 = set(sent1_processed.split())
        words2 = set(sent2_processed.split())
        # Jaccard similarity
        union = words1.union(words2)
        if not union:
            return 0.0
        
        return len(words1.intersection(words2)) / len(union)
    
    def fit(self, articles: List[List[str]], labels: List[List[int]]):
        """TextRank is unsupervised; just set the trained flag"""
        self.is_trained = True
        return self
    
    def predict(self, article: List[str], num_sentences: int = 3) -> List[int]:
        """Predict sentence labels using TextRank"""
        if len(article) <= num_sentences:
            # Select every sentence (binary labels, consistent with the path below)
            return [1] * len(article)
        
        # Build the pairwise sentence-similarity matrix
        similarity_matrix = np.zeros((len(article), len(article)))
        
        for i in range(len(article)):
            for j in range(len(article)):
                if i != j:
                    similarity_matrix[i][j] = self.calculate_sentence_similarity(article[i], article[j])
        
        # Build a graph and compute PageRank; fall back to uniform scores
        # if the power iteration fails to converge
        nx_graph = nx.from_numpy_array(similarity_matrix)
        try:
            scores = nx.pagerank(nx_graph)
        except nx.PowerIterationFailedConvergence:
            scores = {i: 1.0 for i in range(len(article))}
        
        # Take the top-k sentences
        ranked_sentences = sorted(scores.items(), key=lambda x: x[1], reverse=True)
        selected_indices = [idx for idx, score in ranked_sentences[:num_sentences]]
        
        # Convert to binary labels
        labels = [0] * len(article)
        for idx in selected_indices:
            labels[idx] = 1
            
        return labels

class TfidfLogisticModel(BaseExtractiveModel):
    """
    Logistic regression model with TF-IDF features
    """
    
    def __init__(self):
        super().__init__()
        self.vectorizer = TfidfVectorizer(max_features=5000, stop_words=None, ngram_range=(1, 2))
        self.scaler = StandardScaler()
        self.classifier = LogisticRegression(random_state=42, max_iter=1000)
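        # Note (assumption): summary labels are typically imbalanced (only a
        # few positive sentences per article), so class_weight='balanced'
        # may be worth trying on the classifier above.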
        
    def extract_features(self, articles: List[List[str]]) -> np.ndarray:
        """Extract features from the articles (fits the vectorizer and scaler)"""
        # Flatten all sentences across articles
        all_sentences = []
        for article in articles:
            all_sentences.extend(article)
        
        # TF-IDF features
        tfidf_features = self.vectorizer.fit_transform(all_sentences)
        
        # Hand-crafted additional features
        additional_features = []
        
        for article in articles:
            for i, sentence in enumerate(article):
                # Position features
                position_ratio = i / len(article) if len(article) > 1 else 0
                
                # Length features
                sentence_length = len(sentence.split())
                
                # Sentence features
                features = [
                    position_ratio,
                    sentence_length,
                    len(sentence),  # Character length
                    sentence.count('.'),  # Number of periods
                    sentence.count(','),  # Number of commas
                ]
                
                additional_features.append(features)
        
        # Scale the hand-crafted features
        additional_features = np.array(additional_features)
        additional_features = self.scaler.fit_transform(additional_features)
        
        # Combine TF-IDF with the additional features (toarray() densifies
        # the sparse matrix, which can be memory-hungry on large corpora)
        combined_features = np.hstack([tfidf_features.toarray(), additional_features])
        
        return combined_features
    
    def fit(self, articles: List[List[str]], labels: List[List[int]]):
        """Train model"""
        # Extract features
        X = self.extract_features(articles)
        
        # Flatten labels
        y = []
        for article_labels in labels:
            y.extend(article_labels)
        
        y = np.array(y)
        
        # Train classifier
        self.classifier.fit(X, y)
        self.is_trained = True
        
        return self
    
    def predict(self, article: List[str], num_sentences: int = 3) -> List[int]:
        """Predict labels for a single article"""
        if not self.is_trained:
            raise ValueError("Model has not been trained yet!")
        
        # Extract features for this article
        tfidf_features = self.vectorizer.transform(article)
        
        additional_features = []
        for i, sentence in enumerate(article):
            position_ratio = i / len(article) if len(article) > 1 else 0
            sentence_length = len(sentence.split())
            
            features = [
                position_ratio,
                sentence_length,
                len(sentence),
                sentence.count('.'),
                sentence.count(','),
            ]
            additional_features.append(features)
        
        additional_features = np.array(additional_features)
        additional_features = self.scaler.transform(additional_features)
        
        # Combine features
        X = np.hstack([tfidf_features.toarray(), additional_features])
        
        # Predicted probability of the positive (selected) class
        probabilities = self.classifier.predict_proba(X)[:, 1]
        
        # Select the top-k sentences
        top_indices = np.argsort(probabilities)[-num_sentences:]
        
        # Convert to binary labels
        labels = [0] * len(article)
        for idx in top_indices:
            labels[idx] = 1
            
        return labels

class BertExtractiveModel(BaseExtractiveModel):
    """
    BERT-based extractive summarization model
    """
    
    def __init__(self, model_name='indolem/indobert-base-uncased'):
        super().__init__()
        
        if not BERT_AVAILABLE:
            raise ImportError("transformers library tidak tersedia!")
        
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.bert_model = AutoModel.from_pretrained(model_name)
        self.classifier = LogisticRegression(random_state=42, max_iter=1000)
        self.scaler = StandardScaler()
        
        # Set device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.bert_model.to(self.device)
        
    def get_bert_embeddings(self, sentences: List[str]) -> np.ndarray:
        """Get BERT [CLS] embeddings for the sentences (one at a time; batching would be faster)"""
        embeddings = []
        
        self.bert_model.eval()
        with torch.no_grad():
            for sentence in sentences:
                # Tokenize
                inputs = self.tokenizer(
                    sentence, 
                    return_tensors='pt', 
                    max_length=512, 
                    truncation=True, 
                    padding=True
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                
                # Get embeddings
                outputs = self.bert_model(**inputs)
                # Use [CLS] token embedding
                cls_embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
                embeddings.append(cls_embedding[0])
        
        return np.array(embeddings)
    
    def extract_features(self, articles: List[List[str]]) -> np.ndarray:
        """Extract BERT features dari artikel"""
        all_sentences = []
        for article in articles:
            all_sentences.extend(article)
        
        # Get BERT embeddings
        bert_features = self.get_bert_embeddings(all_sentences)
        
        # Additional positional features
        additional_features = []
        
        for article in articles:
            for i, sentence in enumerate(article):
                position_ratio = i / len(article) if len(article) > 1 else 0
                sentence_length = len(sentence.split())
                
                features = [
                    position_ratio,
                    sentence_length,
                    len(sentence),
                ]
                
                additional_features.append(features)
        
        additional_features = np.array(additional_features)
        additional_features = self.scaler.fit_transform(additional_features)
        
        # Combine BERT with additional features
        combined_features = np.hstack([bert_features, additional_features])
        
        return combined_features
    
    def fit(self, articles: List[List[str]], labels: List[List[int]]):
        """Train BERT model"""
        print("Extracting BERT features...")
        X = self.extract_features(articles)
        
        # Flatten labels
        y = []
        for article_labels in labels:
            y.extend(article_labels)
        
        y = np.array(y)
        
        print("Training classifier...")
        self.classifier.fit(X, y)
        self.is_trained = True
        
        return self
    
    def predict(self, article: List[str], num_sentences: int = 3) -> List[int]:
        """Predict labels for a single article"""
        if not self.is_trained:
            raise ValueError("Model has not been trained yet!")
        
        # Get BERT embeddings
        bert_features = self.get_bert_embeddings(article)
        
        # Additional features
        additional_features = []
        for i, sentence in enumerate(article):
            position_ratio = i / len(article) if len(article) > 1 else 0
            sentence_length = len(sentence.split())
            
            features = [
                position_ratio,
                sentence_length,
                len(sentence),
            ]
            additional_features.append(features)
        
        additional_features = np.array(additional_features)
        additional_features = self.scaler.transform(additional_features)
        
        # Combine features
        X = np.hstack([bert_features, additional_features])
        
        # Predicted probability of the positive (selected) class
        probabilities = self.classifier.predict_proba(X)[:, 1]
        
        # Select the top-k sentences
        top_indices = np.argsort(probabilities)[-num_sentences:]
        
        # Convert to binary labels
        labels = [0] * len(article)
        for idx in top_indices:
            labels[idx] = 1
            
        return labels

def create_model(model_type: str = 'textrank') -> BaseExtractiveModel:
    """
    Factory function for creating a model
    
    Args:
        model_type (str): 'textrank', 'tfidf_lr', or 'bert'
        
    Returns:
        BaseExtractiveModel: An instance of the chosen model type
    """
    if model_type == 'textrank':
        return TextRankModel()
    elif model_type == 'tfidf_lr':
        return TfidfLogisticModel()
    elif model_type == 'bert':
        if not BERT_AVAILABLE:
            raise ImportError("transformers library tidak tersedia untuk BERT model!")
        return BertExtractiveModel()
    else:
        raise ValueError(f"Model type '{model_type}' tidak dikenal!")

def main():
    """
    Quick smoke test of the models
    """
    # Sample data
    sample_article = [
        "Presiden Joko Widodo mengumumkan kebijakan baru hari ini.",
        "Kebijakan tersebut berkaitan dengan ekonomi digital Indonesia.",
        "Menurut presiden, hal ini akan meningkatkan pertumbuhan ekonomi.",
        "Para ahli ekonomi memberikan respons positif terhadap kebijakan ini.",
        "Implementasi kebijakan akan dimulai bulan depan."
    ]
    
    # Test TextRank
    print("=== Testing TextRank Model ===")
    textrank_model = create_model('textrank')
    textrank_model.fit([], [])  # TextRank requires no training data
    
    prediction = textrank_model.predict(sample_article, num_sentences=2)
    print(f"TextRank prediction: {prediction}")
    
    selected_sentences = [sample_article[i] for i, label in enumerate(prediction) if label == 1]
    print("Selected sentences:")
    for sent in selected_sentences:
        print(f"- {sent}")

if __name__ == "__main__":
    main()
