#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Data Loader untuk Dataset Liputan6 Extractive Summarization
"""

import json
import os
import glob
from typing import List, Dict, Tuple
import numpy as np
from sklearn.model_selection import train_test_split
import pandas as pd

class Liputan6DataLoader:
    """
    Loader and preprocessor for the Liputan6 dataset, used for
    extractive summarization.

    The dataset is expected to be stored as one JSON file per article,
    split across "train", "dev" and "test" subfolders. Each JSON record
    carries tokenized sentences in 'clean_article'/'clean_summary' and
    gold extractive sentence indices in 'extractive_summary'.
    """

    def __init__(self, data_path: str = "liputan6_data/liputan6_data/canonical"):
        """
        Initialize the data loader.

        Args:
            data_path (str): Path to the canonical dataset folder,
                expected to contain "train", "dev" and "test" subfolders.
        """
        self.data_path = data_path
        self.train_path = os.path.join(data_path, "train")
        self.dev_path = os.path.join(data_path, "dev")
        self.test_path = os.path.join(data_path, "test")

    def load_json_files(self, folder_path: str, limit: int = None) -> List[Dict]:
        """
        Load all JSON files from a folder.

        Args:
            folder_path (str): Path to the folder containing JSON files.
            limit (int): Maximum number of files to load (useful for
                testing). None means no limit.

        Returns:
            List[Dict]: Parsed contents of every successfully loaded
                file. Unreadable/corrupt files are reported and skipped.
        """
        # Sort for a deterministic load order; glob's order is
        # filesystem-dependent and would make runs non-reproducible.
        json_files = sorted(glob.glob(os.path.join(folder_path, "*.json")))

        # Compare against None explicitly so limit=0 means "load nothing"
        # rather than being treated as "no limit".
        if limit is not None:
            json_files = json_files[:limit]

        data = []
        for file_path in json_files:
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    data.append(json.load(f))
            except Exception as e:
                # Best-effort loading: report the file and keep going.
                print(f"Error loading {file_path}: {e}")
                continue

        return data

    def preprocess_article(self, clean_article: List[List[str]]) -> List[str]:
        """
        Convert an article from per-sentence token lists into sentences.

        Args:
            clean_article (List[List[str]]): Article as a list of
                sentences, each sentence a list of tokens.

        Returns:
            List[str]: One space-joined string per sentence.
        """
        return [" ".join(sentence_tokens) for sentence_tokens in clean_article]

    def preprocess_summary(self, clean_summary: List[List[str]]) -> str:
        """
        Convert a summary from per-sentence token lists into a single text.

        Args:
            clean_summary (List[List[str]]): Summary as a list of
                sentences, each sentence a list of tokens.

        Returns:
            str: All summary sentences joined into one string.
        """
        return " ".join(" ".join(sentence_tokens) for sentence_tokens in clean_summary)

    def create_extractive_labels(self, article_sentences: List[str],
                                extractive_indices: List[int]) -> List[int]:
        """
        Build binary labels for extractive summarization.

        Args:
            article_sentences (List[str]): Article sentences.
            extractive_indices (List[int]): Indices of sentences selected
                for the summary.

        Returns:
            List[int]: Binary labels (1 for selected sentences, 0 otherwise).
        """
        labels = [0] * len(article_sentences)
        for idx in extractive_indices:
            # Ignore out-of-range indices; a negative index would
            # otherwise wrap around and label the wrong sentence.
            if 0 <= idx < len(labels):
                labels[idx] = 1
        return labels

    def load_dataset(self, split: str = "train", limit: int = None) -> Tuple[List[List[str]], List[List[int]], List[str]]:
        """
        Load one dataset split.

        Args:
            split (str): "train", "dev", or "test".
            limit (int): Maximum number of records to load (None = all).

        Returns:
            Tuple containing:
            - articles: list of articles (each article is a list of sentences)
            - labels: list of extractive labels (each a list of 0/1 ints)
            - summaries: list of reference summaries for evaluation

        Raises:
            ValueError: If `split` is not one of the known split names.
        """
        split_paths = {
            "train": self.train_path,
            "dev": self.dev_path,
            "test": self.test_path,
        }
        if split not in split_paths:
            raise ValueError("Split harus 'train', 'dev', atau 'test'")
        folder_path = split_paths[split]

        raw_data = self.load_json_files(folder_path, limit)

        articles = []
        labels = []
        summaries = []

        for item in raw_data:
            article_sentences = self.preprocess_article(item['clean_article'])
            summary = self.preprocess_summary(item['clean_summary'])
            extractive_labels = self.create_extractive_labels(
                article_sentences,
                item['extractive_summary']
            )

            articles.append(article_sentences)
            labels.append(extractive_labels)
            summaries.append(summary)

        return articles, labels, summaries

    def get_dataset_info(self):
        """
        Compute summary statistics for each split, based on a small sample.

        Returns:
            dict: Per-split sample count and average article/summary
                lengths, or an "error" entry if a split failed to load.
        """
        info = {}

        for split in ["train", "dev", "test"]:
            try:
                # Small sample only; full splits are large.
                articles, labels, summaries = self.load_dataset(split, limit=100)

                # Guard against an empty split: np.mean([]) yields nan
                # and emits a RuntimeWarning.
                avg_article_len = float(np.mean([len(a) for a in articles])) if articles else 0.0
                avg_summary_len = float(np.mean([sum(l) for l in labels])) if labels else 0.0

                info[split] = {
                    "num_samples": len(articles),
                    "avg_article_sentences": avg_article_len,
                    "avg_extractive_sentences": avg_summary_len
                }
            except Exception as e:
                info[split] = {"error": str(e)}

        return info

def main():
    """
    Smoke-test the data loader: print dataset statistics and a couple of
    samples with their extractive labels marked.
    """
    loader = Liputan6DataLoader()

    # Dataset statistics across all splits.
    print("=== Info Dataset ===")
    for split, data in loader.get_dataset_info().items():
        print(f"{split.upper()}: {data}")

    # Show two training samples, marking the extractive sentences.
    print("\n=== Sample Data ===")
    articles, labels, summaries = loader.load_dataset("train", limit=2)

    for i, (article, label, summary) in enumerate(zip(articles, labels, summaries)):
        print(f"\nSample {i+1}:")
        print(f"Artikel ({len(article)} kalimat):")
        for j, sentence in enumerate(article):
            if label[j] == 1:
                marker = ">>> "
            else:
                marker = "    "
            print(f"{marker}{j}: {sentence[:100]}...")

        print(f"\nSummary: {summary[:200]}...")
        selected = [pos for pos, flag in enumerate(label) if flag == 1]
        print(f"Extractive indices: {selected}")

if __name__ == "__main__":
    main()
