import operator
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

from Support.datareader import datareader_standard


# Resize target handed to the image transform (for non-square images this
# is the length of the smaller edge).
img_size: int = 32
# Dataset root passed to the data reader; replace the placeholder with the
# real location before running.
data_root: str = r'Your_dir/fashion-dataset/'
# Directory that training checkpoints are saved into.
net_path: str = 'model_save/'


class MainNet(torch.nn.Module):
    """Three-branch 8-way image classifier.

    Branch k (k = 1, 2, 3) is ``conv_net_k`` followed by ``sub_k``; each maps
    a 3-channel image batch to an (N, 8) matrix of softmax probabilities, and
    ``forward`` returns all three.

    Each conv trunk is five stride-2 convolutions (kernel 4, padding 1), so a
    32x32 input shrinks to 1x1 and ``Flatten`` yields ``96 * 3`` features,
    which is what the first ``Linear`` of every head expects.

    ``net_show`` is a fourth, self-contained trunk + head that ``forward``
    never calls.  It is still constructed (in its original position) so that
    parameter names in saved ``state_dict`` checkpoints are unchanged.
    """

    def __init__(self):
        super().__init__()
        # Construction order is preserved on purpose: it fixes the sequence of
        # RNG draws used for weight initialisation.
        self.conv_net_1 = nn.Sequential(*self._trunk_layers())
        self.conv_net_2 = nn.Sequential(*self._trunk_layers())
        self.conv_net_3 = nn.Sequential(*self._trunk_layers())

        # Trunk and head flattened into one Sequential, so its layer indices
        # (0-11 trunk, 12-15 head) match the original definition.
        self.net_show = nn.Sequential(*self._trunk_layers(), *self._head_layers())

        self.sub_1 = nn.Sequential(*self._head_layers())
        self.sub_2 = nn.Sequential(*self._head_layers())
        self.sub_3 = nn.Sequential(*self._head_layers())

    @staticmethod
    def _trunk_layers():
        """Return fresh layers for one conv trunk ending in ``Flatten`` (12 modules)."""
        return [
            # NOTE(review): there is no activation between the first two convs;
            # confirm this is intended. Inserting one would shift the
            # Sequential indices and break loading of existing checkpoints.
            nn.Conv2d(3, 32 * 3, 4, 2, 1),
            nn.Conv2d(32 * 3, 48 * 3, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(48 * 3, 48 * 3, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(48 * 3, 96 * 3, 4, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(96 * 3),
            nn.Conv2d(96 * 3, 96 * 3, 4, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(96 * 3),
            nn.Flatten(),
        ]

    @staticmethod
    def _head_layers():
        """Return fresh layers for one classifier head: 288 features -> 8 scores (4 modules)."""
        return [
            nn.Linear(96 * 3, 32 * 8),
            nn.ReLU(),
            # NOTE(review): the last two Linear layers have no activation
            # between them, so together they are a single affine map.
            nn.Linear(32 * 8, 32 * 4),
            nn.Linear(32 * 4, 8),
        ]

    def forward(self, x):
        """Run all three branches on ``x``.

        Args:
            x: image batch with 3 channels whose spatial size shrinks to 1x1
               through the five stride-2 convolutions (e.g. (N, 3, 32, 32)).

        Returns:
            Tuple ``(out_1, out_2, out_3)``. Each element is an (N, 8) tensor
            of softmax probabilities, with rows summing to 1. These are
            probabilities, not logits: use ``NLLLoss`` on their log rather
            than ``CrossEntropyLoss``, which would apply softmax a second time.
        """
        fea_1 = self.conv_net_1(x)
        fea_2 = self.conv_net_2(x)
        fea_3 = self.conv_net_3(x)
        # dim=1 is what the deprecated implicit-dim rule chose for 2-D input;
        # stating it explicitly avoids the deprecation warning.
        out_1 = F.softmax(self.sub_1(fea_1), dim=1)
        out_2 = F.softmax(self.sub_2(fea_2), dim=1)
        out_3 = F.softmax(self.sub_3(fea_3), dim=1)

        return out_1, out_2, out_3


# Floor applied to predicted probabilities before taking their log, so a
# probability that underflows to 0 cannot produce -inf / NaN in the loss.
_PROB_EPS = 1e-12


def _make_dataset(file_txt, transform, label_index):
    """Return a ``datareader_standard`` dataset over ``file_txt`` whose target
    is element ``label_index`` of the reader's raw target.

    ``choice=slice(1, 4)`` and ``offset=None`` are forwarded unchanged.
    NOTE(review): presumably ``choice`` selects three label columns per line
    and the target transform keeps one of them; confirm against
    ``Support.datareader``.
    """
    return datareader_standard(
        file_txt=file_txt,
        data_root=data_root,
        transform=transform,
        choice=slice(1, 4),
        # Same as ``lambda y: y[label_index]``, but picklable, which is
        # required if DataLoader worker processes are ever enabled.
        target_transform=operator.itemgetter(label_index),
        offset=None,
    )


def _train_one_epoch(model, loader, opt, device):
    """Train ``model`` for one pass over ``loader`` with a min-over-heads loss.

    For every sample, the negative log-likelihood under each of the three
    heads is computed and only the smallest is kept. ``min`` back-propagates
    only to the selected head, so each sample updates the head that currently
    fits it best.

    Returns:
        ``(mean_loss, oracle_accuracy)``. The first is the sample-weighted
        mean of the minimised loss. The second is the fraction of samples for
        which at least one head's argmax equals the label. Both are NaN when
        ``loader`` is empty.
    """
    model.train()
    nll = nn.NLLLoss(reduction='none')
    loss_sum, hit_count, seen = 0.0, 0, 0
    for batch_x, batch_y in loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        probs = model(batch_x)

        # MainNet.forward returns softmax probabilities, not logits. Feeding
        # them to CrossEntropyLoss would apply softmax twice, so take the log
        # explicitly and use NLLLoss.
        per_head = torch.stack(
            [nll(torch.log(p.clamp_min(_PROB_EPS)), batch_y) for p in probs], dim=0)
        loss = per_head.min(dim=0).values.mean()

        opt.zero_grad()
        loss.backward()
        opt.step()

        with torch.no_grad():
            # A sample counts as a hit if ANY head predicts its label.
            hits = torch.stack([p.argmax(dim=1) == batch_y for p in probs], dim=0).any(dim=0)
        n = batch_y.size(0)
        loss_sum += loss.item() * n
        hit_count += int(hits.sum().item())
        seen += n

    if seen == 0:
        return float('nan'), float('nan')
    return loss_sum / seen, hit_count / seen


def _evaluate(model, loader, device):
    """Return a list with the accuracy of each of the model's three heads on ``loader``.

    Accuracy is correct / total over all samples, so a short final batch is
    weighted correctly. Every entry is NaN when ``loader`` is empty.
    """
    model.eval()
    correct = [0, 0, 0]
    seen = 0
    with torch.no_grad():  # evaluation needs no autograd graph
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            for k, p in enumerate(model(batch_x)):
                correct[k] += int((p.argmax(dim=1) == batch_y).sum().item())
            seen += batch_y.size(0)
    if seen == 0:
        return [float('nan')] * len(correct)
    return [c / seen for c in correct]


def main(epochs=1000, batch_size=200):
    """Train MainNet, checkpointing every 20 epochs and testing every 5."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # transforms.Scale was deprecated in favour of transforms.Resize (same
    # semantics) and has been removed from recent torchvision releases.
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Each line of the training list appears three times, paired with target
    # element 0, 1 and 2 respectively.
    train_data = torch.utils.data.ConcatDataset(
        [_make_dataset('list_train_mix_all.txt', transform, k) for k in range(3)])
    test_loader_list = [
        DataLoader(dataset=_make_dataset('list_test_mix_all.txt', transform, k),
                   batch_size=batch_size)
        for k in range(3)
    ]
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

    model = MainNet().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.5, 0.999))

    # torch.save does not create missing directories.
    os.makedirs(net_path, exist_ok=True)

    for epoch in range(epochs):
        mean_loss, oracle_acc = _train_one_epoch(model, train_loader, opt, device)
        print(epoch)
        print(mean_loss)
        print(oracle_acc)

        if epoch % 20 == 0:
            torch.save(model.state_dict(), os.path.join(net_path, str(epoch) + '_model.pkl'))

        if epoch % 5 == 0:
            print("Test")
            print(epoch)
            for test_type, test_loader in enumerate(test_loader_list):
                accuracies = _evaluate(model, test_loader, device)
                print(test_type)
                for acc in accuracies:
                    print(acc)


if __name__ == '__main__':
    main()
