import copy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import torchvision.models as models

from collections import OrderedDict
import time
import matplotlib.pyplot as plt
import torchvision

from util import *

from stargan import Generator, Discriminator, load_stargan

import sys
# NOTE(review): the path literal is empty -- presumably a placeholder for a
# local project directory (e.g. one providing `iaf` / `ddl`); an empty entry
# on sys.path makes imports resolve against the current working directory.
sys.path.append('')
import warnings
# NOTE(review): IterAlignFlow is not referenced in the visible code; it is
# possibly kept so the pickled translator loaded via torch.load can be
# unpickled -- confirm before removing.
from iaf import IterAlignFlow
from ddl.base import (BoundaryWarning, DataConversionWarning)

# NOTE(review): empty path again -- presumably the directory that contains
# the `autoencoders` package; confirm.
sys.path.append('')
from autoencoders.ae_model import AE
warnings.simplefilter('ignore', BoundaryWarning) # Ignore boundary warnings from ddl
warnings.simplefilter('ignore', DataConversionWarning) # Ignore data conversion warnings from ddl




class Client(object):
    """One federated participant: a local model, its optimizer and its data stream.

    Args:
        loader: re-iterable source of ``(x, y, d)`` mini-batches; ``d`` is
            forwarded to the model as the domain label.
        config: namespace; this class reads ``device``, ``trans`` and ``lr``
            (the model constructors read further fields).
    """

    def __init__(self, loader, config):

        self.loader = loader
        self.config = config
        self.device = config.device

        # Iterator over ``loader``; created lazily and re-created when exhausted.
        self.data_iter = None
        # Per-parameter optimizer state captured after every step so it can be
        # re-attached when ``Central`` swaps in a fresh copy of the model.
        # Initialised here so ``restore_optim`` also works before any ``train``.
        self.optimizer_state = None

        self._build_model()

    def _build_model(self):
        """Create the local model (flow- or StarGAN-based), its Adam optimizer and meters."""
        if self.config.trans != 'stargan':
            self.model = FedDIRT(self.config).to(self.device)
        else:
            self.model = FedDIRTStarGAN(self.config).to(self.device)

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.lr)
        self.lossMeter = AverageMeter()
        self.accMeter = AverageMeter()

    def restore_optim(self):
        """Rebuild the optimizer for the current ``self.model`` and reload saved state.

        Used after ``self.model`` has been replaced; the per-parameter state
        stored by the last ``train`` call (if any) is loaded into the new
        Adam optimizer.
        """
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.lr)
        if self.optimizer_state is not None:
            state = self.optimizer.state_dict()
            state['state'] = self.optimizer_state
            self.optimizer.load_state_dict(state)

    def _next_batch(self):
        """Return the next ``(x, y, d)`` batch, restarting ``loader`` when exhausted."""
        if self.data_iter is not None:
            try:
                return next(self.data_iter)
            except StopIteration:
                pass
        # First call or end of epoch: start a new pass over the loader. Any
        # other error raised by the loader now propagates instead of being
        # silently swallowed.
        self.data_iter = iter(self.loader)
        return next(self.data_iter)

    def train(self, fig=False, iters=0, sup_target=None):
        """Run one optimisation step on the next local mini-batch.

        Args:
            fig, iters, sup_target: forwarded unchanged to the model's forward.
        """
        x, y, d = self._next_batch()

        x, y, d = x.to(self.device), y.to(self.device), d.to(self.device)
        self.optimizer.zero_grad()
        loss, acc = self.model(x, y, d, fig=fig, iters=iters, sup_target=sup_target)
        loss.backward()
        self.optimizer.step()
        # Keep the optimizer state so it survives model replacement.
        self.optimizer_state = self.optimizer.state_dict()['state']

        with torch.no_grad():
            self.lossMeter.update(loss.item(), len(x))
            self.accMeter.update(acc, len(x))


class Central(object):
    """Server for federated training.

    Builds one ``Client`` per training domain, periodically replaces the
    central model with a weighted average of the client weights, broadcasts
    it back, and evaluates it on ``test_loader``.

    Args:
        loader_dict: mapping from every entry of ``config.list_train_domains``
            to that client's training loader.
        test_loader: iterable of ``(x, y)`` evaluation batches.
        config: namespace; reads ``list_train_domains``, ``device``, ``iters``,
            ``sync_step``, ``eval_step``, ``trans`` and ``note`` here.
    """

    def __init__(self, loader_dict, test_loader, config):

        self.config = config
        self.list_train_domains = config.list_train_domains
        self.device = config.device

        self.num_iters = config.iters
        print('total iters: ',self.num_iters)

        self.sync_step = config.sync_step
        self.eval_step = config.eval_step

        self._build_model()
        self._init_client(loader_dict)
        self.test_loader = test_loader

        self.nparams = 0
        self.model_size = check_nparams(self.model)

    def _build_model(self):
        """Create the central model of the same kind the clients use."""
        if self.config.trans != 'stargan':
            self.model = FedDIRT(self.config).to(self.device)
        else:
            print(self.config.trans)
            self.model = FedDIRTStarGAN(self.config).to(self.device)

    def _init_client(self, loader_dict):
        """Create one ``Client`` per training domain, in ``list_train_domains`` order."""
        self.clients_dict = {
            domain: Client(loader_dict[domain], self.config)
            for domain in self.list_train_domains
        }

    def _aggregate(self, coeffs=None):
        """Load the coefficient-weighted sum of client state dicts into ``self.model``.

        Args:
            coeffs: one weight per entry of ``list_train_domains`` (same order).
                ``None`` or empty means a uniform average.
        """
        n_clients = len(self.list_train_domains)
        # `is None`/len() instead of truthiness so numpy/tensor coeffs also work.
        if coeffs is None or len(coeffs) == 0:
            coeffs = [1 / n_clients for _ in range(n_clients)]

        keys = list(self.model.state_dict().keys())
        averaged_weights = OrderedDict()
        for i, domain in enumerate(self.list_train_domains):
            local_weight = self.clients_dict[domain].model.state_dict()
            for key in keys:
                contribution = coeffs[i] * local_weight[key]
                if i == 0:
                    averaged_weights[key] = contribution
                else:
                    averaged_weights[key] += contribution
        self.model.load_state_dict(averaged_weights)

    def _transmit(self):
        """Give every client a fresh deep copy of the central model and rebuild its optimizer."""
        for domain in self.list_train_domains:
            client = self.clients_dict[domain]
            client.model = copy.deepcopy(self.model).to(client.device)
            client.model.train()
            client.restore_optim()

    def eval(self):
        """Evaluate the central model on ``test_loader``.

        Returns:
            ``(lossMeter, accMeter)`` accumulated over all test batches.
        """
        self.model.eval()
        lossMeter = AverageMeter()
        accMeter = AverageMeter()
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            for x, y in self.test_loader:
                x, y = x.to(self.device), y.to(self.device)

                loss, acc = self.model(x, y)

                lossMeter.update(loss.item(), len(x))
                accMeter.update(acc, len(x))

        return lossMeter, accMeter

    def _agg_train_loss(self):
        """Average the clients' training meters, reset them, and return the averages.

        Returns:
            ``(loss, acc)``: mean over clients of ``lossMeter.value()`` and
            ``accMeter.value()``.
        """
        loss = 0
        acc = 0
        for domain in self.list_train_domains:
            client = self.clients_dict[domain]
            loss += client.lossMeter.value()
            acc += client.accMeter.value()
            client.lossMeter = AverageMeter()
            client.accMeter = AverageMeter()

        loss = loss / len(self.list_train_domains)
        acc = acc / len(self.list_train_domains)
        return loss, acc

    def train(self):
        """Run ``config.iters`` iterations of local updates with periodic sync and evaluation.

        Every ``sync_step`` iterations the clients are averaged into the central
        model and it is broadcast back; every ``eval_step`` iterations the central
        model is evaluated and the tracker dict is saved to ``./saved/``.

        Returns:
            dict with lists ``'loss'``, ``'acc'`` (values returned by ``eval``)
            and ``'np'`` (cumulative ``nparams`` counter).
        """
        start_iters = 0

        loss_tracker, acc_tracker, np_tracker = [], [], []
        # The lists are appended in place, so this dict always holds the latest values.
        tracker = {'loss': loss_tracker, 'acc': acc_tracker, 'np': np_tracker}

        print('Start training...')
        for i in range(start_iters, self.num_iters):

            # 1. Train every local client on one mini-batch.
            # One shared target domain index per iteration (used by clients
            # configured with sup_uni_target).
            sup_target = np.random.choice(5)
            fig = (i + 1) % 10 == 0
            for domain in self.list_train_domains:
                self.clients_dict[domain].train(fig=fig, iters=i, sup_target=sup_target)

            # 2. Synchronise central model and clients.
            if (i + 1) % self.sync_step == 0:
                self._aggregate()
                self._transmit()

            # 3. Evaluate and checkpoint the tracker.
            if (i + 1) % self.eval_step == 0:
                self._agg_train_loss()
                loss, acc = self.eval()
                print(f'after {i + 1} iters, test loss: {loss}, test acc: {acc}')
                nparams = self.eval_step * 2 * self.model_size
                self.nparams += nparams

                loss_tracker.append(loss)
                acc_tracker.append(acc)
                np_tracker.append(self.nparams)

                save_name = self.config.trans.replace('/', '-') + self.config.note
                torch.save(tracker, f'./saved/{save_name}.pt')

        return tracker



class FedDIRT(nn.Module):
    """CNN classifier with a feature-consistency regulariser over domain translations.

    ``encoder`` + ``fc11`` map a single-channel image to a 64-d feature ``z``;
    ``cls`` maps ``relu(z)`` to 10 logits. In training mode each batch is also
    translated to target domains with the pretrained translator loaded from
    ``config.model_dir``, and ``config.reg * MSE(z_translated, z)`` is added to
    the cross-entropy loss.

    Config fields read: ``n_domains``, ``use_shared``, ``tn`` ('inb' or
    'indaeinb'), ``extra``, ``uni_target``, ``sup_uni_target``, ``model_dir``,
    ``reg``, ``device`` (plus what ``wrap_enc``/``wrap_dec`` read when
    ``tn == 'indaeinb'``).
    """

    def __init__(self, config):
        super(FedDIRT, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, bias=False), nn.InstanceNorm2d(32), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=5, stride=1, bias=False), nn.InstanceNorm2d(64), nn.ReLU(), nn.MaxPool2d(2, 2),
        )
        self.fc11 = nn.Sequential(nn.Linear(1024, 64))

        torch.nn.init.xavier_uniform_(self.encoder[0].weight)
        torch.nn.init.xavier_uniform_(self.encoder[4].weight)
        torch.nn.init.xavier_uniform_(self.fc11[0].weight)
        self.fc11[0].bias.data.zero_()

        self.cls = nn.Linear(64, 10)
        torch.nn.init.xavier_uniform_(self.cls.weight)
        self.cls.bias.data.zero_()

        self.n_domains = config.n_domains
        self.use_shared = config.use_shared
        self.trans_name = config.tn

        self.extra = config.extra
        self.uni_target = config.uni_target
        self.sup_uni_target = config.sup_uni_target

        # (1) Shared-space mode: translate into the shared space instead of a target domain.
        if self.use_shared:
            print('Use shared space !!!')
        # (2) Per-domain autoencoders around the translator.
        if config.tn == 'indaeinb':
            self._load_indae(config)

        # (3) Pretrained translator, indexed by class label (see batch_* helpers).
        model_dir = config.model_dir

        print(model_dir)
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
        self.trans = torch.load(model_dir, map_location=config.device)

        self.reg = config.reg
        self.device = config.device

    def _embed(self, x):
        """Return the 64-d feature ``fc11(flatten(encoder(x)))``."""
        return self.fc11(self.encoder(x).flatten(1))

    def _sample_target_domains(self, n, sup_target):
        """Draw ``n`` int64 target-domain indices in ``range(5)`` on ``self.device``.

        ``extra``/``uni_target``: one random domain for the whole batch;
        ``sup_uni_target``: every entry equals ``sup_target``;
        otherwise: an independent random domain per sample.
        """
        if self.extra or self.uni_target:
            return (torch.ones(n) * np.random.choice(5)).to(torch.int64).to(self.device)
        if self.sup_uni_target:
            return (torch.ones(n) * sup_target).to(torch.int64).to(self.device)
        return torch.Tensor([np.random.choice(5) for _ in range(n)]).to(torch.int64).to(self.device)

    def _translate(self, x, y, d, target_d):
        """Translate ``x`` with the translator selected by ``self.trans_name``.

        Raises:
            ValueError: if ``trans_name`` is neither 'inb' nor 'indaeinb'.
        """
        if self.trans_name == 'inb':
            if self.use_shared:
                return self._inb_findshare(x, y, d)
            return self._inb(x, y, d, target_d)
        if self.trans_name == 'indaeinb':
            return self._indaeinb(x, y, d, target_d)
        raise ValueError(
            f"unsupported translator {self.trans_name!r}; expected 'inb' or 'indaeinb'"
        )

    def forward(self, x, y, d=None, fig=False, iters=0, sup_target=None):
        """Compute the (regularised) classification loss and batch accuracy.

        Args:
            x: image batch (the translators reshape it to ``(-1, 1, 28, 28)``).
            y: integer class labels.
            d: integer domain labels; required in training mode.
            fig, iters: unused; accepted for call compatibility.
            sup_target: target domain index used when ``sup_uni_target`` is set.

        Returns:
            ``(loss, acc)``: scalar loss tensor and accuracy as a Python float.
        """
        if self.training and self.extra:
            x, y, d = self._extra_data(x, y, d)

        z = self._embed(x)
        logits = self.cls(F.relu(z))
        loss = F.cross_entropy(logits, y)
        acc = ((logits.argmax(1) == y).sum().float() / len(y)).item()

        if self.training:
            # The translator is frozen: no gradients through the translation.
            with torch.no_grad():
                # Sampled even when unused (use_shared) to keep the RNG stream unchanged.
                target_d = self._sample_target_domains(x.shape[0], sup_target)
                x_ = self._translate(x, y, d, target_d)

            z_ = self._embed(x_)
            reg = F.mse_loss(z_, z, reduction='mean')
            loss = loss + self.reg * reg

        return loss, acc

    def _extra_data(self, x, y, d):
        """Append a translated copy of the batch (random per-sample target domains)."""
        de = torch.Tensor([np.random.choice(5) for _ in range(x.shape[0])]).to(torch.int64).to(self.device)
        if self.trans_name == 'inb':
            xe = self._inb(x, y, d, de)
        else:
            xe = self._indaeinb(x, y, d, de)

        x = torch.cat((x, xe), dim=0)
        d = torch.cat((d, de))
        y = torch.cat((y, y))
        return x, y, d

    def _inb(self, x, y, d, target_d):
        """Translate each sample from domain ``d`` to ``target_d`` with the per-class flow."""
        x_ = batch_inb_translate(self.trans, x.view(-1, 784), y, d, target_d).view(-1, 1, 28, 28)
        return x_

    def _inb_findshare(self, x, y, d):
        """Map each sample through the per-class flow's forward pass only."""
        x_ = batch_inb_findshare(self.trans, x.view(-1, 784), y, d).view(-1, 1, 28, 28)
        return x_

    def _load_indae(self, args):
        """Attach the per-domain autoencoder encoder/decoder wrappers."""
        self.enc = wrap_enc(args)
        self.dec = wrap_dec(args)
        return self

    def _indaeinb(self, x, y, d, target_d):
        """Translate in the autoencoder latent space (encode, flow, decode)."""
        x_ = batch_indaeinb_translate(self.trans, x.view(-1, 784), y, d,
                                      target_d,
                                      self.enc, self.dec,
                                      ).view(-1, 1, 28, 28)
        return x_





def inb_translate(cd, x, d, target_d):
    """Move ``x`` from domains ``d`` to domains ``target_d`` through flow ``cd``.

    ``cd(x, d)`` maps into the flow's latent space; ``cd.inverse(z, target_d)``
    maps back out conditioned on the target domain labels.
    """
    latent = cd(x, d)
    return cd.inverse(latent, target_d)

def batch_inb_translate(cd_dict, x, y, d, target_d):
    """Translate a batch class by class with ``inb_translate``.

    Rows with label ``c`` are translated by ``cd_dict[c]`` and written back to
    the same positions of an output the same shape as ``x``.
    """
    translated = torch.zeros_like(x)
    for label in torch.unique(y):
        rows = y == label
        translated[rows] = inb_translate(
            cd_dict[int(label)], x[rows], d[rows], target_d[rows]
        )
    return translated




def indaeinb_translate(cd, x, d, target_d, enc, dec):
    """Translate ``x`` inside per-domain autoencoder latent spaces.

    Each row is encoded by ``enc`` under its source domain into a 288-wide
    code, moved with ``cd`` (forward under ``d``, inverse under ``target_d``),
    then decoded by ``dec`` under its target domain into a 784-wide row.
    """
    n_rows = x.shape[0]

    codes = torch.zeros(n_rows, 288, device=x.device)
    for src in torch.unique(d):
        sel = d == src
        codes[sel] = enc(x[sel], src)

    moved = cd.inverse(cd(codes, d), target_d)

    decoded = torch.zeros(n_rows, 784, device=x.device)
    for dst in torch.unique(target_d):
        sel = target_d == dst
        decoded[sel] = dec(moved[sel], dst)
    return decoded

def batch_indaeinb_translate(cd_dict, x, y, d, target_d, enc, dec):
    """Class-wise wrapper around ``indaeinb_translate``.

    Rows labelled ``c`` go through ``cd_dict[c]`` together with the shared
    ``enc``/``dec`` wrappers; results are scattered back to their positions in
    an output shaped like ``x``.
    """
    translated = torch.zeros_like(x)
    for label in torch.unique(y):
        rows = y == label
        translated[rows] = indaeinb_translate(
            cd_dict[int(label)], x[rows], d[rows], target_d[rows], enc, dec
        )
    return translated

def batch_inb_findshare(cd_dict, x, y, d):
    """Apply each class's flow forward pass (``cd_dict[c](x, d)``) to its rows.

    Output has the same shape as ``x``; row order is preserved.
    """
    mapped = torch.zeros_like(x)
    for label in torch.unique(y):
        rows = y == label
        flow = cd_dict[int(label)]
        mapped[rows] = flow(x[rows], d[rows])
    return mapped


class wrap_enc(nn.Module):
    """Dispatch to one pretrained autoencoder *encoder* per training domain.

    ``ae_list[i]`` is the encoder of ``<ae_dir>/ae-<list_train_domains[i]>.pt``.
    The encoders are kept in a plain list on purpose: they are frozen feature
    extractors and are not registered as submodules, so they stay out of the
    parent model's ``state_dict``/optimizer. Because ``.train()``/``.eval()``
    on a parent do not reach them, they are put in eval mode here once.
    """

    def __init__(self, args):
        super().__init__()
        self.ae_list = []
        for dd in args.list_train_domains:
            ae = AE()
            ae_path = args.ae_dir + '/ae' + '-' + str(dd) + '.pt'
            # map_location lets a checkpoint saved on another device load here.
            ae.load_state_dict(torch.load(ae_path, map_location=args.device))
            ae = ae.to(args.device)
            # Inference only: disable any dropout / batch-stat updates.
            ae.eval()
            self.ae_list.append(ae.encoder)
            print(f'Finish loading encoder from {ae_path}')

    def forward(self, X, y):
        """Encode ``X`` (reshaped to ``(-1, 1, 28, 28)``) with the encoder at index ``int(y)``.

        NOTE(review): ``y`` is used as a position in ``list_train_domains``;
        presumably domain labels are 0..len-1 in that order -- confirm.

        Returns:
            encoder output flattened to ``(batch, -1)``.
        """
        X = X.view(-1, 1, 28, 28)
        return self.ae_list[int(y)](X).view(X.shape[0], -1)


class wrap_dec(nn.Module):
    """Dispatch to one pretrained autoencoder *decoder* per training domain.

    ``ae_list[i]`` is the decoder of ``<ae_dir>/ae-<list_train_domains[i]>.pt``.
    The decoders are kept in a plain list on purpose (frozen, not registered
    as submodules), so a parent's ``.train()``/``.eval()`` never reaches them;
    they are therefore put in eval mode here once.
    """

    def __init__(self, args):
        super().__init__()
        self.ae_list = []
        for dd in args.list_train_domains:
            ae = AE()
            ae_path = args.ae_dir + '/ae' + '-' + str(dd) + '.pt'
            # map_location lets a checkpoint saved on another device load here.
            ae.load_state_dict(torch.load(ae_path, map_location=args.device))
            ae = ae.to(args.device)
            # Inference only: disable any dropout / batch-stat updates.
            ae.eval()
            self.ae_list.append(ae.decoder)
            print(f'Finish loading decoder from {ae_path}')

    def forward(self, X, y):
        """Decode ``X`` (reshaped to ``(-1, 32, 3, 3)``) with the decoder at index ``int(y)``.

        NOTE(review): ``y`` is used as a position in ``list_train_domains``;
        presumably domain labels are 0..len-1 in that order -- confirm.

        Returns:
            decoder output flattened to ``(batch, -1)``.
        """
        X = X.view(-1, 32, 3, 3)
        return self.ae_list[int(y)](X).view(X.shape[0], -1)


def conv3x3(in_channels, out_channels, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the spatial size is preserved.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_channels, out_channels, **conv_kwargs)


class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with a shortcut: ``relu(F(x) + shortcut(x))``.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        stride: stride of the first conv (and of the default shortcut).
        downsample: optional module producing the shortcut from ``x``. When
            ``None`` (default) a 3x3 conv + BatchNorm projection with ``stride``
            is used -- the same shortcut this block always built before the
            argument was honoured.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Previously the `downsample` argument was silently ignored.
        if downsample is None:
            downsample = nn.Sequential(
                conv3x3(in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels))
        self.downsample = downsample

    def forward(self, x):
        """Apply the residual block to ``x`` (NCHW)."""
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

class FedDIRTStarGAN(nn.Module):
    """Small CNN classifier regularised with StarGAN domain translations.

    ``encoder`` + ``fc11`` map a single-channel image to a 64-d feature ``z``;
    ``cls`` maps ``relu(z)`` to 10 logits. In training mode each batch is also
    translated by a frozen StarGAN generator (``self.trans``) from its source
    domains to target domains, and ``config.reg * MSE(z_translated, z)`` is
    added to the cross-entropy loss.

    Config fields read: ``model_dir``, ``device``, ``reg``, ``iters``,
    ``sup_uni_target``.
    """

    def __init__(self, config):
        super(FedDIRTStarGAN, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=5, stride=1, bias=False), nn.InstanceNorm2d(8), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(8, 16, kernel_size=5, stride=1, bias=False), nn.InstanceNorm2d(16), nn.ReLU(), nn.MaxPool2d(2, 2),
        )
        self.fc11 = nn.Sequential(nn.Linear(256, 64))

        torch.nn.init.xavier_uniform_(self.encoder[0].weight)
        torch.nn.init.xavier_uniform_(self.encoder[4].weight)
        torch.nn.init.xavier_uniform_(self.fc11[0].weight)
        self.fc11[0].bias.data.zero_()

        self.cls = nn.Linear(64, 10)
        torch.nn.init.xavier_uniform_(self.cls.weight)
        self.cls.bias.data.zero_()

        # NOTE(review): the generator checkpoint is hard-coded to the
        # RotatedMnist model for domain 75, independent of config.dataset /
        # config.target_domain -- confirm this is intended.
        ckpt = config.model_dir + 'stargan_model/RotatedMnist_domain{}_last-G.ckpt'.format(75)
        # Log the path that is actually loaded (the old message printed a
        # different, config-derived path).
        print(ckpt)
        self.trans = load_stargan(ckpt=ckpt)
        self.trans.eval()

        self.device = config.device
        self.reg = config.reg
        self.max_iters = config.iters
        self.sup_uni_target = config.sup_uni_target

    def train(self, mode=True):
        """Set training mode for the classifier while keeping the generator in eval mode.

        ``self.trans`` is a registered submodule, so ``nn.Module.train`` would
        otherwise switch it to training mode as well (``Central._transmit``
        calls ``.train()`` on every client copy), undoing ``trans.eval()``.
        """
        super(FedDIRTStarGAN, self).train(mode)
        self.trans.eval()
        return self

    def _embed(self, x):
        """Return the 64-d feature ``fc11(flatten(encoder(x)))``."""
        return self.fc11(self.encoder(x).flatten(1))

    def forward(self, x, y, d=None, fig=False, iters=0, sup_target=None):
        """Compute the (regularised) classification loss and batch accuracy.

        Args:
            x: image batch fed to the encoder and the generator.
            y: integer class labels.
            d: int64 source-domain indices in ``range(5)``; required in training mode.
            fig, iters: unused; accepted for call compatibility.
            sup_target: target domain index used when ``sup_uni_target`` is set.

        Returns:
            ``(loss, acc)``: scalar loss tensor and accuracy as a Python float.
        """
        z = self._embed(x)

        logits = self.cls(F.relu(z))
        loss = F.cross_entropy(logits, y)
        acc = ((logits.argmax(1) == y).sum().float() / len(y)).item()

        if self.training:
            # The generator is frozen: no gradients through the translation.
            with torch.no_grad():
                one_hot_d = x.new_zeros([x.shape[0], 5])
                one_hot_d.scatter_(1, d[:, None], 1)
                if self.sup_uni_target:
                    d_ = (torch.ones(x.shape[0]) * sup_target).to(torch.int64).to(self.device)
                else:
                    d_ = torch.Tensor([np.random.choice(5) for _ in range(x.shape[0])]).to(torch.int64).to(self.device)
                one_hot_d_ = x.new_zeros([x.shape[0], 5])
                one_hot_d_.scatter_(1, d_[:, None], 1)
                x_ = self.trans(x, one_hot_d, one_hot_d_)

            z_ = self._embed(x_)
            reg = F.mse_loss(z_, z, reduction='mean')
            loss = loss + self.reg * reg
        return loss, acc


# class FedDIRTStarGAN(nn.Module):
#     def __init__(self,config):
#         super(FedDIRTStarGAN,self).__init__()
#         # self.encoder = nn.Sequential(
#         #     nn.Conv2d(1, 32, kernel_size=5, stride=1, bias=False), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2, 2),
#         #     nn.Conv2d(32, 64, kernel_size=5, stride=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2, 2),
#         # )
#         self.encoder = nn.Sequential(
#             nn.Conv2d(1, 32, kernel_size=5, stride=1, bias=False), nn.InstanceNorm2d(32), nn.ReLU(), nn.MaxPool2d(2, 2),
#             nn.Conv2d(32, 64, kernel_size=5, stride=1, bias=False), nn.InstanceNorm2d(64), nn.ReLU(), nn.MaxPool2d(2, 2),
#         )
#         self.fc11 = nn.Sequential(nn.Linear(1024, 64))
#
#         torch.nn.init.xavier_uniform_(self.encoder[0].weight)
#         torch.nn.init.xavier_uniform_(self.encoder[4].weight)
#         torch.nn.init.xavier_uniform_(self.fc11[0].weight)
#         self.fc11[0].bias.data.zero_()
#
#         self.cls = nn.Linear(64, 10)
#         torch.nn.init.xavier_uniform_(self.cls.weight)
#         self.cls.bias.data.zero_()
#         print('saved/stargan_model/{}_domain{}_last-G.ckpt'.format(config.dataset,config.target_domain))
#         self.trans = load_stargan(ckpt=config.model_dir + 'stargan_model/{}_domain{}_last-G.ckpt'.format(config.dataset,75))
#         self.trans.eval()
#
#         self.device = config.device
#         self.reg = config.reg
#
#     def forward(self,x,y,d=None,fig=False):
#         h = self.encoder(x)
#         h = h.view(-1, 1024)
#         z = self.fc11(h)
#
#         logits = self.cls(F.relu(z))
#         loss = F.cross_entropy(logits, y)
#         acc = ((logits.argmax(1)==y).sum().float()/len(y)).item()
#
#
#         if self.training:
#             with torch.no_grad():
#                 one_hot_d = x.new_zeros([x.shape[0],5])
#                 one_hot_d.scatter_(1, d[:,None], 1)
#                 d_ = torch.Tensor([np.random.choice(5) for _ in range(x.shape[0])]).to(torch.int64).to(self.device)
#                 #d_ = x.new_ones(x.shape[0]).to(torch.int64)*np.random.choice(5)
#                 one_hot_d_ = x.new_zeros([x.shape[0],5])
#                 one_hot_d_.scatter_(1, d_[:,None], 1)
#                 x_ = self.trans(x,one_hot_d,one_hot_d_)
#
#             h_ = self.encoder(x_)
#             h_ = h_.view(-1, 1024)
#             z_ = self.fc11(h_)
#             reg = F.mse_loss(z_,z,reduction='mean')
#             loss = loss + reg * self.reg
#
#         return loss, acc



# v0
#         # self.encoder = nn.Sequential(
#         #     nn.Conv2d(1, 32, kernel_size=5, stride=1, bias=False), nn.InstanceNorm2d(32), nn.ReLU(), nn.MaxPool2d(2, 2),
#         #     nn.Conv2d(32, 64, kernel_size=5, stride=1, bias=False), nn.InstanceNorm2d(64), nn.ReLU(), nn.MaxPool2d(2, 2),
#         # )
#         self.encoder = nn.Sequential(
#             nn.Conv2d(1, 32, kernel_size=5, stride=1),  nn.ReLU(), nn.MaxPool2d(2, 2),
#             nn.Conv2d(32, 64, kernel_size=5, stride=1), nn.ReLU(), nn.MaxPool2d(2, 2),
#         )
#         self.fc11 = nn.Sequential(nn.Linear(1024, 64))
#
#         # torch.nn.init.xavier_uniform_(self.encoder[0].weight)
#         # torch.nn.init.xavier_uniform_(self.encoder[4].weight)
#         # torch.nn.init.xavier_uniform_(self.fc11[0].weight)
#         # self.fc11[0].bias.data.zero_()
#
#         self.cls = nn.Linear(64, 10)
#         #torch.nn.init.xavier_uniform_(self.cls.weight)
#         #self.cls.bias.data.zero_()


# v1
#         self.encoder = nn.Sequential(
#             nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
#             nn.BatchNorm2d(32),
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=2, stride=2),
#             nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
#             nn.BatchNorm2d(64),
#             nn.ReLU(),
#             nn.MaxPool2d(2)
#         )
#         self.fc11 = nn.Sequential(
#             nn.Linear(in_features=64 * 6 * 6, out_features=600),
#             nn.Dropout2d(0.25),
#             nn.Linear(in_features=600, out_features=120),
#         )
#
#         self.cls = nn.Linear(in_features=120, out_features=10)


# v2
#         self.encoder = nn.Sequential(
#             nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0),
#             nn.InstanceNorm2d(6),
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=2, stride=2),
#             nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
#             nn.InstanceNorm2d(16),
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=2, stride=2)
#         )
#
#         self.fc11 = nn.Sequential(
#             nn.Linear(in_features=256, out_features=120),
#             nn.ReLU(),
#             nn.Dropout2d(0.25),
#             nn.Linear(in_features=120, out_features=84),
#             nn.ReLU(),
#         )
#
#         self.cls = nn.Linear(in_features=84, out_features=10)

# v4
        # self.encoder = nn.Sequential(
        #     nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2,bias=False), nn.InstanceNorm2d(32), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2),
        #     nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=2, bias=False), nn.InstanceNorm2d(64), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2),
        #     nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1, bias=False), nn.InstanceNorm2d(256), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2),
        #     nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=2, bias=False), nn.InstanceNorm2d(256), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2),
        # )
        # self.fc11 = nn.Sequential(nn.Linear(1152, 400))
        # self.cls = nn.Sequential(nn.Linear(400, 120),
        #                          nn.ReLU(),
        #                          nn.Dropout2d(0.2),
        #                          nn.Linear(120,64),
        #                          nn.ReLU(),
        #                          nn.Dropout2d(0.2),
        #                          nn.Linear(64,10))