import os
import copy
import torch
import torchvision.models as models
import torch.nn as nn

from coopt.server_client.model_training import (optimize_model, 
                                                optimize_data,
                                                align,
                                                OneHotEncoder)


class Client:
    """One participant in the server/client co-optimization setup.

    A client owns a single model and forwards training, data optimization
    and alignment calls to the shared routines imported from
    ``coopt.server_client.model_training``.

    Two kinds of client are handled:

    * "human" clients (any ``model_type`` containing the substring
      ``'human'``): the model is a frozen ``nn.Embedding`` with
      ``args.nclass`` rows of size ``args.feature_dim``. For these clients the
      ``resolution`` and ``feature_dim`` constructor arguments are ignored in
      favour of ``args.input_size`` and ``args.feature_dim``.
    * all other clients: a torchvision model (see ``load_pretrained_model``)
      using the ``resolution`` and ``feature_dim`` given to the constructor.
    """

    def __init__(self, client_id, model_type, resolution, feature_dim, args):
        """Store the client configuration; the model is loaded separately.

        Args:
            client_id: Identifier for this client (stored, not interpreted here).
            model_type: Model name, e.g. ``'mobilenet_v2'``, ``'resnet101'``,
                or any string containing ``'human'``.
            resolution: Input resolution used by non-human clients.
            feature_dim: Feature dimensionality used by non-human clients.
            args: Run configuration object. Human clients read
                ``input_size``, ``feature_dim`` and ``nclass`` from it; it is
                also forwarded to ``optimize_model`` by ``train``.
        """
        self.client_id = client_id
        self.model_type = model_type
        self.args = args
        self.model = None  # populated by load_pretrained_model()
        self.best_val_acc = 0.0  # overwritten by train()
        if 'human' in self.model_type:
            self.resolution = args.input_size
            self.feature_dim = args.feature_dim
            # One fixed (randomly initialised, non-trainable) vector per class.
            # An earlier variant used a one-hot encoding instead
            # (OneHotEncoder(args.nclass) with feature_dim = args.nclass).
            self.class_embedding = nn.Embedding(args.nclass, self.feature_dim)
            self.class_embedding.requires_grad_(False)
        else:
            self.resolution = resolution
            self.feature_dim = feature_dim

    def load_pretrained_model(self):
        """Populate ``self.model`` according to ``self.model_type``.

        * Human clients use the frozen class embedding built in ``__init__``.
        * Otherwise, if a saved model exists at
          ``outputs/models/client/<model_type>`` (relative to the current
          working directory) it is loaded; if not, a pretrained torchvision
          model is constructed.

        Raises:
            ValueError: if no saved model exists and ``model_type`` is not a
                supported torchvision architecture, rather than silently
                leaving ``self.model`` as ``None``.
        """
        if 'human' in self.model_type:
            self.model = self.class_embedding
            return

        saved_path = f'outputs/models/client/{self.model_type}'
        # NOTE(review): `pretrained=True` is deprecated since torchvision 0.13
        # in favour of `weights=...`; kept for compatibility with older
        # torchvision releases.
        if os.path.exists(saved_path):
            self.model = self._load_saved_model(saved_path)
        elif self.model_type == 'mobilenet_v2':
            self.model = models.mobilenet_v2(pretrained=True)
        elif self.model_type == 'resnet101':
            self.model = models.resnet101(pretrained=True)
        else:
            raise ValueError(
                f"Unsupported model_type {self.model_type!r}: no saved model "
                f"at {saved_path!r} and no builder for it (expected "
                f"'mobilenet_v2', 'resnet101', or a type containing 'human')."
            )

    @staticmethod
    def _load_saved_model(path):
        """Load a whole pickled model object from ``path``.

        The loaded object is used directly as the model. The file is therefore
        expected to hold a full module (as written by ``torch.save(model, p)``),
        not a state_dict.

        torch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which
        refuses to unpickle arbitrary module objects. Full unpickling is
        therefore requested explicitly whenever the installed torch accepts
        the keyword (torch < 1.13 has no such keyword).

        NOTE: full unpickling can execute arbitrary code. Only load files that
        this project wrote itself.
        """
        import inspect  # local: only needed for the torch version probe

        load_kwargs = {}
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        return torch.load(path, **load_kwargs)

    def train(self, global_x, global_optimal_data, global_y):
        """Train ``self.model`` with ``optimize_model`` and keep its results.

        The arguments are forwarded unchanged; their meaning is defined by
        ``optimize_model``. ``self.model`` and ``self.best_val_acc`` are
        replaced by the two values it returns.
        """
        self.model, self.best_val_acc = optimize_model(
            self.args,
            global_x,
            global_optimal_data,
            global_y,
            self.model,
            self.model_type)

    def optimize_data(self, x, max_feature_dim, dim_up=False, align_W=None, align_b=None):
        """Run the module-level ``optimize_data`` routine for this client.

        A deep copy of the model is passed, so nothing the routine does to it
        can affect ``self.model``. Inside this body the name ``optimize_data``
        resolves to the imported function, not to this method, because method
        names are not in scope inside method bodies.

        Returns:
            Whatever the module-level ``optimize_data`` returns.
        """
        return optimize_data(
            copy.deepcopy(self.model),
            self.feature_dim,
            max_feature_dim,
            self.resolution,
            x,
            dim_up=dim_up,
            align_W=align_W, align_b=align_b
        )

    def align(
        self,
        client_data,
        align_feature_dim,
        align_data,
        align_features,
        dim_up=False,
        align_W=None, align_b=None
        ):
        """Run the module-level ``align`` routine for this client.

        Unlike ``optimize_data``, this passes the live ``self.model`` rather
        than a copy. As in ``optimize_data``, the call below refers to the
        imported function, not to this method.

        Returns:
            Whatever the module-level ``align`` returns.
        """
        return align(
            client_data,
            self.model,
            self.feature_dim,
            self.resolution,
            align_feature_dim,
            align_data,
            align_features,
            dim_up=dim_up,
            align_W=align_W, align_b=align_b
        )