import os
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
import pandas as pd

import sys
sys.path.insert(1, './')

from dataset import CelebADataset
from torchvision.models import resnet50
from torch.nn import CrossEntropyLoss
from torch.optim import Adam, SGD

from torch.utils.data import DataLoader
from torch.nn import Sigmoid
from tqdm import tqdm
import shutil
import argparse
from utils import set_seed, compute_stats
from torchvision import transforms as T
from utils import DROLossComputer
from utils import ResNet18, ResNet50

class ResNet50Trainer():
    """Train, select and test a ResNet50 classifier on CelebA with a DRO loss.

    The training split is augmented with generated data (``load_gen_data``)
    and optimised with a ``DROLossComputer`` over the real/generated groups
    exposed by the training dataset. After every epoch the model is scored
    on the validation split; the checkpoint with the best value of the first
    metric returned by ``compute_stats`` (printed as "Worst acc") is kept at
    ``self.model_save_name`` and reloaded by :meth:`test`.

    Args:
        root_dir: dataset root passed to every ``CelebADataset`` split.
        transform: image transform applied to every split.
        opt: parsed command-line options (``argparse.Namespace``-like).
            The constructor writes ``opt.num_per_group`` before loading the
            generated data.
    """

    def __init__(self, root_dir, transform=None, opt=None):
        self.root_dir = root_dir
        self.transform = transform
        self.batch_size = opt.batch_size
        self.opt = opt

        self.model_save_name = f'{opt.output_dir}/best_model_{opt.seed}.pth'

        self.epochs = opt.epochs
        self.lr = opt.lr
        # Seed before the datasets are built so any randomness inside them
        # (e.g. generated-data selection) is seeded as well.
        set_seed(opt.seed)

        # Build the train split first: its size, taken before
        # load_gen_data() runs, sizes the generated-data budget.
        self.dataset_train = CelebADataset(root_dir, split='train', transform=transform, opt=opt)
        num_real = len(self.dataset_train)
        # NOTE(review): the divisor 8 presumably is the number of groups the
        # generated budget is spread over -- confirm against CelebADataset.
        opt.num_per_group = int(num_real * opt.gen_ratio) // 8

        self.dataset_train.load_gen_data()
        self.dataset_train.get_gen_ratio()

        self.dataset_test = CelebADataset(root_dir, split='test', transform=transform, opt=opt)
        self.dataset_val = CelebADataset(root_dir, split='val', transform=transform, opt=opt)

        # Defaults to 3 workers; can be overridden with an optional
        # ``opt.num_workers``.
        num_workers = getattr(opt, 'num_workers', 3)
        self.dataloader_train = DataLoader(self.dataset_train, batch_size=opt.batch_size, shuffle=True, num_workers=num_workers)
        self.dataloader_test = DataLoader(self.dataset_test, batch_size=opt.batch_size, num_workers=num_workers)
        self.dataloader_val = DataLoader(self.dataset_val, batch_size=opt.batch_size, num_workers=num_workers)

        self.dataset_train.get_class_distribution()
        self.dataset_test.get_class_distribution()
        self.dataset_val.get_class_distribution()

    @staticmethod
    def _parse_adjustments(spec, n_groups):
        """Parse a comma-separated generalization-adjustment string.

        A single value is repeated ``n_groups`` times; several values are
        returned as given (presumably one per group -- their count is not
        checked here).

        Returns:
            A NumPy float array of adjustments.
        """
        adjustments = [float(c) for c in spec.split(',')]
        if len(adjustments) == 1:
            adjustments = adjustments * n_groups
        return np.array(adjustments)

    def load_model(self):
        """Create the model, DRO criterion and optimizer.

        Must be called before :meth:`train` or :meth:`test`. Weights are
        initialised from ``opt.pretrain_dir`` unless it is the string
        ``'None'`` (the CLI default) or ``None``.
        """
        self.model = ResNet50(num_classes=len(self.dataset_train.targets_values))
        if self.opt.pretrain_dir not in (None, 'None'):
            print(f'Loading pretrained model from {self.opt.pretrain_dir}')
            # Load onto the CPU so a checkpoint saved on any device works;
            # the weights move to the GPU with the model just below.
            self.model.load_state_dict(torch.load(self.opt.pretrain_dir, map_location='cpu'))

        self.model.to('cuda')

        n_groups = self.dataset_train.n_groups_dro_real_gen()
        adjustments = self._parse_adjustments(self.opt.generalization_adjustment, n_groups)

        self.criterion = DROLossComputer(
            torch.nn.CrossEntropyLoss(reduction='none'),
            is_robust=self.opt.robust,
            n_groups=n_groups,
            group_counts=self.dataset_train.group_counts_dro_real_gen(),
            alpha=self.opt.alpha,
            gamma=self.opt.gamma,
            adj=adjustments,
            step_size=self.opt.robust_step_size,
            normalize_loss=self.opt.use_normalized_loss,
            btl=self.opt.btl,
            min_var_weight=self.opt.minimum_variational_weight)

        self.optimizer = Adam(self.model.parameters(), lr=self.lr, weight_decay=self.opt.weight_decay)
        # Start below any attainable score so the first validation pass always
        # writes a checkpoint; otherwise test() could find no file to load.
        self.best_accuracy = float('-inf')

    def train(self):
        """Train for ``self.epochs`` epochs, checkpointing the best model.

        After each epoch the model is evaluated on the validation split and
        saved to ``self.model_save_name`` whenever the first metric returned
        by :meth:`evaluate` beats ``self.best_accuracy``.
        """
        self.model.train()
        for epoch in range(self.epochs):
            for i, data in enumerate(tqdm(self.dataloader_train, ascii=True)):
                self.optimizer.zero_grad()
                img, label = data['img'].to('cuda'), data['target'].to('cuda')
                group_idx = data['group_idx_real_gen'].to('cuda')

                output = self.model(img)

                # NOTE(review): the trailing True presumably marks a training
                # step -- confirm against DROLossComputer.loss.
                loss = self.criterion.loss(output, label, group_idx, True)
                loss.backward()
                self.optimizer.step()
                if i % 100 == 0:
                    print('Epoch: {} Iteration: {} Loss: {}'.format(epoch, i, loss.item()))

            accuracy, _, _ = self.evaluate(self.dataloader_val, self.model)
            # evaluate() put the model in eval mode; switch back for training.
            self.model.train()
            if accuracy > self.best_accuracy:
                self.best_accuracy = accuracy
                os.makedirs(self.opt.output_dir, exist_ok=True)
                torch.save(self.model.state_dict(), self.model_save_name)

    def evaluate(self, dataloader, model):
        """Evaluate ``model`` on ``dataloader``; print and return its metrics.

        ``model`` is switched to eval mode (and left there). Batches must
        provide ``'img'``, ``'target'`` and ``'bias'`` entries.

        Returns:
            ``(worst_acc, conflict_acc, balanced_acc)`` as computed by
            ``compute_stats`` from the stacked outputs, targets and biases.
        """
        model.eval()

        output_list = []
        target_list = []
        bias_list = []

        with torch.no_grad():
            for data in tqdm(dataloader, ascii=True):
                img = data['img'].cuda()
                output_list.append(model(img).cpu())
                target_list.append(data['target'])
                bias_list.append(data['bias'])

        outputs = torch.cat(output_list, dim=0).numpy()
        targets = torch.cat(target_list, dim=0).numpy()
        biases = torch.cat(bias_list, dim=0).numpy()

        worst_acc, conflict_acc, balanced_acc = compute_stats(outputs, targets, biases)

        print(f'Worst acc: {worst_acc}')
        print(f'Conflict acc: {conflict_acc}')
        print(f'Balanced acc: {balanced_acc}')

        return worst_acc, conflict_acc, balanced_acc

    def test(self, split='test'):
        """Reload the best checkpoint and write val/test metrics to a CSV.

        The CSV is written to ``opt.results_dir/results_{seed}.csv``; the
        directory is created if missing.

        Args:
            split: unused, kept for backward compatibility -- both the
                validation and test splits are always evaluated.
        """
        self.model.load_state_dict(torch.load(self.model_save_name, map_location='cpu'))

        worst_acc_val, conflict_acc_val, balanced_acc_val = self.evaluate(self.dataloader_val, self.model)
        worst_acc_test, conflict_acc_test, balanced_acc_test = self.evaluate(self.dataloader_test, self.model)

        results = {
            "worst_acc_val": worst_acc_val,
            "conflict_acc_val": conflict_acc_val,
            "balanced_acc_val": balanced_acc_val,
            "worst_acc_test": worst_acc_test,
            "conflict_acc_test": conflict_acc_test,
            "balanced_acc_test": balanced_acc_test,
        }

        df = pd.DataFrame(results, index=[0])
        os.makedirs(self.opt.results_dir, exist_ok=True)
        df.to_csv(os.path.join(self.opt.results_dir, f'results_{self.opt.seed}.csv'))


if __name__ == '__main__':
    #define the transform
    transform = transforms.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, default='test_records')
    parser.add_argument('--results_dir', type=str, default='test_records')

    parser.add_argument('--target_attr', type=str, default='Blond_Hair')
    parser.add_argument('--pretrain_dir', type=str, default='None')

    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=128, help='batch_size')
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=1e-5)
    parser.add_argument('--gen_balance', action='store_true')
    parser.add_argument('--gen_add_ratio', type=float, default=0.0)
    parser.add_argument('--write_results', action='store_true')
    parser.add_argument('--minority_to_keep', type=float, default=1.0)

    parser.add_argument('--balance_gen', action='store_true')
    parser.add_argument('--limit_to_gen', action='store_true')
    parser.add_argument("--freeze_backbone", action="store_true")
    parser.add_argument("--reinit_linear_layer", action="store_true")

    parser.add_argument('--robust', default=False, action='store_true')
    parser.add_argument('--alpha', type=float, default=0.2)
    parser.add_argument('--generalization_adjustment', default="0.0")
    parser.add_argument('--automatic_adjustment', default=False, action='store_true')
    parser.add_argument('--robust_step_size', default=0.01, type=float)
    parser.add_argument('--use_normalized_loss', default=False, action='store_true')
    parser.add_argument('--btl', default=False, action='store_true')
    parser.add_argument('--gamma', type=float, default=0.1)
    parser.add_argument('--minimum_variational_weight', type=float, default=0)
    parser.add_argument('--gen_ratio', type=float, default=1.0)


