from dataset.selectedRotateImageFolder import *
"""
Copyright to SAR Authors, ICLR 2023 Oral (notable-top-5%)
built upon the Tent and EATA code.
"""
from logging import debug
import os
import time
import argparse
import json
import random
import numpy as np
from pycm import *

import math
from typing import ValuesView

from utils.utils import get_logger
from dataset.selectedRotateImageFolder import prepare_test_data
from utils.cli_utils import *

import torch    
import torch.nn.functional as F

import tent
import eata
import sar
from sam import SAM
import timm
import bnadapt
import models.Res as Resnet

from utils_dart import *



def prepare_int_data(args, dirichunks, diri_delta, use_transforms=True):
    """Build the dataset and data loader over the ``train`` split of ``args.data``.

    Args:
        args: namespace providing ``data``, ``int_batch_size`` and
            ``if_shuffle``; ``workers`` is read as well and is set to 1 on
            the namespace when it is missing.
        dirichunks: forwarded as ``diri_chunks`` to
            ``SelectedRotateImageFolder_dirichlet``.
        diri_delta: forwarded as ``diri_delta`` to
            ``SelectedRotateImageFolder_dirichlet``.
        use_transforms: accepted for signature compatibility but not
            consulted; ``te_transforms`` is always applied.

    Returns:
        A ``(dataset, loader)`` tuple.
    """
    train_root = os.path.join(args.data, 'train')
    dataset_options = {
        'diri_chunks': dirichunks,
        'diri_delta': diri_delta,
        'original': True,
        'rotation': False,
        'rotation_transform': rotation_te_transforms,
    }
    dataset = SelectedRotateImageFolder_dirichlet(
        train_root, te_transforms, **dataset_options
    )

    # Callers may hand in a namespace without a worker count; default to one.
    args.workers = getattr(args, 'workers', 1)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.int_batch_size,
        shuffle=args.if_shuffle,
        num_workers=args.workers,
        pin_memory=True,
    )
    return dataset, loader
    

def _parse_bool_flag(value):
    """Convert a command-line token such as ``true``/``0``/``no`` into a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this parser interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def get_args():
    """Parse the command-line options of this script from ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options, with defaults applied for
        every flag that was not supplied.
    """
    parser = argparse.ArgumentParser(description='SAR exps')

    # path
    parser.add_argument('--data', default="path of clean ImageNet training dataset", help='path to dataset')
    parser.add_argument('--data_corruption', default='path of ImageNet-C dataset', help='path to corruption dataset')
    parser.add_argument('--output', default='./eval_results/log', help='the output directory of this experiment')

    parser.add_argument('--seed', default=0, type=int, help='seed for initializing training. ')
    parser.add_argument('--gpu', default=0, type=int, help='GPU id to use.')
    # Parsed with _parse_bool_flag so that "--debug False" really disables it.
    parser.add_argument('--debug', default=False, type=_parse_bool_flag, help='debug or not.')

    # dataloader
    parser.add_argument('--workers', default=2, type=int, help='number of data loading workers (default: 2)')
    parser.add_argument('--test_batch_size', default=64, type=int, help='mini-batch size for testing, before default value is 4')
    parser.add_argument('--int_batch_size', default=50, type=int, help='mini-batch size for the intermediate-time training loader')

    # Parsed with _parse_bool_flag so that "--if_shuffle False" really disables it.
    parser.add_argument('--if_shuffle', default=True, type=_parse_bool_flag, help='if shuffle the test set.')

    # corruption settings
    parser.add_argument('--level', default=5, type=int, help='corruption level of test(val) set.')
    parser.add_argument('--corruption', default='gaussian_noise', type=str, help='corruption type of test(val) set.')

    # eata settings
    parser.add_argument('--fisher_size', default=2000, type=int, help='number of samples to compute fisher information matrix.')
    parser.add_argument('--fisher_alpha', type=float, default=2000., help='the trade-off between entropy and regularization loss, in Eqn. (8)')
    parser.add_argument('--e_margin', type=float, default=math.log(1000)*0.40, help='entropy margin E_0 in Eqn. (3) for filtering reliable samples')
    # Raw string: "\e" is not a valid escape sequence in a normal literal.
    parser.add_argument('--d_margin', type=float, default=0.05, help=r'\epsilon in Eqn. (5) for filtering redundant samples')

    # Exp Settings
    parser.add_argument('--method', default='bn_adapt', type=str, help='no_adapt, tent, eata, sar, bn_adapt')
    parser.add_argument('--model', default='resnet50_bn_torch', type=str, help='resnet50_gn_timm or resnet50_bn_torch or vitbase_timm')
    parser.add_argument('--exp_type', default='label_shifts', type=str, help='normal, mix_shifts, bs1, label_shifts')

    # SAR parameters
    parser.add_argument('--sar_margin_e0', default=math.log(1000)*0.40, type=float, help='the threshold for reliable minimization in SAR, Eqn. (2)')
    parser.add_argument('--imbalance_ratio', default=500000, type=int, help='imbalance ratio for label shift exps, selected from [1, 1000, 2000, 3000, 4000, 5000, 500000], 1  denotes totally uniform and 500000 denotes (almost the same to Pure Class Order). See Section 4.3 for details;')

    # int time
    parser.add_argument('--num_chunks', default=2000, type=int)
    parser.add_argument('--delta', default=0.001, type=float)
    parser.add_argument('--num_epochs', default=100, type=int)
    return parser.parse_args()

# Define a function for testing using TTA (Temp Test Accuracy) 
# The logits and labels from BNAdapt are needed for our training log, for convenience
def _g_phi_device(g_phi):
    """Return the device on which ``g_phi`` should receive its input.

    Uses the device of the first parameter when ``g_phi`` exposes
    ``parameters()``; otherwise falls back to CUDA when it is available
    (the device that used to be hard-coded) and to the CPU when it is not.
    """
    parameters = getattr(g_phi, 'parameters', None)
    if callable(parameters):
        first_param = next(iter(parameters()), None)
        if first_param is not None:
            return first_param.device
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def test_temp(logits, labels, g_phi, args):
    """Measure accuracy before and after re-scaling the logits with ``g_phi``.

    The examples are processed in consecutive chunks of
    ``args.test_batch_size``. For each chunk the mean softmax over the chunk
    is fed to ``g_phi``; its output is turned into a diagonal matrix that
    right-multiplies the chunk's logits.

    Args:
        logits: 2-D tensor of logits, one row per example.
        labels: tensor of integer class labels aligned with ``logits``.
        g_phi: callable mapping the mean softmax vector to a vector of
            per-class scales (it must return a 1-D tensor).
        args: namespace providing ``test_batch_size``.

    Returns:
        ``(acc_old, acc_new)``: the fraction of correctly classified examples
        using the raw logits and the re-scaled logits, as Python floats.

    Raises:
        ValueError: if ``logits`` contains no examples.
    """
    num_examples = logits.size(0)
    if num_examples == 0:
        raise ValueError('test_temp needs at least one example')

    batch_size = args.test_batch_size
    device = _g_phi_device(g_phi)

    correct_old, correct_new = 0, 0
    for start in range(0, num_examples, batch_size):
        logits_curr = logits[start:start + batch_size]
        labels_curr = labels[start:start + batch_size]

        with torch.no_grad():
            mean_probs = torch.softmax(logits_curr, 1).mean(0)
            # Bring the scales back to wherever the logits live so the
            # matrix product below never mixes devices.
            scales = g_phi(mean_probs.to(device)).to(logits_curr.device)
            T_curr = torch.diag(scales)

        correct_old += (logits_curr.argmax(1) == labels_curr).sum().item()
        correct_new += ((logits_curr @ T_curr).argmax(1) == labels_curr).sum().item()

    return correct_old / num_examples, correct_new / num_examples


# Despite its name this helper is not a pytest test; opt out of collection.
test_temp.__test__ = False



if __name__ == '__main__':
    # `copy` and `nn` are used below but were only reachable through the star
    # imports at the top of the file; import them explicitly so this entry
    # point does not depend on what those modules happen to re-export.
    import copy
    import torch.nn as nn

    args = get_args()

    # Seed every RNG in use and request deterministic cuDNN kernels.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Load the pre-trained backbone and move it to the GPU.
    net_init = Resnet.__dict__['resnet50'](pretrained=True)
    net_init = net_init.cuda()

    # Load the test logits and labels saved beforehand, one file pair per
    # imbalance ratio (the comment in the original script says they come from
    # a BNAdapt run and must exist before this script is started).
    # Only the first entry (imbalance ratio 1) is evaluated further down.
    irs = [1, 1000, 2000, 3000, 4000, 5000, 500000]
    test_lv = 5
    logits_tes = []
    labels_tes = []
    for ir in irs:
        logits_te = np.load("./eval_results/data/logits_te_lv%d_imb%d_seed%d.npy"%(test_lv, ir, args.seed))
        labels_te = np.load("./eval_results/data/labels_te_lv%d_imb%d_seed%d.npy"%(test_lv, ir, args.seed))
        logits_tes.append(torch.tensor(logits_te))
        labels_tes.append(torch.tensor(labels_te).long())

    # The loader built by prepare_int_data must not shuffle, whatever the CLI said.
    args.if_shuffle = False

    # MLP to be trained; its output is turned into a diagonal re-scaling matrix.
    num_epochs = args.num_epochs
    hiddendim = 1000
    netW = MLP_diag(1000, hiddendim=hiddendim).cuda()
    loss_fn = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.Adam(netW.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # The backbone copy is kept in train mode but only ever run under no_grad.
    # NOTE(review): presumably so BatchNorm uses batch statistics - confirm.
    model_adapt = copy.deepcopy(net_init)
    model_adapt.train()

    # Accuracy history, appended to on every logging epoch.
    acc_storage = {
        'acc_tr_pre': [],
        'acc_tr_post': [],
        'acc_te_pre': [],
        'acc_te_post': [],
    }

    for epoch in range(num_epochs):
        acc_tr_old, acc_tr_new = 0, 0
        print ("Epoch%d training begins"%epoch)

        # A fresh dataset/loader is built at the start of every epoch.
        _, int_loader = prepare_int_data(args, args.num_chunks, args.delta)
        for batch in int_loader:
            x_curr = batch[0].cuda()
            y_curr = batch[1].cuda()

            # The backbone is frozen: no gradients flow into it.
            with torch.no_grad():
                outputs = model_adapt(x_curr)

            # Diagonal matrix predicted from the batch-averaged softmax.
            T = torch.diag(netW(torch.softmax(outputs, dim=-1).mean(0)))

            # Re-scaled logits, computed once and reused for loss and accuracy.
            outputs_scaled = outputs @ T

            loss = loss_fn(outputs_scaled, y_curr)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Correct-prediction counts before and after applying T.
            acc_tr_old += (outputs.argmax(1) == y_curr).float().sum()
            acc_tr_new += (outputs_scaled.argmax(1) == y_curr).float().sum()

        # Every 20th epoch: report train accuracy and evaluate on the stored test logits.
        is_log_epoch = epoch % 20 == 19
        if is_log_epoch:
            n_train = len(int_loader.dataset)
            tr_acc_old = acc_tr_old.item() / n_train
            tr_acc_new = acc_tr_new.item() / n_train
            print ("tr accs(%.4f_%d)"%(args.delta, args.num_chunks), tr_acc_old, tr_acc_new)
            acc_old, acc_new = test_temp(logits_tes[0], labels_tes[0], netW, args)
            print ("te accs(%.4f_%d)"%(args.delta, args.num_chunks), epoch, acc_old, acc_new)

            acc_storage['acc_tr_pre'].append(tr_acc_old)
            acc_storage['acc_tr_post'].append(tr_acc_new)
            acc_storage['acc_te_pre'].append(acc_old)
            acc_storage['acc_te_post'].append(acc_new)

        # Checkpoint the MLP weights and the accuracy history.
        # NOTE(review): both file names look like placeholders - set real paths before running.
        if epoch == 0 or is_log_epoch:
            torch.save(netW.state_dict(), "path of trained gphi")
            torch.save(acc_storage, "path of tr logs")

        # Advance the cosine learning-rate schedule once per epoch.
        scheduler.step()




            