from __future__ import print_function, division
import os
import torch
import numpy as np
import math
import option
import nni
import torch.optim as optim
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from util import EDMLoss, AverageMeter
from torchvision import models
from dataset import AVADataset, BBDataset2
from scipy.stats import pearsonr
from scipy.stats import spearmanr
from sklearn.metrics import accuracy_score
from tensorboardX import SummaryWriter
from tqdm import tqdm
from collections import OrderedDict
from nni.utils import merge_parameter
import pandas as pd
from torchsummary import summary
import matplotlib.pyplot as plt
import numpy as np
import csv


# Run options, built once at import time by the project's `option` module;
# they are merged with NNI tuner parameters in the `__main__` guard at the
# bottom of this file.
opt = option.init()
# NOTE(review): this name is rebound to torch.device("cuda:0") further down
# this file, so this value is never the one the functions below use.
device = torch.device("cuda")
# NOTE(review): not referenced anywhere in this file; appears to be unused.
MOBILE_NET_V2_UTR = 'https://s3-us-west-1.amazonaws.com/models-nima/mobilenetv2.pth.tar'


### TaNet ***

import os
import torch
import numpy as np
import math
import torch.optim as optim
# import option
import nni
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torchvision import models
# from dataset import AVADataset
from util import EDMLoss, AverageMeter
from tensorboardX import SummaryWriter
from tqdm import tqdm
from scipy.stats import pearsonr
from scipy.stats import spearmanr
from sklearn.metrics import accuracy_score
from nni.utils import merge_parameter
from torchsummary import summary

# opt = option.init()
# NOTE(review): rebinds the `device` defined near the top of the file. Because
# the functions below look `device` up at call time, this "cuda:0" value is
# the one actually used.
device = torch.device("cuda:0")
# NOTE(review): duplicate of the constant defined near the top; unused here.
MOBILE_NET_V2_UTR = 'https://s3-us-west-1.amazonaws.com/models-nima/mobilenetv2.pth.tar'


import sys

# Add the current and parent directories to the module search path.
# NOTE(review): these run *after* the project imports at the top of the file
# (`option`, `util`, `dataset`), so they cannot help those imports resolve.
sys.path.append('./')
sys.path.append('../')

import torch
import torch.nn as nn

from torchsummary import summary

from torchvision import transforms
import torch.nn.init as init




class Net(nn.Module):
    """Frozen, staged VGG encoder that returns its deepest feature map.

    The first 44 children of ``encoder`` are split into five ``nn.Sequential``
    stages, ``enc_1`` .. ``enc_5``. When given the normalised VGG encoder built
    in ``start_train``, their outputs are relu1_1, relu2_1, relu3_1, relu4_1 and
    relu5_1. Every stage parameter has ``requires_grad`` switched off.
    """

    def __init__(self, encoder):
        super(Net, self).__init__()
        children = list(encoder.children())
        # Child-index cut points between consecutive stages:
        # input | relu1_1 | relu2_1 | relu3_1 | relu4_1 | relu5_1
        cuts = (0, 4, 11, 18, 31, 44)
        for stage, (start, stop) in enumerate(zip(cuts[:-1], cuts[1:]), start=1):
            setattr(self, 'enc_{:d}'.format(stage), nn.Sequential(*children[start:stop]))

        # The encoder is used as a fixed feature extractor: freeze every stage.
        for stage in range(1, 6):
            getattr(self, 'enc_{:d}'.format(stage)).requires_grad_(False)

    def encode_with_intermediate(self, input):
        """Run ``input`` through all five stages in order.

        Returns:
            A list of five tensors: the outputs of ``enc_1`` .. ``enc_5``.
        """
        outputs = []
        hidden = input
        for stage in (self.enc_1, self.enc_2, self.enc_3, self.enc_4, self.enc_5):
            hidden = stage(hidden)
            outputs.append(hidden)
        return outputs

    def forward(self, style):
        """Return the output of the last stage (``enc_5``) for ``style``."""
        *_, deepest = self.encode_with_intermediate(style)
        return deepest


def create_data_part(opt):
    """Build the (unshuffled) DataLoader over the images to encode.

    Args:
        opt: option dict; reads ``'path_to_images'``, ``'batch_size'`` and
            ``'num_workers'``.

    Returns:
        A ``DataLoader`` over ``BBDataset2(opt['path_to_images'])`` that keeps
        the dataset order (``shuffle=False``).
    """
    dataset = BBDataset2(opt['path_to_images'])
    return DataLoader(
        dataset,
        batch_size=opt['batch_size'],
        num_workers=opt['num_workers'],
        shuffle=False,
    )

def train(opt, model, loader, optimizer, criterion, writer=None, global_step=None, name=None):
    """Run one optimisation pass over ``loader`` and return the average loss.

    Parameters whose names start with ``res365_last`` are frozen; every other
    parameter of ``model`` gets ``requires_grad = True``.

    NOTE(review): this unfreezes *all* other parameters. For the VGG-based
    ``Net`` in this file, that would undo its encoder freeze. Confirm this is
    intended before calling it on that model.

    Args:
        opt: run options (unused; kept for call-site compatibility).
        model: module mapping a float input batch to predictions.
        loader: iterable of ``(x, y)`` batches.
        optimizer: optimiser over ``model``'s parameters.
        criterion: loss function called as ``criterion(y_pred, y)``.
        writer, global_step, name: unused; kept for call-site compatibility.

    Returns:
        ``AverageMeter.avg`` of the per-batch losses. Each update passes the
        batch size as its count, so this is presumably sample-weighted; confirm
        against ``util.AverageMeter``.
    """
    model.train()
    # Freeze the `res365_last` sub-module and unfreeze everything else. The
    # loop variable is not called `name`, which would shadow the argument.
    for param_name, param in model.named_parameters():
        param.requires_grad = not param_name.startswith("res365_last")

    train_losses = AverageMeter()
    for x, y in tqdm(loader):
        x = x.type(torch.FloatTensor).to(device)
        # Flatten targets to (batch, -1); reshape also copes with non-contiguous y.
        y = y.to(device).reshape(y.size(0), -1).float()
        y_pred = model(x).float()
        loss = criterion(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_losses.update(loss.item(), x.size(0))
    return train_losses.avg

def validate(opt, model, loader, criterion, writer=None, global_step=None, name=None):
    """Extract one flattened feature vector per image and save them as a codebook.

    Despite its name this computes no loss. It runs ``model`` in eval mode,
    under ``torch.no_grad()``, over every batch of ``loader``. Each image's
    output is flattened to 1-D, and two files are written:

    * ``./codebook/style_dict_1k.pt``: a tensor of shape ``(N, D)``. ``N`` is
      the number of images and ``D`` the element count of one image's output.
    * ``./codebook/image_lists_1k.pt``: the list of the ``N`` image
      basenames, row-aligned with the tensor above.

    Args:
        opt: run options (unused; kept for call-site compatibility).
        model: feature extractor applied to each image batch.
        loader: iterable whose items hold the images at index 0 and the
            corresponding file paths at index 1.
        criterion, writer, global_step, name: unused; kept for call-site
            compatibility.

    Raises:
        RuntimeError: if ``loader`` yields no batches, or if the number of
            feature rows differs from the number of paths.
    """
    model.eval()
    # NOTE(review): global side effect kept from the original; nothing here
    # prints tensors.
    torch.set_printoptions(precision=3)

    batch_feats = []  # one (B, D) tensor of flattened features per batch
    image_lists = []  # image basenames, row-aligned with batch_feats

    with torch.no_grad():  # pure inference: no autograd graph needed
        for batch in tqdm(loader):
            images, paths = batch[0], batch[1]
            images = images.type(torch.FloatTensor).to(device)
            feats = model(images).detach().cpu()
            # NOTE(review): emptying the CUDA cache every batch is usually
            # unnecessary and slows the loop.
            torch.cuda.empty_cache()

            # Flatten each image's output rather than assuming a fixed element
            # count, so any input resolution works.
            batch_feats.append(feats.reshape(feats.size(0), -1))
            image_lists.extend(os.path.basename(p) for p in paths)

    if not batch_feats:
        raise RuntimeError('validate(): loader yielded no images; nothing to save')

    style_feats = torch.cat(batch_feats, dim=0)
    if style_feats.size(0) != len(image_lists):
        raise RuntimeError(
            'validate(): got {} feature rows but {} image paths'.format(
                style_feats.size(0), len(image_lists)))

    os.makedirs('./codebook', exist_ok=True)
    torch.save(style_feats, './codebook/style_dict_1k.pt')
    torch.save(image_lists, './codebook/image_lists_1k.pt')

def _build_vgg_normalised():
    """Return the 53-layer VGG-19-style encoder laid out like ``vgg_normalised.pth``.

    Layer 0 is a 1x1 ``Conv2d(3, 3)``. Each number in ``cfg`` becomes a
    ``ReflectionPad2d -> 3x3 Conv2d -> ReLU`` triple, and each ``'M'`` becomes a
    ceil-mode 2x2 max-pool. Child indices determine the state-dict keys
    (``'0.weight'``, ``'2.weight'``, ...), so this order must match the
    checkpoint. relu1_1 .. relu5_1 sit at child indices 3, 10, 17, 30 and 43.
    """
    cfg = [64, 64, 'M',
           128, 128, 'M',
           256, 256, 256, 256, 'M',
           512, 512, 512, 512, 'M',
           512, 512, 512, 512]
    layers = [nn.Conv2d(3, 3, (1, 1))]
    in_channels = 3
    for item in cfg:
        if item == 'M':
            layers.append(nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True))
        else:
            layers.extend([
                nn.ReflectionPad2d((1, 1, 1, 1)),
                nn.Conv2d(in_channels, item, (3, 3)),
                nn.ReLU(),
            ])
            in_channels = item
    return nn.Sequential(*layers)


def start_train(opt):
    """Build the frozen VGG encoder and dump the image feature codebook.

    Despite its name, no training happens. ``validate`` is run over the image
    loader once per ``opt['num_epoch']`` iteration, and each run (re)writes the
    codebook files.

    Args:
        opt: dict of run options. Reads ``'path_to_images'``, ``'batch_size'``
            and ``'num_workers'`` (via ``create_data_part``), plus
            ``'experiment_dir_name'`` and ``'num_epoch'``.
    """
    dataloader_test = create_data_part(opt)
    criterion = nn.MSELoss()
    criterion.to(device)

    vgg = _build_vgg_normalised()
    # map_location makes loading independent of the device the checkpoint was
    # saved from; the whole model is moved to `device` below.
    vgg.load_state_dict(torch.load('./checkpoints/vgg_normalised.pth', map_location='cpu'))
    # Keep children up to and including relu5_1 (index 43).
    vgg = nn.Sequential(*list(vgg.children())[:44])

    model = Net(vgg).to(device)

    writer = SummaryWriter(log_dir=os.path.join(opt['experiment_dir_name'], 'logs'))
    try:
        for e in range(opt['num_epoch']):
            # NOTE(review): as written, validate's output does not depend on `e`,
            # so num_epoch > 1 only repeats the same extraction.
            validate(opt, model=model, loader=dataloader_test, criterion=criterion,
                     writer=writer, global_step=len(dataloader_test) * e,
                     name=f"{opt['experiment_dir_name']}_by_batch")
    finally:
        writer.close()


if __name__ == "__main__":
    import warnings

    # Suppress all warnings for the rest of the run.
    warnings.filterwarnings('ignore')
    print(os.getcwd())
    # Overlay any NNI tuner-proposed parameters on the parsed options, then
    # pass start_train a plain dict.
    tuned = nni.get_next_parameter()
    start_train(vars(merge_parameter(opt, tuned)))


