File size: 40,401 Bytes
39b7b21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
import numpy as np
import cv2 as cv
import os
from numpy.linalg import norm, inv
from scipy.stats import multivariate_normal as mv_norm
import joblib  # or import pickle
import os
import torch
from torch.distributions import MultivariateNormal
import torch.nn.functional as F
# Module-level defaults for the GMM background model.
# Reference mixture weights for four components.
init_weight = [0.7, 0.11, 0.1, 0.09]
# Reference mean colour vector (all-zero, 3 channels).
init_u = np.zeros(3)
# Reference covariance: isotropic, variance 225 (std-dev 15) per channel.
init_sigma = np.diag(np.full(3, 225.0))
# Default learning rate; used as the default `alpha` of GMM.__init__.
init_alpha = 0.05

class GMM():
    def __init__(self, data_dir, train_num, alpha=init_alpha):
        # Training configuration: directory of b%05d.png frames, number of
        # frames to train on, and the online learning rate.
        self.data_dir, self.train_num, self.alpha = data_dir, train_num, alpha
        # Frame shape; set by train() / load_model().
        self.img_shape = None

        # Per-pixel mixture parameters and component counts; populated by
        # train() or load_model().
        self.weight = self.mu = self.sigma = self.K = self.B = None

    def check(self, pixel, mu, sigma, threshold=2.5):
        '''
        Check whether a pixel matches a Gaussian distribution.

        A pixel matches when its Mahalanobis distance to the distribution,
        sqrt((x - mu)^T sigma^-1 (x - mu)), is strictly less than `threshold`
        (2.5 by default, the same cut-off train() uses).

        Args:
            pixel: colour vector x (NumPy array, torch tensor or sequence).
            mu: mean vector of the distribution.
            sigma: covariance matrix of the distribution.
            threshold: Mahalanobis distance below which the pixel matches.

        Returns:
            bool: True if the pixel matches the distribution.
        '''
        # Work on mu's device when it is already a tensor, else on the CPU.
        # Cast everything to float32 so mixed inputs (uint8 pixels, float64
        # arrays, tensors) combine without dtype errors.
        device = mu.device if isinstance(mu, torch.Tensor) else torch.device('cpu')
        mu = torch.as_tensor(mu, dtype=torch.float32, device=device)
        sigma = torch.as_tensor(sigma, dtype=torch.float32, device=device)
        pixel = torch.as_tensor(pixel, dtype=torch.float32, device=device)

        delta = pixel - mu
        # Solve sigma @ y = delta rather than forming the explicit inverse.
        d_squared = delta @ torch.linalg.solve(sigma, delta)

        # Compare squared distances: no sqrt, and no NaN from tiny negative
        # round-off in d_squared.
        return bool(d_squared.item() < threshold ** 2)
        
    @staticmethod
    def rgba_to_rgb_for_processing(image_path):
        '''
        Load an image and, if it has an alpha channel, composite it over a
        white background so the result has 3 colour channels.

        Declared as a static method so it can be called both as
        GMM.rgba_to_rgb_for_processing(path) and self.rgba_to_rgb_for_processing(path).

        Note: cv.imread loads colour images in BGR order; the returned array
        keeps OpenCV's channel order despite "rgb" in the name.

        Args:
            image_path: path to the image file.

        Returns:
            np.ndarray: for 4-channel input, a uint8 3-channel image blended
            over white by alpha; any other image is returned exactly as loaded.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        '''
        img = cv.imread(image_path, cv.IMREAD_UNCHANGED)
        # cv.imread signals failure by returning None rather than raising.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")

        # Grayscale images are 2-D; only 4-channel images need compositing.
        if img.ndim != 3 or img.shape[2] != 4:
            return img

        white = np.full(img.shape[:2] + (3,), 255, dtype=np.uint8)
        # Alpha is scaled by 255, so this assumes 8-bit channels.
        alpha = img[:, :, 3:4] / 255.0
        rgb_img = white * (1 - alpha) + img[:, :, :3] * alpha
        return rgb_img.astype(np.uint8)

    
    def train(self, K=4):
        '''
        Fit the per-pixel Gaussian mixture background model, using CUDA when available.

        Reads ``self.train_num`` frames named ``b00000.png``, ``b00001.png``, ...
        from ``self.data_dir`` in order. For every frame:

        * each pixel is matched to the first of its K components whose
          Mahalanobis distance is below 2.5;
        * every weight decays by (1 - alpha) and the matched component gains
          alpha; weights are then renormalized to sum to 1 per pixel;
        * the matched component's mean and covariance move toward the pixel
          with rate rho = alpha * N(x | mu, sigma);
        * pixels with no match replace their lowest-weight component with one
          centred on the pixel with covariance 225 * I.

        Afterwards ``self.reorder()`` sorts the components and sets ``self.B``.
        On return ``self.weight``, ``self.mu``, ``self.sigma`` and ``self.B`` are
        torch tensors on the training device.

        Args:
            K: number of Gaussian components per pixel.

        Raises:
            ValueError: if ``self.train_num`` < 1, or a frame's shape differs
                from the first frame's.
            FileNotFoundError: if a training frame cannot be read.
        '''
        self.K = K
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")

        if self.train_num < 1:
            raise ValueError(f"train_num must be at least 1, got {self.train_num}")

        file_list = [os.path.join(self.data_dir, 'b%05d' % i + '.png')
                     for i in range(self.train_num)]

        def read_frame(path):
            # cv.imread returns None instead of raising on unreadable files.
            frame = cv.imread(path)
            if frame is None:
                raise FileNotFoundError(f"Could not read training image: {path}")
            return frame

        # Initialize with the first image
        img_init = read_frame(file_list[0])
        self.img_shape = img_init.shape
        height, width = self.img_shape[:2]
        # Pixel sampled by the debug prints, clamped so small frames work.
        py, px = min(100, height - 1), min(100, width - 1)

        # Covariance given to fresh components: 225 * I (std-dev 15 per channel).
        init_cov = torch.eye(3, dtype=torch.float32, device=device) * 225

        # Model parameters on the training device: uniform weights, every
        # component centred on the first frame, isotropic covariances.
        self.weight = torch.full((height, width, K), 1.0 / K,
                                 dtype=torch.float32, device=device)
        self.mu = (torch.from_numpy(img_init).float().to(device)
                   .unsqueeze(2).repeat(1, 1, K, 1))  # (H,W,K,3)
        self.sigma = init_cov.repeat(height, width, K, 1, 1)  # (H,W,K,3,3)
        self.B = torch.ones((height, width), dtype=torch.int32, device=device)

        for file in file_list:
            print('training:{}'.format(file))
            img = read_frame(file)
            if img.shape != self.img_shape:
                raise ValueError(
                    f"{file} has shape {img.shape}, expected {self.img_shape}")
            img_tensor = torch.from_numpy(img).float().to(device)  # (H,W,3)

            # Squared Mahalanobis distance of every pixel to each of its K
            # components. It does not depend on k, so compute it once per frame.
            delta = img_tensor.unsqueeze(2) - self.mu  # (H,W,K,3)
            sigma_inv = torch.linalg.inv(self.sigma)  # (H,W,K,3,3)
            temp = torch.einsum('hwki,hwkij->hwkj', delta, sigma_inv)
            d_squared = torch.einsum('hwki,hwki->hwk', temp, delta)

            # The first component (in order) within distance 2.5 wins;
            # -1 marks pixels that match nothing.
            matches = torch.full((height, width), -1, dtype=torch.long, device=device)
            for k in range(K):
                match_mask = (d_squared[:, :, k] < 2.5 ** 2) & (matches == -1)
                matches[match_mask] = k

            # Every weight decays; matched components regain alpha below.
            self.weight *= (1 - self.alpha)

            for k in range(K):
                mask = matches == k
                if not mask.any():
                    continue
                self.weight[:, :, k][mask] += self.alpha

                matched_pixels = img_tensor[mask]  # (N,3)
                matched_mu = self.mu[:, :, k, :][mask]  # (N,3)
                matched_sigma = self.sigma[:, :, k, :, :][mask]  # (N,3,3)
                try:
                    mvn = MultivariateNormal(matched_mu,
                                             covariance_matrix=matched_sigma)
                    rho = self.alpha * torch.exp(mvn.log_prob(matched_pixels))
                except (RuntimeError, ValueError) as e:
                    # ValueError: argument validation rejected a covariance;
                    # RuntimeError: factorization failure. Skip this component's
                    # mean/covariance update for this frame.
                    print(f"Error updating distribution {k}: {e}")
                    continue

                diff = matched_pixels - matched_mu
                self.mu[:, :, k, :][mask] = matched_mu + rho.unsqueeze(1) * diff
                # sigma <- (1 - rho) * sigma + rho * diff diff^T
                diff_outer = torch.einsum('bi,bj->bij', diff, diff)
                self.sigma[:, :, k, :, :][mask] = (
                    matched_sigma + rho[:, None, None] * (diff_outer - matched_sigma))

            # Unmatched pixels: replace their lowest-weight component with a
            # fresh one centred on the current pixel value.
            non_matched = matches == -1
            if non_matched.any():
                weakest = torch.argmin(self.weight[non_matched], dim=1)  # (N,)
                ys, xs = non_matched.nonzero(as_tuple=True)
                self.mu[ys, xs, weakest, :] = img_tensor[ys, xs]
                self.sigma[ys, xs, weakest, :, :] = init_cov

            # Keep each pixel's weights a proper mixture (sum to 1).
            self.weight /= self.weight.sum(dim=2, keepdim=True)

            print('img:{}'.format(img[py][px]))
            print('weight:{}'.format(self.weight[py, px].cpu().numpy()))

        # reorder() works on NumPy arrays, so round-trip through the CPU.
        self.weight = self.weight.cpu().numpy()
        self.mu = self.mu.cpu().numpy()
        self.sigma = self.sigma.cpu().numpy()
        self.B = self.B.cpu().numpy()

        self.reorder()
        for i in range(self.K):
            print('u:{}'.format(self.mu[py][px][i]))

        # Leave the trained parameters as tensors on the training device.
        self.weight = torch.from_numpy(self.weight).to(device)
        self.mu = torch.from_numpy(self.mu).to(device)
        self.sigma = torch.from_numpy(self.sigma).to(device)
        self.B = torch.from_numpy(self.B).to(device)

    def save_model(self, file_path):
        """
        Persist the model's parameters and configuration with joblib.

        The parent directory of `file_path` is created when the path names one.
        Parameter objects (arrays/tensors) are stored as-is, without any device
        or type conversion.

        Args:
            file_path: destination file path.
        """
        parent = os.path.dirname(file_path)
        if parent:
            os.makedirs(parent, exist_ok=True)

        state = dict(
            weight=self.weight,
            mu=self.mu,
            sigma=self.sigma,
            K=self.K,
            B=self.B,
            img_shape=self.img_shape,
            alpha=self.alpha,
            data_dir=self.data_dir,
            train_num=self.train_num,
        )
        joblib.dump(state, file_path)

        print(f"Model saved to {file_path}")

    @classmethod
    def load_model(cls, file_path):
        """
        Load a model previously written by save_model().

        Args:
            file_path: path to a joblib file produced by save_model().

        Returns:
            GMM: a new instance built from the stored data_dir, train_num and
            alpha, with weight, mu, sigma, K, B and img_shape restored.

        NOTE(review): joblib.load unpickles the file, which can execute
        arbitrary code; only load trusted files. Tensors pickled from a GPU run
        may need CUDA to unpickle; TODO confirm if CPU-only loading is needed.
        """
        data = joblib.load(file_path)

        gmm = cls(data['data_dir'], data['train_num'], data['alpha'])

        # Restore the trained parameters.
        gmm.weight = data['weight']
        gmm.mu = data['mu']
        gmm.sigma = data['sigma']
        gmm.K = data['K']
        gmm.B = data['B']
        gmm.img_shape = data['img_shape']
        # Alias under a second name; kept in case other code reads `image_shape`.
        gmm.image_shape = data['img_shape']

        print(f"Model loaded from {file_path}")
        return gmm

    


    def reorder(self, T=0.90):
        '''
        Reorder the estimated components based on the ratio pi / the norm of standard deviation.
        The first B components are chosen as background components.
        The default threshold is 0.90.
        '''
        epsilon = 1e-6  # to prevent divide-by-zero

        for i in range(self.img_shape[0]):
            for j in range(self.img_shape[1]):
                k_weight = self.weight[i][j]
                k_norm = []

                for k in range(self.K):
                    cov = self.sigma[i][j][k]
                    try:
                        if np.all(np.linalg.eigvals(cov) >= 0):
                            # stddev = np.sqrt(cov)
                            epsilon = 1e-6
                            stddev = np.sqrt(np.maximum(cov, epsilon))
                            k_norm.append(norm(stddev))
                        else:
                            k_norm.append(epsilon)
                    except:
                        k_norm.append(epsilon)

                k_norm = np.array(k_norm)
                ratio = k_weight / (k_norm + epsilon)
                descending_order = np.argsort(-ratio)

                self.weight[i][j] = self.weight[i][j][descending_order]
                self.mu[i][j] = self.mu[i][j][descending_order]
                self.sigma[i][j] = self.sigma[i][j][descending_order]

                cum_weight = 0
                for index, order in enumerate(descending_order):
                    cum_weight += self.weight[i][j][index]
                    if cum_weight > T:
                        self.B[i][j] = index + 1
                        break
    from typing import Tuple, Optional
    
    def region_propfill_enhancement(self, binary_mask: np.ndarray,
                               table_mask: Optional[np.ndarray] = None,
                               dilation_kernel_size: int = 5,
                               dilation_iterations: int = 2,
                               erosion_iterations: int = 1,
                               fill_threshold: int = 200,
                               min_contour_area: int = 50) -> Tuple[np.ndarray, dict]:
        """
        Enhance a GMM foreground mask by dilating it and filling each blob.

        NOTE(review): GMM defines region_propfill_enhancement twice; Python keeps
        the later definition, so this copy is shadowed. Consider deleting one.

        Steps: (1) restrict to `table_mask` if given, (2) dilate to join nearby
        fragments, (3) fill every external contour whose area is at least
        `min_contour_area` (filling an external contour also closes holes
        inside it), (4) erode to roughly undo the dilation, (5) OR the original
        detections back in and re-apply `table_mask`.

        Args:
            binary_mask: foreground mask from GMM detection. Boolean masks are
                treated as 0/255; other dtypes are cast to uint8 as-is.
            table_mask: optional table-area mask of any dtype (nonzero = inside).
                Resized with nearest-neighbour if its shape differs.
            dilation_kernel_size: elliptical kernel size for dilation and erosion.
            dilation_iterations: dilation passes used to connect fragments.
            erosion_iterations: erosion passes after filling (0 disables).
            fill_threshold: unused; kept for backward compatibility.
            min_contour_area: contours smaller than this are not filled.

        Returns:
            enhanced_mask: boolean if `binary_mask` was boolean, else uint8.
            debug_info: dict with 'original_mask', 'dilated_mask',
                'enhanced_mask' and 'num_contours_processed'.
        """
        # Convert boolean mask to uint8 if needed (astype copies the input).
        if binary_mask.dtype == bool:
            mask_uint8 = (binary_mask * 255).astype(np.uint8)
        else:
            mask_uint8 = binary_mask.astype(np.uint8)

        if table_mask is not None:
            # Normalize to a boolean mask of matching size. With a uint8 mask,
            # `~table_mask` would be a bitwise NOT used as integer indices.
            table_mask = np.asarray(table_mask)
            if table_mask.shape != mask_uint8.shape:
                table_mask = cv.resize(table_mask.astype(np.uint8),
                                       (mask_uint8.shape[1], mask_uint8.shape[0]),
                                       interpolation=cv.INTER_NEAREST)
            table_mask = table_mask > 0
            # Zero out everything outside the table area.
            mask_uint8[~table_mask] = 0

        original_mask = mask_uint8.copy()

        # Dilate to connect nearby fragments.
        kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE,
                                          (dilation_kernel_size, dilation_kernel_size))
        dilated_mask = cv.dilate(mask_uint8, kernel, iterations=dilation_iterations)

        # Fill each sufficiently large external contour; drawing an external
        # contour filled also fills any holes it encloses.
        contours, _ = cv.findContours(dilated_mask, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
        kept = [c for c in contours if cv.contourArea(c) >= min_contour_area]
        enhanced_mask = np.zeros_like(dilated_mask)
        for contour in kept:
            # One contour per call so each is filled independently.
            cv.drawContours(enhanced_mask, [contour], -1, 255, cv.FILLED)

        # Optional erosion to restore approximately the original size.
        if erosion_iterations > 0:
            enhanced_mask = cv.erode(enhanced_mask, kernel, iterations=erosion_iterations)

        # Never lose an original detection, and stay inside the table.
        enhanced_mask = cv.bitwise_or(enhanced_mask, original_mask)
        if table_mask is not None:
            enhanced_mask[~table_mask] = 0

        if binary_mask.dtype == bool:
            enhanced_mask = enhanced_mask > 0

        debug_info = {
            'original_mask': original_mask,
            'dilated_mask': dilated_mask,
            'enhanced_mask': enhanced_mask,
            'num_contours_processed': len(kept)
        }

        return enhanced_mask, debug_info
    
    def draw_heatmap_colorbar(self, frame: np.ndarray, heatmap: np.ndarray) -> np.ndarray:
        """
        Overlay a vertical JET colour bar (1.0 at the top, 0.0 at the bottom)
        near the right edge of `frame`, with tick labels, a "HEAT" title and
        the maximum value of `heatmap`.

        The frame is drawn on in place and the same object is returned.

        Args:
            frame: image to draw on (presumably 3-channel, since the bar is a
                3-channel colormap; confirm with callers).
            heatmap: heat values; only heatmap.max() is used.

        Returns:
            The same `frame`, with the colour bar drawn.
        """
        frame_h, frame_w = frame.shape[:2]

        # Bar geometry: 30 px wide, 60% of the frame height, starting 20% down,
        # 20 px in from the right edge.
        bar_w = 30
        bar_h = int(frame_h * 0.6)
        left = frame_w - bar_w - 20
        top = int(frame_h * 0.2)
        right = left + bar_w
        bottom = top + bar_h

        white = (255, 255, 255)
        font = cv.FONT_HERSHEY_SIMPLEX

        # Values falling from 1 (top row) to 0 (bottom row), widened to bar_w columns.
        column = np.linspace(1, 0, bar_h)[:, np.newaxis]
        ramp = np.repeat(column, bar_w, axis=1)
        ramp_colored = cv.applyColorMap((ramp * 255).astype(np.uint8), cv.COLORMAP_JET)

        # White outer outline, black inner outline, then the bar itself.
        cv.rectangle(frame, (left - 2, top - 2), (right + 2, bottom + 2), white, 2)
        cv.rectangle(frame, (left - 1, top - 1), (right + 1, bottom + 1), (0, 0, 0), 1)
        frame[top:bottom, left:right] = ramp_colored

        ticks = (("1.0", 0), ("0.75", 0.25), ("0.5", 0.5), ("0.25", 0.75), ("0.0", 1.0))
        for text, frac in ticks:
            y = top + int(frac * bar_h)
            cv.putText(frame, text, (right + 5, y + 5), font, 0.4, white, 1)

        cv.putText(frame, "HEAT", (left - 5, top - 10), font, 0.5, white, 1)

        peak = heatmap.max()
        cv.putText(frame, f"Max: {peak:.2f}", (left - 20, bottom + 20), font, 0.4, white, 1)

        return frame

    def region_propfill_enhancement(self, binary_mask: np.ndarray,
                               table_mask: Optional[np.ndarray] = None,
                               dilation_kernel_size: int = 5,
                               dilation_iterations: int = 2,
                               erosion_iterations: int = 1,
                               fill_threshold: int = 200,
                               min_contour_area: int = 50) -> Tuple[np.ndarray, dict]:
        """
        Enhance a GMM foreground mask by dilating it and filling each blob.

        NOTE(review): GMM contains an earlier definition with the same name;
        this later one is the one Python binds. Consider deleting the other copy.

        Steps: (1) restrict to `table_mask` if given, (2) dilate to join nearby
        fragments, (3) fill every external contour whose area is at least
        `min_contour_area` (filling an external contour also closes holes
        inside it), (4) erode to roughly undo the dilation, (5) OR the original
        detections back in and re-apply `table_mask`.

        Args:
            binary_mask: foreground mask from GMM detection. Boolean masks are
                treated as 0/255; other dtypes are cast to uint8 as-is.
            table_mask: optional table-area mask of any dtype (nonzero = inside).
                Resized with nearest-neighbour if its shape differs.
            dilation_kernel_size: elliptical kernel size for dilation and erosion.
            dilation_iterations: dilation passes used to connect fragments.
            erosion_iterations: erosion passes after filling (0 disables).
            fill_threshold: unused; kept for backward compatibility.
            min_contour_area: contours smaller than this are not filled.

        Returns:
            enhanced_mask: boolean if `binary_mask` was boolean, else uint8.
            debug_info: dict with 'original_mask', 'dilated_mask',
                'enhanced_mask' and 'num_contours_processed'.
        """
        # Convert boolean mask to uint8 if needed (astype copies the input).
        if binary_mask.dtype == bool:
            mask_uint8 = (binary_mask * 255).astype(np.uint8)
        else:
            mask_uint8 = binary_mask.astype(np.uint8)

        if table_mask is not None:
            # Normalize to a boolean mask of matching size. With a uint8 mask,
            # `~table_mask` would be a bitwise NOT used as integer indices.
            table_mask = np.asarray(table_mask)
            if table_mask.shape != mask_uint8.shape:
                table_mask = cv.resize(table_mask.astype(np.uint8),
                                       (mask_uint8.shape[1], mask_uint8.shape[0]),
                                       interpolation=cv.INTER_NEAREST)
            table_mask = table_mask > 0
            # Zero out everything outside the table area.
            mask_uint8[~table_mask] = 0

        original_mask = mask_uint8.copy()

        # Dilate to connect nearby fragments.
        kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE,
                                          (dilation_kernel_size, dilation_kernel_size))
        dilated_mask = cv.dilate(mask_uint8, kernel, iterations=dilation_iterations)

        # Fill each sufficiently large external contour; drawing an external
        # contour filled also fills any holes it encloses.
        contours, _ = cv.findContours(dilated_mask, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
        kept = [c for c in contours if cv.contourArea(c) >= min_contour_area]
        enhanced_mask = np.zeros_like(dilated_mask)
        for contour in kept:
            # One contour per call so each is filled independently.
            cv.drawContours(enhanced_mask, [contour], -1, 255, cv.FILLED)

        # Optional erosion to restore approximately the original size.
        if erosion_iterations > 0:
            enhanced_mask = cv.erode(enhanced_mask, kernel, iterations=erosion_iterations)

        # Never lose an original detection, and stay inside the table.
        enhanced_mask = cv.bitwise_or(enhanced_mask, original_mask)
        if table_mask is not None:
            enhanced_mask[~table_mask] = 0

        if binary_mask.dtype == bool:
            enhanced_mask = enhanced_mask > 0

        debug_info = {
            'original_mask': original_mask,
            'dilated_mask': dilated_mask,
            'enhanced_mask': enhanced_mask,
            'num_contours_processed': len(kept)
        }

        return enhanced_mask, debug_info
        
    def visualize_mask_enhancement(self, original_mask: np.ndarray,
                                   enhanced_mask: np.ndarray,
                                   debug_info: dict,
                                   window_prefix: str = "Enhancement"):
        """
        Display the stages of a mask enhancement in OpenCV windows.

        Four windows are opened, each titled ``"<window_prefix> - <stage>"``:
        the original mask, ``debug_info['dilated_mask']``, the enhanced mask,
        and the absolute difference between enhanced and original
        ("Added Regions").

        Args:
            original_mask: Mask before enhancement (bool or numeric).
            enhanced_mask: Mask after enhancement (bool or numeric).
            debug_info: Dict that must contain ``'dilated_mask'``.
            window_prefix: Text prepended to every window title.
        """
        def as_display(mask):
            # Bool masks are scaled to 0/255; other dtypes are cast to uint8 as-is.
            scaled = mask * 255 if mask.dtype == bool else mask
            return scaled.astype(np.uint8)

        def show(stage, image):
            cv.imshow(f"{window_prefix} - {stage}", image)

        before = as_display(original_mask)
        after = as_display(enhanced_mask)

        show("Original Mask", before)
        show("Dilated Mask", debug_info['dilated_mask'])
        show("Enhanced Mask", after)
        # Pixels that differ between the two masks, i.e. what the enhancement changed.
        show("Added Regions", cv.absdiff(after, before))
        
    def infer(self, img, heatmap=None, alpha_start=0.002, alpha_end=0.0001,
              table_mask=None, cleaning_mask=None):
        """
        Run one frame of GMM background matching and update the heatmap.

        The frame is resized once to the model's resolution when needed. All
        per-pixel work happens at that working resolution. The heatmap is then
        resized back to the input resolution for output.

        Args:
            img: Input frame as a NumPy array. The GMM step subtracts a
                per-pixel mean with a trailing channel axis, so this is
                presumably (H, W, C); confirm against ``self.mu``.
            heatmap: Heatmap returned by a previous call (a NumPy array; it is
                passed to ``torch.from_numpy``), or None to start from zeros.
                It is resized to the working resolution if its shape differs.
            alpha_start: Per-step rate used where the heat is 0.
            alpha_end: Per-step rate used where the heat is 1. It is also the
                lower bound of the rate.
            table_mask: Optional mask; nonzero pixels are inside the table.
                Detection, growth and decay are restricted to it, and the
                returned heatmap is zeroed outside it.
            cleaning_mask: Optional mask; heat decays at nonzero pixels. It may
                be given at the input or the working resolution.

        Returns:
            ``(result, heatmap_np)``:
            - ``result`` is a copy of ``img``. Where the heat exceeds 0.1
              inside the table, a JET colormap is blended in with weights
              0.7 (image) and 0.3 (colormap).
            - ``heatmap_np`` is the updated heatmap at the input resolution.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        def binary_mask_at(mask, width, height):
            """Return ``mask > 0`` as a bool array of shape (height, width)."""
            # Threshold BEFORE resizing: a plain astype(np.uint8) would wrap
            # values >= 256 to 0 and truncate fractional values to 0.
            binary = np.asarray(mask) > 0
            if binary.shape != (height, width):
                # INTER_NEAREST keeps the mask strictly binary.
                binary = cv.resize(binary.astype(np.uint8), (width, height),
                                   interpolation=cv.INTER_NEAREST) > 0
            return binary

        # The working resolution is always the model's resolution. It equals
        # the input resolution when no resize is needed.
        orig_H, orig_W = img.shape[:2]
        model_H, model_W = self.B.shape[:2]
        needs_resize = (orig_H, orig_W) != (model_H, model_W)

        # --- Input image ----------------------------------------------------
        if needs_resize:
            print(f"🔧 Resizing input from ({orig_H}, {orig_W}) to model size ({model_H}, {model_W})")
            # Use INTER_LINEAR for better quality, avoid INTER_NEAREST
            img_work = cv.resize(img, (model_W, model_H), interpolation=cv.INTER_LINEAR)
        else:
            img_work = img
        # np.array(...) always copies. torch.from_numpy(img).float() does not
        # copy a float32 array, so the in-place background replacement below
        # would otherwise write through into the caller's ``img``.
        img_tensor = torch.from_numpy(
            np.array(img_work, dtype=np.float32, order="C")).to(device)

        # --- Table mask -----------------------------------------------------
        if table_mask is not None:
            if needs_resize:
                print(f"🔧 Resizing table mask from {table_mask.shape} to ({model_H}, {model_W})")
            table_mask_tensor = torch.from_numpy(
                binary_mask_at(table_mask, model_W, model_H)).to(device)
        else:
            table_mask_tensor = torch.ones((model_H, model_W), dtype=torch.bool, device=device)

        # --- Incoming heatmap -----------------------------------------------
        if heatmap is None:
            heatmap = torch.zeros((model_H, model_W), dtype=torch.float32, device=device)
        else:
            if heatmap.shape != (model_H, model_W):
                heatmap = cv.resize(heatmap, (model_W, model_H), interpolation=cv.INTER_LINEAR)
            heatmap = torch.from_numpy(heatmap).float().to(device)

        # Pixels start as foreground candidates only inside the table.
        detection_mask = table_mask_tensor.clone()

        # --- GMM background matching ----------------------------------------
        for k in range(self.K):
            # Presumably self.B holds the number of background components per
            # pixel; component k is only considered where B >= k + 1.
            B_mask = (self.B >= (k + 1)).to(device) & table_mask_tensor

            mu_k = self.mu[:, :, k, :].to(device)
            sigma_k = self.sigma[:, :, k, :, :].to(device)

            # Per-pixel Mahalanobis distance to component k.
            delta = (img_tensor - mu_k).unsqueeze(-1)
            temp = torch.matmul(torch.linalg.inv(sigma_k), delta)
            dist_sq = torch.matmul(delta.transpose(-2, -1), temp).squeeze(-1).squeeze(-1)
            dist = torch.sqrt(dist_sq + 1e-5)

            # Close matches are background. Drop them from the detection mask
            # and overwrite the pixel with the component mean, so later
            # components compare against the replaced value.
            match_mask = (dist < 7.0) & B_mask
            detection_mask[match_mask] = False
            img_tensor[match_mask] = mu_k[match_mask]

        # Foreground: unmatched, non-zero pixels inside the table.
        foreground_mask = detection_mask & (img_tensor.abs().sum(dim=-1) > 0) & table_mask_tensor

        # --- Region propfill: fill hollow foreground regions (NumPy/OpenCV) --
        foreground_np = foreground_mask.detach().cpu().numpy()
        table_mask_np = table_mask_tensor.detach().cpu().numpy()
        enhanced_mask, _debug_info = self.region_propfill_enhancement(
            foreground_np, table_mask=table_mask_np,
            dilation_kernel_size=3,      # Hardcoded: size of dilation kernel
            dilation_iterations=1,       # Hardcoded: connect nearby fragments
            erosion_iterations=2,        # Hardcoded: restore original size
            fill_threshold=230,          # Hardcoded: threshold for flood fill
            min_contour_area=200         # Hardcoded: filter small noise
        )
        filled_mask = torch.from_numpy(enhanced_mask).bool().to(device)

        # --- Heatmap update -------------------------------------------------
        # The rate falls linearly from alpha_start (heat 0) to alpha_end
        # (heat 1) and is clamped to at least alpha_end.
        pixelwise_alpha = torch.clamp(
            alpha_start - heatmap * (alpha_start - alpha_end), min=alpha_end)
        heatmap = torch.where(
            filled_mask & table_mask_tensor,
            torch.clamp(heatmap + pixelwise_alpha * 0.3, 0, 1),  # 0.3 factor = SLOW growth
            heatmap,
        )

        if cleaning_mask is not None:
            # Bring the cleaning mask to the working resolution. Without this,
            # a full-resolution mask raised a shape-mismatch error whenever
            # the input had been resized.
            cleaning_tensor = torch.from_numpy(
                binary_mask_at(cleaning_mask, model_W, model_H)).to(device)

            # The decay rate is computed from the already-grown heatmap.
            decay_alpha = torch.clamp(
                alpha_start - heatmap * (alpha_start - alpha_end), min=alpha_end)
            heatmap = torch.where(
                cleaning_tensor & table_mask_tensor,
                torch.clamp(heatmap - decay_alpha * 0.8, 0, 1),  # 0.8 vs 0.3 growth: decay is faster
                heatmap,
            )

        # --- Back to the input resolution -----------------------------------
        heatmap_np = heatmap.detach().cpu().numpy()
        if needs_resize:
            # INTER_LINEAR preserves smooth gradients in the heatmap.
            heatmap_np = cv.resize(heatmap_np, (orig_W, orig_H), interpolation=cv.INTER_LINEAR)
            if table_mask is not None:
                # Clip to the model-resolution table mask scaled back up, which
                # removes interpolation bleed across the table edge.
                heatmap_np = heatmap_np * binary_mask_at(
                    table_mask_tensor.detach().cpu().numpy(), orig_W, orig_H)

        # Ensure the heatmap stays ONLY within the table bounds.
        if table_mask is not None:
            table_mask_final = binary_mask_at(table_mask, heatmap_np.shape[1], heatmap_np.shape[0])
            heatmap_np = heatmap_np * table_mask_final.astype(np.float32)
        else:
            table_mask_final = np.ones(heatmap_np.shape, dtype=bool)

        # --- Visualization: blend a JET colormap where the heat is significant
        result = img.copy()
        heatmap_colored = cv.applyColorMap(
            (heatmap_np * 255).astype(np.uint8),
            cv.COLORMAP_JET
        )
        significant_heat = (heatmap_np > 0.1) & table_mask_final
        if np.any(significant_heat):
            # The right-hand side reads the pre-blend pixels before assignment.
            result[significant_heat] = cv.addWeighted(
                result[significant_heat], 0.7,
                heatmap_colored[significant_heat], 0.3, 0
            )

        return result, heatmap_np