diff --git a/trainer.py b/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..16b8d4bac2f2f545f77b190f9aa412b9acc65b79
--- /dev/null
+++ b/trainer.py
@@ -0,0 +1,705 @@
+from __future__ import print_function, absolute_import
+import time
+import numpy as np
+import torch
+from sklearn.metrics import accuracy_score
+from evaluation.eval import accuracy
+from utils.meters import AverageMeter
+from torch import nn
+
class BaseTrainer(object):
    """Abstract training-loop skeleton.

    Subclasses implement the three hooks below: a full `train` pass, batch
    unpacking via `_parse_data`, and per-batch loss computation via `_forward`.
    """

    def __init__(self):
        super().__init__()

    def train(self, combined_loader, optimizer, epochs, stage, print_freq=1, mode='ir'):
        """Run one training pass over *combined_loader*; must be overridden."""
        raise NotImplementedError

    def _parse_data(self, inputs):
        """Unpack one raw loader batch; must be overridden."""
        raise NotImplementedError

    def _forward(self, inputs, targets, opt):
        """Compute losses/metrics for one batch; must be overridden."""
        raise NotImplementedError
+
+
class UnsupervisedTrainer(BaseTrainer):
    """Cross-modal (RGB/IR) trainer driven by a triplet loss over modality pairs.

    Each batch supplies an RGB anchor plus positives/negatives from both
    modalities; `_forward` embeds them with the shared backbone and returns a
    triplet loss together with an L1 sparsity penalty on the IR embeddings.
    """

    def __init__(self, shared_model, ir_obj, is_adapt=False, mode='ir'):
        super().__init__()
        self.shared_model = shared_model  # backbone: (images, modal=…) -> (logits, embedding)
        self.ir_obj = ir_obj              # identity criterion; unused in the current _forward
        self.is_adapt = is_adapt
        self.mode = mode

        # Fixed-size domain labels. NOTE(review): assumes batch size 32 —
        # confirm against the loader configuration.
        self.ir_targets = torch.ones((32, 1)).cuda()
        self.rgb_targets = torch.zeros((32, 1)).cuda()
        self.domain_loss = nn.BCEWithLogitsLoss()
        self.triplet_loss = nn.TripletMarginLoss(swap=True)

    def train(self, combined_loader, optimizer, epochs, stage, print_freq, lr=0.1):
        """Run one epoch (capped at 200 iterations) over the triplet loader.

        Args:
            combined_loader: sequence whose first element yields triplet batches
                of (anchor_rgb, pid, _, rgb_pos, rgb_neg, ir_neg, ir_pos).
            optimizer: optimizer stepping the shared model.
            epochs: current epoch index (display only).
            stage: unused; kept for interface compatibility.
            print_freq: unused; progress is printed every iteration.
            lr: learning rate shown in the progress line (display only).
        """
        self.shared_model.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        ir_losses = AverageMeter()
        rgb_losses = AverageMeter()
        domain_losses = AverageMeter()
        part_losses = AverageMeter()
        ir_precisions = AverageMeter()
        domain_precisions = AverageMeter()
        end = time.time()

        for i, triplet_inputs in enumerate(combined_loader[0]):
            # Hard cap: at most 200 iterations per epoch.
            if i == 200:
                break
            data_time.update(time.time() - end)

            anchor_rgb, pid, _, rgb_positive, rgb_negative, ir_negative, ir_positive = triplet_inputs

            ir_loss, ir_prec, domain_loss, domain_prec = self._forward(
                [anchor_rgb, rgb_positive, rgb_negative,
                 ir_negative, ir_positive, pid])

            ir_losses.update(ir_loss.item(), anchor_rgb.size(0))
            # BUGFIX: record a plain float, not the live tensor — storing the
            # tensor in the meter kept its autograd graph alive for the epoch.
            domain_losses.update(float(domain_loss), anchor_rgb.size(0))
            ir_precisions.update(ir_prec, anchor_rgb.size(0))
            domain_precisions.update(domain_prec, anchor_rgb.size(0))

            combined_loss = ir_loss + domain_loss
            optimizer.zero_grad()
            combined_loss.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            # Progress is deliberately printed every iteration, overwriting
            # one console line (end='\r').
            print('Epoch: [{}][{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'
                  'IR Loss {:.3f} ({:.3f})\t'
                  'IR Prec {:.2%} ({:.2%})\t'
                  'Domain Loss {:.3f} ({:.3f})\t'
                  'Domain Prec {:.2%} ({:.2%})\t'
                  'LR :{:.3f}\t\t'
                  # BUGFIX: report the batch count of the loader actually
                  # iterated (combined_loader[0]); len(combined_loader) is
                  # just the number of loaders in the container.
                  .format(epochs, i + 1, len(combined_loader[0]),
                          batch_time.val, batch_time.avg,
                          data_time.val, data_time.avg,
                          ir_losses.val, ir_losses.avg,
                          ir_precisions.val, ir_precisions.avg,
                          domain_losses.val, domain_losses.avg,
                          domain_precisions.val, domain_precisions.avg,
                          lr), end='\r')

    def _parse_data(self, inputs):
        """Unpack a loader batch into (images, cluster ids, camera ids) on GPU."""
        imgs, pids, camids, clusterid, condition = inputs
        imgs.requires_grad = False
        clusterid = [int(i) for i in clusterid]
        clusterid = torch.IntTensor(clusterid).long()
        imgs, clusterid = imgs.cuda(), clusterid.cuda()
        return imgs, clusterid, camids

    def _forward(self, triplets):
        """Embed the triplet members and compute the training losses.

        Args:
            triplets: [anchor_rgb, rgb_positive, rgb_negative, ir_negative,
                       ir_positive, pid] as produced by the loader.

        Returns:
            (triplet_loss, 0, l1_penalty_loss, 0) — the zeros are precision
            placeholders that keep the caller's 4-tuple interface.
        """
        anchor_rgb, rgb_positive, rgb_negative, ir_negative, ir_positive, pid = triplets
        _, anchor_rgb = self.shared_model(anchor_rgb.cuda(), modal=1)
        _, rgb_negative = self.shared_model(rgb_negative.cuda(), modal=1)
        _, ir_negative = self.shared_model(ir_negative.cuda(), modal=2)
        _, ir_positive = self.shared_model(ir_positive.cuda(), modal=2)
        # Cross-modal triplet: RGB anchor pulled towards its IR positive,
        # pushed from an RGB negative (swap=True uses the harder negative).
        triplet_loss = self.triplet_loss(anchor_rgb, ir_positive, rgb_negative)

        # L1 sparsity penalty on the IR embeddings. Summing per-row L1 norms
        # equals the L1 norm of the whole tensor, so compute it in one call.
        l1_lambda = 0.000001
        l1_penalty_loss = l1_lambda * torch.norm(ir_positive, 1)

        return triplet_loss, 0, l1_penalty_loss, 0
+
+
+
class StaticMemBankTrainer(BaseTrainer):
    """IR-only trainer whose criterion (`ir_obj`) is a memory-bank objective.

    Only ``combined_loader[0]`` (the IR stream) is consumed; the RGB branch in
    `_forward` is commented out, so the "RGB Loss/Prec" columns in the progress
    line actually repeat the IR numbers.
    """

    def __init__(self, shared_model, ir_obj, is_adapt=False, mode='ir'):
        super().__init__()
        self.shared_model = shared_model  # backbone: (images, modal=…) -> (logits, features)
        self.ir_obj = ir_obj  # criterion: (features, targets, extra) -> (loss, logits)
        self.is_adapt = is_adapt
        self.mode = mode

    def train(self, combined_loader, optimizer, epochs, stage, print_freq, lr=0.1):
        """Run one epoch over the IR loader, optimizing only the IR loss.

        `stage` is unused; `lr` is display-only. Progress is printed every
        `print_freq` iterations, overwriting one console line (end='\\r').
        """
        self.shared_model.train()



        batch_time = AverageMeter()
        data_time = AverageMeter()
        ir_losses = AverageMeter()
        rgb_losses = AverageMeter()
        domain_losses = AverageMeter()
        part_losses = AverageMeter()
        ir_precisions = AverageMeter()
        domain_precisions = AverageMeter()
        end = time.time()

        for i, (ir_inputs) in enumerate(combined_loader[0]):
            # if i == 200:
            #     break
            data_time.update(time.time() - end)

            ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs)
            # rgb_inputs, rgb_targets, rgb_camids = self._parse_data(rgb_inputs)

            # NOTE(review): _forward returns the IR loss/precision twice, so
            # domain_losses/domain_precisions mirror ir_losses/ir_precisions.
            ir_loss, ir_prec, domain_loss, domain_prec = self._forward(ir_inputs, ir_targets)
                                                                       #rgb_inputs, rgb_targets)

            ir_losses.update(ir_loss.item(), ir_targets.size(0))
            domain_losses.update(domain_loss.item(), ir_targets.size(0))
            ir_precisions.update(ir_prec, ir_targets.size(0))
            domain_precisions.update(domain_prec, ir_targets.size(0))

            # Only the IR loss is optimized; the domain term is disabled.
            combined_loss = ir_loss #+ domain_loss
            optimizer.zero_grad()
            combined_loss.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:

                print('Epoch: [{}][{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'
                      'IR Loss {:.3f} ({:.3f})\t'
                      'IR Prec {:.2%} ({:.2%})\t'
                      'RGB Loss {:.3f} ({:.3f})\t'
                      'RGB Prec {:.2%} ({:.2%})\t'
                      'LR :{:.3f}\t\t\t'
                      .format(epochs, i + 1, len(combined_loader[0]),
                              batch_time.val, batch_time.avg,
                              data_time.val, data_time.avg,
                              ir_losses.val, ir_losses.avg,
                              ir_precisions.val, ir_precisions.avg,
                              domain_losses.val, domain_losses.avg,
                              domain_precisions.val, domain_precisions.avg,
                              lr), end='\r')


    def _parse_data(self, inputs):
        """Unpack a loader batch into (images, cluster ids, camera ids) on GPU."""
        imgs, pids, camids, clusterid, condition = inputs
        # clusterid = pids
        imgs.requires_grad = False
        clusterid = [int(i) for i in clusterid]
        clusterid = torch.IntTensor(clusterid).long()
        imgs, clusterid = imgs.cuda(), clusterid.cuda()
        return imgs, clusterid, camids

    def _forward(self, ir_inputs, ir_targets):
        """Score one IR batch.

        Returns (loss, prec, loss, prec): the pair is duplicated so the
        caller's 4-tuple unpacking keeps working with the RGB branch disabled.
        """
        domain_class, ir_fmap = self.shared_model(ir_inputs, modal=2)
        ir_loss, ir_data = self.ir_obj(ir_fmap, ir_targets, None)
        ir_prec, = accuracy(ir_data.data, ir_targets.data)
        # print(f"rgb_inputs: {rgb_inputs.shape}")
        # domain_class, rgb_fmap = self.shared_model(rgb_inputs, modal=1)
        # rgb_loss, rgb_data = self.ir_obj(rgb_fmap, rgb_targets, None)
        # # print(f"rgb_FMAP: {rgb_fmap.shape}")
        # rgb_prec, = accuracy(rgb_data.data, rgb_targets.data)

        return ir_loss, ir_prec[0], ir_loss, ir_prec[0]#domain_prec#[0]
+
class BaselineTrainer(BaseTrainer):
    """Baseline trainer optimizing identity losses on both IR and RGB streams.

    Iterates the two loaders in lockstep (zip), so one epoch is bounded by the
    shorter loader. Despite the variable names, the third/fourth values in the
    progress line are the RGB loss/precision returned by `_forward`.
    """

    def __init__(self, shared_model, ir_obj, is_adapt=False, mode='ir'):
        super().__init__()
        self.shared_model = shared_model  # backbone: (images, modal=…) -> (logits, features)
        self.ir_obj = ir_obj  # shared criterion used for both modalities
        self.is_adapt = is_adapt
        self.mode = mode

    def train(self, combined_loader, optimizer, epochs, stage, print_freq, lr=0.1):
        """Run one epoch over paired (IR, RGB) loaders.

        `stage` is unused; `lr` is display-only. Progress is printed every
        `print_freq` iterations, overwriting one console line (end='\\r').
        """
        self.shared_model.train()


        # if early_stopping:
        #     print(f"Early stopping at epoch: {epoch}")
        #     break

        batch_time = AverageMeter()
        data_time = AverageMeter()
        ir_losses = AverageMeter()
        rgb_losses = AverageMeter()
        domain_losses = AverageMeter()
        part_losses = AverageMeter()
        ir_precisions = AverageMeter()
        domain_precisions = AverageMeter()
        end = time.time()

        for i, (ir_inputs, rgb_inputs) in enumerate(zip(combined_loader[0], combined_loader[1])):
            # if i == 200:
            #     break
            data_time.update(time.time() - end)

            ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs)
            rgb_inputs, rgb_targets, rgb_camids = self._parse_data(rgb_inputs)


            # Here domain_loss/domain_prec actually carry the RGB loss/prec.
            ir_loss, ir_prec, domain_loss, domain_prec = self._forward(ir_inputs, ir_targets,
                                                                       rgb_inputs, rgb_targets)

            ir_losses.update(ir_loss.item(), ir_targets.size(0))
            domain_losses.update(domain_loss.item(), ir_targets.size(0))
            ir_precisions.update(ir_prec, ir_targets.size(0))
            domain_precisions.update(domain_prec, ir_targets.size(0))

            # Both modality losses contribute to the optimizer step.
            combined_loss = ir_loss + domain_loss
            optimizer.zero_grad()
            combined_loss.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:

                print('Epoch: [{}][{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'
                      'IR Loss {:.3f} ({:.3f})\t'
                      'IR Prec {:.2%} ({:.2%})\t'
                      'RGB Loss {:.3f} ({:.3f})\t'
                      'RGB Prec {:.2%} ({:.2%})\t'
                      'LR :{:.3f}\t\t\t'
                      .format(epochs, i + 1, len(combined_loader[0]),
                              batch_time.val, batch_time.avg,
                              data_time.val, data_time.avg,
                              ir_losses.val, ir_losses.avg,
                              ir_precisions.val, ir_precisions.avg,
                              domain_losses.val, domain_losses.avg,
                              domain_precisions.val, domain_precisions.avg,
                              lr), end='\r')


    def _parse_data(self, inputs):
        """Unpack a loader batch into (images, cluster ids, camera ids) on GPU."""
        imgs, pids, camids, clusterid, condition = inputs
        # clusterid = pids
        imgs.requires_grad = False
        clusterid = [int(i) for i in clusterid]
        clusterid = torch.IntTensor(clusterid).long()
        imgs, clusterid = imgs.cuda(), clusterid.cuda()
        return imgs, clusterid, camids

    def _forward(self, ir_inputs, ir_targets, rgb_inputs, rgb_targets):
        """Score one IR batch and one RGB batch with the shared criterion.

        Returns (ir_loss, ir_prec, rgb_loss, rgb_prec).
        """
        domain_class, ir_fmap = self.shared_model(ir_inputs, modal=2)
        # print(f"ir_inputs: {ir_inputs.shape}")
        #
        # print(f"IR_FMAP: {ir_fmap.shape}")
        ir_loss, ir_data = self.ir_obj(ir_fmap, ir_targets, None)
        ir_prec, = accuracy(ir_data.data, ir_targets.data)
        # print(f"rgb_inputs: {rgb_inputs.shape}")
        domain_class, rgb_fmap = self.shared_model(rgb_inputs, modal=1)
        rgb_loss, rgb_data = self.ir_obj(rgb_fmap, rgb_targets, None)
        # print(f"rgb_FMAP: {rgb_fmap.shape}")
        rgb_prec, = accuracy(rgb_data.data, rgb_targets.data)

        return ir_loss, ir_prec[0], rgb_loss, rgb_prec[0]#domain_prec#[0]
+
+
class DomainTrainer(BaseTrainer):
    """Trains the shared model's domain classifier to separate IR from RGB.

    IR images are labelled 1 and RGB images 0; both streams are scored with
    BCE-with-logits and the two losses are summed for the optimizer step.
    """

    def __init__(self, shared_model):
        super().__init__()
        self.shared_model = shared_model  # backbone with an `is_classifier` head
        self.domain_loss = nn.BCEWithLogitsLoss()

    def train(self, combined_loader, optimizer, epochs, stage, print_freq, mode='ir', lr=0.1):
        """Run one epoch of domain-classifier training.

        Args:
            combined_loader: (ir_loader, rgb_loader) pair, iterated in lockstep.
            optimizer: optimizer stepping the shared model.
            epochs: current epoch index (display only).
            stage, mode: unused; kept for interface compatibility.
            print_freq: print progress every `print_freq` iterations.
            lr: learning rate shown in the progress line (display only).
        """
        self.shared_model.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        ir_losses = AverageMeter()
        rgb_losses = AverageMeter()
        domain_losses = AverageMeter()
        part_losses = AverageMeter()
        ir_precisions = AverageMeter()
        domain_precisions = AverageMeter()
        end = time.time()

        for i, (ir_inputs, rgb_inputs) in enumerate(zip(combined_loader[0], combined_loader[1])):
            data_time.update(time.time() - end)

            ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs)
            rgb_inputs, rgb_targets, rgb_camids = self._parse_data(rgb_inputs)

            # The 4-tuple carries (ir_domain_loss, ir_domain_prec,
            # rgb_domain_loss, rgb_domain_prec).
            ir_loss, ir_prec, domain_loss, domain_prec = self._forward(ir_inputs, rgb_inputs)

            ir_losses.update(ir_loss.item(), ir_targets.size(0))
            domain_losses.update(domain_loss.item(), ir_targets.size(0))
            ir_precisions.update(ir_prec, ir_targets.size(0))
            domain_precisions.update(domain_prec, ir_targets.size(0))

            combined_loss = ir_loss + domain_loss
            optimizer.zero_grad()
            combined_loss.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:

                print('Epoch: [{}][{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'
                      'IR Loss {:.3f} ({:.3f})\t'
                      'IR Prec {:.2%} ({:.2%})\t'
                      'RGB Loss {:.3f} ({:.3f})\t'
                      'RGB Prec {:.2%} ({:.2%})\t'
                      'LR :{:.3f}\t'
                      # BUGFIX: len(combined_loader) is just the number of
                      # loaders (2); report the IR loader's batch count, as
                      # the other trainers in this file do.
                      .format(epochs, i + 1, len(combined_loader[0]),
                              batch_time.val, batch_time.avg,
                              data_time.val, data_time.avg,
                              ir_losses.val, ir_losses.avg,
                              ir_precisions.val, ir_precisions.avg,
                              domain_losses.val, domain_losses.avg,
                              domain_precisions.val, domain_precisions.avg,
                              lr))


    def _parse_data(self, inputs):
        """Unpack a loader batch into (images, cluster ids, camera ids) on GPU."""
        imgs, pids, camids, clusterid, condition = inputs
        imgs.requires_grad = False
        clusterid = [int(i) for i in clusterid]
        clusterid = torch.IntTensor(clusterid).long()
        imgs, clusterid = imgs.cuda(), clusterid.cuda()
        return imgs, clusterid, camids

    def _forward(self, ir_inputs, rgb_inputs):
        """Classify modality for both batches and compute BCE losses.

        Returns (ir_domain_loss, ir_domain_prec, rgb_domain_loss,
        rgb_domain_prec); precisions are sklearn accuracy over thresholded
        sigmoid outputs.
        """
        ir_domain_class, ir_fmap = self.shared_model(ir_inputs, is_classifier=True)
        rgb_domain_class, rgb_fmap = self.shared_model(rgb_inputs, is_classifier=True)

        # Targets: IR -> 1, RGB -> 0 (rebuilt per batch to match batch size).
        self.ir_domain_target = torch.ones((ir_inputs.size(0), 1)).cuda()
        self.rgb_domain_target = torch.zeros((rgb_inputs.size(0), 1)).cuda()

        ir_domain_loss = self.domain_loss(ir_domain_class, self.ir_domain_target)
        rgb_domain_loss = self.domain_loss(rgb_domain_class, self.rgb_domain_target)

        ir_domain_prec = accuracy_score(self.ir_domain_target.data.cpu().numpy(), np.round(torch.sigmoid(ir_domain_class.data).cpu().numpy()))
        rgb_domain_prec = accuracy_score(self.rgb_domain_target.data.cpu().numpy(), np.round(torch.sigmoid(rgb_domain_class.data).cpu().numpy()))

        return ir_domain_loss, ir_domain_prec, rgb_domain_loss, rgb_domain_prec
+
+
class SupervisedTrainer(BaseTrainer):
    """Supervised trainer that fine-tunes on RGB with the CNN backbone frozen.

    NOTE(review): every call to `train` freezes all `shared_model.module.CNN`
    parameters, and only `rgb_loss` is back-propagated; the IR branch is
    evaluated for monitoring only. Confirm this is intentional.
    """

    def __init__(self, shared_model, ir_obj):
        super().__init__()
        self.shared_model = shared_model
        self.ir_obj = ir_obj
        # Same criterion object is aliased for both modalities.
        self.rgb_obj = ir_obj
        # NOTE(review): constructed but never used in the visible code.
        self.domain_loss = torch.nn.KLDivLoss()

    def train(self, combined_loader, optimizer, epochs, stage, print_freq, mode='ir'):
        """Run one epoch over paired (IR, RGB) loaders, optimizing RGB only.

        `stage`/`mode` are unused; the LR shown in the progress line is a
        hard-coded 0.1, not the optimizer's actual rate.
        """
        # Freeze the whole CNN backbone so only the remaining heads train.
        # if epochs >=2:
        for name, module in self.shared_model.module.CNN.named_modules():
            for param in module.parameters():
                param.requires_grad = False

        # for epoch in range(epochs):
        self.shared_model.train()


        # if early_stopping:
        #     print(f"Early stopping at epoch: {epoch}")
        #     break

        batch_time = AverageMeter()
        data_time = AverageMeter()
        ir_losses = AverageMeter()
        rgb_losses = AverageMeter()
        domain_losses = AverageMeter()
        part_losses = AverageMeter()
        ir_precisions = AverageMeter()
        rgb_precisions = AverageMeter()
        end = time.time()

        for i, (ir_inputs, rgb_inputs) in enumerate(zip(combined_loader[0], combined_loader[1])):
            data_time.update(time.time() - end)

            ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs)
            rgb_inputs, rgb_targets, rgb_camids = self._parse_data(rgb_inputs)

            # Skip ragged final batches where the two loaders disagree in size.
            if len(ir_inputs) != len(rgb_inputs):
                continue

            # NOTE(review): domain_loss and part_loss returned here are both
            # aliases of ir_loss (see _forward) — placeholders, not real terms.
            ir_loss, rgb_loss, ir_prec, rgb_prec, domain_loss, part_loss = self._forward([ir_inputs, rgb_inputs],
                                                                              [ir_targets, rgb_targets],
                                                                              )

            ir_losses.update(ir_loss.item(), ir_targets.size(0))
            rgb_losses.update(rgb_loss.item(), rgb_targets.size(0))
            domain_losses.update(domain_loss.item(), rgb_targets.size(0))
            part_losses.update(part_loss.item(), rgb_targets.size(0))

            ir_precisions.update(ir_prec, ir_targets.size(0))
            rgb_precisions.update(rgb_prec, rgb_targets.size(0))


            # Only the RGB branch is optimized.
            combined_loss = rgb_loss

            optimizer.zero_grad()
            combined_loss.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'
                      'IR Loss {:.3f} ({:.3f})\t'
                      'RGB Loss {:.3f} ({:.3f})\t'
                      'Triplet Loss {:.3f} ({:.3f})\t'
                      'Part Loss {:.3f} ({:.3f})\t'
                      'IR Prec {:.2%} ({:.2%})\t'
                      'RGB Prec {:.2%} ({:.2%})\t'
                      'LR :{:.3f}\t'
                      .format(epochs, i + 1, len(combined_loader[0]),
                              batch_time.val, batch_time.avg,
                              data_time.val, data_time.avg,
                              ir_losses.val, ir_losses.avg,
                              rgb_losses.val, rgb_losses.avg,
                              domain_losses.val, domain_losses.avg,
                              part_losses.val, part_losses.avg,
                              ir_precisions.val, ir_precisions.avg,
                              rgb_precisions.val, rgb_precisions.avg, 0.1))



    def _parse_data(self, inputs):
        """Unpack a batch; targets are the ground-truth pids (supervised)."""
        imgs, pids, camids, realid, clusterid = inputs
        clusterid = pids
        imgs.requires_grad = False
        clusterid = [int(i) for i in clusterid]
        clusterid = torch.IntTensor(clusterid).long()
        imgs, clusterid = imgs.cuda(), clusterid.cuda()
        return imgs, clusterid, camids

    def _forward(self, inputs, targets):
        """Score the IR batch (inputs[0]) and RGB batch (inputs[1]).

        Returns (ir_loss, rgb_loss, ir_prec, rgb_prec, ir_loss, ir_loss);
        the trailing ir_loss repeats fill the caller's domain/part slots.
        """
        _, ir_pooled = self.shared_model(inputs[0])
        ir_loss, ir_outputs = self.ir_obj(ir_pooled, targets[0], None)
        ir_prec, = accuracy(ir_outputs.data, targets[0].data)

        # RGB passes through the adaptation path of the shared model.
        _, rgb_pooled = self.shared_model(inputs[1], is_adapt=True)
        rgb_loss, outputs = self.ir_obj(rgb_pooled, targets[1], None)
        rgb_prec, = accuracy(outputs.data, targets[1].data)


        return ir_loss, rgb_loss, ir_prec[0], rgb_prec[0], ir_loss, ir_loss
+
+
class IRSupervisedTrainer(BaseTrainer):
    """Supervised trainer for a single (IR) loader with identity labels.

    `train` returns `self.ir_obj.M` — presumably the criterion's memory-bank
    matrix; confirm against the `ir_obj` implementation.
    """

    def __init__(self, shared_model, ir_obj):
        super().__init__()
        self.shared_model = shared_model  # backbone: images -> (logits, pooled features)
        self.ir_obj = ir_obj  # criterion exposing an `M` attribute

    def train(self, combined_loader, optimizer, epochs, stage, print_freq, mode='ir', lr=0.1):
        """Run one epoch over `combined_loader` (a single loader here).

        `stage`/`mode` are unused; `lr` is display-only.
        Returns self.ir_obj.M after the epoch.
        """
        # for name, module in self.shared_model.module.CNN.named_modules():
        #     for param in module.parameters():
        #         param.requires_grad = False
        # for epoch in range(epochs):
        self.shared_model.train()


        # if early_stopping:
        #     print(f"Early stopping at epoch: {epoch}")
        #     break

        batch_time = AverageMeter()
        data_time = AverageMeter()
        ir_losses = AverageMeter()
        rgb_losses = AverageMeter()
        domain_losses = AverageMeter()
        part_losses = AverageMeter()
        ir_precisions = AverageMeter()
        rgb_precisions = AverageMeter()
        end = time.time()

        for i, ir_inputs in enumerate(combined_loader):
            data_time.update(time.time() - end)

            ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs)

            ir_loss, ir_prec = self._forward(ir_inputs, ir_targets)

            ir_losses.update(ir_loss.item(), ir_targets.size(0))
            ir_precisions.update(ir_prec, ir_targets.size(0))

            combined_loss = ir_loss
            optimizer.zero_grad()
            combined_loss.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:

                print('Epoch: [{}][{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'
                      'IR Loss {:.3f} ({:.3f})\t'
                      'IR Prec {:.2%} ({:.2%})\t'
                      'LR :{:.3f}\t'
                      .format(epochs, i + 1, len(combined_loader),
                              batch_time.val, batch_time.avg,
                              data_time.val, data_time.avg,
                              ir_losses.val, ir_losses.avg,
                              ir_precisions.val, ir_precisions.avg,
                              lr))
        return self.ir_obj.M


    def _parse_data(self, inputs):
        """Unpack a batch; targets are the ground-truth pids (supervised)."""
        imgs, pids, camids, realid, clusterid = inputs
        clusterid = pids
        imgs.requires_grad = False
        clusterid = [int(i) for i in clusterid]
        clusterid = torch.IntTensor(clusterid).long()
        imgs, clusterid = imgs.cuda(), clusterid.cuda()
        return imgs, clusterid, camids

    def _forward(self, inputs, targets):
        """Embed one batch and score it; returns (loss, top-1 precision)."""
        _, ir_pooled = self.shared_model(inputs)
        ir_loss, ir_outputs = self.ir_obj(ir_pooled, targets, None)
        ir_prec, = accuracy(ir_outputs.data, targets.data)

        return ir_loss, ir_prec[0]
+
+
class FCSupervisedTrainer(BaseTrainer):
    """Supervised trainer variant that forwards `is_adapt` into the model.

    Like IRSupervisedTrainer, `train` returns `self.ir_obj.M` — presumably
    the criterion's memory bank; confirm against the `ir_obj` implementation.
    """

    def __init__(self, shared_model, ir_obj, is_adapt=False):
        super().__init__()
        self.shared_model = shared_model  # backbone: (images, is_adapt) -> (fc logits, feature map)
        self.ir_obj = ir_obj  # criterion exposing an `M` attribute
        # NOTE(review): constructed but never used in the visible code.
        self.domain_loss = nn.CrossEntropyLoss().cuda()
        self.is_adapt = is_adapt

    def train(self, combined_loader, optimizer, epochs, stage, print_freq, mode='ir', lr=0.1):
        """Run one epoch over `combined_loader` (a single loader here).

        `stage`/`mode` are unused; `lr` is display-only.
        Returns self.ir_obj.M after the epoch.
        """
        # for name, module in self.shared_model.module.CNN.named_modules():
        #     for param in module.parameters():
        #         param.requires_grad = False
        # for epoch in range(epochs):
        self.shared_model.train()


        # if early_stopping:
        #     print(f"Early stopping at epoch: {epoch}")
        #     break

        batch_time = AverageMeter()
        data_time = AverageMeter()
        ir_losses = AverageMeter()
        rgb_losses = AverageMeter()
        domain_losses = AverageMeter()
        part_losses = AverageMeter()
        ir_precisions = AverageMeter()
        rgb_precisions = AverageMeter()
        end = time.time()

        for i, ir_inputs in enumerate(combined_loader):
            data_time.update(time.time() - end)

            ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs)

            ir_loss, ir_prec = self._forward(ir_inputs, ir_targets)

            ir_losses.update(ir_loss.item(), ir_targets.size(0))
            # domain_losses.update(domain_inv_loss.item(), ir_targets.size(0))
            ir_precisions.update(ir_prec, ir_targets.size(0))

            combined_loss = ir_loss
            optimizer.zero_grad()
            combined_loss.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:

                print('Epoch: [{}][{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'
                      'IR Loss {:.3f} ({:.3f})\t'
                      'IR Prec {:.2%} ({:.2%})\t'
                      'LR :{:.3f}\t'
                      .format(epochs, i + 1, len(combined_loader),
                              batch_time.val, batch_time.avg,
                              data_time.val, data_time.avg,
                              ir_losses.val, ir_losses.avg,
                              ir_precisions.val, ir_precisions.avg,
                              # domain_losses.val, domain_losses.avg,
                              lr))
        return self.ir_obj.M


    def _parse_data(self, inputs):
        """Unpack a batch; targets are the ground-truth pids (supervised)."""
        imgs, pids, camids, realid, clusterid = inputs
        clusterid = pids
        imgs.requires_grad = False
        clusterid = [int(i) for i in clusterid]
        clusterid = torch.IntTensor(clusterid).long()
        imgs, clusterid = imgs.cuda(), clusterid.cuda()
        return imgs, clusterid, camids

    def _forward(self, inputs, targets):
        """Embed one batch (is_adapt passed positionally) and score it.

        Returns (loss, top-1 precision).
        """
        ir_fc, ir_fmap = self.shared_model(inputs, self.is_adapt)
        ir_loss, ir_data = self.ir_obj(ir_fmap, targets, None)
        ir_prec, = accuracy(ir_data.data, targets.data)

        return ir_loss, ir_prec[0]#, domain_inv_loss
+
+
class Trainer(BaseTrainer):
    """Generic trainer operating on dict-style batches.

    NOTE(review): relies on `self.model` and `self.criterion`, which are never
    set in this file — presumably assigned by a subclass or externally after
    construction; confirm before use.
    """

    def _parse_data(self, inputs):
        """Unpack a dict batch into (images, cluster-id targets, camera ids).

        Targets are the cluster ids (not subids), cast to int64 on GPU.
        """
        imgs, subid, camera, condition, cluster_id = inputs['image'], inputs['subid'], inputs['cam'], inputs['condition'], inputs['cluster_id']
        return imgs.cuda(), cluster_id.cuda().type(torch.int64), camera #subid.cuda()

    def _forward(self, inputs, targets, cam):
        """Embed one batch and score it with `self.criterion`.

        The criterion is called as (features, targets, cam) and returns
        (loss, logits); returns (loss, top-1 precision).
        """
        _, outputs = self.model(inputs)

        loss, outputs = self.criterion(outputs, targets, cam) #torch.nn.functional.so(outputs, targets)
        # outputs = np.argmax(outputs.data.cpu(), axis=1)

        acc, = accuracy(outputs.data, targets.data)
        return loss, acc[0]

        #
        # if isinstance(self.criterion, ACL_IDL):
        #     prec, = accuracy(outputs.data, targets.data[0:int(len(targets.data) // 2)])
        #     prec = prec[0]
        # else:
        #     prec = 0
        # return loss, prec
+