Skip to content
Snippets Groups Projects
Commit c798f3fd authored by Kshitij Nikhal's avatar Kshitij Nikhal
Browse files

add trainer

parent 00ff70ad
Branches
No related tags found
No related merge requests found
from __future__ import print_function, absolute_import
import time
import numpy as np
import torch
from sklearn.metrics import accuracy_score
from evaluation.eval import accuracy
from utils.meters import AverageMeter
from torch import nn
class BaseTrainer(object):
def __init__(self):
super(BaseTrainer, self).__init__()
def train(self, combined_loader, optimizer, epochs, stage, print_freq=1, mode='ir'):
raise NotImplementedError
def _parse_data(self, inputs):
raise NotImplementedError
def _forward(self, inputs, targets, opt):
raise NotImplementedError
class UnsupervisedTrainer(BaseTrainer):
def __init__(self, shared_model, ir_obj, is_adapt=False, mode='ir'):
super().__init__()
self.shared_model = shared_model
self.ir_obj = ir_obj
self.is_adapt = is_adapt
self.mode = mode
self.ir_targets = torch.ones((32, 1)).cuda()
self.rgb_targets = torch.zeros((32, 1)).cuda()
self.domain_loss = nn.BCEWithLogitsLoss()
self.triplet_loss = nn.TripletMarginLoss(swap=True)
def train(self, combined_loader, optimizer, epochs, stage, print_freq, lr=0.1):
self.shared_model.train()
# if early_stopping:
# print(f"Early stopping at epoch: {epoch}")
# break
batch_time = AverageMeter()
data_time = AverageMeter()
ir_losses = AverageMeter()
rgb_losses = AverageMeter()
domain_losses = AverageMeter()
part_losses = AverageMeter()
ir_precisions = AverageMeter()
domain_precisions = AverageMeter()
end = time.time()
for i, triplet_inputs in enumerate(combined_loader[0]):
if i == 200:
break
data_time.update(time.time() - end)
# ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs)
anchor_rgb, pid, _, rgb_positive, rgb_negative, ir_negative, ir_positive = triplet_inputs
ir_loss, ir_prec, domain_loss, domain_prec = self._forward(#ir_inputs, ir_targets,
[anchor_rgb, rgb_positive, rgb_negative,
ir_negative, ir_positive, pid])
ir_losses.update(ir_loss.item(), anchor_rgb.size(0))
domain_losses.update(domain_loss, anchor_rgb.size(0))
ir_precisions.update(ir_prec, anchor_rgb.size(0))
domain_precisions.update(domain_prec, anchor_rgb.size(0))
combined_loss = ir_loss + domain_loss
optimizer.zero_grad()
combined_loss.backward()
optimizer.step()
batch_time.update(time.time() - end)
end = time.time()
if True:#(i + 1) % print_freq == 0:
print('Epoch: [{}][{}/{}]\t'
'Time {:.3f} ({:.3f})\t'
'Data {:.3f} ({:.3f})\t'
'IR Loss {:.3f} ({:.3f})\t'
'IR Prec {:.2%} ({:.2%})\t'
'Domain Loss {:.3f} ({:.3f})\t'
'Domain Prec {:.2%} ({:.2%})\t'
'LR :{:.3f}\t\t'
.format(epochs, i + 1, len(combined_loader),
batch_time.val, batch_time.avg,
data_time.val, data_time.avg,
ir_losses.val, ir_losses.avg,
ir_precisions.val, ir_precisions.avg,
domain_losses.val, domain_losses.avg,
domain_precisions.val, domain_precisions.avg,
lr), end='\r')
def _parse_data(self, inputs):
imgs, pids, camids, clusterid, condition = inputs
# clusterid = pids
imgs.requires_grad = False
clusterid = [int(i) for i in clusterid]
clusterid = torch.IntTensor(clusterid).long()
imgs, clusterid = imgs.cuda(), clusterid.cuda()
return imgs, clusterid, camids
def _forward(self, triplets):
# domain_class, ir_fmap = self.shared_model(inputs, is_classifier=True)
# ir_loss, ir_data = self.ir_obj(ir_fmap, targets, None)
# ir_prec, = accuracy(ir_data.data, targets.data)
# self.rgb_targets = torch.zeros((inputs.size(0), 1)).cuda()
# domain_loss = self.domain_loss(domain_class, self.rgb_targets)
# domain_prec = accuracy_score(self.rgb_targets.data.cpu().numpy(), np.round(torch.sigmoid(domain_class.data).cpu().numpy()))
# domain_prec, = accuracy(domain_class.data, self.rgb_targets.data)
anchor_rgb, rgb_positive, rgb_negative, ir_negative, ir_positive, pid = triplets
_, anchor_rgb = self.shared_model(anchor_rgb.cuda(), modal=1)
# _, rgb_positive = self.shared_model(rgb_positive.cuda(), modal=1)
_, rgb_negative = self.shared_model(rgb_negative.cuda(), modal=1)
_, ir_negative = self.shared_model(ir_negative.cuda(), modal=2)
_, ir_positive = self.shared_model(ir_positive.cuda(), modal=2)
triplet_loss = self.triplet_loss(anchor_rgb, ir_positive, rgb_negative) # +
# triplet_loss = self.triplet_loss(ir_positive, anchor_rgb, ir_negative)
l1_lambda = 0.000001#0.000001
l1_penalty_loss = 0
for output in ir_positive:
l1_penalty_loss += torch.norm(output, 1)
l1_penalty_loss *= l1_lambda
# ir_loss, ir_data = self.ir_obj(ir_positive.cuda(), pid.cuda(), None)
# rgb_loss, rgb_data = self.ir_obj(anchor_rgb.cuda(), pid.cuda(), None)
# ir_prec, = accuracy(ir_data.data, targets.data)
return triplet_loss, 0, l1_penalty_loss, 0 # domain_prec#[0]
class StaticMemBankTrainer(BaseTrainer):
def __init__(self, shared_model, ir_obj, is_adapt=False, mode='ir'):
super().__init__()
self.shared_model = shared_model
self.ir_obj = ir_obj
self.is_adapt = is_adapt
self.mode = mode
def train(self, combined_loader, optimizer, epochs, stage, print_freq, lr=0.1):
self.shared_model.train()
batch_time = AverageMeter()
data_time = AverageMeter()
ir_losses = AverageMeter()
rgb_losses = AverageMeter()
domain_losses = AverageMeter()
part_losses = AverageMeter()
ir_precisions = AverageMeter()
domain_precisions = AverageMeter()
end = time.time()
for i, (ir_inputs) in enumerate(combined_loader[0]):
# if i == 200:
# break
data_time.update(time.time() - end)
ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs)
# rgb_inputs, rgb_targets, rgb_camids = self._parse_data(rgb_inputs)
ir_loss, ir_prec, domain_loss, domain_prec = self._forward(ir_inputs, ir_targets)
#rgb_inputs, rgb_targets)
ir_losses.update(ir_loss.item(), ir_targets.size(0))
domain_losses.update(domain_loss.item(), ir_targets.size(0))
ir_precisions.update(ir_prec, ir_targets.size(0))
domain_precisions.update(domain_prec, ir_targets.size(0))
combined_loss = ir_loss #+ domain_loss
optimizer.zero_grad()
combined_loss.backward()
optimizer.step()
batch_time.update(time.time() - end)
end = time.time()
if (i + 1) % print_freq == 0:
print('Epoch: [{}][{}/{}]\t'
'Time {:.3f} ({:.3f})\t'
'Data {:.3f} ({:.3f})\t'
'IR Loss {:.3f} ({:.3f})\t'
'IR Prec {:.2%} ({:.2%})\t'
'RGB Loss {:.3f} ({:.3f})\t'
'RGB Prec {:.2%} ({:.2%})\t'
'LR :{:.3f}\t\t\t'
.format(epochs, i + 1, len(combined_loader[0]),
batch_time.val, batch_time.avg,
data_time.val, data_time.avg,
ir_losses.val, ir_losses.avg,
ir_precisions.val, ir_precisions.avg,
domain_losses.val, domain_losses.avg,
domain_precisions.val, domain_precisions.avg,
lr), end='\r')
def _parse_data(self, inputs):
imgs, pids, camids, clusterid, condition = inputs
# clusterid = pids
imgs.requires_grad = False
clusterid = [int(i) for i in clusterid]
clusterid = torch.IntTensor(clusterid).long()
imgs, clusterid = imgs.cuda(), clusterid.cuda()
return imgs, clusterid, camids
def _forward(self, ir_inputs, ir_targets):
domain_class, ir_fmap = self.shared_model(ir_inputs, modal=2)
ir_loss, ir_data = self.ir_obj(ir_fmap, ir_targets, None)
ir_prec, = accuracy(ir_data.data, ir_targets.data)
# print(f"rgb_inputs: {rgb_inputs.shape}")
# domain_class, rgb_fmap = self.shared_model(rgb_inputs, modal=1)
# rgb_loss, rgb_data = self.ir_obj(rgb_fmap, rgb_targets, None)
# # print(f"rgb_FMAP: {rgb_fmap.shape}")
# rgb_prec, = accuracy(rgb_data.data, rgb_targets.data)
return ir_loss, ir_prec[0], ir_loss, ir_prec[0]#domain_prec#[0]
class BaselineTrainer(BaseTrainer):
def __init__(self, shared_model, ir_obj, is_adapt=False, mode='ir'):
super().__init__()
self.shared_model = shared_model
self.ir_obj = ir_obj
self.is_adapt = is_adapt
self.mode = mode
def train(self, combined_loader, optimizer, epochs, stage, print_freq, lr=0.1):
self.shared_model.train()
# if early_stopping:
# print(f"Early stopping at epoch: {epoch}")
# break
batch_time = AverageMeter()
data_time = AverageMeter()
ir_losses = AverageMeter()
rgb_losses = AverageMeter()
domain_losses = AverageMeter()
part_losses = AverageMeter()
ir_precisions = AverageMeter()
domain_precisions = AverageMeter()
end = time.time()
for i, (ir_inputs, rgb_inputs) in enumerate(zip(combined_loader[0], combined_loader[1])):
# if i == 200:
# break
data_time.update(time.time() - end)
ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs)
rgb_inputs, rgb_targets, rgb_camids = self._parse_data(rgb_inputs)
ir_loss, ir_prec, domain_loss, domain_prec = self._forward(ir_inputs, ir_targets,
rgb_inputs, rgb_targets)
ir_losses.update(ir_loss.item(), ir_targets.size(0))
domain_losses.update(domain_loss.item(), ir_targets.size(0))
ir_precisions.update(ir_prec, ir_targets.size(0))
domain_precisions.update(domain_prec, ir_targets.size(0))
combined_loss = ir_loss + domain_loss
optimizer.zero_grad()
combined_loss.backward()
optimizer.step()
batch_time.update(time.time() - end)
end = time.time()
if (i + 1) % print_freq == 0:
print('Epoch: [{}][{}/{}]\t'
'Time {:.3f} ({:.3f})\t'
'Data {:.3f} ({:.3f})\t'
'IR Loss {:.3f} ({:.3f})\t'
'IR Prec {:.2%} ({:.2%})\t'
'RGB Loss {:.3f} ({:.3f})\t'
'RGB Prec {:.2%} ({:.2%})\t'
'LR :{:.3f}\t\t\t'
.format(epochs, i + 1, len(combined_loader[0]),
batch_time.val, batch_time.avg,
data_time.val, data_time.avg,
ir_losses.val, ir_losses.avg,
ir_precisions.val, ir_precisions.avg,
domain_losses.val, domain_losses.avg,
domain_precisions.val, domain_precisions.avg,
lr), end='\r')
def _parse_data(self, inputs):
imgs, pids, camids, clusterid, condition = inputs
# clusterid = pids
imgs.requires_grad = False
clusterid = [int(i) for i in clusterid]
clusterid = torch.IntTensor(clusterid).long()
imgs, clusterid = imgs.cuda(), clusterid.cuda()
return imgs, clusterid, camids
def _forward(self, ir_inputs, ir_targets, rgb_inputs, rgb_targets):
domain_class, ir_fmap = self.shared_model(ir_inputs, modal=2)
# print(f"ir_inputs: {ir_inputs.shape}")
#
# print(f"IR_FMAP: {ir_fmap.shape}")
ir_loss, ir_data = self.ir_obj(ir_fmap, ir_targets, None)
ir_prec, = accuracy(ir_data.data, ir_targets.data)
# print(f"rgb_inputs: {rgb_inputs.shape}")
domain_class, rgb_fmap = self.shared_model(rgb_inputs, modal=1)
rgb_loss, rgb_data = self.ir_obj(rgb_fmap, rgb_targets, None)
# print(f"rgb_FMAP: {rgb_fmap.shape}")
rgb_prec, = accuracy(rgb_data.data, rgb_targets.data)
return ir_loss, ir_prec[0], rgb_loss, rgb_prec[0]#domain_prec#[0]
class DomainTrainer(BaseTrainer):
def __init__(self, shared_model):
super().__init__()
self.shared_model = shared_model
self.domain_loss = nn.BCEWithLogitsLoss()
def train(self, combined_loader, optimizer, epochs, stage, print_freq, mode='ir', lr=0.1):
self.shared_model.train()
batch_time = AverageMeter()
data_time = AverageMeter()
ir_losses = AverageMeter()
rgb_losses = AverageMeter()
domain_losses = AverageMeter()
part_losses = AverageMeter()
ir_precisions = AverageMeter()
domain_precisions = AverageMeter()
end = time.time()
for i, (ir_inputs, rgb_inputs) in enumerate(zip(combined_loader[0], combined_loader[1])):
data_time.update(time.time() - end)
ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs)
rgb_inputs, rgb_targets, rgb_camids = self._parse_data(rgb_inputs)
ir_loss, ir_prec, domain_loss, domain_prec = self._forward(ir_inputs, rgb_inputs)
ir_losses.update(ir_loss.item(), ir_targets.size(0))
domain_losses.update(domain_loss.item(), ir_targets.size(0))
ir_precisions.update(ir_prec, ir_targets.size(0))
domain_precisions.update(domain_prec, ir_targets.size(0))
combined_loss = ir_loss + domain_loss
optimizer.zero_grad()
combined_loss.backward()
optimizer.step()
batch_time.update(time.time() - end)
end = time.time()
if (i + 1) % print_freq == 0:
print('Epoch: [{}][{}/{}]\t'
'Time {:.3f} ({:.3f})\t'
'Data {:.3f} ({:.3f})\t'
'IR Loss {:.3f} ({:.3f})\t'
'IR Prec {:.2%} ({:.2%})\t'
'RGB Loss {:.3f} ({:.3f})\t'
'RGB Prec {:.2%} ({:.2%})\t'
'LR :{:.3f}\t'
.format(epochs, i + 1, len(combined_loader),
batch_time.val, batch_time.avg,
data_time.val, data_time.avg,
ir_losses.val, ir_losses.avg,
ir_precisions.val, ir_precisions.avg,
domain_losses.val, domain_losses.avg,
domain_precisions.val, domain_precisions.avg,
lr))
def _parse_data(self, inputs):
imgs, pids, camids, clusterid, condition = inputs
# clusterid = pids
imgs.requires_grad = False
clusterid = [int(i) for i in clusterid]
clusterid = torch.IntTensor(clusterid).long()
imgs, clusterid = imgs.cuda(), clusterid.cuda()
return imgs, clusterid, camids
def _forward(self, ir_inputs, rgb_inputs):
ir_domain_class, ir_fmap = self.shared_model(ir_inputs, is_classifier=True)
rgb_domain_class, rgb_fmap = self.shared_model(rgb_inputs, is_classifier=True)
self.ir_domain_target = torch.ones((ir_inputs.size(0), 1)).cuda()
self.rgb_domain_target = torch.zeros((rgb_inputs.size(0), 1)).cuda()
ir_domain_loss = self.domain_loss(ir_domain_class, self.ir_domain_target)
rgb_domain_loss = self.domain_loss(rgb_domain_class, self.rgb_domain_target)
# ir_domain_prec, = accuracy(ir_domain_class.data, self.ir_domain_target.data)
# rgb_domain_prec, = accuracy(rgb_domain_class.data, self.rgb_domain_target.data)
ir_domain_prec = accuracy_score(self.ir_domain_target.data.cpu().numpy(), np.round(torch.sigmoid(ir_domain_class.data).cpu().numpy()) )
rgb_domain_prec = accuracy_score(self.rgb_domain_target.data.cpu().numpy(), np.round(torch.sigmoid(rgb_domain_class.data).cpu().numpy()))
return ir_domain_loss, ir_domain_prec, rgb_domain_loss, rgb_domain_prec
class SupervisedTrainer(BaseTrainer):
def __init__(self, shared_model, ir_obj):
super().__init__()
self.shared_model = shared_model
self.ir_obj = ir_obj
self.rgb_obj = ir_obj
self.domain_loss = torch.nn.KLDivLoss()
def train(self, combined_loader, optimizer, epochs, stage, print_freq, mode='ir'):
# if epochs >=2:
for name, module in self.shared_model.module.CNN.named_modules():
for param in module.parameters():
param.requires_grad = False
# for epoch in range(epochs):
self.shared_model.train()
# if early_stopping:
# print(f"Early stopping at epoch: {epoch}")
# break
batch_time = AverageMeter()
data_time = AverageMeter()
ir_losses = AverageMeter()
rgb_losses = AverageMeter()
domain_losses = AverageMeter()
part_losses = AverageMeter()
ir_precisions = AverageMeter()
rgb_precisions = AverageMeter()
end = time.time()
for i, (ir_inputs, rgb_inputs) in enumerate(zip(combined_loader[0], combined_loader[1])):
data_time.update(time.time() - end)
ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs)
rgb_inputs, rgb_targets, rgb_camids = self._parse_data(rgb_inputs)
if len(ir_inputs) != len(rgb_inputs):
continue
ir_loss, rgb_loss, ir_prec, rgb_prec, domain_loss, part_loss = self._forward([ir_inputs, rgb_inputs],
[ir_targets, rgb_targets],
)
ir_losses.update(ir_loss.item(), ir_targets.size(0))
rgb_losses.update(rgb_loss.item(), rgb_targets.size(0))
domain_losses.update(domain_loss.item(), rgb_targets.size(0))
part_losses.update(part_loss.item(), rgb_targets.size(0))
ir_precisions.update(ir_prec, ir_targets.size(0))
rgb_precisions.update(rgb_prec, rgb_targets.size(0))
combined_loss = rgb_loss
optimizer.zero_grad()
combined_loss.backward()
optimizer.step()
batch_time.update(time.time() - end)
end = time.time()
if (i + 1) % print_freq == 0:
print('Epoch: [{}][{}/{}]\t'
'Time {:.3f} ({:.3f})\t'
'Data {:.3f} ({:.3f})\t'
'IR Loss {:.3f} ({:.3f})\t'
'RGB Loss {:.3f} ({:.3f})\t'
'Triplet Loss {:.3f} ({:.3f})\t'
'Part Loss {:.3f} ({:.3f})\t'
'IR Prec {:.2%} ({:.2%})\t'
'RGB Prec {:.2%} ({:.2%})\t'
'LR :{:.3f}\t'
.format(epochs, i + 1, len(combined_loader[0]),
batch_time.val, batch_time.avg,
data_time.val, data_time.avg,
ir_losses.val, ir_losses.avg,
rgb_losses.val, rgb_losses.avg,
domain_losses.val, domain_losses.avg,
part_losses.val, part_losses.avg,
ir_precisions.val, ir_precisions.avg,
rgb_precisions.val, rgb_precisions.avg, 0.1))
def _parse_data(self, inputs):
imgs, pids, camids, realid, clusterid = inputs
clusterid = pids
imgs.requires_grad = False
clusterid = [int(i) for i in clusterid]
clusterid = torch.IntTensor(clusterid).long()
imgs, clusterid = imgs.cuda(), clusterid.cuda()
return imgs, clusterid, camids
def _forward(self, inputs, targets):
_, ir_pooled = self.shared_model(inputs[0])
ir_loss, ir_outputs = self.ir_obj(ir_pooled, targets[0], None)
ir_prec, = accuracy(ir_outputs.data, targets[0].data)
_, rgb_pooled = self.shared_model(inputs[1], is_adapt=True)
rgb_loss, outputs = self.ir_obj(rgb_pooled, targets[1], None)
rgb_prec, = accuracy(outputs.data, targets[1].data)
return ir_loss, rgb_loss, ir_prec[0], rgb_prec[0], ir_loss, ir_loss
class IRSupervisedTrainer(BaseTrainer):
def __init__(self, shared_model, ir_obj):
super().__init__()
self.shared_model = shared_model
self.ir_obj = ir_obj
def train(self, combined_loader, optimizer, epochs, stage, print_freq, mode='ir', lr=0.1):
# for name, module in self.shared_model.module.CNN.named_modules():
# for param in module.parameters():
# param.requires_grad = False
# for epoch in range(epochs):
self.shared_model.train()
# if early_stopping:
# print(f"Early stopping at epoch: {epoch}")
# break
batch_time = AverageMeter()
data_time = AverageMeter()
ir_losses = AverageMeter()
rgb_losses = AverageMeter()
domain_losses = AverageMeter()
part_losses = AverageMeter()
ir_precisions = AverageMeter()
rgb_precisions = AverageMeter()
end = time.time()
for i, ir_inputs in enumerate(combined_loader):
data_time.update(time.time() - end)
ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs)
ir_loss, ir_prec = self._forward(ir_inputs, ir_targets)
ir_losses.update(ir_loss.item(), ir_targets.size(0))
ir_precisions.update(ir_prec, ir_targets.size(0))
combined_loss = ir_loss
optimizer.zero_grad()
combined_loss.backward()
optimizer.step()
batch_time.update(time.time() - end)
end = time.time()
if (i + 1) % print_freq == 0:
print('Epoch: [{}][{}/{}]\t'
'Time {:.3f} ({:.3f})\t'
'Data {:.3f} ({:.3f})\t'
'IR Loss {:.3f} ({:.3f})\t'
'IR Prec {:.2%} ({:.2%})\t'
'LR :{:.3f}\t'
.format(epochs, i + 1, len(combined_loader),
batch_time.val, batch_time.avg,
data_time.val, data_time.avg,
ir_losses.val, ir_losses.avg,
ir_precisions.val, ir_precisions.avg,
lr))
return self.ir_obj.M
def _parse_data(self, inputs):
imgs, pids, camids, realid, clusterid = inputs
clusterid = pids
imgs.requires_grad = False
clusterid = [int(i) for i in clusterid]
clusterid = torch.IntTensor(clusterid).long()
imgs, clusterid = imgs.cuda(), clusterid.cuda()
return imgs, clusterid, camids
def _forward(self, inputs, targets):
_, ir_pooled = self.shared_model(inputs)
ir_loss, ir_outputs = self.ir_obj(ir_pooled, targets, None)
ir_prec, = accuracy(ir_outputs.data, targets.data)
return ir_loss, ir_prec[0]
class FCSupervisedTrainer(BaseTrainer):
def __init__(self, shared_model, ir_obj, is_adapt=False):
super().__init__()
self.shared_model = shared_model
self.ir_obj = ir_obj
self.domain_loss = nn.CrossEntropyLoss().cuda()
self.is_adapt = is_adapt
def train(self, combined_loader, optimizer, epochs, stage, print_freq, mode='ir', lr=0.1):
# for name, module in self.shared_model.module.CNN.named_modules():
# for param in module.parameters():
# param.requires_grad = False
# for epoch in range(epochs):
self.shared_model.train()
# if early_stopping:
# print(f"Early stopping at epoch: {epoch}")
# break
batch_time = AverageMeter()
data_time = AverageMeter()
ir_losses = AverageMeter()
rgb_losses = AverageMeter()
domain_losses = AverageMeter()
part_losses = AverageMeter()
ir_precisions = AverageMeter()
rgb_precisions = AverageMeter()
end = time.time()
for i, ir_inputs in enumerate(combined_loader):
data_time.update(time.time() - end)
ir_inputs, ir_targets, ir_camids = self._parse_data(ir_inputs)
ir_loss, ir_prec = self._forward(ir_inputs, ir_targets)
ir_losses.update(ir_loss.item(), ir_targets.size(0))
# domain_losses.update(domain_inv_loss.item(), ir_targets.size(0))
ir_precisions.update(ir_prec, ir_targets.size(0))
combined_loss = ir_loss
optimizer.zero_grad()
combined_loss.backward()
optimizer.step()
batch_time.update(time.time() - end)
end = time.time()
if (i + 1) % print_freq == 0:
print('Epoch: [{}][{}/{}]\t'
'Time {:.3f} ({:.3f})\t'
'Data {:.3f} ({:.3f})\t'
'IR Loss {:.3f} ({:.3f})\t'
'IR Prec {:.2%} ({:.2%})\t'
'LR :{:.3f}\t'
.format(epochs, i + 1, len(combined_loader),
batch_time.val, batch_time.avg,
data_time.val, data_time.avg,
ir_losses.val, ir_losses.avg,
ir_precisions.val, ir_precisions.avg,
# domain_losses.val, domain_losses.avg,
lr))
return self.ir_obj.M
def _parse_data(self, inputs):
imgs, pids, camids, realid, clusterid = inputs
clusterid = pids
imgs.requires_grad = False
clusterid = [int(i) for i in clusterid]
clusterid = torch.IntTensor(clusterid).long()
imgs, clusterid = imgs.cuda(), clusterid.cuda()
return imgs, clusterid, camids
def _forward(self, inputs, targets):
ir_fc, ir_fmap = self.shared_model(inputs, self.is_adapt)
ir_loss, ir_data = self.ir_obj(ir_fmap, targets, None)
ir_prec, = accuracy(ir_data.data, targets.data)
return ir_loss, ir_prec[0]#, domain_inv_loss
class Trainer(BaseTrainer):
def _parse_data(self, inputs):
imgs, subid, camera, condition, cluster_id = inputs['image'], inputs['subid'], inputs['cam'], inputs['condition'], inputs['cluster_id']
return imgs.cuda(), cluster_id.cuda().type(torch.int64), camera #subid.cuda()
def _forward(self, inputs, targets, cam):
_, outputs = self.model(inputs)
loss, outputs = self.criterion(outputs, targets, cam) #torch.nn.functional.so(outputs, targets)
# outputs = np.argmax(outputs.data.cpu(), axis=1)
acc, = accuracy(outputs.data, targets.data)
return loss, acc[0]
#
# if isinstance(self.criterion, ACL_IDL):
# prec, = accuracy(outputs.data, targets.data[0:int(len(targets.data) // 2)])
# prec = prec[0]
# else:
# prec = 0
# return loss, prec
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment