
# train_from_binpack.py
import argparse, torch, torch.nn as nn
from torch.utils.data import DataLoader, random_split
from binpack_dataset import BinpackDataset

class TinyNNUE(nn.Module):
    """Small fully-connected evaluator: HalfKP presence features -> scalar eval (pawns).

    Layer widths are feat_n -> 256 -> 32 -> 32 -> 1, with a ReLU after every
    hidden layer and a plain linear output. The layers are stored in
    ``self.net`` (an ``nn.Sequential``), so checkpoint keys are
    ``net.0``, ``net.2``, ``net.4`` and ``net.6``.

    Args:
        feat_n: number of input features (width of the first Linear layer).
    """

    def __init__(self, feat_n):
        super().__init__()
        widths = [feat_n, 256, 32, 32, 1]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        # The output layer stays linear: drop the activation appended after it.
        stack.pop()
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Map a ``(..., feat_n)`` tensor to ``(..., 1)`` evaluations."""
        return self.net(x)

def _batch_loss(net, loss_fn, xb, yb, device):
    """Return the loss tensor for one ``(features, targets)`` batch.

    Features are cast to float32 because presence features may arrive as
    integer/bool tensors, which ``nn.Linear`` rejects. Targets are cast to the
    prediction's dtype and reshaped to its shape. Without the reshape, ``(B,)``
    targets would silently broadcast against the ``(B, 1)`` network output
    inside ``MSELoss`` and produce a ``(B, B)`` loss. A target with a different
    number of elements now fails loudly in ``reshape`` instead.
    """
    pred = net(xb.to(device, dtype=torch.float32))
    target = yb.to(device, dtype=pred.dtype).reshape(pred.shape)
    return loss_fn(pred, target)


def _train_epoch(net, loader, opt, loss_fn, device):
    """Run one optimisation pass over *loader*; return the mean per-sample loss.

    Returns ``nan`` if the loader yields no batches.
    """
    net.train()
    total, seen = 0.0, 0
    for xb, yb in loader:
        opt.zero_grad()
        loss = _batch_loss(net, loss_fn, xb, yb, device)
        loss.backward()
        opt.step()
        # The loss is a per-batch mean; weight by batch size so the epoch
        # average is per-sample even when the last batch is short.
        total += loss.item() * xb.size(0)
        seen += xb.size(0)
    return total / seen if seen else float("nan")


def _evaluate(net, loader, loss_fn, device):
    """Return the mean per-sample loss of *net* over *loader* without gradients.

    Returns ``nan`` if the loader yields no batches.
    """
    net.eval()
    total, seen = 0.0, 0
    with torch.no_grad():
        for xb, yb in loader:
            total += _batch_loss(net, loss_fn, xb, yb, device).item() * xb.size(0)
            seen += xb.size(0)
    return total / seen if seen else float("nan")


def main(args):
    """Train a TinyNNUE on a binpack dataset and save a checkpoint.

    Reads ``args.data``, ``args.batch_size``, ``args.gpus``, ``args.lr``,
    ``args.epochs`` and ``args.out``. About 5% of the samples (at least one)
    are held out for validation. After each epoch, the mean per-sample MSE on
    the train and validation splits is printed. The final checkpoint is a dict
    ``{"state_dict": ..., "feat_n": ...}`` written with ``torch.save`` to
    ``args.out``.

    Raises:
        ValueError: if the dataset has fewer than 2 samples (one is needed
            for each split).
    """
    ds = BinpackDataset(args.data)
    feat_n = ds.feat_n
    n_samples = len(ds)
    if n_samples < 2:
        raise ValueError(
            f"need at least 2 samples for a train/val split, "
            f"got {n_samples} from {args.data!r}"
        )

    # Small validation split: ~5%, but never empty; train keeps >= 1 sample.
    val_size = max(1, int(0.05 * n_samples))
    train_size = n_samples - val_size
    train_ds, val_ds = random_split(ds, [train_size, val_size])

    dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=0)
    vl = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=0)

    device = torch.device("cuda" if (args.gpus and torch.cuda.is_available()) else "cpu")
    net = TinyNNUE(feat_n).to(device)
    opt = torch.optim.Adam(net.parameters(), lr=args.lr)
    loss_fn = nn.MSELoss()

    for epoch in range(args.epochs):
        train_loss = _train_epoch(net, dl, opt, loss_fn, device)
        val_loss = _evaluate(net, vl, loss_fn, device)
        print(f"epoch {epoch+1}/{args.epochs}  train={train_loss:.6f}  val={val_loss:.6f}")

    torch.save({"state_dict": net.state_dict(), "feat_n": feat_n}, args.out)
    print(f"Saved checkpoint -> {args.out}")

if __name__ == "__main__":
    # Command-line entry point: each option becomes an attribute of the
    # namespace passed to main() (dashes turn into underscores, e.g. batch_size).
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", required=True)
    typed_options = (
        ("--epochs", int, 5),
        ("--batch-size", int, 256),
        ("--lr", float, 1e-3),
        ("--gpus", int, 0),  # set 1 if you want CUDA
    )
    for flag, cast, fallback in typed_options:
        parser.add_argument(flag, type=cast, default=fallback)
    parser.add_argument("--out", default="tiny_nnue.pth")
    main(parser.parse_args())
