diff --git a/trainer.py b/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..16b8d4bac2f2f545f77b190f9aa412b9acc65b79 --- /dev/null +++ b/trainer.py @@ -0,0 +1,705 @@ +from __future__ import print_function, absolute_import +import time +import numpy as np +import torch +from sklearn.metrics import accuracy_score +from evaluation.eval import accuracy +from utils.meters import AverageMeter +from torch import nn + +class BaseTrainer(object): + def __init__(self): + super(BaseTrainer, self).__init__() + + def train(self, combined_loader, optimizer, epochs, stage, print_freq=1, mode='ir'): + raise NotImplementedError + + def _parse_data(self, inputs): + raise NotImplementedError + + def _forward(self, inputs, targets, opt): + raise NotImplementedError + + +class UnsupervisedTrainer(BaseTrainer): + def __init__(self, shared_model, ir_obj, is_adapt=False, mode='ir'): + super().__init__() + self.shared_model = shared_model + self.ir_obj = ir_obj + self.is_adapt = is_adapt + self.mode = mode + + self.ir_targets = torch.ones((32, 1)).cuda() + self.rgb_targets = torch.zeros((32, 1)).cuda() + self.domain_loss = nn.BCEWithLogitsLoss() + self.triplet_loss = nn.TripletMarginLoss(swap=True) + + def train(self, combined_loader, optimizer, epochs, stage, print_freq, lr=0.1): + self.shared_model.train() + + # if early_stopping: + # print(f"Early stopping at epoch: {epoch}") + # break + + batch_time = AverageMeter() + data_time = AverageMeter() + ir_losses = AverageMeter() + rgb_losses = AverageMeter() + domain_losses = AverageMeter() + part_losses = AverageMeter() + ir_precisions = AverageMeter() + domain_precisions = AverageMeter() + end = time.time() + + for i, triplet_inputs in enumerate(combined_loader[0]): + if i == 200: + break + data_time.update(time.time() - end) + + # ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs) + anchor_rgb, pid, _, rgb_positive, rgb_negative, ir_negative, ir_positive = triplet_inputs + + ir_loss, ir_prec, domain_loss, domain_prec = self._forward(#ir_inputs, ir_targets, + [anchor_rgb, rgb_positive, rgb_negative, + ir_negative, ir_positive, pid]) + + ir_losses.update(ir_loss.item(), anchor_rgb.size(0)) + domain_losses.update(domain_loss, anchor_rgb.size(0)) + ir_precisions.update(ir_prec, anchor_rgb.size(0)) + domain_precisions.update(domain_prec, anchor_rgb.size(0)) + + combined_loss = ir_loss + domain_loss + optimizer.zero_grad() + combined_loss.backward() + optimizer.step() + + batch_time.update(time.time() - end) + end = time.time() + + if True:#(i + 1) % print_freq == 0: + print('Epoch: [{}][{}/{}]\t' + 'Time {:.3f} ({:.3f})\t' + 'Data {:.3f} ({:.3f})\t' + 'IR Loss {:.3f} ({:.3f})\t' + 'IR Prec {:.2%} ({:.2%})\t' + 'Domain Loss {:.3f} ({:.3f})\t' + 'Domain Prec {:.2%} ({:.2%})\t' + 'LR :{:.3f}\t\t' + .format(epochs, i + 1, len(combined_loader), + batch_time.val, batch_time.avg, + data_time.val, data_time.avg, + ir_losses.val, ir_losses.avg, + ir_precisions.val, ir_precisions.avg, + domain_losses.val, domain_losses.avg, + domain_precisions.val, domain_precisions.avg, + lr), end='\r') + + def _parse_data(self, inputs): + imgs, pids, camids, clusterid, condition = inputs + # clusterid = pids + imgs.requires_grad = False + clusterid = [int(i) for i in clusterid] + clusterid = torch.IntTensor(clusterid).long() + imgs, clusterid = imgs.cuda(), clusterid.cuda() + return imgs, clusterid, camids + + def _forward(self, triplets): + # domain_class, ir_fmap = self.shared_model(inputs, is_classifier=True) + # ir_loss, ir_data = self.ir_obj(ir_fmap, targets, None) + # ir_prec, = accuracy(ir_data.data, targets.data) + # self.rgb_targets = torch.zeros((inputs.size(0), 1)).cuda() + + # domain_loss = self.domain_loss(domain_class, self.rgb_targets) + + # domain_prec = accuracy_score(self.rgb_targets.data.cpu().numpy(), np.round(torch.sigmoid(domain_class.data).cpu().numpy())) + # domain_prec, = accuracy(domain_class.data, self.rgb_targets.data) + + anchor_rgb, rgb_positive, rgb_negative, ir_negative, ir_positive, pid = triplets + _, anchor_rgb = self.shared_model(anchor_rgb.cuda(), modal=1) + # _, rgb_positive = self.shared_model(rgb_positive.cuda(), modal=1) + _, rgb_negative = self.shared_model(rgb_negative.cuda(), modal=1) + _, ir_negative = self.shared_model(ir_negative.cuda(), modal=2) + _, ir_positive = self.shared_model(ir_positive.cuda(), modal=2) + triplet_loss = self.triplet_loss(anchor_rgb, ir_positive, rgb_negative) # + + # triplet_loss = self.triplet_loss(ir_positive, anchor_rgb, ir_negative) + + l1_lambda = 0.000001#0.000001 + l1_penalty_loss = 0 + for output in ir_positive: + l1_penalty_loss += torch.norm(output, 1) + l1_penalty_loss *= l1_lambda + + # ir_loss, ir_data = self.ir_obj(ir_positive.cuda(), pid.cuda(), None) + # rgb_loss, rgb_data = self.ir_obj(anchor_rgb.cuda(), pid.cuda(), None) + # ir_prec, = accuracy(ir_data.data, targets.data) + + return triplet_loss, 0, l1_penalty_loss, 0 # domain_prec#[0] + + + +class StaticMemBankTrainer(BaseTrainer): + def __init__(self, shared_model, ir_obj, is_adapt=False, mode='ir'): + super().__init__() + self.shared_model = shared_model + self.ir_obj = ir_obj + self.is_adapt = is_adapt + self.mode = mode + + def train(self, combined_loader, optimizer, epochs, stage, print_freq, lr=0.1): + self.shared_model.train() + + + + batch_time = AverageMeter() + data_time = AverageMeter() + ir_losses = AverageMeter() + rgb_losses = AverageMeter() + domain_losses = AverageMeter() + part_losses = AverageMeter() + ir_precisions = AverageMeter() + domain_precisions = AverageMeter() + end = time.time() + + for i, (ir_inputs) in enumerate(combined_loader[0]): + # if i == 200: + # break + data_time.update(time.time() - end) + + ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs) + # rgb_inputs, rgb_targets, rgb_camids = self._parse_data(rgb_inputs) + + ir_loss, ir_prec, domain_loss, domain_prec = self._forward(ir_inputs, ir_targets) + #rgb_inputs, rgb_targets) + + ir_losses.update(ir_loss.item(), ir_targets.size(0)) + domain_losses.update(domain_loss.item(), ir_targets.size(0)) + ir_precisions.update(ir_prec, ir_targets.size(0)) + domain_precisions.update(domain_prec, ir_targets.size(0)) + + combined_loss = ir_loss #+ domain_loss + optimizer.zero_grad() + combined_loss.backward() + optimizer.step() + + batch_time.update(time.time() - end) + end = time.time() + + if (i + 1) % print_freq == 0: + + print('Epoch: [{}][{}/{}]\t' + 'Time {:.3f} ({:.3f})\t' + 'Data {:.3f} ({:.3f})\t' + 'IR Loss {:.3f} ({:.3f})\t' + 'IR Prec {:.2%} ({:.2%})\t' + 'RGB Loss {:.3f} ({:.3f})\t' + 'RGB Prec {:.2%} ({:.2%})\t' + 'LR :{:.3f}\t\t\t' + .format(epochs, i + 1, len(combined_loader[0]), + batch_time.val, batch_time.avg, + data_time.val, data_time.avg, + ir_losses.val, ir_losses.avg, + ir_precisions.val, ir_precisions.avg, + domain_losses.val, domain_losses.avg, + domain_precisions.val, domain_precisions.avg, + lr), end='\r') + + + def _parse_data(self, inputs): + imgs, pids, camids, clusterid, condition = inputs + # clusterid = pids + imgs.requires_grad = False + clusterid = [int(i) for i in clusterid] + clusterid = torch.IntTensor(clusterid).long() + imgs, clusterid = imgs.cuda(), clusterid.cuda() + return imgs, clusterid, camids + + def _forward(self, ir_inputs, ir_targets): + domain_class, ir_fmap = self.shared_model(ir_inputs, modal=2) + ir_loss, ir_data = self.ir_obj(ir_fmap, ir_targets, None) + ir_prec, = accuracy(ir_data.data, ir_targets.data) + # print(f"rgb_inputs: {rgb_inputs.shape}") + # domain_class, rgb_fmap = self.shared_model(rgb_inputs, modal=1) + # rgb_loss, rgb_data = self.ir_obj(rgb_fmap, rgb_targets, None) + # # print(f"rgb_FMAP: {rgb_fmap.shape}") + # rgb_prec, = accuracy(rgb_data.data, rgb_targets.data) + + return ir_loss, ir_prec[0], ir_loss, ir_prec[0]#domain_prec#[0] + +class BaselineTrainer(BaseTrainer): + def __init__(self, shared_model, ir_obj, is_adapt=False, mode='ir'): + super().__init__() + self.shared_model = shared_model + self.ir_obj = ir_obj + self.is_adapt = is_adapt + self.mode = mode + + def train(self, combined_loader, optimizer, epochs, stage, print_freq, lr=0.1): + self.shared_model.train() + + + # if early_stopping: + # print(f"Early stopping at epoch: {epoch}") + # break + + batch_time = AverageMeter() + data_time = AverageMeter() + ir_losses = AverageMeter() + rgb_losses = AverageMeter() + domain_losses = AverageMeter() + part_losses = AverageMeter() + ir_precisions = AverageMeter() + domain_precisions = AverageMeter() + end = time.time() + + for i, (ir_inputs, rgb_inputs) in enumerate(zip(combined_loader[0], combined_loader[1])): + # if i == 200: + # break + data_time.update(time.time() - end) + + ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs) + rgb_inputs, rgb_targets, rgb_camids = self._parse_data(rgb_inputs) + + + ir_loss, ir_prec, domain_loss, domain_prec = self._forward(ir_inputs, ir_targets, + rgb_inputs, rgb_targets) + + ir_losses.update(ir_loss.item(), ir_targets.size(0)) + domain_losses.update(domain_loss.item(), ir_targets.size(0)) + ir_precisions.update(ir_prec, ir_targets.size(0)) + domain_precisions.update(domain_prec, ir_targets.size(0)) + + combined_loss = ir_loss + domain_loss + optimizer.zero_grad() + combined_loss.backward() + optimizer.step() + + batch_time.update(time.time() - end) + end = time.time() + + if (i + 1) % print_freq == 0: + + print('Epoch: [{}][{}/{}]\t' + 'Time {:.3f} ({:.3f})\t' + 'Data {:.3f} ({:.3f})\t' + 'IR Loss {:.3f} ({:.3f})\t' + 'IR Prec {:.2%} ({:.2%})\t' + 'RGB Loss {:.3f} ({:.3f})\t' + 'RGB Prec {:.2%} ({:.2%})\t' + 'LR :{:.3f}\t\t\t' + .format(epochs, i + 1, len(combined_loader[0]), + batch_time.val, batch_time.avg, + data_time.val, data_time.avg, + ir_losses.val, ir_losses.avg, + ir_precisions.val, ir_precisions.avg, + domain_losses.val, domain_losses.avg, + domain_precisions.val, domain_precisions.avg, + lr), end='\r') + + + def _parse_data(self, inputs): + imgs, pids, camids, clusterid, condition = inputs + # clusterid = pids + imgs.requires_grad = False + clusterid = [int(i) for i in clusterid] + clusterid = torch.IntTensor(clusterid).long() + imgs, clusterid = imgs.cuda(), clusterid.cuda() + return imgs, clusterid, camids + + def _forward(self, ir_inputs, ir_targets, rgb_inputs, rgb_targets): + domain_class, ir_fmap = self.shared_model(ir_inputs, modal=2) + # print(f"ir_inputs: {ir_inputs.shape}") + # + # print(f"IR_FMAP: {ir_fmap.shape}") + ir_loss, ir_data = self.ir_obj(ir_fmap, ir_targets, None) + ir_prec, = accuracy(ir_data.data, ir_targets.data) + # print(f"rgb_inputs: {rgb_inputs.shape}") + domain_class, rgb_fmap = self.shared_model(rgb_inputs, modal=1) + rgb_loss, rgb_data = self.ir_obj(rgb_fmap, rgb_targets, None) + # print(f"rgb_FMAP: {rgb_fmap.shape}") + rgb_prec, = accuracy(rgb_data.data, rgb_targets.data) + + return ir_loss, ir_prec[0], rgb_loss, rgb_prec[0]#domain_prec#[0] + + +class DomainTrainer(BaseTrainer): + def __init__(self, shared_model): + super().__init__() + self.shared_model = shared_model + + self.domain_loss = nn.BCEWithLogitsLoss() + + def train(self, combined_loader, optimizer, epochs, stage, print_freq, mode='ir', lr=0.1): + self.shared_model.train() + + batch_time = AverageMeter() + data_time = AverageMeter() + ir_losses = AverageMeter() + rgb_losses = AverageMeter() + domain_losses = AverageMeter() + part_losses = AverageMeter() + ir_precisions = AverageMeter() + domain_precisions = AverageMeter() + end = time.time() + + for i, (ir_inputs, rgb_inputs) in enumerate(zip(combined_loader[0], combined_loader[1])): + data_time.update(time.time() - end) + + ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs) + rgb_inputs, rgb_targets, rgb_camids = self._parse_data(rgb_inputs) + + ir_loss, ir_prec, domain_loss, domain_prec = self._forward(ir_inputs, rgb_inputs) + + ir_losses.update(ir_loss.item(), ir_targets.size(0)) + domain_losses.update(domain_loss.item(), ir_targets.size(0)) + ir_precisions.update(ir_prec, ir_targets.size(0)) + domain_precisions.update(domain_prec, ir_targets.size(0)) + + combined_loss = ir_loss + domain_loss + optimizer.zero_grad() + combined_loss.backward() + optimizer.step() + + batch_time.update(time.time() - end) + end = time.time() + + if (i + 1) % print_freq == 0: + + print('Epoch: [{}][{}/{}]\t' + 'Time {:.3f} ({:.3f})\t' + 'Data {:.3f} ({:.3f})\t' + 'IR Loss {:.3f} ({:.3f})\t' + 'IR Prec {:.2%} ({:.2%})\t' + 'RGB Loss {:.3f} ({:.3f})\t' + 'RGB Prec {:.2%} ({:.2%})\t' + 'LR :{:.3f}\t' + .format(epochs, i + 1, len(combined_loader), + batch_time.val, batch_time.avg, + data_time.val, data_time.avg, + ir_losses.val, ir_losses.avg, + ir_precisions.val, ir_precisions.avg, + domain_losses.val, domain_losses.avg, + domain_precisions.val, domain_precisions.avg, + lr)) + + + def _parse_data(self, inputs): + imgs, pids, camids, clusterid, condition = inputs + # clusterid = pids + imgs.requires_grad = False + clusterid = [int(i) for i in clusterid] + clusterid = torch.IntTensor(clusterid).long() + imgs, clusterid = imgs.cuda(), clusterid.cuda() + return imgs, clusterid, camids + + def _forward(self, ir_inputs, rgb_inputs): + ir_domain_class, ir_fmap = self.shared_model(ir_inputs, is_classifier=True) + rgb_domain_class, rgb_fmap = self.shared_model(rgb_inputs, is_classifier=True) + + self.ir_domain_target = torch.ones((ir_inputs.size(0), 1)).cuda() + self.rgb_domain_target = torch.zeros((rgb_inputs.size(0), 1)).cuda() + + ir_domain_loss = self.domain_loss(ir_domain_class, self.ir_domain_target) + rgb_domain_loss = self.domain_loss(rgb_domain_class, self.rgb_domain_target) + # ir_domain_prec, = accuracy(ir_domain_class.data, self.ir_domain_target.data) + # rgb_domain_prec, = accuracy(rgb_domain_class.data, self.rgb_domain_target.data) + + ir_domain_prec = accuracy_score(self.ir_domain_target.data.cpu().numpy(), np.round(torch.sigmoid(ir_domain_class.data).cpu().numpy()) ) + rgb_domain_prec = accuracy_score(self.rgb_domain_target.data.cpu().numpy(), np.round(torch.sigmoid(rgb_domain_class.data).cpu().numpy())) + + return ir_domain_loss, ir_domain_prec, rgb_domain_loss, rgb_domain_prec + + +class SupervisedTrainer(BaseTrainer): + def __init__(self, shared_model, ir_obj): + super().__init__() + self.shared_model = shared_model + self.ir_obj = ir_obj + self.rgb_obj = ir_obj + self.domain_loss = torch.nn.KLDivLoss() + + def train(self, combined_loader, optimizer, epochs, stage, print_freq, mode='ir'): + # if epochs >=2: + for name, module in self.shared_model.module.CNN.named_modules(): + for param in module.parameters(): + param.requires_grad = False + + # for epoch in range(epochs): + self.shared_model.train() + + + # if early_stopping: + # print(f"Early stopping at epoch: {epoch}") + # break + + batch_time = AverageMeter() + data_time = AverageMeter() + ir_losses = AverageMeter() + rgb_losses = AverageMeter() + domain_losses = AverageMeter() + part_losses = AverageMeter() + ir_precisions = AverageMeter() + rgb_precisions = AverageMeter() + end = time.time() + + for i, (ir_inputs, rgb_inputs) in enumerate(zip(combined_loader[0], combined_loader[1])): + data_time.update(time.time() - end) + + ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs) + rgb_inputs, rgb_targets, rgb_camids = self._parse_data(rgb_inputs) + + if len(ir_inputs) != len(rgb_inputs): + continue + + ir_loss, rgb_loss, ir_prec, rgb_prec, domain_loss, part_loss = self._forward([ir_inputs, rgb_inputs], + [ir_targets, rgb_targets], + ) + + ir_losses.update(ir_loss.item(), ir_targets.size(0)) + rgb_losses.update(rgb_loss.item(), rgb_targets.size(0)) + domain_losses.update(domain_loss.item(), rgb_targets.size(0)) + part_losses.update(part_loss.item(), rgb_targets.size(0)) + + ir_precisions.update(ir_prec, ir_targets.size(0)) + rgb_precisions.update(rgb_prec, rgb_targets.size(0)) + + + combined_loss = rgb_loss + + optimizer.zero_grad() + combined_loss.backward() + optimizer.step() + + batch_time.update(time.time() - end) + end = time.time() + + if (i + 1) % print_freq == 0: + print('Epoch: [{}][{}/{}]\t' + 'Time {:.3f} ({:.3f})\t' + 'Data {:.3f} ({:.3f})\t' + 'IR Loss {:.3f} ({:.3f})\t' + 'RGB Loss {:.3f} ({:.3f})\t' + 'Triplet Loss {:.3f} ({:.3f})\t' + 'Part Loss {:.3f} ({:.3f})\t' + 'IR Prec {:.2%} ({:.2%})\t' + 'RGB Prec {:.2%} ({:.2%})\t' + 'LR :{:.3f}\t' + .format(epochs, i + 1, len(combined_loader[0]), + batch_time.val, batch_time.avg, + data_time.val, data_time.avg, + ir_losses.val, ir_losses.avg, + rgb_losses.val, rgb_losses.avg, + domain_losses.val, domain_losses.avg, + part_losses.val, part_losses.avg, + ir_precisions.val, ir_precisions.avg, + rgb_precisions.val, rgb_precisions.avg, 0.1)) + + + + def _parse_data(self, inputs): + imgs, pids, camids, realid, clusterid = inputs + clusterid = pids + imgs.requires_grad = False + clusterid = [int(i) for i in clusterid] + clusterid = torch.IntTensor(clusterid).long() + imgs, clusterid = imgs.cuda(), clusterid.cuda() + return imgs, clusterid, camids + + def _forward(self, inputs, targets): + _, ir_pooled = self.shared_model(inputs[0]) + ir_loss, ir_outputs = self.ir_obj(ir_pooled, targets[0], None) + ir_prec, = accuracy(ir_outputs.data, targets[0].data) + + _, rgb_pooled = self.shared_model(inputs[1], is_adapt=True) + rgb_loss, outputs = self.ir_obj(rgb_pooled, targets[1], None) + rgb_prec, = accuracy(outputs.data, targets[1].data) + + + return ir_loss, rgb_loss, ir_prec[0], rgb_prec[0], ir_loss, ir_loss + + +class IRSupervisedTrainer(BaseTrainer): + def __init__(self, shared_model, ir_obj): + super().__init__() + self.shared_model = shared_model + self.ir_obj = ir_obj + + def train(self, combined_loader, optimizer, epochs, stage, print_freq, mode='ir', lr=0.1): + # for name, module in self.shared_model.module.CNN.named_modules(): + # for param in module.parameters(): + # param.requires_grad = False + # for epoch in range(epochs): + self.shared_model.train() + + + # if early_stopping: + # print(f"Early stopping at epoch: {epoch}") + # break + + batch_time = AverageMeter() + data_time = AverageMeter() + ir_losses = AverageMeter() + rgb_losses = AverageMeter() + domain_losses = AverageMeter() + part_losses = AverageMeter() + ir_precisions = AverageMeter() + rgb_precisions = AverageMeter() + end = time.time() + + for i, ir_inputs in enumerate(combined_loader): + data_time.update(time.time() - end) + + ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs) + + ir_loss, ir_prec = self._forward(ir_inputs, ir_targets) + + ir_losses.update(ir_loss.item(), ir_targets.size(0)) + ir_precisions.update(ir_prec, ir_targets.size(0)) + + combined_loss = ir_loss + optimizer.zero_grad() + combined_loss.backward() + optimizer.step() + + batch_time.update(time.time() - end) + end = time.time() + + if (i + 1) % print_freq == 0: + + print('Epoch: [{}][{}/{}]\t' + 'Time {:.3f} ({:.3f})\t' + 'Data {:.3f} ({:.3f})\t' + 'IR Loss {:.3f} ({:.3f})\t' + 'IR Prec {:.2%} ({:.2%})\t' + 'LR :{:.3f}\t' + .format(epochs, i + 1, len(combined_loader), + batch_time.val, batch_time.avg, + data_time.val, data_time.avg, + ir_losses.val, ir_losses.avg, + ir_precisions.val, ir_precisions.avg, + lr)) + return self.ir_obj.M + + + def _parse_data(self, inputs): + imgs, pids, camids, realid, clusterid = inputs + clusterid = pids + imgs.requires_grad = False + clusterid = [int(i) for i in clusterid] + clusterid = torch.IntTensor(clusterid).long() + imgs, clusterid = imgs.cuda(), clusterid.cuda() + return imgs, clusterid, camids + + def _forward(self, inputs, targets): + _, ir_pooled = self.shared_model(inputs) + ir_loss, ir_outputs = self.ir_obj(ir_pooled, targets, None) + ir_prec, = accuracy(ir_outputs.data, targets.data) + + return ir_loss, ir_prec[0] + + +class FCSupervisedTrainer(BaseTrainer): + def __init__(self, shared_model, ir_obj, is_adapt=False): + super().__init__() + self.shared_model = shared_model + self.ir_obj = ir_obj + self.domain_loss = nn.CrossEntropyLoss().cuda() + self.is_adapt = is_adapt + + def train(self, combined_loader, optimizer, epochs, stage, print_freq, mode='ir', lr=0.1): + # for name, module in self.shared_model.module.CNN.named_modules(): + # for param in module.parameters(): + # param.requires_grad = False + # for epoch in range(epochs): + self.shared_model.train() + + + # if early_stopping: + # print(f"Early stopping at epoch: {epoch}") + # break + + batch_time = AverageMeter() + data_time = AverageMeter() + ir_losses = AverageMeter() + rgb_losses = AverageMeter() + domain_losses = AverageMeter() + part_losses = AverageMeter() + ir_precisions = AverageMeter() + rgb_precisions = AverageMeter() + end = time.time() + + for i, ir_inputs in enumerate(combined_loader): + data_time.update(time.time() - end) + + ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs) + + ir_loss, ir_prec = self._forward(ir_inputs, ir_targets) + + ir_losses.update(ir_loss.item(), ir_targets.size(0)) + # domain_losses.update(domain_inv_loss.item(), ir_targets.size(0)) + ir_precisions.update(ir_prec, ir_targets.size(0)) + + combined_loss = ir_loss + optimizer.zero_grad() + combined_loss.backward() + optimizer.step() + + batch_time.update(time.time() - end) + end = time.time() + + if (i + 1) % print_freq == 0: + + print('Epoch: [{}][{}/{}]\t' + 'Time {:.3f} ({:.3f})\t' + 'Data {:.3f} ({:.3f})\t' + 'IR Loss {:.3f} ({:.3f})\t' + 'IR Prec {:.2%} ({:.2%})\t' + 'LR :{:.3f}\t' + .format(epochs, i + 1, len(combined_loader), + batch_time.val, batch_time.avg, + data_time.val, data_time.avg, + ir_losses.val, ir_losses.avg, + ir_precisions.val, ir_precisions.avg, + # domain_losses.val, domain_losses.avg, + lr)) + return self.ir_obj.M + + + def _parse_data(self, inputs): + imgs, pids, camids, realid, clusterid = inputs + clusterid = pids + imgs.requires_grad = False + clusterid = [int(i) for i in clusterid] + clusterid = torch.IntTensor(clusterid).long() + imgs, clusterid = imgs.cuda(), clusterid.cuda() + return imgs, clusterid, camids + + def _forward(self, inputs, targets): + ir_fc, ir_fmap = self.shared_model(inputs, self.is_adapt) + ir_loss, ir_data = self.ir_obj(ir_fmap, targets, None) + ir_prec, = accuracy(ir_data.data, targets.data) + + return ir_loss, ir_prec[0]#, domain_inv_loss + + +class Trainer(BaseTrainer): + def _parse_data(self, inputs): + imgs, subid, camera, condition, cluster_id = inputs['image'], inputs['subid'], inputs['cam'], inputs['condition'], inputs['cluster_id'] + return imgs.cuda(), cluster_id.cuda().type(torch.int64), camera #subid.cuda() + + def _forward(self, inputs, targets, cam): + _, outputs = self.model(inputs) + + loss, outputs = self.criterion(outputs, targets, cam) #torch.nn.functional.so(outputs, targets) + # outputs = np.argmax(outputs.data.cpu(), axis=1) + + acc, = accuracy(outputs.data, targets.data) + return loss, acc[0] + + # + # if isinstance(self.criterion, ACL_IDL): + # prec, = accuracy(outputs.data, targets.data[0:int(len(targets.data) // 2)]) + # prec = prec[0] + # else: + # prec = 0 + # return loss, prec +