import operator
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

from Support.datareader import datareader_standard

# --- Run configuration ------------------------------------------------------
img_size: int = 32  # size handed to the image-resize transform below
data_root: str = r'Your_dir/fashion-dataset/'  # looks like a placeholder: point it at the local dataset copy

net_path: str = 'model_save/'  # prefix (directory) prepended to checkpoint file names

class MainNet(torch.nn.Module):
    """Three-branch image classifier that returns one 8-way softmax per branch.

    Each branch is an independent convolutional trunk (``conv_net_k``) followed by an MLP head
    (``sub_k``); branches share no weights and all receive the same input. Every trunk applies
    five 4x4, stride-2, padding-1 convolutions, each halving the spatial size, so a 32x32 input
    ends at 1x1 and flattens to ``96 * 3`` features -- exactly what the heads' first Linear
    expects. Other input sizes only work if every spatial dimension also reaches 1x1.

    ``net_show`` is a trunk+head with the same layout. ``forward`` does not use it, but it is a
    registered submodule, so its parameters appear in ``parameters()`` and ``state_dict()``.
    """

    @staticmethod
    def _conv_trunk_layers():
        """Return the layers of one convolutional trunk (3 channels in, flattened features out).

        ``nn.Sequential`` keys parameters by position, so this order must not change or saved
        checkpoints will no longer load.
        """
        return [
            # NOTE(review): no activation between the first two convs, so they compose to a
            # single linear map; kept as-is because inserting one would shift the indices.
            nn.Conv2d(3, 32 * 3, 4, 2, 1),
            nn.Conv2d(32 * 3, 48 * 3, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(48 * 3, 48 * 3, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(48 * 3, 96 * 3, 4, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(96 * 3),
            nn.Conv2d(96 * 3, 96 * 3, 4, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(96 * 3),
            nn.Flatten(),
        ]

    @staticmethod
    def _head_layers():
        """Return the layers of one classification head (``96 * 3`` features -> 8 scores)."""
        return [
            nn.Linear(96 * 3, 32 * 8),
            nn.ReLU(),
            nn.Linear(32 * 8, 32 * 4),
            nn.Linear(32 * 4, 8),
        ]

    def __init__(self):
        super(MainNet, self).__init__()
        # Creation order (trunks 1-3, net_show, heads 1-3) fixes both the RNG draws used for
        # weight initialisation and the state_dict key layout; keep it stable.
        self.conv_net_1 = nn.Sequential(*self._conv_trunk_layers())
        self.conv_net_2 = nn.Sequential(*self._conv_trunk_layers())
        self.conv_net_3 = nn.Sequential(*self._conv_trunk_layers())

        self.net_show = nn.Sequential(*self._conv_trunk_layers(), *self._head_layers())

        self.sub_1 = nn.Sequential(*self._head_layers())
        self.sub_2 = nn.Sequential(*self._head_layers())
        self.sub_3 = nn.Sequential(*self._head_layers())

    def forward(self, x):
        """Return a tuple of three ``(batch, 8)`` probability tensors, one per branch."""
        branches = (
            (self.conv_net_1, self.sub_1),
            (self.conv_net_2, self.sub_2),
            (self.conv_net_3, self.sub_3),
        )
        # dim=1 is the class axis of the (batch, 8) head output. Passing it explicitly avoids
        # the deprecated implicit-dim softmax, which picks this same axis for 2-D input.
        return tuple(F.softmax(head(trunk(x)), dim=1) for trunk, head in branches)

# Image preprocessing: resize (shorter edge -> img_size, aspect ratio kept), convert to a tensor,
# then normalise each RGB channel with mean 0.5 / std 0.5 ([0, 1] -> [-1, 1] for 8-bit images).
# ``transforms.Resize`` replaces the deprecated ``transforms.Scale`` alias (same behaviour for an
# int size), which newer torchvision releases no longer provide.
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# target_transform selectors: each picks element 0, 1 or 2 out of a sample's target.
# NOTE(review): presumably the three label columns chosen by ``choice=slice(1, 4)`` -- confirm
# against Support.datareader. ``operator.itemgetter`` is used instead of lambdas because it is
# picklable, which DataLoader worker processes require on spawn-based platforms.
trans_1 = operator.itemgetter(0)
trans_2 = operator.itemgetter(1)
trans_3 = operator.itemgetter(2)


def _make_dataset(file_txt, target_transform):
    """Create a ``datareader_standard`` over ``file_txt`` with this script's shared settings."""
    return datareader_standard(file_txt=file_txt, data_root=data_root, transform=transform,
                               choice=slice(1, 4), target_transform=target_transform, offset=None)


# Training set: the same list file is read three times, differing only in target_transform,
# and the three readers are concatenated.
train_data_1 = _make_dataset('list_train_mix_all.txt', trans_1)
train_data_2 = _make_dataset('list_train_mix_all.txt', trans_2)
train_data_3 = _make_dataset('list_train_mix_all.txt', trans_3)

train_data = torch.utils.data.ConcatDataset([train_data_1, train_data_2, train_data_3])

# Test sets stay separate so accuracy can be reported per target selector.
test_data_1 = _make_dataset('list_test_mix_all.txt', trans_1)
test_data_2 = _make_dataset('list_test_mix_all.txt', trans_2)
test_data_3 = _make_dataset('list_test_mix_all.txt', trans_3)

train_loader = DataLoader(dataset=train_data, batch_size=200, shuffle=True)
test_loader_1 = DataLoader(dataset=test_data_1, batch_size=200)
test_loader_2 = DataLoader(dataset=test_data_2, batch_size=200)
test_loader_3 = DataLoader(dataset=test_data_3, batch_size=200)

# test_loader_list[k] uses selector trans_{k+1}.
test_loader_list = [test_loader_1, test_loader_2, test_loader_3]

# Build the three-branch network and move it to the GPU in one step
# (``nn.Module.cuda`` moves the parameters in place and returns the module itself).
model = MainNet().cuda()

# Adam with a lowered first-moment decay (beta1 = 0.5 instead of PyTorch's default 0.9).
opt = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.5, 0.999))

def loss_func(probs, target):
    """Return the per-sample negative log-likelihood of ``target`` under ``probs``.

    ``MainNet.forward`` already applies softmax, so the model's outputs are probabilities.
    ``nn.CrossEntropyLoss`` expects raw logits and applies log-softmax itself, so feeding it
    probabilities would softmax twice: the loss is squashed into a narrow band and gradients
    shrink. Taking the log here and applying NLL yields the intended cross-entropy.

    Args:
        probs: ``(batch, classes)`` class probabilities.
        target: ``(batch,)`` integer class indices.

    Returns:
        ``(batch,)`` tensor of per-sample losses (no reduction, so the caller can take the
        minimum over heads).
    """
    # Clamp keeps log() finite if a probability underflows to exactly zero.
    return F.nll_loss(torch.log(probs.clamp(min=1e-12)), target, reduction='none')


def _train_one_epoch(model, loader, opt, loss_func):
    """Run one pass of min-over-heads training over ``loader``.

    Each of the model's outputs is scored per sample with ``loss_func``. For every sample only
    the smallest of the per-head losses is kept, and the batch mean of those minima is
    back-propagated.

    Returns:
        ``(mean_loss, oracle_accuracy)`` averaged over samples (not batches). A sample counts
        as correct when at least one head's argmax equals its label. Both are NaN if
        ``loader`` is empty.
    """
    model.train()
    loss_total = 0.0
    hits_total = 0
    seen = 0
    for batch_x, batch_y in loader:
        batch_x_cd, batch_y_cd = batch_x.cuda(), batch_y.cuda()
        preds = model(batch_x_cd)

        per_head_loss = torch.stack([loss_func(p, batch_y_cd) for p in preds], dim=0)
        loss_bp = per_head_loss.min(dim=0).values.mean()

        opt.zero_grad()
        loss_bp.backward()
        opt.step()

        with torch.no_grad():
            # Logical OR across heads: a sample is a hit if any head predicts its label.
            any_hit = torch.zeros_like(batch_y_cd, dtype=torch.bool)
            for p in preds:
                any_hit |= torch.argmax(p, dim=1) == batch_y_cd

        n = batch_y_cd.size(0)
        # Weight by batch size so a short final batch doesn't skew the epoch average.
        loss_total += loss_bp.item() * n
        hits_total += int(any_hit.sum().item())
        seen += n
    if seen == 0:
        return float('nan'), float('nan')
    return loss_total / seen, hits_total / seen


def _evaluate(model, loader):
    """Return the per-head accuracy of ``model`` on ``loader`` as a list (NaN entries if empty).

    Runs in eval mode and under ``torch.no_grad()``, so no autograd graph is built.
    """
    model.eval()
    correct = [0, 0, 0]  # one counter per output of MainNet.forward
    seen = 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x_cd, batch_y_cd = batch_x.cuda(), batch_y.cuda()
            for head, p in enumerate(model(batch_x_cd)):
                correct[head] += int((torch.argmax(p, dim=1) == batch_y_cd).sum().item())
            seen += batch_y_cd.size(0)
    return [c / seen if seen else float('nan') for c in correct]


# torch.save does not create missing directories; ensure the checkpoint folder exists.
os.makedirs(net_path, exist_ok=True)

for epoch in range(1000):
    epoch_loss, epoch_accu = _train_one_epoch(model, train_loader, opt, loss_func)
    print(epoch)
    print(epoch_loss)
    print(epoch_accu)

    if epoch % 20 == 0:
        torch.save(model.state_dict(), os.path.join(net_path, str(epoch) + '_model.pkl'))

    if epoch % 5 == 0:
        print("Test")
        print(epoch)
        for test_type, test_loader in enumerate(test_loader_list):
            print(test_type)
            for head_accu in _evaluate(model, test_loader):
                print(head_accu)