import argparse
import os
import random
import numpy as np
import pandas as pd

import torch
import torchvision
import torch.optim
import torch.utils.data
import torch.nn.functional as F

import _init_paths
from dataset.new_dataset import get_datasets

from models.smodel import build_model


from utils.helper import clean_state_dict, get_raw_dict, ModelEma
from utils.rkloss import ranking_loss


# Enumerate GPUs in PCI bus order so CUDA device indices match nvidia-smi.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Comma-separated device list, no whitespace ("0, 1" is not a reliably parsed
# value). Must be set before CUDA is initialised in this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"


def parser_args():
    """Parse the command-line options of the pseudo-labelling script.

    Switches that default to on (``--amp``, ``--pretrained``) come with a
    matching ``--no_amp`` / ``--no_pretrained`` flag writing to the same
    destination; with ``action='store_true', default=True`` alone they could
    never be disabled from the command line.

    Returns:
        argparse.Namespace of all options below. ``dataset_dir`` is rewritten
        to ``os.path.join(dataset_dir, dataset_name)``.
    """
    parser = argparse.ArgumentParser(description='Second Training')

    # data
    parser.add_argument('--dataset_name', help='dataset name', default='flickr', choices=['flickr', 'twitter', 'raf', 'emotion6', 'fbp5500'])
    parser.add_argument('--dataset_dir', help='dir of all datasets', default='./data')
    parser.add_argument('--img_size', default=256, type=int,
                        help='size of input images')

    # train
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers (default: 8)')
    parser.add_argument('-b', '--batch_size', default=32, type=int,
                        help='batch size')
    parser.add_argument('-p', '--print_freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--amp', dest='amp', action='store_true', default=True,
                        help='apply amp')
    parser.add_argument('--no_amp', dest='amp', action='store_false',
                        help='disable amp')
    parser.add_argument('--train_unlabel', action='store_true', default=False,
                        help="train unlabel data")
    parser.add_argument('--train_ensemble', action='store_true', default=False,
                        help='apply ensemble during training')

    # random seed
    parser.add_argument('--seed', default=1, type=int,
                        help='seed for initializing training. ')

    # model
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")
    parser.add_argument('--pretrained', dest='pretrained', action='store_true', default=True,
                        help='use pre-trained model. default is True. ')
    parser.add_argument('--no_pretrained', dest='pretrained', action='store_false',
                        help='do not use pre-trained weights')
    parser.add_argument('--is_data_parallel', action='store_true', default=False,
                        help='on/off nn.DataParallel()')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    # state-dict keys to drop from the checkpoint before loading it
    parser.add_argument('--resume_omit', default=[], type=str, nargs='*')

    args = parser.parse_args()

    # all dataset files live under <dataset_dir>/<dataset_name>
    args.dataset_dir = os.path.join(args.dataset_dir, args.dataset_name)

    return args


def get_args():
    """Return the parsed command-line namespace (thin alias of ``parser_args``)."""
    return parser_args()


def same_seeds(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cuda = torch.cuda
    if cuda.is_available():
        cuda.manual_seed(seed)
        cuda.manual_seed_all(seed)

    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def main():
    """Script entry point: parse options, seed RNGs, then run ``main_worker``.

    Returns:
        Whatever ``main_worker`` returns.
    """
    torchvision.disable_beta_transforms_warning()
    cfg = get_args()

    # seeding is skipped only when the seed is explicitly None
    seed = cfg.seed
    if seed is not None:
        same_seeds(seed)

    result = main_worker(cfg)
    return result

def main_worker(args):
    """Pseudo-label the unlabelled training split and write ``wlabel_data.csv``.

    Builds the model, optionally restores weights from ``args.resume``, runs
    ``validate`` over the weak-augmentation dataset returned by
    ``get_datasets`` and writes the averaged predictions into the label
    columns of ``train_unlabel_data.csv``. That frame is concatenated with
    ``train_label_data.csv`` and saved as ``wlabel_data.csv``, all inside
    ``args.dataset_dir``.

    Returns:
        0 on success.

    Raises:
        FileNotFoundError: ``args.resume`` is set but is not a file.
        ValueError: the checkpoint has neither a ``state_dict`` nor a
            ``model`` entry, or the predictions do not match the CSV shape.
    """
    # build model. Weights are restored *before* any DataParallel wrapping:
    # clean_state_dict presumably strips the "module." prefix, so loading the
    # cleaned keys into a wrapped model with strict=False would silently load
    # nothing.
    model = build_model(args)
    if args.resume:
        _load_checkpoint(model, args)
    if args.is_data_parallel:
        # device_ids=None -> every GPU visible via CUDA_VISIBLE_DEVICES; a
        # hard-coded [0, 1, 2, 3] fails when fewer devices are visible.
        model = torch.nn.DataParallel(model)
    model = model.cuda()

    # Data loading code
    train_label_dataset, train_unlabel_dataset, val_dataset, test_dataset, weak_setlabel_dataset = get_datasets(args)
    print("len(train_label_dataset):", len(train_label_dataset))
    print("len(train_unlabel_dataset):", len(train_unlabel_dataset))
    print("len(val_dataset):", len(val_dataset))
    print("len(test_dataset):", len(test_dataset))
    print("len(weakset_data):", len(weak_setlabel_dataset))

    # Honour --workers, capped by CPU count, batch size (no workers for
    # batch_size <= 1) and 8. os.cpu_count() may return None.
    args.workers = min([os.cpu_count() or 1, args.batch_size if args.batch_size > 1 else 0, args.workers, 8])

    # shuffle=False: predictions must stay in dataset order so they can be
    # written back row-by-row into the CSV below.
    train_loader = torch.utils.data.DataLoader(
        weak_setlabel_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    torch.cuda.empty_cache()
    labels = validate(train_loader, model, args)
    print(labels.shape)
    _write_pseudo_labels(labels, args.dataset_dir)
    return 0


def _load_checkpoint(model, args):
    """Load ``args.resume`` into ``model`` non-strictly, dropping ``args.resume_omit`` keys."""
    if not os.path.isfile(args.resume):
        # Continuing would pseudo-label with un-restored weights.
        raise FileNotFoundError(f"=> no checkpoint found at '{args.resume}'")
    # map to CPU so loading does not depend on the GPU the checkpoint was saved from
    checkpoint = torch.load(args.resume, map_location='cpu')

    if 'state_dict' in checkpoint:
        state_dict = clean_state_dict(checkpoint['state_dict'])
    elif 'model' in checkpoint:
        state_dict = clean_state_dict(checkpoint['model'])
    else:
        raise ValueError("No model or state_dict found in checkpoint!")
    for omit_name in args.resume_omit:
        del state_dict[omit_name]

    result = model.load_state_dict(state_dict, strict=False)
    # strict=False hides key mismatches; surface them instead of failing silently
    if result.missing_keys:
        print("missing keys when loading checkpoint:", result.missing_keys)
    if result.unexpected_keys:
        print("unexpected keys when loading checkpoint:", result.unexpected_keys)
    del checkpoint, state_dict


def _write_pseudo_labels(labels, dataset_dir):
    """Write ``labels`` into the unlabelled CSV's label columns and save ``wlabel_data.csv``."""
    data = pd.read_csv(os.path.join(dataset_dir, 'train_unlabel_data.csv'))
    data2 = pd.read_csv(os.path.join(dataset_dir, 'train_label_data.csv'))

    probs = labels.detach().float().cpu().numpy()
    # Column 0 is kept as-is; the remaining columns receive the predictions.
    # NOTE(review): assumes weak_setlabel_dataset iterates in CSV row order.
    expected = (len(data), data.shape[1] - 1)
    if probs.shape != expected:
        raise ValueError(
            f"prediction shape {probs.shape} does not match label columns {expected} "
            f"of train_unlabel_data.csv")
    data.iloc[:, 1:] = probs

    merged = pd.concat([data, data2], ignore_index=True)
    merged.to_csv(os.path.join(dataset_dir, 'wlabel_data.csv'), index=False)


@torch.no_grad()
def validate(val_loader, model, args):
    """Predict class probabilities averaged over four views of every sample.

    Each batch from ``val_loader`` must be ``((X, X_w1, X_w2, X_w3), y)``.
    The four views are stacked into one forward pass. Softmax is taken over
    the last dimension, and the four probability vectors of each sample are
    averaged. ``y`` is ignored.

    Args:
        val_loader: iterable of ``((X, X_w1, X_w2, X_w3), y)`` batches whose
            four view tensors share the same leading batch size.
        model: network mapping a batch of inputs to logits.
        args: namespace; only ``args.amp`` (CUDA autocast on/off) is read.

    Returns:
        float32 tensor with one row per sample, in loader order, on the
        model's device.
    """
    model.eval()
    # Send inputs to wherever the model lives (cuda:0 for a .cuda()/
    # DataParallel model); fall back for parameter-less modules.
    param = next(model.parameters(), None)
    if param is not None:
        device = param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_amp = bool(args.amp) and device.type == 'cuda'

    batches = []
    for (X, X_w1, X_w2, X_w3), _ in val_loader:
        batch_size = X.size(0)
        views = torch.cat([X, X_w1, X_w2, X_w3], dim=0).to(device, non_blocking=True)
        # compute output
        with torch.autocast(device_type=device.type, enabled=use_amp):
            logits = model(views)
        # Outside autocast, softmax on fp16 logits would run in half precision;
        # upcast so the probabilities written out keep full precision.
        probs = F.softmax(logits.float(), dim=-1)
        # (4*B, ...) -> (4, B, ...) -> mean over the four views
        per_view = probs.reshape(4, batch_size, *probs.shape[1:])
        batches.append(per_view.mean(dim=0))

    return torch.cat(batches, dim=0)

if __name__ == '__main__':
    main()