Skip to content
Snippets Groups Projects
Commit c798f3fd authored by Kshitij Nikhal's avatar Kshitij Nikhal
Browse files

add trainer

parent 00ff70ad
Branches
Tags v1.0.4
No related merge requests found
from __future__ import print_function, absolute_import
import time
import numpy as np
import torch
from sklearn.metrics import accuracy_score
from evaluation.eval import accuracy
from utils.meters import AverageMeter
from torch import nn
class BaseTrainer(object):
def __init__(self):
super(BaseTrainer, self).__init__()
def train(self, combined_loader, optimizer, epochs, stage, print_freq=1, mode='ir'):
raise NotImplementedError
def _parse_data(self, inputs):
raise NotImplementedError
def _forward(self, inputs, targets, opt):
raise NotImplementedError
class UnsupervisedTrainer(BaseTrainer):
def __init__(self, shared_model, ir_obj, is_adapt=False, mode='ir'):
super().__init__()
self.shared_model = shared_model
self.ir_obj = ir_obj
self.is_adapt = is_adapt
self.mode = mode
self.ir_targets = torch.ones((32, 1)).cuda()
self.rgb_targets = torch.zeros((32, 1)).cuda()
self.domain_loss = nn.BCEWithLogitsLoss()
self.triplet_loss = nn.TripletMarginLoss(swap=True)
def train(self, combined_loader, optimizer, epochs, stage, print_freq, lr=0.1):
self.shared_model.train()
# if early_stopping:
# print(f"Early stopping at epoch: {epoch}")
# break
batch_time = AverageMeter()
data_time = AverageMeter()
ir_losses = AverageMeter()
rgb_losses = AverageMeter()
domain_losses = AverageMeter()
part_losses = AverageMeter()
ir_precisions = AverageMeter()
domain_precisions = AverageMeter()
end = time.time()
for i, triplet_inputs in enumerate(combined_loader[0]):
if i == 200:
break
data_time.update(time.time() - end)
# ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs)
anchor_rgb, pid, _, rgb_positive, rgb_negative, ir_negative, ir_positive = triplet_inputs
ir_loss, ir_prec, domain_loss, domain_prec = self._forward(#ir_inputs, ir_targets,
[anchor_rgb, rgb_positive, rgb_negative,
ir_negative, ir_positive, pid])
ir_losses.update(ir_loss.item(), anchor_rgb.size(0))
domain_losses.update(domain_loss, anchor_rgb.size(0))
ir_precisions.update(ir_prec, anchor_rgb.size(0))
domain_precisions.update(domain_prec, anchor_rgb.size(0))
combined_loss = ir_loss + domain_loss
optimizer.zero_grad()
combined_loss.backward()
optimizer.step()
batch_time.update(time.time() - end)
end = time.time()
if True:#(i + 1) % print_freq == 0:
print('Epoch: [{}][{}/{}]\t'
'Time {:.3f} ({:.3f})\t'
'Data {:.3f} ({:.3f})\t'
'IR Loss {:.3f} ({:.3f})\t'
'IR Prec {:.2%} ({:.2%})\t'
'Domain Loss {:.3f} ({:.3f})\t'
'Domain Prec {:.2%} ({:.2%})\t'
'LR :{:.3f}\t\t'
.format(epochs, i + 1, len(combined_loader),
batch_time.val, batch_time.avg,
data_time.val, data_time.avg,
ir_losses.val, ir_losses.avg,
ir_precisions.val, ir_precisions.avg,
domain_losses.val, domain_losses.avg,
domain_precisions.val, domain_precisions.avg,
lr), end='\r')
def _parse_data(self, inputs):
imgs, pids, camids, clusterid, condition = inputs
# clusterid = pids
imgs.requires_grad = False
clusterid = [int(i) for i in clusterid]
clusterid = torch.IntTensor(clusterid).long()
imgs, clusterid = imgs.cuda(), clusterid.cuda()
return imgs, clusterid, camids
def _forward(self, triplets):
# domain_class, ir_fmap = self.shared_model(inputs, is_classifier=True)
# ir_loss, ir_data = self.ir_obj(ir_fmap, targets, None)
# ir_prec, = accuracy(ir_data.data, targets.data)
# self.rgb_targets = torch.zeros((inputs.size(0), 1)).cuda()
# domain_loss = self.domain_loss(domain_class, self.rgb_targets)
# domain_prec = accuracy_score(self.rgb_targets.data.cpu().numpy(), np.round(torch.sigmoid(domain_class.data).cpu().numpy()))
# domain_prec, = accuracy(domain_class.data, self.rgb_targets.data)
anchor_rgb, rgb_positive, rgb_negative, ir_negative, ir_positive, pid = triplets
_, anchor_rgb = self.shared_model(anchor_rgb.cuda(), modal=1)
# _, rgb_positive = self.shared_model(rgb_positive.cuda(), modal=1)
_, rgb_negative = self.shared_model(rgb_negative.cuda(), modal=1)
_, ir_negative = self.shared_model(ir_negative.cuda(), modal=2)
_, ir_positive = self.shared_model(ir_positive.cuda(), modal=2)
triplet_loss = self.triplet_loss(anchor_rgb, ir_positive, rgb_negative) # +
# triplet_loss = self.triplet_loss(ir_positive, anchor_rgb, ir_negative)
l1_lambda = 0.000001#0.000001
l1_penalty_loss = 0
for output in ir_positive:
l1_penalty_loss += torch.norm(output, 1)
l1_penalty_loss *= l1_lambda
# ir_loss, ir_data = self.ir_obj(ir_positive.cuda(), pid.cuda(), None)
# rgb_loss, rgb_data = self.ir_obj(anchor_rgb.cuda(), pid.cuda(), None)
# ir_prec, = accuracy(ir_data.data, targets.data)
return triplet_loss, 0, l1_penalty_loss, 0 # domain_prec#[0]
class StaticMemBankTrainer(BaseTrainer):
def __init__(self, shared_model, ir_obj, is_adapt=False, mode='ir'):
super().__init__()
self.shared_model = shared_model
self.ir_obj = ir_obj
self.is_adapt = is_adapt
self.mode = mode
def train(self, combined_loader, optimizer, epochs, stage, print_freq, lr=0.1):
self.shared_model.train()
batch_time = AverageMeter()
data_time = AverageMeter()
ir_losses = AverageMeter()
rgb_losses = AverageMeter()
domain_losses = AverageMeter()
part_losses = AverageMeter()
ir_precisions = AverageMeter()
domain_precisions = AverageMeter()
end = time.time()
for i, (ir_inputs) in enumerate(combined_loader[0]):
# if i == 200:
# break
data_time.update(time.time() - end)
ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs)
# rgb_inputs, rgb_targets, rgb_camids = self._parse_data(rgb_inputs)
ir_loss, ir_prec, domain_loss, domain_prec = self._forward(ir_inputs, ir_targets)
#rgb_inputs, rgb_targets)
ir_losses.update(ir_loss.item(), ir_targets.size(0))
domain_losses.update(domain_loss.item(), ir_targets.size(0))
ir_precisions.update(ir_prec, ir_targets.size(0))
domain_precisions.update(domain_prec, ir_targets.size(0))
combined_loss = ir_loss #+ domain_loss
optimizer.zero_grad()
combined_loss.backward()
optimizer.step()
batch_time.update(time.time() - end)
end = time.time()
if (i + 1) % print_freq == 0:
print('Epoch: [{}][{}/{}]\t'
'Time {:.3f} ({:.3f})\t'
'Data {:.3f} ({:.3f})\t'
'IR Loss {:.3f} ({:.3f})\t'
'IR Prec {:.2%} ({:.2%})\t'
'RGB Loss {:.3f} ({:.3f})\t'
'RGB Prec {:.2%} ({:.2%})\t'
'LR :{:.3f}\t\t\t'
.format(epochs, i + 1, len(combined_loader[0]),
batch_time.val, batch_time.avg,
data_time.val, data_time.avg,
ir_losses.val, ir_losses.avg,
ir_precisions.val, ir_precisions.avg,
domain_losses.val, domain_losses.avg,
domain_precisions.val, domain_precisions.avg,
lr), end='\r')
def _parse_data(self, inputs):
imgs, pids, camids, clusterid, condition = inputs
# clusterid = pids
imgs.requires_grad = False
clusterid = [int(i) for i in clusterid]
clusterid = torch.IntTensor(clusterid).long()
imgs, clusterid = imgs.cuda(), clusterid.cuda()
return imgs, clusterid, camids
def _forward(self, ir_inputs, ir_targets):
domain_class, ir_fmap = self.shared_model(ir_inputs, modal=2)
ir_loss, ir_data = self.ir_obj(ir_fmap, ir_targets, None)
ir_prec, = accuracy(ir_data.data, ir_targets.data)
# print(f"rgb_inputs: {rgb_inputs.shape}")
# domain_class, rgb_fmap = self.shared_model(rgb_inputs, modal=1)
# rgb_loss, rgb_data = self.ir_obj(rgb_fmap, rgb_targets, None)
# # print(f"rgb_FMAP: {rgb_fmap.shape}")
# rgb_prec, = accuracy(rgb_data.data, rgb_targets.data)
return ir_loss, ir_prec[0], ir_loss, ir_prec[0]#domain_prec#[0]
class BaselineTrainer(BaseTrainer):
def __init__(self, shared_model, ir_obj, is_adapt=False, mode='ir'):
super().__init__()
self.shared_model = shared_model
self.ir_obj = ir_obj
self.is_adapt = is_adapt
self.mode = mode
def train(self, combined_loader, optimizer, epochs, stage, print_freq, lr=0.1):
self.shared_model.train()
# if early_stopping:
# print(f"Early stopping at epoch: {epoch}")
# break
batch_time = AverageMeter()
data_time = AverageMeter()
ir_losses = AverageMeter()
rgb_losses = AverageMeter()
domain_losses = AverageMeter()
part_losses = AverageMeter()
ir_precisions = AverageMeter()
domain_precisions = AverageMeter()
end = time.time()
for i, (ir_inputs, rgb_inputs) in enumerate(zip(combined_loader[0], combined_loader[1])):
# if i == 200:
# break
data_time.update(time.time() - end)
ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs)
rgb_inputs, rgb_targets, rgb_camids = self._parse_data(rgb_inputs)
ir_loss, ir_prec, domain_loss, domain_prec = self._forward(ir_inputs, ir_targets,
rgb_inputs, rgb_targets)
ir_losses.update(ir_loss.item(), ir_targets.size(0))
domain_losses.update(domain_loss.item(), ir_targets.size(0))
ir_precisions.update(ir_prec, ir_targets.size(0))
domain_precisions.update(domain_prec, ir_targets.size(0))
combined_loss = ir_loss + domain_loss
optimizer.zero_grad()
combined_loss.backward()
optimizer.step()
batch_time.update(time.time() - end)
end = time.time()
if (i + 1) % print_freq == 0:
print('Epoch: [{}][{}/{}]\t'
'Time {:.3f} ({:.3f})\t'
'Data {:.3f} ({:.3f})\t'
'IR Loss {:.3f} ({:.3f})\t'
'IR Prec {:.2%} ({:.2%})\t'
'RGB Loss {:.3f} ({:.3f})\t'
'RGB Prec {:.2%} ({:.2%})\t'
'LR :{:.3f}\t\t\t'
.format(epochs, i + 1, len(combined_loader[0]),
batch_time.val, batch_time.avg,
data_time.val, data_time.avg,
ir_losses.val, ir_losses.avg,
ir_precisions.val, ir_precisions.avg,
domain_losses.val, domain_losses.avg,
domain_precisions.val, domain_precisions.avg,
lr), end='\r')
def _parse_data(self, inputs):
imgs, pids, camids, clusterid, condition = inputs
# clusterid = pids
imgs.requires_grad = False
clusterid = [int(i) for i in clusterid]
clusterid = torch.IntTensor(clusterid).long()
imgs, clusterid = imgs.cuda(), clusterid.cuda()
return imgs, clusterid, camids
def _forward(self, ir_inputs, ir_targets, rgb_inputs, rgb_targets):
domain_class, ir_fmap = self.shared_model(ir_inputs, modal=2)
# print(f"ir_inputs: {ir_inputs.shape}")
#
# print(f"IR_FMAP: {ir_fmap.shape}")
ir_loss, ir_data = self.ir_obj(ir_fmap, ir_targets, None)
ir_prec, = accuracy(ir_data.data, ir_targets.data)
# print(f"rgb_inputs: {rgb_inputs.shape}")
domain_class, rgb_fmap = self.shared_model(rgb_inputs, modal=1)
rgb_loss, rgb_data = self.ir_obj(rgb_fmap, rgb_targets, None)
# print(f"rgb_FMAP: {rgb_fmap.shape}")
rgb_prec, = accuracy(rgb_data.data, rgb_targets.data)
return ir_loss, ir_prec[0], rgb_loss, rgb_prec[0]#domain_prec#[0]
class DomainTrainer(BaseTrainer):
def __init__(self, shared_model):
super().__init__()
self.shared_model = shared_model
self.domain_loss = nn.BCEWithLogitsLoss()
def train(self, combined_loader, optimizer, epochs, stage, print_freq, mode='ir', lr=0.1):
self.shared_model.train()
batch_time = AverageMeter()
data_time = AverageMeter()
ir_losses = AverageMeter()
rgb_losses = AverageMeter()
domain_losses = AverageMeter()
part_losses = AverageMeter()
ir_precisions = AverageMeter()
domain_precisions = AverageMeter()
end = time.time()
for i, (ir_inputs, rgb_inputs) in enumerate(zip(combined_loader[0], combined_loader[1])):
data_time.update(time.time() - end)
ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs)
rgb_inputs, rgb_targets, rgb_camids = self._parse_data(rgb_inputs)
ir_loss, ir_prec, domain_loss, domain_prec = self._forward(ir_inputs, rgb_inputs)
ir_losses.update(ir_loss.item(), ir_targets.size(0))
domain_losses.update(domain_loss.item(), ir_targets.size(0))
ir_precisions.update(ir_prec, ir_targets.size(0))
domain_precisions.update(domain_prec, ir_targets.size(0))
combined_loss = ir_loss + domain_loss
optimizer.zero_grad()
combined_loss.backward()
optimizer.step()
batch_time.update(time.time() - end)
end = time.time()
if (i + 1) % print_freq == 0:
print('Epoch: [{}][{}/{}]\t'
'Time {:.3f} ({:.3f})\t'
'Data {:.3f} ({:.3f})\t'
'IR Loss {:.3f} ({:.3f})\t'
'IR Prec {:.2%} ({:.2%})\t'
'RGB Loss {:.3f} ({:.3f})\t'
'RGB Prec {:.2%} ({:.2%})\t'
'LR :{:.3f}\t'
.format(epochs, i + 1, len(combined_loader),
batch_time.val, batch_time.avg,
data_time.val, data_time.avg,
ir_losses.val, ir_losses.avg,
ir_precisions.val, ir_precisions.avg,
domain_losses.val, domain_losses.avg,
domain_precisions.val, domain_precisions.avg,
lr))
def _parse_data(self, inputs):
imgs, pids, camids, clusterid, condition = inputs
# clusterid = pids
imgs.requires_grad = False
clusterid = [int(i) for i in clusterid]
clusterid = torch.IntTensor(clusterid).long()
imgs, clusterid = imgs.cuda(), clusterid.cuda()
return imgs, clusterid, camids
def _forward(self, ir_inputs, rgb_inputs):
ir_domain_class, ir_fmap = self.shared_model(ir_inputs, is_classifier=True)
rgb_domain_class, rgb_fmap = self.shared_model(rgb_inputs, is_classifier=True)
self.ir_domain_target = torch.ones((ir_inputs.size(0), 1)).cuda()
self.rgb_domain_target = torch.zeros((rgb_inputs.size(0), 1)).cuda()
ir_domain_loss = self.domain_loss(ir_domain_class, self.ir_domain_target)
rgb_domain_loss = self.domain_loss(rgb_domain_class, self.rgb_domain_target)
# ir_domain_prec, = accuracy(ir_domain_class.data, self.ir_domain_target.data)
# rgb_domain_prec, = accuracy(rgb_domain_class.data, self.rgb_domain_target.data)
ir_domain_prec = accuracy_score(self.ir_domain_target.data.cpu().numpy(), np.round(torch.sigmoid(ir_domain_class.data).cpu().numpy()) )
rgb_domain_prec = accuracy_score(self.rgb_domain_target.data.cpu().numpy(), np.round(torch.sigmoid(rgb_domain_class.data).cpu().numpy()))
return ir_domain_loss, ir_domain_prec, rgb_domain_loss, rgb_domain_prec
class SupervisedTrainer(BaseTrainer):
def __init__(self, shared_model, ir_obj):
super().__init__()
self.shared_model = shared_model
self.ir_obj = ir_obj
self.rgb_obj = ir_obj
self.domain_loss = torch.nn.KLDivLoss()
def train(self, combined_loader, optimizer, epochs, stage, print_freq, mode='ir'):
# if epochs >=2:
for name, module in self.shared_model.module.CNN.named_modules():
for param in module.parameters():
param.requires_grad = False
# for epoch in range(epochs):
self.shared_model.train()
# if early_stopping:
# print(f"Early stopping at epoch: {epoch}")
# break
batch_time = AverageMeter()
data_time = AverageMeter()
ir_losses = AverageMeter()
rgb_losses = AverageMeter()
domain_losses = AverageMeter()
part_losses = AverageMeter()
ir_precisions = AverageMeter()
rgb_precisions = AverageMeter()
end = time.time()
for i, (ir_inputs, rgb_inputs) in enumerate(zip(combined_loader[0], combined_loader[1])):
data_time.update(time.time() - end)
ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs)
rgb_inputs, rgb_targets, rgb_camids = self._parse_data(rgb_inputs)
if len(ir_inputs) != len(rgb_inputs):
continue
ir_loss, rgb_loss, ir_prec, rgb_prec, domain_loss, part_loss = self._forward([ir_inputs, rgb_inputs],
[ir_targets, rgb_targets],
)
ir_losses.update(ir_loss.item(), ir_targets.size(0))
rgb_losses.update(rgb_loss.item(), rgb_targets.size(0))
domain_losses.update(domain_loss.item(), rgb_targets.size(0))
part_losses.update(part_loss.item(), rgb_targets.size(0))
ir_precisions.update(ir_prec, ir_targets.size(0))
rgb_precisions.update(rgb_prec, rgb_targets.size(0))
combined_loss = rgb_loss
optimizer.zero_grad()
combined_loss.backward()
optimizer.step()
batch_time.update(time.time() - end)
end = time.time()
if (i + 1) % print_freq == 0:
print('Epoch: [{}][{}/{}]\t'
'Time {:.3f} ({:.3f})\t'
'Data {:.3f} ({:.3f})\t'
'IR Loss {:.3f} ({:.3f})\t'
'RGB Loss {:.3f} ({:.3f})\t'
'Triplet Loss {:.3f} ({:.3f})\t'
'Part Loss {:.3f} ({:.3f})\t'
'IR Prec {:.2%} ({:.2%})\t'
'RGB Prec {:.2%} ({:.2%})\t'
'LR :{:.3f}\t'
.format(epochs, i + 1, len(combined_loader[0]),
batch_time.val, batch_time.avg,
data_time.val, data_time.avg,
ir_losses.val, ir_losses.avg,
rgb_losses.val, rgb_losses.avg,
domain_losses.val, domain_losses.avg,
part_losses.val, part_losses.avg,
ir_precisions.val, ir_precisions.avg,
rgb_precisions.val, rgb_precisions.avg, 0.1))
def _parse_data(self, inputs):
imgs, pids, camids, realid, clusterid = inputs
clusterid = pids
imgs.requires_grad = False
clusterid = [int(i) for i in clusterid]
clusterid = torch.IntTensor(clusterid).long()
imgs, clusterid = imgs.cuda(), clusterid.cuda()
return imgs, clusterid, camids
def _forward(self, inputs, targets):
_, ir_pooled = self.shared_model(inputs[0])
ir_loss, ir_outputs = self.ir_obj(ir_pooled, targets[0], None)
ir_prec, = accuracy(ir_outputs.data, targets[0].data)
_, rgb_pooled = self.shared_model(inputs[1], is_adapt=True)
rgb_loss, outputs = self.ir_obj(rgb_pooled, targets[1], None)
rgb_prec, = accuracy(outputs.data, targets[1].data)
return ir_loss, rgb_loss, ir_prec[0], rgb_prec[0], ir_loss, ir_loss
class IRSupervisedTrainer(BaseTrainer):
def __init__(self, shared_model, ir_obj):
super().__init__()
self.shared_model = shared_model
self.ir_obj = ir_obj
def train(self, combined_loader, optimizer, epochs, stage, print_freq, mode='ir', lr=0.1):
# for name, module in self.shared_model.module.CNN.named_modules():
# for param in module.parameters():
# param.requires_grad = False
# for epoch in range(epochs):
self.shared_model.train()
# if early_stopping:
# print(f"Early stopping at epoch: {epoch}")
# break
batch_time = AverageMeter()
data_time = AverageMeter()
ir_losses = AverageMeter()
rgb_losses = AverageMeter()
domain_losses = AverageMeter()
part_losses = AverageMeter()
ir_precisions = AverageMeter()
rgb_precisions = AverageMeter()
end = time.time()
for i, ir_inputs in enumerate(combined_loader):
data_time.update(time.time() - end)
ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs)
ir_loss, ir_prec = self._forward(ir_inputs, ir_targets)
ir_losses.update(ir_loss.item(), ir_targets.size(0))
ir_precisions.update(ir_prec, ir_targets.size(0))
combined_loss = ir_loss
optimizer.zero_grad()
combined_loss.backward()
optimizer.step()
batch_time.update(time.time() - end)
end = time.time()
if (i + 1) % print_freq == 0:
print('Epoch: [{}][{}/{}]\t'
'Time {:.3f} ({:.3f})\t'
'Data {:.3f} ({:.3f})\t'
'IR Loss {:.3f} ({:.3f})\t'
'IR Prec {:.2%} ({:.2%})\t'
'LR :{:.3f}\t'
.format(epochs, i + 1, len(combined_loader),
batch_time.val, batch_time.avg,
data_time.val, data_time.avg,
ir_losses.val, ir_losses.avg,
ir_precisions.val, ir_precisions.avg,
lr))
return self.ir_obj.M
def _parse_data(self, inputs):
imgs, pids, camids, realid, clusterid = inputs
clusterid = pids
imgs.requires_grad = False
clusterid = [int(i) for i in clusterid]
clusterid = torch.IntTensor(clusterid).long()
imgs, clusterid = imgs.cuda(), clusterid.cuda()
return imgs, clusterid, camids
def _forward(self, inputs, targets):
_, ir_pooled = self.shared_model(inputs)
ir_loss, ir_outputs = self.ir_obj(ir_pooled, targets, None)
ir_prec, = accuracy(ir_outputs.data, targets.data)
return ir_loss, ir_prec[0]
class FCSupervisedTrainer(BaseTrainer):
def __init__(self, shared_model, ir_obj, is_adapt=False):
super().__init__()
self.shared_model = shared_model
self.ir_obj = ir_obj
self.domain_loss = nn.CrossEntropyLoss().cuda()
self.is_adapt = is_adapt
def train(self, combined_loader, optimizer, epochs, stage, print_freq, mode='ir', lr=0.1):
# for name, module in self.shared_model.module.CNN.named_modules():
# for param in module.parameters():
# param.requires_grad = False
# for epoch in range(epochs):
self.shared_model.train()
# if early_stopping:
# print(f"Early stopping at epoch: {epoch}")
# break
batch_time = AverageMeter()
data_time = AverageMeter()
ir_losses = AverageMeter()
rgb_losses = AverageMeter()
domain_losses = AverageMeter()
part_losses = AverageMeter()
ir_precisions = AverageMeter()
rgb_precisions = AverageMeter()
end = time.time()
for i, ir_inputs in enumerate(combined_loader):
data_time.update(time.time() - end)
ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs)
ir_loss, ir_prec = self._forward(ir_inputs, ir_targets)
ir_losses.update(ir_loss.item(), ir_targets.size(0))
# domain_losses.update(domain_inv_loss.item(), ir_targets.size(0))
ir_precisions.update(ir_prec, ir_targets.size(0))
combined_loss = ir_loss
optimizer.zero_grad()
combined_loss.backward()
optimizer.step()
batch_time.update(time.time() - end)
end = time.time()
if (i + 1) % print_freq == 0:
print('Epoch: [{}][{}/{}]\t'
'Time {:.3f} ({:.3f})\t'
'Data {:.3f} ({:.3f})\t'
'IR Loss {:.3f} ({:.3f})\t'
'IR Prec {:.2%} ({:.2%})\t'
'LR :{:.3f}\t'
.format(epochs, i + 1, len(combined_loader),
batch_time.val, batch_time.avg,
data_time.val, data_time.avg,
ir_losses.val, ir_losses.avg,
ir_precisions.val, ir_precisions.avg,
# domain_losses.val, domain_losses.avg,
lr))
return self.ir_obj.M
def _parse_data(self, inputs):
imgs, pids, camids, realid, clusterid = inputs
clusterid = pids
imgs.requires_grad = False
clusterid = [int(i) for i in clusterid]
clusterid = torch.IntTensor(clusterid).long()
imgs, clusterid = imgs.cuda(), clusterid.cuda()
return imgs, clusterid, camids
def _forward(self, inputs, targets):
ir_fc, ir_fmap = self.shared_model(inputs, self.is_adapt)
ir_loss, ir_data = self.ir_obj(ir_fmap, targets, None)
ir_prec, = accuracy(ir_data.data, targets.data)
return ir_loss, ir_prec[0]#, domain_inv_loss
class Trainer(BaseTrainer):
def _parse_data(self, inputs):
imgs, subid, camera, condition, cluster_id = inputs['image'], inputs['subid'], inputs['cam'], inputs['condition'], inputs['cluster_id']
return imgs.cuda(), cluster_id.cuda().type(torch.int64), camera #subid.cuda()
def _forward(self, inputs, targets, cam):
_, outputs = self.model(inputs)
loss, outputs = self.criterion(outputs, targets, cam) #torch.nn.functional.so(outputs, targets)
# outputs = np.argmax(outputs.data.cpu(), axis=1)
acc, = accuracy(outputs.data, targets.data)
return loss, acc[0]
#
# if isinstance(self.criterion, ACL_IDL):
# prec, = accuracy(outputs.data, targets.data[0:int(len(targets.data) // 2)])
# prec = prec[0]
# else:
# prec = 0
# return loss, prec
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment