import numpy as np
import random
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from pathlib import Path
import os
from getpass import getuser
from socket import gethostname
from utils.progressbar import track_iter_progress
from utils.text import TextLogger
from utils.timer import Timer
from .builder import STRATEGIES
from datasets.dataloader import GetDataLoader, GetHandler
from copy import deepcopy
from datasets.base_dataset import BaseDataset
from datasets.randAug import AblationAugmentPool
from evaluation import *
from .utils import get_initialized_module, get_lr
import csv


def generate_split_info(split):
    """Derive the loader name(s) and soft-label flag(s) for one or several splits.

    A split whose name contains ``'single'`` is served by the ``'aug'`` loader,
    one whose name contains ``'mixup'`` by the ``'mixup'`` loader and every
    other split by the ``'base'`` loader (``'single'`` takes precedence when
    both substrings occur).

    :param split: (str | list | tuple | set) A split name or a collection of split names.

    Returns:
        tuple: ``(split, loader_name, to_soft)``. For a string input,
        ``loader_name`` is a string and ``to_soft`` is ``False``. For a
        collection, both are lists following the iteration order of ``split``;
        if at least one split uses the mixup loader, every split whose name
        does not contain ``'mixup'`` is flagged ``True``, otherwise all flags
        are ``False``.
    """
    assert type(split) in [list, tuple, set, str]

    def pick_loader(name):
        # 'single' is tested first, exactly like the precedence described above.
        if 'single' in name:
            return 'aug'
        if 'mixup' in name:
            return 'mixup'
        return 'base'

    if type(split) == str:
        return split, pick_loader(split), False

    loader_name = [pick_loader(name) for name in split]
    has_mixup_loader = 'mixup' in loader_name
    to_soft = [has_mixup_loader and 'mixup' not in name for name in split]
    return split, loader_name, to_soft


@STRATEGIES.register_module()
class Strategy:
    def __init__(self, dataset: BaseDataset, net, args, logger, timestamp):
        """Store the collaborators, build the model stack and reset all counters.

        :param dataset: (BaseDataset) Dataset wrapper; ``len(dataset.CLASSES)`` is used as the number of classes.
        :param net: Passed unchanged to ``get_initialized_module``.
        :param args: Run options. Read here: ``lr``, ``momentum``, ``weight_decay``, ``milestones``, ``drop_ratio``.
        :param logger: Logger used for progress messages and handed to ``TextLogger``.
        :param timestamp: Stored as-is on the instance.
        """
        self.dataset, self.net, self.args = dataset, net, args
        self.logger = logger
        self.timestamp = timestamp

        # Classifier, optimizer and LR scheduler all come from one helper call.
        built = get_initialized_module(
            self.net, self.args.lr, self.args.momentum, self.args.weight_decay,
            self.args.milestones, num_classes=len(dataset.CLASSES),
            drop_ratio=args.drop_ratio)
        self.clf, self.optimizer, self.scheduler = built

        # Progress counters: active-learning round and epoch within the round.
        self.cycle, self.epoch = 0, 0

        # Per-round history filled in by run().
        self.acc_val_list, self.acc_test_list, self.num_labels_list = [], [], []

        self.timer = Timer()
        self.TextLogger = TextLogger(self.clf, vars(args), logger)
        # Dump the full configuration once at start-up.
        self.TextLogger._dump_log(vars(args))

    def init_clf(self):
        """Re-create the classifier, optimizer and scheduler from scratch.

        Uses exactly the same arguments as the construction in ``__init__``
        (including ``drop_ratio``), so a re-initialised model has the same
        configuration as the first one instead of falling back to the
        helper's default drop ratio.
        """
        self.clf, self.optimizer, self.scheduler = \
            get_initialized_module(self.net, self.args.lr, self.args.momentum, self.args.weight_decay,
                                   self.args.milestones, num_classes=len(self.dataset.CLASSES),
                                   drop_ratio=self.args.drop_ratio)

    def query(self, n):
        """Query new samples according the current model or some strategies.

        :param n: (int)The number of samples to query.

        Returns:
            list[int]: The indices of queried samples.

        """
        raise NotImplementedError

    def get_ulb_list(self):
        if self.args.aug_ulb <= 0:
            return 'train_u'
        split_list = ['train_u']
        if 'train_u_aug_single' in self.dataset.DATA_INFOS.keys():
            split_list.append('train_u_aug_single')
        if 'train_u_aug_mixup' in self.dataset.DATA_INFOS.keys():
            split_list.append('train_u_aug_mixup')
        return split_list

    def update(self, n):
        if n == 0:
            return None
        
        if self.args.aug_ulb > 0:
            
            if self.args.aug_ratio > 0:
                self.generate_aug(self.args.aug_ratio, split='train_u', aug_intensity=self.args.aug_ulb,
                                  aug_type=self.args.ablation_aug_type)
            if self.args.mixup_ratio > 0:
                self.generate_mixup(self.args.mixup_ratio, split='train_u', fix_lam=self.args.aug_ulb * 0.1)
            
            idxs_temp_q = self.query(n)
            idxs_q = self.find_idxs_q(idxs_temp_q)
            print(f'Selected initial samples: {len(idxs_q)}')
            if len(idxs_q) > n:  
                idxs_q = np.random.choice(idxs_q, n, replace=False)
            self.dataset.update_lb(idxs_q)
            if len(idxs_q) < n:  
                idxs_u = np.arange(len(self.dataset.DATA_INFOS['train_full']))[self.dataset.INDEX_ULB]
                idxs_q_makeup = np.random.choice(idxs_u, n - len(idxs_q), replace=False)
                self.dataset.update_lb(idxs_q_makeup)
        else:
            idxs_q = self.query(n)
            self.dataset.update_lb(idxs_q)
        
        if self.args.aug_lab > 0:
            if self.args.aug_ratio > 0:
                self.generate_aug(self.args.aug_ratio, aug_intensity=self.args.aug_lab,
                                  aug_type=self.args.ablation_aug_type)
            if self.args.mixup_ratio > 0:
                self.generate_mixup(self.args.mixup_ratio, fix_lam=self.args.aug_lab * 0.1)
        return idxs_q

    def find_idxs_q(self, idxs_temp_q):
        
        
        
        idxs_q = []
        data_u_all = deepcopy(self.dataset.DATA_INFOS['train_u'])
        if 'train_u_aug_single' in self.dataset.DATA_INFOS.keys():
            data_u_all.extend(self.dataset.DATA_INFOS['train_u_aug_single'])
        if 'train_u_aug_mixup' in self.dataset.DATA_INFOS.keys():
            data_u_all.extend(self.dataset.DATA_INFOS['train_u_aug_mixup'])
        for idx_temp in idxs_temp_q:
            data_u_item = data_u_all[idx_temp]
            if 'img' in data_u_item.keys():  
                idxs_q.append(data_u_item['no'])
                continue
            if 'aug_transform' in data_u_item.keys():  
                idxs_q.append(
                    self.dataset.DATA_INFOS[data_u_item['split']][data_u_item['idx']]['no'])
                continue
            if 'idx_a' in data_u_item.keys():  
                idxs_q.extend([
                    self.dataset.DATA_INFOS[data_u_item['split']][data_u_item['idx_a']]['no'],
                    self.dataset.DATA_INFOS[data_u_item['split']][data_u_item['idx_b']]['no']
                ])
                continue
        idxs_q = list(set(idxs_q))
        return np.array(idxs_q)

    def _train(self, loader_tr, clf_group: dict, clf_name='train', soft_target=False, log_show=True):
        """Train for one epoch.

        The model, optimizer and scheduler are all taken from ``clf_group``;
        the forward pass runs through ``clf_group['clf']`` as well, so the
        model that is put in train mode and stepped is the one that produces
        the loss.

        :param loader_tr: (:obj:`torch.utils.data.DataLoader`) Training data; each batch is a 4-tuple whose
            first two items are the inputs and the targets.
        :param clf_group: (dict) Holds ``'clf'``, ``'optimizer'`` and ``'scheduler'``.
        :param clf_name: (str) Value logged under ``mode``.
        :param soft_target: (bool) If True, targets are per-class distributions and the loss is a
            KL divergence on log-softmax outputs; otherwise targets are class indices and the loss
            is cross entropy.
        :param log_show: (bool) If True, a log record is emitted every ``args.out_iter_freq`` iterations.

        The logged ``acc`` is the accuracy over the last ``out_iter_freq``
        batches and the logged ``loss`` is the sum of their losses. The
        scheduler is stepped once at the end of the epoch.
        """
        iter_out = self.args.out_iter_freq
        model = clf_group['clf']
        optimizer = clf_group['optimizer']
        loss_list = []
        right_count_list = []
        samples_per_batch = []

        # Built once per epoch instead of once per batch.
        # NOTE(review): the default reduction ('mean') averages over all elements
        # rather than over the batch; it is kept as-is so that the loss scale of
        # existing runs does not change. Consider 'batchmean'.
        soft_criterion = torch.nn.KLDivLoss() if soft_target else None

        model.train()
        for batch_idx, (x, y, _, _) in enumerate(loader_tr):
            x, y = x.cuda(), y.cuda()
            optimizer.zero_grad()
            out, _, _ = model(x)
            pred_labels = out.max(1)[1]
            if soft_target:
                # For soft targets the reference class is the most likely one.
                true_labels = y.max(1)[1]
                loss = soft_criterion(F.log_softmax(out, dim=1), y)
            else:
                true_labels = y
                loss = F.cross_entropy(out, y)
            right_count_list.append((pred_labels == true_labels).sum().item())
            samples_per_batch.append(len(y))
            loss_list.append(loss.item())
            loss.backward()
            optimizer.step()
            iter_time = self.timer.since_last_check()
            if log_show and (batch_idx + 1) % iter_out == 0:
                log_dict = dict(
                    mode=clf_name,
                    epoch=self.epoch,
                    iter=batch_idx + 1,
                    lr=get_lr(optimizer),
                    time=iter_time,
                    acc=1.0 * np.sum(right_count_list[-iter_out:]) / np.sum(samples_per_batch[-iter_out:]),
                    loss=np.sum(loss_list[-iter_out:])
                )
                self.TextLogger.log(
                    log_dict=log_dict,
                    iters_per_epoch=len(loader_tr),
                    iter_count=self.epoch * len(loader_tr) + batch_idx,
                    max_iters=self.args.n_epoch * len(loader_tr),
                    interval=iter_out
                )
        clf_group['scheduler'].step()

    def get_train_info(self):
        split = ['train']
        to_soft = [True if self.args.mixup_ratio > 0 else False]
        loader_name = ['base']
        repeat = [self.args.duplicate_ratio]

        if self.args.aug_ratio > 0 and self.args.aug_lab > 0:
            if self.cycle == 0:
                self.generate_aug(self.args.aug_ratio, aug_intensity=self.args.aug_lab,
                                  aug_type=self.args.ablation_aug_type)
            split.append('train_aug_single')
            to_soft.append(True if self.args.mixup_ratio > 0 else False)
            loader_name.append('aug')
            repeat.append(self.args.duplicate_ratio)
        if self.args.mixup_ratio > 0 and self.args.aug_lab > 0:
            if self.cycle == 0:
                self.generate_mixup(self.args.mixup_ratio)
            split.append('train_aug_mixup')
            to_soft.append(False)
            loader_name.append('mixup')
            repeat.append(self.args.duplicate_ratio)
        return split, to_soft, loader_name, repeat

    def train(self):
        """Train ``self.clf`` on the current labeled data for ``args.n_epoch`` epochs.

        The training splits come from ``get_train_info``. The loop continues
        from the current value of ``self.epoch`` and resets it to 0 when done.
        Soft targets are used whenever ``args.mixup_ratio`` is positive.
        """
        self.logger.info('Start running, host: %s, work_dir: %s',
                         f'{getuser()}@{gethostname()}', self.args.work_dir)
        self.logger.info('max: %d epochs', self.args.n_epoch)
        self.clf.train()

        split, to_soft, loader_name, repeat = self.get_train_info()
        loader_tr = DataLoader(
            GetHandler(self.dataset, split=split, repeat=repeat, to_soft=to_soft, loader_name=loader_name),
            shuffle=True, batch_size=self.args.batch_size, num_workers=self.args.num_workers)

        while self.epoch < self.args.n_epoch:
            # Restart the timer so the first iteration time excludes set-up work.
            self.timer.since_last_check()
            clf_group = dict(clf=self.clf, optimizer=self.optimizer, scheduler=self.scheduler)
            use_soft_targets = bool(self.args.mixup_ratio > 0)
            self._train(loader_tr, clf_group, soft_target=use_soft_targets)

            checkpoint_due = self.epoch % self.args.save_freq == 0 and self.epoch > 0
            if checkpoint_due:
                # Nothing is saved at the moment; the condition is only evaluated.
                pass

            self.epoch += 1
        self.epoch = 0
        

    def run(self):
        """Execute the active-learning loop for ``args.n_cycle`` rounds.

        Each round creates its ``active_round_<cycle>`` directory under
        ``args.work_dir``, dumps the round's meta data, re-initialises the
        model unless ``args.updating`` is set, trains, records the
        ``'val'`` and ``'test'`` results of ``predict`` and finally labels
        ``args.num_query`` new samples. After the last round the test results
        are written by ``record_test_accuracy``.
        """
        while self.cycle < self.args.n_cycle:
            round_dir = os.path.join(self.args.work_dir, f'active_round_{self.cycle}')
            os.makedirs(round_dir, mode=0o777, exist_ok=True)

            infos = self.dataset.DATA_INFOS
            num_labels = len(infos['train'])
            self.logger.info(f'Active Round {self.cycle} with {num_labels} labeled instances')

            # Positions in 'train_full' that are currently labeled.
            labeled_positions = np.arange(len(infos['train_full']))[self.dataset.INDEX_LB]
            self.TextLogger._dump_log(dict(
                mode='active_meta',
                cycle=self.cycle,
                num_labels=num_labels,
                idxs_queried=list(self.dataset.QUERIED_HISTORY[-1]),
                idxs_lb=list(labeled_positions)
            ))

            if not self.args.updating:
                self.init_clf()
            self.train()

            for history, eval_split in ((self.acc_val_list, 'val'), (self.acc_test_list, 'test')):
                history.append(self.predict(self.clf, split=eval_split))
            self.num_labels_list.append(num_labels)

            self.update(self.args.num_query)
            self.cycle += 1
        self.record_test_accuracy()

    def predict(self, clf, split='train', metric='accuracy',
                topk=None, n_drop=None, thrs=None, dropout_split=False, log_show=True):
        """Score ``split`` with ``clf`` using a performance or an informativeness metric.

        Performance metrics (``'accuracy'``, ``'precision'``, ``'recall'``,
        ``'f1_score'``, ``'support'``) are computed from softmax outputs with
        the model in eval mode. Any other metric name is handled as an
        informativeness request: a ``torch.nn.Module`` is switched to train
        mode (and left in it) and ``n_drop`` forward passes are run. Their
        softmax outputs are averaged when ``dropout_split is False`` and kept
        per pass otherwise.

        :param clf: A ``torch.nn.Module`` returning a 3-tuple whose first item holds the logits,
            or any other callable returning the logits directly.
        :param split: (str | list | tuple | set) Split name(s) to evaluate.
        :param metric: (str) One of the performance metrics above or
            ``'entropy'``, ``'lc'``, ``'margin'``, ``'prob'``.
        :param topk: Forwarded to ``accuracy``; defaults to 1.
        :param n_drop: (int) Number of forward passes in informativeness mode; defaults to 1.
        :param thrs: Forwarded to ``accuracy``/``precision``/``recall``/``f1_score``; defaults to 0.
        :param dropout_split: If not ``False``, no metric is applied and the per-pass
            probabilities are returned.
        :param log_show: (bool) Whether a performance result is written to the text log.

        Returns:
            For performance metrics a Python number (single-element result) or a
            list. For informativeness metrics a tensor with one row per sample:
            the entropy, the maximum probability (``'lc'``), the gap between
            the two largest probabilities (``'margin'``) or the averaged
            probabilities (``'prob'``). With ``dropout_split`` a tensor of shape
            ``[n_drop, num_samples, num_classes]``.

        Raises:
            Exception: If an informativeness metric name is not implemented.
        """
        if isinstance(split, (list, tuple, set)):
            repeat = [1 for _ in split]
        else:
            repeat = 1
        split, loader_name, to_soft = generate_split_info(split)
        dataset = GetHandler(self.dataset, split=split, repeat=repeat,
                             to_soft=to_soft, loader_name=loader_name)
        # No shuffling: the informativeness branch accumulates per batch position.
        loader = DataLoader(dataset, shuffle=False, batch_size=self.args.batch_size,
                            num_workers=self.args.num_workers)

        is_module = isinstance(clf, torch.nn.Module)

        def forward_probs(x):
            # Modules return a 3-tuple (logits first); plain callables return logits.
            if is_module:
                out, _, _ = clf(x)
            else:
                out = clf(x)
            return F.softmax(out, dim=1)

        if is_module:
            clf.eval()
        if n_drop is None:
            n_drop = 1
        if topk is None:
            topk = 1
        if thrs is None:
            thrs = 0.

        if metric in ['accuracy', 'precision', 'recall', 'f1_score', 'support']:
            self.logger.info(f"Calculating Performance with {metric}...")
            pred = []
            target = []
            with torch.no_grad():
                for x, y, _, idxs in track_iter_progress(loader):
                    x, y = x.cuda(), y.cuda()
                    pred.append(forward_probs(x))
                    target.append(y)
            pred = torch.cat(pred).cuda()
            target = torch.cat(target).cuda()
            if metric == 'accuracy':
                result = accuracy(pred, target, topk, thrs)
            elif metric == 'precision':
                result = precision(pred, target, thrs=thrs)
            elif metric == 'recall':
                result = recall(pred, target, thrs=thrs)
            elif metric == 'f1_score':
                result = f1_score(pred, target, thrs=thrs)
            elif metric == 'support':
                result = support(pred, target)
            else:
                raise Exception(f"Metric {metric} not implemented!")
            if len(result) == 1:
                result = result.item()
            else:
                # The inputs live on the GPU, so the result may too; a CUDA
                # tensor must be moved to the host before .numpy().
                result = result.cpu().numpy().tolist()
            if log_show:
                log_dict = dict(mode=split, cycle=self.cycle)
                log_dict[metric] = result
                self.TextLogger.log(log_dict)
        else:
            self.logger.info(f"Calculating Informativeness with {metric}...")
            if is_module:
                # presumably to keep dropout active across the n_drop passes
                clf.train()
            if dropout_split is False:
                pred = []
                for i in range(n_drop):
                    self.logger.info('n_drop {}/{}'.format(i + 1, n_drop))
                    with torch.no_grad():
                        for batch_idx, (x, _, _, _) in enumerate(track_iter_progress(loader)):
                            probs = forward_probs(x.cuda())
                            if i == 0:
                                pred.append(probs)
                            else:
                                pred[batch_idx] += probs
                pred = torch.cat(pred).cuda()
                # Average over the forward passes.
                pred /= n_drop
                if metric == 'entropy':
                    # Softmax can underflow to exactly 0; log(0) = -inf would turn
                    # 0 * log(0) into NaN. Clamping makes those terms contribute 0.
                    log_pred = torch.log(pred.clamp_min(torch.finfo(pred.dtype).tiny))
                    result = - (pred * log_pred).sum(1)
                elif metric == 'lc':
                    # Confidence of the most likely class.
                    result = pred.max(1)[0]
                elif metric == 'margin':
                    # Gap between the two most likely classes.
                    pred_sorted, _ = pred.sort(descending=True)
                    result = pred_sorted[:, 0] - pred_sorted[:, 1]
                elif metric == 'prob':
                    result = pred
                else:
                    raise Exception(f"Metric {metric} not implemented!")
            else:
                print("No metric will be used in dropout split mode!")
                if isinstance(split, str):
                    data_length = len(self.dataset.DATA_INFOS[split])
                else:
                    data_length = sum([len(self.dataset.DATA_INFOS[split_elem]) for split_elem in split])
                result = torch.zeros([n_drop, data_length, len(self.dataset.CLASSES)]).cuda()
                for i in range(n_drop):
                    self.logger.info('n_drop {}/{}'.format(i + 1, n_drop))
                    with torch.no_grad():
                        for x, _, _, idxs in track_iter_progress(loader):
                            # Rows are addressed by the indices the loader yields.
                            result[i][idxs] += forward_probs(x.cuda())

        return result
        

    def get_embedding(self, clf, split='train', embed_type='default'):
        """Extract per-sample embeddings of ``split`` from ``clf``.

        :param clf: Model returning a 3-tuple; the first item holds the logits and the second
            the embedding.
        :param split: (str | list | tuple | set) Split name(s) to process.
        :param embed_type: (str) ``'default'`` returns the model's embedding output.
            ``'grad'`` returns, for every sample, the concatenation over all classes ``c`` of
            ``embedding * (1 - p_c)`` for the predicted class and ``embedding * (-p_c)`` for the
            other classes, where ``p`` is the softmax output.

        Returns:
            torch.Tensor: For ``'default'`` the concatenated embeddings, moved
            to the GPU. For ``'grad'`` a CPU tensor of shape
            ``[num_samples, embedding_dim * num_classes]`` whose rows are placed
            at the indices yielded by the loader.

        Raises:
            Exception: If ``embed_type`` is not implemented.
        """
        if isinstance(split, (list, tuple, set)):
            repeat = [1 for _ in split]
        else:
            repeat = 1
        split, loader_name, to_soft = generate_split_info(split)
        dataset = GetHandler(self.dataset, split=split, repeat=repeat,
                             to_soft=to_soft, loader_name=loader_name)
        loader = DataLoader(dataset, shuffle=False, batch_size=self.args.batch_size,
                            num_workers=self.args.num_workers)

        clf.eval()
        self.logger.info(f"Extracting embedding of type {embed_type}...")
        embdim = self.get_embedding_dim()
        nlabs = len(self.dataset.CLASSES)

        if embed_type == 'default':
            embedding = []
            with torch.no_grad():
                for x, _, _, idxs in track_iter_progress(loader):
                    _, e1, _ = clf(x.cuda())
                    embedding.append(e1)
            return torch.cat(embedding).cuda()

        if embed_type == 'grad':
            if isinstance(split, str):
                data_length = len(self.dataset.DATA_INFOS[split])
            else:
                data_length = sum([len(self.dataset.DATA_INFOS[split_elem]) for split_elem in split])
            embedding = np.zeros([data_length, embdim * nlabs])
            # Only forward values are needed, so autograd bookkeeping is disabled
            # to avoid holding a graph for every batch.
            with torch.no_grad():
                for x, _, _, idxs in loader:
                    cout, e, _ = clf(x.cuda())
                    feats = e.cpu().numpy()
                    probs = F.softmax(cout, dim=1).cpu().numpy()
                    # Per-class coefficient: (1 - p) for the predicted class, -p otherwise.
                    coeff = -probs
                    coeff[np.arange(len(probs)), probs.argmax(1)] += 1
                    # Outer product per sample, flattened class-major:
                    # [class 0 block | class 1 block | ...], each block embdim wide.
                    block = coeff[:, :, None] * feats[:, None, :]
                    embedding[np.asarray(idxs)] = block.reshape(len(probs), -1)
            return torch.Tensor(embedding)

        raise Exception(f'Embedding of type {embed_type} not implemented!')

    def save(self):
        """Save the current model parameters."""
        model_out_path = Path(os.path.join(self.args['work_dir'], f'active_round_{self.cycle}'))
        state = self.clf.state_dict(),
        if not model_out_path.exists():
            model_out_path.mkdir()
        save_target = model_out_path / f"active_round_{self.cycle}-" \
                                       f"label_num_{np.sum(self.idxs_lb).item()}-epoch_{self.epoch}.pth"
        torch.save(state, save_target)

        self.logger.info('==> save model to {}'.format(save_target))

    def get_embedding_dim(self) -> int:
        """Return the width of the embedding produced by ``self.clf``.

        One batch of the ``'train'`` split is pushed through the model in eval
        mode and the second dimension of the model's second output is
        returned. If the loader yields no batch, the method falls through and
        returns ``None``.
        """
        train_loader = GetDataLoader(
            self.dataset,
            split='train',
            shuffle=False,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
        )
        self.clf.eval()
        with torch.no_grad():
            for batch in train_loader:
                inputs, _, _, _ = batch
                _, features, _ = self.clf(inputs.cuda())
                # The first batch is enough to read the size.
                return features.shape[1]

    def generate_aug(self, aug_ratio: int = 1, split='train', aug_intensity=2, aug_type='AutoContrast'):
        
        aug_split_name = split + '_aug_single'
        self.dataset.DATA_INFOS[aug_split_name] = []
        for _ in range(aug_ratio):
            self.dataset.DATA_INFOS[aug_split_name].extend([
                dict(split=split, idx=i, aug_transform=AblationAugmentPool(n=aug_intensity, m=10, aug_type=aug_type))
                for i in range(len(self.dataset.DATA_INFOS[split]))])
        
        self.dataset.TRANSFORM[aug_split_name] = self.dataset.TRANSFORM['val']

    def del_aug(self, split='train'):
        if split + '_aug_single' in self.dataset.DATA_INFOS.keys():
            del self.dataset.DATA_INFOS[split + '_aug_single']

    def generate_mixup(self, aug_mix_ratio: int = 1, split='train', fix_lam=None):
        
        aug_split_name = split + '_aug_mixup'
        self.dataset.DATA_INFOS[aug_split_name] = []
        for _ in range(aug_mix_ratio):
            indices_rand_perm = list(range(len(self.dataset.DATA_INFOS[split])))
            random.shuffle(indices_rand_perm)
            if fix_lam is None:
                self.dataset.DATA_INFOS[aug_split_name].extend([
                    dict(split=split, idx_a=i, idx_b=indices_rand_perm[i], lam=np.random.beta(1.0, 1.0))
                    for i in range(len(self.dataset.DATA_INFOS[split]))])
            else:
                self.dataset.DATA_INFOS[aug_split_name].extend([
                    dict(split=split, idx_a=i, idx_b=indices_rand_perm[i], lam=fix_lam)
                    for i in range(len(self.dataset.DATA_INFOS[split]))])
        self.dataset.TRANSFORM[aug_split_name] = self.dataset.TRANSFORM['val']
        return indices_rand_perm

    def del_mixup(self, split='train'):
        if split + '_aug_mixup' in self.dataset.DATA_INFOS.keys():
            del self.dataset.DATA_INFOS[split + '_aug_mixup']

    def record_test_accuracy(self):
        file_name = os.path.join(self.args.work_dir, 'accuracy.csv')
        header = ['num_labels', 'accuracy']
        with open(file_name, 'w', newline='') as f:
            f_csv = csv.writer(f)
            f_csv.writerow(header)
            for i, acc in enumerate(self.acc_test_list):
                f_csv.writerow([(i + 1) * self.args.num_query, acc])
