import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
from datasets import input_dataset_c1m
import time
import random
import argparse
import numpy as np
from utils import *
from focalloss import *
from loss import loss_cross_entropy, lq_loss, loss_peer, loss_pls, loss_nls, cb_loss, logit_adj
from metrics import *


# Options ----------------------------------------------------------------------
# Command-line configuration for the training run.
parser = argparse.ArgumentParser()
parser.add_argument("--pre_type", type=str, default='image50', help='pretrained backbone (only image50 is supported)')  # image
parser.add_argument('--dataset', type=str, help='clothing1M', default='clothing1M')
# Placeholder only: the main section overwrites config.device with set_device().
parser.add_argument('--device', type=str, help='cuda', default='device')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
parser.add_argument('--loss', type=str, help='ce, peerloss, pls, nls, focal, logit_adj', default='ce')
# Overwritten by the class count reported by the dataset loader.
parser.add_argument('--num_classes', type=int, default=14, help='num of classes')
parser.add_argument('--g_idx', type=int, default=2, help='num of sub-populations to consider')
# Weight of the fairness regulariser (scaled by 1 / g_idx in the training loop).
parser.add_argument('--lmd', default=0.1, type=float, help='lmd')
parser.add_argument('--cluster_type', type=str, help='knn, random', default='random')

def adjust_learning_rate(optimizer, epoch, alpha_plan):
    """Apply the scheduled learning rate for ``epoch`` to the optimizer.

    Every entry of ``optimizer.param_groups`` gets its ``'lr'`` key set to
    ``alpha_plan[epoch]``; all other group settings are left untouched.
    """
    for group in optimizer.param_groups:
        group.update(lr=alpha_plan[epoch])

def set_model_min(config):
    """Build the classifier backbone selected by ``config.pre_type``.

    Only ``'image50'`` is supported: ``res_image.resnet50(pretrained=True)``,
    whose final ``fc`` layer is replaced with a new ``nn.Linear`` producing
    ``config.num_classes`` outputs. The model is then moved to
    ``config.device``.

    Args:
        config: parsed options; reads ``pre_type``, ``num_classes`` and
            ``device``.

    Returns:
        ``(model, None)`` -- the second element is a preprocessing slot that
        is always ``None`` here.

    Raises:
        RuntimeError: if ``config.pre_type`` names an unsupported backbone.
    """
    print(f'Use model {config.pre_type}')

    if config.pre_type == 'image50':
        # ``res_image`` is not imported by name in this file; presumably it is
        # provided by ``from utils import *`` -- TODO confirm.
        model = res_image.resnet50(pretrained=True)
    else:
        # Raise (not merely construct) the error so an unsupported name fails
        # here instead of with an UnboundLocalError on ``model`` below.
        raise RuntimeError(f'Undefined pretrained model: {config.pre_type}')
    if 'image' in config.pre_type:
        # Swap the classification head for one sized to this dataset.
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, config.num_classes)
    model.to(config.device)
    return model, None


def _cycle_batches(loader):
    """Yield batches from ``loader`` forever, starting a new pass when one ends.

    Each pass re-iterates the DataLoader, so a loader built with
    ``shuffle=True`` is reshuffled between passes. Raises RuntimeError if a
    full pass yields no batches (otherwise this would loop forever).
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise RuntimeError('DataLoader yielded no batches.')


if __name__ == "__main__":

    # Setup ------------------------------------------------------------------------
    t1 = time.time()
    torch.multiprocessing.set_sharing_strategy('file_system')
    torch.set_num_threads(3)
    config = parser.parse_args()
    config.device = set_device()
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    random.seed(config.seed)
    torch.backends.cudnn.deterministic = True
    model, preprocess = set_model_min(config)

    # No transform is passed to the dataset builder.
    preprocess_rand = None

    # Index pool passed to the dataset builder and re-applied every epoch via
    # shuffle_and_imbalance; presumably spans the whole Clothing1M training
    # set -- TODO confirm against input_dataset_c1m.
    sel_idx = range(int(1e6))
    use_peer = config.loss == 'peerloss'
    if use_peer:
        train_dataset, train_peer, test_dataset, num_classes, train_prior = input_dataset_c1m(config.dataset, g_idx=config.g_idx, cluster_type = config.cluster_type, transform=preprocess_rand, sel_idx = sel_idx, is_peer = True)
    else:
        train_peer = None
        train_dataset, test_dataset, num_classes, train_prior = input_dataset_c1m(config.dataset, g_idx=config.g_idx, cluster_type = config.cluster_type, transform=preprocess_rand, sel_idx = sel_idx, is_peer = False)
    config.num_classes = num_classes
    config.train_prior = train_prior
    print(f'train_prior is {train_prior}')

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
    if use_peer:
        # drop_last keeps every peer batch at the full 32 samples.
        peer_dataloader = torch.utils.data.DataLoader(train_peer, batch_size=32, shuffle=True, drop_last=True)
        # One long-lived iterator: re-creating the DataLoader iterator on
        # every step restarts the shuffled sampler each time.
        peer_batches = _cycle_batches(peer_dataloader)
    else:
        peer_dataloader = None
        peer_batches = None
    test_loader = torch.utils.data.DataLoader(dataset = test_dataset, batch_size = 16, shuffle = False)

    # Learning-rate schedule, one entry per epoch (its length is the epoch count).
    alpha_plan = [0.01] * 30 + [0.001] * 30 + [0.0001] * 30 + [0.00001] * 30 # plan-1
    if config.loss in ['nls']:
        # NLS starts from the last checkpoint of the matching CE run with lmd 0.0.
        alpha_plan = [1e-7] * 40
        pretrain_path = f'./results/c1m_frf/ce_{config.cluster_type}_group{config.g_idx}_lmd0.0_last.pth.tar'
        state_dict = torch.load(pretrain_path, map_location = "cpu")
        model.load_state_dict(state_dict['state_dict'])
    if use_peer:
        alpha_plan = [0.002] * 10 + [5e-5] * 5 + [1e-5] * 5 + [5e-6] * 5 + [1e-6] * 5 + [5e-7] * 5 + [1e-7] * 5
    # The initial lr is overwritten by adjust_learning_rate at every epoch.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.001)
    best_acc_ = 0.0

    base_path = 'results/c1m_frf'
    os.makedirs(base_path, exist_ok=True)
    ckpt_prefix = f'{base_path}/{config.loss}_{config.cluster_type}_group{config.g_idx}_lmd{config.lmd}'
    # Fairness regulariser; looked up once instead of on every batch.
    constraints_fair = constraints_dict['acc']

    for epoch in range(len(alpha_plan)):
        adjust_learning_rate(optimizer, epoch, alpha_plan)
        epoch_start = time.time()
        # training
        model.train()
        correct = 0
        cnt = 0
        if config.dataset == "clothing1M":
            # Presumably re-draws the (imbalanced) training subset each epoch -- TODO confirm.
            train_dataset.shuffle_and_imbalance(sel_idx)
        for feature, label, groups, index in train_dataloader:
            cnt += index.size(0)
            optimizer.zero_grad()
            feature = feature.to(config.device)
            label = label.to(config.device)
            groups = groups.to(config.device)
            _, y_pred = model(feature)

            if config.loss == 'ce':
                loss = loss_cross_entropy(epoch, y_pred, label)
            elif config.loss == 'peerloss':
                if epoch > 10:
                    # Peer term: next peer batch paired with uniformly random
                    # labels drawn from the dataset's classes, one per peer
                    # sample (assumes labels are 1-D class indices).
                    feature_peer, _, _, _ = next(peer_batches)
                    _, y_pred_peer = model(feature_peer.to(config.device))
                    label_peer = torch.randint(0, config.num_classes, (feature_peer.size(0),), device=config.device)
                    loss, _ = loss_peer(epoch, y_pred, y_pred_peer, label, label_peer)
                else:
                    # Plain cross-entropy warm-up for epochs 0-10.
                    loss = loss_cross_entropy(epoch, y_pred, label)
            elif config.loss == 'pls':
                loss = loss_pls(epoch, y_pred, label)
            elif config.loss == 'nls':
                loss = loss_nls(epoch, y_pred, label)
            elif config.loss == 'focal':
                loss = FocalLoss(gamma=2.0)(y_pred, label)
            elif config.loss == 'logit_adj':
                loss = logit_adj(config.train_prior, y_pred, label)
            else:
                raise NotImplementedError(f" {config.loss} Not Implemented")

            # Add the group regulariser: lmd / g_idx times the summed |per-group term|.
            loss_reg, _ = constraints_fair(y_pred, groups, label, n_groups=config.g_idx)
            loss = loss + torch.sum((config.lmd / config.g_idx) * torch.abs(loss_reg))

            # argmax over logits equals argmax over their softmax.
            pred = y_pred.detach().argmax(dim=1)
            correct += (pred == label).sum().item()
            loss.backward()
            optimizer.step()

        t2 = time.time()
        print(f"time in an epoch: {t2-epoch_start}(s), total: {t2-t1}(s)", flush=True)
        train_acc = float(correct)/float(cnt) if cnt != 0 else -1

        # evaluation (no autograd graph needed)
        model.eval()
        acc = 0.0
        cnt = 0.0
        with torch.no_grad():
            for feature, label, index in test_loader:
                cnt += index.shape[0]
                feature = feature.to(config.device)
                label = label.to(config.device)
                _, y_pred = model(feature)
                pred = y_pred.argmax(dim=1)
                acc += (pred.cpu() == label.cpu()).sum()
        test_acc = float(acc)/float(cnt) if cnt != 0 else -1
        print(f'Epoch {epoch}: train acc = {train_acc}, test acc = {test_acc}', flush = True)
        if test_acc > best_acc_:
            # 'acc' holds the raw count of correct test predictions, not the ratio.
            state = {'state_dict': model.state_dict(),
                     'epoch':epoch,
                     'acc':acc,
            }
            save_path = f'{ckpt_prefix}_best.pth.tar'
            torch.save(state,save_path)
            best_acc_ = test_acc
            print(f'model saved to {save_path}! Best = {best_acc_}')
        if epoch == len(alpha_plan) - 1:
            state = {'state_dict': model.state_dict(),
                     'epoch':epoch,
                     'acc':acc,
            }
            torch.save(state, f'{ckpt_prefix}_last.pth.tar')