import sys
import torch
import torch.nn as nn

class VictimModel(nn.Module):
    """Common wrapper around victim networks loaded from their own repos.

    Only forward functionality and a chosen set of "target" parameters are
    exposed: ``forward`` always returns plain logits, ``target_parameters``
    returns the parameter list selected at construction time, and ``save``
    writes a checkpoint in the format ``__init__`` reads back for that model.

    Supported ``name`` values: "pdarts", "resnet18", "d-darts".
    """

    def __init__(self, name, path, device, arch=None, param_type="arch"):
        """Load the named model from ``path`` and pick its target parameters.

        Args:
            name: "pdarts", "resnet18" or "d-darts".
            path: checkpoint file. For "pdarts" this is a whole pickled module
                (what ``save`` writes for it); for the others a ``state_dict``.
            device: ``map_location`` for ``torch.load``; the state_dict models
                are also moved there with ``.to(device)``.
            arch: name of a genotype defined in the pdarts ``genotypes``
                module. Required for "d-darts", ignored otherwise.
            param_type: "arch" (``self.model.arch_parameters()``, "pdarts"
                only) or "weight" (all of ``self.model.parameters()``).

        Raises:
            ValueError: unknown ``name`` or ``param_type``, or ``arch`` missing
                for "d-darts".
            NotImplementedError: ``param_type="arch"`` for a model without
                architecture parameters.
            AttributeError: ``arch`` is not defined in ``genotypes``.
        """
        super().__init__()
        self.name = name

        # Imports are deferred: each model lives in its own repo and only the
        # requested one needs to be importable.
        if name == "pdarts":
            self._ensure_on_path("../pdarts")
            # These names are unused here, but torch.load below unpickles a
            # whole module whose classes must be importable -- presumably why
            # they are imported.
            import pdarts_utils  # noqa: F401
            from model_search import Network  # noqa: F401

            # NOTE(review): unpickling a full module runs arbitrary code, so
            # ``path`` must be trusted. PyTorch >= 2.6 defaults torch.load to
            # weights_only=True, which rejects full modules; pass
            # weights_only=False here if running on such a version.
            self.model = torch.load(path, map_location=device)
            self._target_parameters = self._select_parameters(param_type)

        elif name == "resnet18":
            from models.resnet import resnet18

            self.model = resnet18()
            # map_location lets a GPU-saved checkpoint load on a CPU-only host.
            self.model.load_state_dict(torch.load(path, map_location=device))
            self.model.to(device)
            self._target_parameters = self._select_parameters(
                param_type,
                arch_error="ResNet18 does not have architecture parameters",
            )

        elif name == "d-darts":
            self._ensure_on_path("../pdarts")
            from model import NetworkCIFAR as Network
            import genotypes

            if arch is None:
                raise ValueError("Genotype must be provided for discretized DARTS model")

            # getattr instead of eval(): same lookup for a plain attribute
            # name, but a caller-supplied string can no longer execute code.
            genotype = getattr(genotypes, arch)

            # Default config; positional args presumably
            # (C, num_classes, layers, auxiliary, genotype) -- confirm
            # against the pdarts model.py.
            self.model = Network(36, 10, 20, True, genotype)
            self.model.load_state_dict(torch.load(path, map_location=device))
            # Presumably disables drop-path regularisation; confirm in model.py.
            self.model.drop_path_prob = 0
            self.model.to(device)
            self._target_parameters = self._select_parameters(
                param_type,
                arch_error="Discretized DARTS does not have architecture parameters",
            )

        else:
            raise ValueError(f"Unknown victim model name: {name!r}")

    @staticmethod
    def _ensure_on_path(directory):
        """Append ``directory`` to ``sys.path`` unless it is already there."""
        # Avoids growing sys.path by one entry per constructed model.
        if directory not in sys.path:
            sys.path.append(directory)

    def _select_parameters(self, param_type, arch_error=None):
        """Return ``self.model``'s parameters selected by ``param_type``.

        Args:
            param_type: "arch" -> ``list(self.model.arch_parameters())``;
                "weight" -> ``list(self.model.parameters())``.
            arch_error: when given, the model has no architecture parameters
                and "arch" raises ``NotImplementedError(arch_error)``.

        Raises:
            ValueError: ``param_type`` is neither "arch" nor "weight".
        """
        if param_type == "arch":
            if arch_error is not None:
                raise NotImplementedError(arch_error)
            return list(self.model.arch_parameters())
        if param_type == "weight":
            return list(self.model.parameters())
        raise ValueError(f"Invalid parameter type: {param_type}")

    def train(self, mode=True):
        """Set training mode (default True) on this wrapper and the wrapped model.

        Same signature and return value (``self``) as ``nn.Module.train``, so
        ``train(False)`` works. ``self.model`` is a registered submodule, so
        the base implementation propagates ``mode`` to it.
        """
        return super().train(mode)

    def eval(self):
        """Switch to evaluation mode; equivalent to ``self.train(False)``."""
        return super().eval()

    def forward(self, x):
        """Return the wrapped model's logits for input ``x``.

        The "d-darts" network returns a pair; only its first element is
        returned and the second (presumably auxiliary-head logits) is dropped.
        """
        if self.name == "d-darts":
            logits, _ = self.model(x)
        else:
            logits = self.model(x)
        return logits

    def target_parameters(self):
        """Return the parameter list selected by ``param_type`` in ``__init__``."""
        return self._target_parameters

    def last_layer_parameters(self):
        """Not implemented yet; currently returns ``None``."""
        # TODO: select the classifier-head parameters for each model.
        return None

    def save(self, path):
        """Write a checkpoint to ``path`` in the format ``__init__`` loads.

        "pdarts" is saved as the whole pickled module; every other model as
        its ``state_dict``.
        """
        if self.name == "pdarts":
            torch.save(self.model, path)
        else:
            torch.save(self.model.state_dict(), path)