"""
Grid search over LR × O-weight-factor for CamemBERT NER fine-tuning.

Imports core functions from train.py to avoid code duplication.
Each run uses early stopping (patience=3) so bad configs stop fast.

Usage:
  python search.py
  python search.py --lr_values 1e-5 3e-5 5e-5 --o_weight_factors 0.05 0.10 0.20
  python search.py --epochs 20 --batch_size 16

Results are appended to search_results.csv after each completed run, so finished
runs survive a crash. The file is overwritten at the start of every search.
"""

import argparse
import csv
import gc
import itertools
import logging
import random
import sys
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
from datasets import Dataset
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    TrainerCallback,
    TrainingArguments,
)

sys.path.insert(0, str(Path(__file__).parent))
from train import (
    DATASET_PATH,
    MODEL_NAME,
    WeightedNERTrainer,
    build_label_maps,
    compute_class_weights,
    compute_metrics_fn,
    load_dataset_csv,
    tokenize_and_align,
)

# Configure logging once at import time (INFO level, timestamped lines); the
# rest of this script logs through the module-level ``logger`` below.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

# Default grid: used when --lr_values / --o_weight_factors are not given on the
# command line. The search covers the Cartesian product of the two lists.
DEFAULT_LR_VALUES = [1e-5, 3e-5, 5e-5]
DEFAULT_O_WEIGHT_FACTORS = [0.05, 0.10, 0.20]


class EarlyStoppingNoSave(TrainerCallback):
    """Early stopping on ``eval_loss`` that does not rely on saved checkpoints.

    Intended for the grid search, where runs use ``save_strategy="no"``. The
    callback tracks the best ``eval_loss`` it has seen and asks the trainer to
    stop once ``patience`` consecutive evaluations fail to improve on it.

    Args:
        patience: Number of consecutive non-improving evaluations tolerated
            before training is stopped.
        min_delta: Minimum decrease of ``eval_loss`` (relative to the best
            value so far) that counts as an improvement.
    """

    def __init__(self, patience=3, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.no_improve = 0

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Update the improvement counter after an evaluation.

        A missing ``metrics`` dict or a missing ``eval_loss`` entry is treated
        as an infinite loss, i.e. as a non-improving evaluation.

        Returns:
            The ``control`` object, with ``should_training_stop`` set to True
            once the patience budget is exhausted.
        """
        loss = (metrics or {}).get("eval_loss", float("inf"))

        if loss < self.best_loss - self.min_delta:
            # Real improvement: remember it and reset the counter.
            self.best_loss = loss
            self.no_improve = 0
            return control

        self.no_improve += 1
        if self.no_improve >= self.patience:
            logger.info(f"Early stopping: no improvement for {self.patience} evals.")
            control.should_training_stop = True
        return control


def run_one(lr, o_weight_factor, train_samples, eval_samples, label2id, id2label,
            seed=42, epochs=15, batch_size=8, patience=3):
    """Train one (lr, o_weight_factor) configuration and report its best metrics.

    No checkpoint is written to disk (``save_strategy="no"``).

    Args:
        lr: Learning rate passed to ``TrainingArguments``.
        o_weight_factor: Forwarded to ``compute_class_weights``.
        train_samples: Training samples, converted with ``Dataset.from_list``;
            each must provide the ``tokens`` and ``ner_tags`` columns.
        eval_samples: Evaluation samples, same format as ``train_samples``.
        label2id: Label name -> id mapping given to the model and tokenizer
            alignment.
        id2label: Id -> label name mapping given to the model and metrics.
        seed: Seed used for model initialization and for the trainer.
        epochs: Maximum number of training epochs.
        batch_size: Per-device batch size for both training and evaluation.
        patience: Early-stopping patience (non-improving evaluations allowed).

    Returns:
        Tuple ``(best_eval_loss, best_f1, epochs_run)``. ``best_eval_loss`` is
        the minimum ``eval_loss`` and ``best_f1`` the maximum ``eval_f1`` over
        all logged evaluations (they may come from different epochs);
        ``epochs_run`` is the trainer's final epoch counter truncated to int.
    """
    # Seed *before* building the model. The seed in TrainingArguments is only
    # applied by the Trainer, i.e. after the model already exists, so weights
    # that are not in the checkpoint (typically the classification head) would
    # otherwise start from a different random state on every grid point.
    torch.manual_seed(seed)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForTokenClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(label2id),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )

    fn_kwargs = {"tokenizer": tokenizer, "label2id": label2id}

    def _encode(samples):
        # Tokenize and align labels, dropping the raw columns afterwards.
        return Dataset.from_list(samples).map(
            tokenize_and_align, batched=True, fn_kwargs=fn_kwargs,
            remove_columns=["tokens", "ner_tags"],
        )

    train_ds = _encode(train_samples)
    eval_ds = _encode(eval_samples)

    class_weights = compute_class_weights(train_samples, label2id, o_weight_factor=o_weight_factor)

    training_args = TrainingArguments(
        output_dir="./search_tmp",   # scratch directory; checkpoints are disabled below
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=lr,
        weight_decay=0.01,
        warmup_steps=50,
        eval_strategy="epoch",
        save_strategy="no",          # no checkpoint is saved
        load_best_model_at_end=False,
        logging_steps=10,
        seed=seed,
        report_to="none",
    )

    early_stopper = EarlyStoppingNoSave(patience=patience)

    trainer = WeightedNERTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=DataCollatorForTokenClassification(tokenizer),
        compute_metrics=compute_metrics_fn(id2label),
        class_weights=class_weights,
        callbacks=[early_stopper],
    )

    trainer.train()

    # Recover the best loss and F1 from the evaluation entries of the log history.
    eval_logs = [log for log in trainer.state.log_history if "eval_loss" in log]
    best_loss = min((log["eval_loss"] for log in eval_logs), default=float("inf"))
    best_f1 = max((log.get("eval_f1", 0.0) for log in eval_logs), default=0.0)
    # state.epoch may still be None if no training step was recorded.
    epochs_run = int(trainer.state.epoch or 0)

    # Free memory before the next run.
    del model, trainer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return best_loss, best_f1, epochs_run


def print_summary(results):
    """Print a table of all runs ranked by eval loss, then the best run.

    Args:
        results: Iterable of dicts with the keys ``lr``, ``o_weight_factor``,
            ``best_eval_loss``, ``eval_f1`` and ``epochs_run``.
    """

    def loss_of(entry):
        return entry["best_eval_loss"]

    rule = "=" * 65
    column_titles = f"{'LR':>10}  {'O weight':>10}  {'Eval loss':>12}  {'F1':>8}  {'Epochs':>7}"

    print("\n" + rule)
    print("GRID SEARCH RESULTS  (sorted by eval_loss ↑ = better)")
    print(rule)
    print(column_titles)
    print("-" * 65)

    # Lowest loss first.
    for entry in sorted(results, key=loss_of):
        config_part = f"{entry['lr']:>10.0e}  {entry['o_weight_factor']:>10.2f}"
        metric_part = (
            f"  {entry['best_eval_loss']:>12.4f}  {entry['eval_f1']:>8.4f}"
            f"  {entry['epochs_run']:>7}"
        )
        print(config_part + metric_part)

    winner = min(results, key=loss_of)
    winner_config = f"\nBest → LR={winner['lr']:.0e}  O weight={winner['o_weight_factor']}"
    winner_metrics = f"  loss={winner['best_eval_loss']:.4f}  F1={winner['eval_f1']:.4f}"
    print(winner_config + winner_metrics)


def main():
    """Run the LR × O-weight-factor grid search from the command line.

    Loads and splits the dataset once, trains every combination with
    ``run_one``, appends each finished run to the results CSV, and prints a
    ranked summary at the end. A failing run is logged and skipped; the
    remaining combinations still run.
    """
    parser = argparse.ArgumentParser(description="Grid search: LR × O-weight-factor")
    parser.add_argument("--results_csv", default="./search_results.csv")
    parser.add_argument("--epochs", type=int, default=15)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--eval_split", type=float, default=0.15)
    parser.add_argument("--lr_values", nargs="+", type=float, default=DEFAULT_LR_VALUES)
    parser.add_argument("--o_weight_factors", nargs="+", type=float, default=DEFAULT_O_WEIGHT_FACTORS)
    args = parser.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # ── Load & split data (shared across all runs) ────────────────────────────
    samples = load_dataset_csv(DATASET_PATH)
    label2id, id2label = build_label_maps(samples)
    logger.info(f"Labels: {list(label2id.keys())}")

    random.shuffle(samples)
    # At least one sample always goes to the eval set.
    split = max(1, int(len(samples) * args.eval_split))
    train_samples = samples[split:]
    eval_samples = samples[:split]
    logger.info(f"Train: {len(train_samples)} | Eval: {len(eval_samples)}")

    # Fail fast instead of letting every grid point crash on an empty train set.
    if not train_samples:
        parser.error(
            f"--eval_split {args.eval_split} leaves no training samples "
            f"({len(samples)} samples loaded)"
        )

    combos = list(itertools.product(args.lr_values, args.o_weight_factors))
    logger.info(f"\nGrid: {len(combos)} combinations")
    for lr, o in combos:
        logger.info(f"  LR={lr:.0e}  O weight={o}")

    # ── CSV header (written once, results appended incrementally) ─────────────
    # Note: mode "w" discards the results of any previous search at this path.
    results_path = Path(args.results_csv)
    results_path.parent.mkdir(parents=True, exist_ok=True)
    with open(results_path, "w", newline="") as f:
        csv.writer(f).writerow(["lr", "o_weight_factor", "best_eval_loss", "eval_f1", "epochs_run", "timestamp"])

    results = []

    for i, (lr, o_weight_factor) in enumerate(combos):
        logger.info(f"\n{'='*65}")
        logger.info(f"Run {i + 1}/{len(combos)}  |  LR={lr:.0e}  |  O weight={o_weight_factor}")
        logger.info(f"{'='*65}")

        try:
            best_loss, f1, epochs_run = run_one(
                lr=lr,
                o_weight_factor=o_weight_factor,
                train_samples=train_samples,
                eval_samples=eval_samples,
                label2id=label2id,
                id2label=id2label,
                seed=args.seed,
                epochs=args.epochs,
                batch_size=args.batch_size,
            )
            row = {
                "lr": lr,
                "o_weight_factor": o_weight_factor,
                "best_eval_loss": best_loss,
                "eval_f1": f1,
                "epochs_run": epochs_run,
            }
            results.append(row)

            # Append to CSV immediately (crash-safe)
            with open(results_path, "a", newline="") as f:
                csv.writer(f).writerow([
                    lr, o_weight_factor,
                    f"{best_loss:.4f}", f"{f1:.4f}",
                    epochs_run, datetime.now().isoformat(),
                ])

            logger.info(f"Done → loss={best_loss:.4f}  F1={f1:.4f}  stopped at epoch {epochs_run}")

        except Exception as e:
            # Top-level boundary: one bad configuration must not abort the grid.
            logger.exception(f"Run failed: {e}")

        finally:
            # run_one only cleans up on success. Doing it here as well releases
            # what a failed run left behind; by this point the handled exception
            # has been cleared, so the failed run's objects should normally be
            # collectable before the next configuration starts.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    if results:
        print_summary(results)
        print(f"\nFull results saved to {results_path}")
    else:
        logger.error("All runs failed.")


if __name__ == "__main__":
    main()
