
from adapters import AutoAdapterController, MetaAdapterController, AdapterLayersHyperNetController, \
    AdapterLayersOneHyperNetController, TaskEmbeddingController
import torch
import torch.nn as nn
from networks.base_model import BaseModel, init_weights
from models import get_model
from models.total_loss import  TotalLoss

from adapters.adapter_configuration import AdapterConfig, MetaAdapterConfig


class Trainer(BaseModel):
    """Training wrapper that builds the model, its optimizer and the loss.

    The constructor creates an ``AdapterLayersOneHyperNetController``, passes
    it to ``get_model`` together with ``opt.arch``, selects the trainable
    parameters, builds the optimizer chosen by ``opt.optim`` and wraps the
    model in ``nn.DataParallel``.

    Attributes read from ``opt``: ``arch``, ``fix_backbone``, ``lr``,
    ``optim``, ``beta1`` (AdamW only) and ``weight_decay``.

    NOTE(review): ``self.device`` is used by ``set_input``/``forward`` but is
    never assigned in this class; it is presumably provided by ``BaseModel``.
    """

    # Upper bound on the number of GPUs handed to nn.DataParallel. The
    # original code hard-coded devices 0 and 1, i.e. at most two GPUs.
    _MAX_DATA_PARALLEL_GPUS = 2

    def name(self):
        """Return the identifier string of this trainer."""
        return 'Trainer'

    def __init__(self, opt):
        """Build the model, the optimizer and the loss from ``opt``.

        Args:
            opt: Options object; see the class docstring for the attributes
                that are read.

        Raises:
            ValueError: If ``opt.optim`` is neither ``'adam'`` nor ``'sgd'``.
        """
        super(Trainer, self).__init__(opt)
        # NOTE(review): ``criterion`` is not used by any method of this class.
        self.criterion = nn.BCEWithLogitsLoss()
        self.opt = opt

        adapter_config = MetaAdapterConfig()
        adapter_config.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.AdapterLayersOneHyperNetController = AdapterLayersOneHyperNetController(adapter_config)
        self.model = get_model(opt.arch, self.AdapterLayersOneHyperNetController).to(torch.float32)

        if opt.fix_backbone:
            # Train only the final ``fc`` layer of the model plus every
            # parameter of the adapter controller; freeze everything else.
            params = []
            for name, p in self.model.named_parameters():
                if name in ("fc.weight", "fc.bias"):
                    params.append({'params': p, 'lr': opt.lr})
                else:
                    p.requires_grad = False
            # Runs after the freeze loop so controller parameters end up
            # trainable even if the loop above just froze them.
            for name, p in self.AdapterLayersOneHyperNetController.named_parameters():
                p.requires_grad = True
                params.append({'params': p, 'lr': opt.lr})
        else:
            print(
                "Your backbone is not fixed. Are you sure you want to proceed? If this is a mistake, enable the --fix_backbone command during training and rerun")
            import time
            # Short pause so the warning can be read before training starts.
            time.sleep(3)
            params = self.model.parameters()

        if opt.optim == 'adam':
            # The 'adam' option selects AdamW (decoupled weight decay).
            self.optimizer = torch.optim.AdamW(params, lr=opt.lr, betas=(opt.beta1, 0.999),
                                               weight_decay=opt.weight_decay)
            print("use adam")
        elif opt.optim == 'sgd':
            self.optimizer = torch.optim.SGD(params, lr=opt.lr, momentum=0.0, weight_decay=opt.weight_decay)
            print("use sgd")
        else:
            raise ValueError("optim should be [adam, sgd]")

        self.loss_fn = TotalLoss()

        # Only touch CUDA when it exists, and only list GPUs that are really
        # present: a fixed ``device_ids=[0, 1]`` fails on single-GPU hosts and
        # an unconditional ``.cuda()`` fails on CPU-only hosts. Machines with
        # two or more GPUs still get ``[0, 1]`` exactly as before.
        if torch.cuda.is_available():
            self.model = self.model.cuda()
            device_ids = list(range(min(torch.cuda.device_count(),
                                        self._MAX_DATA_PARALLEL_GPUS)))
        else:
            device_ids = None
        # Always wrap, so ``self.model.module`` exists regardless of the host.
        self.model = nn.DataParallel(self.model, device_ids=device_ids)

    def adjust_learning_rate(self, min_lr=1e-6):
        """Divide the learning rate of every parameter group by 10.

        All groups are decayed before the threshold is checked, so the groups
        never end up with inconsistent learning rates.

        Args:
            min_lr: Lower bound for the learning rate.

        Returns:
            ``False`` if any group's learning rate is below ``min_lr`` after
            the decay, ``True`` otherwise.
        """
        param_groups = self.optimizer.param_groups
        for param_group in param_groups:
            param_group['lr'] /= 10.
        return not any(param_group['lr'] < min_lr for param_group in param_groups)

    def set_input(self, input):
        """Store one batch on the trainer, moving tensors to ``self.device``.

        Args:
            input: Indexable batch. ``input[0]`` becomes ``self.input``,
                ``input[1]`` becomes ``self.label`` (cast to float),
                ``input[2]`` becomes ``self.inputOrginal`` and ``input[3]``
                is stored unchanged as ``self.image_path``.
        """
        self.image_path = input[3]
        self.input = input[0].to(self.device)
        self.label = input[1].to(self.device).float()
        self.inputOrginal = input[2].to(self.device)

    def forward(self, task):
        """Run the model on both stored inputs.

        Args:
            task: Task argument forwarded to the model for
                ``self.inputOrginal``.
        """
        # NOTE(review): ``self.input`` is always run with the literal task
        # argument 5, independent of ``task`` - confirm this is intentional.
        self.output = self.model(self.input, 5).to(self.device)
        self.outputOrginal = self.model(self.inputOrginal, task).to(self.device)

    def get_loss(self):
        """Return ``loss_fn`` applied to the squeezed output and the labels.

        NOTE(review): ``loss_fn`` is a ``TotalLoss`` that
        ``optimize_parameters`` calls with five arguments, whereas this
        method passes two - verify against ``TotalLoss``.
        """
        return self.loss_fn(self.output.squeeze(1), self.label)

    def optimize_parameters(self, task):
        """Perform one optimisation step on the batch set by ``set_input``.

        Args:
            task: Task argument forwarded to ``forward``.
        """
        self.forward(task)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.loss = self.loss_fn(self.output, self.outputOrginal, self.label, self.image_path, device)
        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()

