import torch
import torch.nn as nn 
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from resnet import  ResNet34
import numpy as np
import logging
import matplotlib.pyplot as plt
import json
from tqdm import tqdm
import os
import argparse
import random
import psutil



# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
# Root logger prints "LEVEL:message" lines at INFO and above.
logging.basicConfig(level=logging.INFO, format='%(levelname)s:%(message)s')

def train_val(criter, n_epochs, train_loader, val_loader):
    """Train the global ``model`` for ``n_epochs`` epochs, validating after each.

    Uses the module-level ``model``, ``optim``, ``scheduler``, ``device`` and
    ``logger``. The optimizer and scheduler are stepped in place, so their
    state (e.g. the scheduler's epoch counter) carries over between calls.

    Args:
        criter: loss called as ``criter(softmax(logits), labels)``. The labels
            from ``train_loader`` are passed through unchanged (the script in
            this file supplies one-hot float targets for an MSE loss).
        n_epochs: number of passes over ``train_loader``.
        train_loader: iterable of ``(imgs, labels)`` training batches.
        val_loader: DataLoader of ``(imgs, labels)`` with integer class labels.
            ``len(val_loader.dataset)`` is the accuracy denominator.

    Returns:
        ``(train_loss_list, val_acc_list)``: the per-epoch *sum* of batch
        training losses, and the per-epoch validation accuracy.
    """
    # Only read here, never rebound; the declaration documents the dependency.
    global model, optim, scheduler
    train_loss_list = []
    val_acc_list = []
    best_acc = 0
    for epoch in range(n_epochs):
        model.train()
        epoch_loss = 0
        for imgs, labels in tqdm(train_loader):
            imgs = imgs.to(device)
            labels = labels.to(device)
            outp = model(imgs)

            optim.zero_grad()

            # Softmax over the class axis. Assumes outp is (batch, n_classes)
            # logits. The implicit-dim form is deprecated and warns per call.
            loss = criter(F.softmax(outp, dim=1), labels)
            epoch_loss += loss.item()

            loss.backward()
            optim.step()

        logger.debug(f'Epochs [{epoch+1}/{n_epochs}], Losses: {epoch_loss:.4f}')
        train_loss_list.append(epoch_loss)

        model.eval()
        with torch.no_grad():
            total_correct_pred = 0
            for imgs, labels in val_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                output = model(imgs)
                # argmax of the logits equals argmax of their softmax.
                pred_y = torch.argmax(output, 1)
                total_correct_pred += (pred_y == labels).sum().item()

        val_acc_list.append(total_correct_pred / len(val_loader.dataset))
        logger.debug(f'val_acc: {val_acc_list[-1]}')
        if val_acc_list[-1] > best_acc:
            best_acc = val_acc_list[-1]
            logger.info(f'Best@epoch {epoch}:\nval_acc: {val_acc_list[-1]}\nloss:{epoch_loss:.4f}')

        scheduler.step()

    return train_loss_list, val_acc_list

class Add_Noise(object):
    """Target transform applying symmetric label noise to a class index.

    The label ``y`` is kept with probability ``1 - noise_rate``. Otherwise it
    is replaced by one of the other ``n_classes - 1`` classes, each with
    probability ``noise_rate / (n_classes - 1)``. The draw uses NumPy's global
    RNG via ``np.random.choice``.
    """

    def __init__(self, n_classes, noise_rate):
        self.n_classes = n_classes
        self.noise_rate = noise_rate

    def __call__(self, y):
        k = self.n_classes
        # Spread the flip probability evenly over all classes, then overwrite
        # the true class with its keep-probability.
        per_other_class = self.noise_rate / (k - 1)
        probs = np.full(k, per_other_class, dtype=np.float64)
        probs[y] = 1 - self.noise_rate
        return np.random.choice(k, p=probs)


def _default_n_workers():
    """Return half the CPUs this process may run on (DataLoader worker default).

    ``psutil.Process.cpu_affinity`` is not provided on every platform (e.g.
    macOS), so fall back to ``os.cpu_count()`` where it is missing.
    """
    proc = psutil.Process()
    if hasattr(proc, 'cpu_affinity'):
        return len(proc.cpu_affinity()) // 2
    return (os.cpu_count() or 0) // 2


# Per dataset: constructor, number of classes, and per-channel (mean, std)
# used for input normalisation.
_DATASET_SPECS = {
    'CIFAR10': (torchvision.datasets.CIFAR10, 10,
                (0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
    'CIFAR100': (torchvision.datasets.CIFAR100, 100,
                 (0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
}


def _build_datasets(dataset, noise_rate):
    """Build ``(n_classes, trainds, testds)`` for *dataset*.

    Training images get random-crop/flip augmentation. Training labels pass
    through ``Add_Noise`` and are then one-hot encoded as float vectors (the
    target format used with the MSE criterion). Test labels stay integer class
    indices, which ``train_val`` compares against the ``argmax`` prediction.

    NOTE(review): ``Add_Noise`` runs inside ``target_transform``, so a fresh
    noisy label is drawn every time a sample is loaded (i.e. every epoch)
    rather than being fixed once per sample. Confirm this is the intended
    noise model.

    Raises:
        NotImplementedError: *dataset* is not a key of ``_DATASET_SPECS``.
    """
    if dataset not in _DATASET_SPECS:
        raise NotImplementedError(f'Unsupported dataset: {dataset}')
    ds_cls, n_cls, mean, std = _DATASET_SPECS[dataset]
    trainds = ds_cls(
        root='./data',
        train=True,
        transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]),
        target_transform=transforms.Compose([
            Add_Noise(n_cls, noise_rate),
            transforms.Lambda(lambda y: F.one_hot(torch.tensor(y), n_cls).float()),
        ]),
        download=True,
    )
    testds = ds_cls(
        root='./data',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]),
        download=True,
    )
    return n_cls, trainds, testds


def _dump_json(path, values):
    """Write the NumPy array *values* to *path* as a JSON list."""
    with open(path, "w") as fp:
        json.dump(values.tolist(), fp)


def _plot_mean_std(mean, std, label, ylabel, title, path):
    """Plot *mean* per epoch with a shaded ``mean +/- std`` band; save to *path*."""
    plt.figure()
    plt.plot(mean, label=label)
    plt.fill_between(range(len(mean)), mean - std, mean + std, alpha=0.2)
    plt.xlabel('epoch')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.savefig(path)
    plt.close()


parser = argparse.ArgumentParser(description='Microbatch online learning experiment')
parser.add_argument('--loss_types', nargs='+', type=str, default=['mse'], help='List of loss functions used for training')
parser.add_argument('--noise_rates', nargs='+', type=float, default=[0, 0.2, 0.4, 0.6], help='List of symmetric noise rates')
parser.add_argument('--seeds', nargs='+', type=int, default=[3, 4, 5], help='List of random seeds for reproducibility')
parser.add_argument('--T', nargs='+', type=int, default=[2, 3, 4], help='2**T microbatches')
parser.add_argument('--batch_size', type=int, default=256, help='Training batch size per routine')
parser.add_argument('--n_epochs', type=int, default=50, help='Training epochs per routine')
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--n_workers_per_dl', type=int, default=_default_n_workers(), help='No. workers used for each dataloader, defaulted to half of the CPUs this process is allowed to run on.')
parser.add_argument('--dataset', type=str, default='CIFAR10', help='Experimenting dataset')


args = parser.parse_args()

logger = logging.getLogger()
loss_type_list = args.loss_types
n_epochs = args.n_epochs
batch_size = args.batch_size
seed_list = args.seeds
lr = args.lr
n_workers_per_dl = args.n_workers_per_dl
dataset = args.dataset

for T in args.T:
    for loss_type in loss_type_list:
        for noise_rate in args.noise_rates:
            logger.info(f'loss_type: {loss_type} \t noise_rate: {noise_rate}')
            train_losses_list = []
            test_acc_list = []
            for seed in seed_list:
                # Seed every RNG the run draws from: `random` and torch for model
                # init / augmentation, and NumPy for train_test_split (no
                # random_state given) and for Add_Noise in the main process.
                random.seed(seed)
                np.random.seed(seed)
                torch.manual_seed(seed)
                logger.info(f'seed: {seed}')
                name = f'vary_b_{dataset}_{T}_n{noise_rate}_{loss_type}_seed{seed}'

                n_classes, trainds, testds = _build_datasets(dataset, noise_rate)

                model = ResNet34(num_classes=n_classes)
                model = model.to(device)

                optim = torch.optim.Adam(model.parameters(), lr=lr)
                scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, milestones=[80, 120], gamma=0.1)

                # Clean (pre-noise) labels; used only to stratify the splits.
                targets = np.array(trainds.targets)
                pretrainds_id, online_trainds_id = train_test_split(range(len(trainds)), test_size=0.5, stratify=targets)
                pretrainds = torch.utils.data.Subset(trainds, pretrainds_id)
                online_trainds = torch.utils.data.Subset(trainds, online_trainds_id)

                # Halve every chunk of the online half T times (stratified), giving
                # 2**T microbatches. All first halves precede all second halves,
                # as in a split-in-place-then-append loop.
                train_indices_list = [online_trainds_id]
                for _ in range(T):
                    halves = [train_test_split(ids, test_size=0.5, stratify=targets[ids])
                              for ids in train_indices_list]
                    train_indices_list = [h[0] for h in halves] + [h[1] for h in halves]

                loader_kwargs = dict(batch_size=batch_size, shuffle=True, num_workers=n_workers_per_dl)
                # Loader 0 is the pretraining half; loaders 1..2**T are the microbatches.
                trainld_list = [torch.utils.data.DataLoader(pretrainds, **loader_kwargs)]
                trainld_list += [torch.utils.data.DataLoader(dataset=torch.utils.data.Subset(trainds, ids), **loader_kwargs)
                                 for ids in train_indices_list]

                testld = torch.utils.data.DataLoader(dataset=testds, batch_size=batch_size, shuffle=False)
                logger.debug(name)

                train_losses_full, test_acc_full = [], []

                if loss_type == 'mse':
                    criter = nn.MSELoss().to(device)
                else:
                    raise NotImplementedError

                # One training pass per loader: the pretraining half, then every
                # one of the 2**T microbatches in turn.
                for t, trainld in enumerate(trainld_list):
                    logger.info(f'Training pass {t}:')
                    train_loss, test_acc = train_val(criter, n_epochs, trainld, testld)
                    train_losses_full += train_loss
                    test_acc_full += test_acc

                train_losses_list.append(train_losses_full)
                test_acc_list.append(test_acc_full)

            # Aggregate per-epoch curves across seeds.
            mean_train_losses = np.mean(train_losses_list, axis=0)
            std_train_losses = np.std(train_losses_list, axis=0)
            mean_acc_test = np.mean(test_acc_list, axis=0)
            std_acc_test = np.std(test_acc_list, axis=0)

            save_dir = f'{dataset}/vary_b_T{T}'
            os.makedirs(save_dir, exist_ok=True)
            name = f'vary_b_{dataset}_{T}_n{noise_rate}_{loss_type}'

            for suffix, values in (('trainLossesMean', mean_train_losses),
                                   ('trainLossesStd', std_train_losses),
                                   ('accTestMean', mean_acc_test),
                                   ('accTestStd', std_acc_test)):
                _dump_json(os.path.join(save_dir, f'{name}_{suffix}.json'), values)

            _plot_mean_std(mean_train_losses, std_train_losses, 'Training loss', 'Loss', name,
                           os.path.join(save_dir, f'{name}_trainLosses.jpg'))
            _plot_mean_std(mean_acc_test, std_acc_test, 'Accuracy', 'Accuracy', name,
                           os.path.join(save_dir, f'{name}_accTest.jpg'))