import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
import random

from sklearn.model_selection import StratifiedKFold

from utils import cos_mat, euc_mat, tovec# , track_grad_

class NearestNeighborClassifier:
    """k-nearest-neighbour classifier on top of a pairwise distance matrix.

    The distance function is called as ``distance_fn(x_train, x_query)`` and
    its result is reduced along dim 0, so rows must correspond to training
    samples and columns to query samples.

    Args:
        distance_fn: ``'cosine'`` (uses ``-cos_mat``) or ``'euclidean'``
            (uses ``euc_mat``).
        k: number of neighbours that vote. ``k <= 1`` gives plain
            nearest-neighbour prediction. Larger values take a majority vote
            (ties are resolved by ``torch.mode``). ``k`` is capped at the
            number of training samples.

    Raises:
        ValueError: if ``distance_fn`` is not one of the recognised names.
    """

    def __init__(self, distance_fn, k=1):
        if distance_fn == 'cosine':
            # Negated so that "smaller is closer" holds for both metrics
            # (assumes cos_mat returns similarities -- confirm in utils).
            self.distance_fn = lambda x, y: -cos_mat(x, y)
        elif distance_fn == 'euclidean':
            self.distance_fn = euc_mat
        else:
            raise ValueError('Unrecognized distance')

        self.k = k

    def fit(self, x_train, y_train):
        """Memorise the training samples and their labels (no computation)."""
        self.x_train = x_train
        self.y_train = y_train

    def predict(self, x):
        """Return the predicted label for every row of ``x``."""
        ds = self.distance_fn(self.x_train, x)
        k = min(int(self.k), ds.shape[0])
        if k <= 1:
            # Same code path as the historical 1-NN behaviour.
            return self.y_train[ds.min(dim=0)[1]]

        # Indices of the k closest training samples per query: (k, n_query).
        neighbour_idx = ds.topk(k, dim=0, largest=False)[1]
        votes = self.y_train[neighbour_idx]
        return votes.mode(dim=0)[0]

    def score(self, x_test, y_test):
        """Return the accuracy (python float) of ``predict`` on ``x_test``."""
        ps = self.predict(x_test)
        return (y_test == ps).float().mean().item()
    
class LogisticRegression(nn.Module):
    """A bare bones PyTorch logistic regression module. Uses LBFGS optimizer.

    Args:
        C: multiplier applied to the cross-entropy term of the objective
            (larger ``C`` means weaker relative L2 regularisation).
        max_iter: ``max_iter`` passed to ``torch.optim.LBFGS``.
    """

    def __init__(self, C=1.0, max_iter=300):
        super().__init__()
        self.C = C
        self.max_iter = max_iter
        self.model = None  # nn.Linear, created by fit()

    def fit(self, X, y):
        """Fit the linear model with a single LBFGS ``step``.

        X: tensor of shape (n_samples, n_features)
        y: tensor of shape (n_samples,) (y[i] must be an integer in [0,1,2,3,...,n_classes-1])

        The minimised objective is
        ``1/(2*n_samples) * ||W||^2 + C * mean_cross_entropy`` (bias is not
        penalised).
        """
        # Features are treated as constants: without this, a feature tensor
        # that is still attached to an autograd graph would receive gradients
        # from every closure evaluation.
        X = X.detach()

        # init linear model (all-zero start makes the fit deterministic)
        n_classes = len(y.unique())
        linear = nn.Linear(X.shape[1], n_classes, bias=True)
        if X.is_floating_point():
            # Match the feature dtype so that e.g. float64 inputs work.
            linear = linear.to(device=X.device, dtype=X.dtype)
        else:
            linear = linear.to(X.device)
        nn.init.zeros_(linear.weight)
        nn.init.zeros_(linear.bias)
        self.model = linear

        # init lbfgs
        opt = torch.optim.LBFGS(self.model.parameters(), max_iter=self.max_iter)
        n_samples = X.shape[0]

        def closure():
            opt.zero_grad()
            # Grad mode is forced on so fit() also works when the caller is
            # inside a torch.no_grad() block.
            with torch.enable_grad():
                penalty = 1 / (2 * n_samples) * (self.model.weight ** 2).sum()
                loss = penalty + self.C * F.cross_entropy(self.model(X), y)
            loss.backward()
            return loss

        # run lbfgs
        opt.step(closure)

    def predict(self, X):
        """Return the arg-max class index for every row of ``X``."""
        with torch.no_grad():
            return self.model(X).max(dim=1)[1]

    def score(self, X, y):
        """Compute the accuracy (python float) of ``predict(X)`` against ``y``."""
        return (self.predict(X) == y).float().mean().item()
    
class LogisticRegressionCV(nn.Module):
    """Pick ``C`` for :class:`LogisticRegression` on a held-out stratified fold.

    Despite the name, only the FIRST split produced by ``StratifiedKFold`` is
    used: every candidate ``C`` is trained on that split's training part and
    scored on its test part.

    Args:
        Cs: candidate values of ``C``; defaults to ``torch.logspace(-4, 4, 10)``.
        max_iter: LBFGS ``max_iter`` used for every fit (including the retrain).
        n_splits: ``n_splits`` of the ``StratifiedKFold`` splitter.
        retrain: if true, refit on all of ``X, y`` with the best ``C``;
            otherwise keep the classifier trained on the first fold.
        verbose: truthy values print progress.

    After ``fit``: ``scores`` (one accuracy per ``C``), ``best_C`` and ``clf``
    (the selected :class:`LogisticRegression`) are available.
    """

    def __init__(self, Cs=None, max_iter=300, n_splits=5, retrain=True, verbose=0):
        super().__init__()
        if Cs is None:
            Cs = torch.logspace(-4, 4, 10)

        self.Cs = Cs
        self.max_iter = max_iter
        self.n_splits = n_splits
        self.retrain = retrain
        self.verbose = verbose
        self.clf = None

    def fit(self, X, y):
        """Select ``C`` on the first stratified fold and set ``self.clf``."""
        # checks
        assert len(X.shape) == 2, 'X must have shape (n_samples, n_features)'
        assert len(y.shape) == 1, 'y must have shape (n_samples,)'
        assert X.shape[0] == y.shape[0], 'X and y must have the same number of samples'

        unq, cts = y.unique(return_counts=True)
        assert y.min() == 0, 'labels must start at 0'
        assert len(unq) - 1 == y.max(), 'labels must be contiguous in [0, n_classes-1]'
        # NOTE(review): the bound 5 is hard-coded and equals the default
        # n_splits; it is not tied to self.n_splits.
        assert cts.min() >= 5, 'every class needs at least 5 samples'

        # split data (only the first fold is consumed)
        skf = StratifiedKFold(n_splits=self.n_splits)
        train_idx, test_idx = next(skf.split(y.cpu(), y.cpu()))
        X_fit, y_fit = X[train_idx], y[train_idx]
        X_val, y_val = X[test_idx], y[test_idx]

        # train all Cs
        scores = []
        clfs = []
        for i, C_ in enumerate(self.Cs):
            clf = LogisticRegression(C=C_, max_iter=self.max_iter)
            clf.fit(X_fit, y_fit)
            score = clf.score(X_val, y_val)
            scores.append(score)
            clfs.append(clf)

            if self.verbose:
                print('iter: {}/{}, C: {:.2e}, score: {:.3f}'.format(i+1,len(self.Cs),C_, score))

        self.scores = scores
        best = int(np.argmax(scores))  # first maximum wins on ties
        self.best_C = self.Cs[best]

        if self.verbose:
            print('best C: {:.2e}, best score: {:.3f}'.format(self.best_C, max(self.scores)))

        # retrain using full
        if self.retrain:
            if self.verbose: print('retraining with all data')
            # Use the configured iteration budget here too; previously the
            # retrain silently fell back to LogisticRegression's default.
            self.clf = LogisticRegression(C=self.best_C, max_iter=self.max_iter)
            self.clf.fit(X, y)
        else:
            self.clf = clfs[best]

    def predict(self, X):
        """Predict with the selected classifier (requires a prior ``fit``)."""
        return self.clf.predict(X)

    def score(self, X, y):
        """Accuracy of the selected classifier (requires a prior ``fit``)."""
        return self.clf.score(X, y)
    
class Evaluator:
    """Train a classifier on label-limited subsets and report test accuracy.

    Args:
        clf: ``'nearest_neighbor'``, ``'logistic_regression'``, or a
            zero-argument callable returning an object with ``fit(x, y)`` and
            ``score(x, y)``. For a callable, the reported name is
            ``'unknown clf'``.
        nlabels_per_class: sequence of configurations; each entry is either an
            int (number of training samples drawn per class) or ``'ALL'``.
        ntrials_per_config: number of repetitions per configuration. A
            single-element sequence is broadcast to all configurations.
        verbose: truthy values print one summary line per configuration;
            the value ``3`` additionally enables verbose logistic regression.
        seed: if not ``None``, ``torch.random.manual_seed(seed)`` is called at
            the start of every ``evaluate``.
    """

    def __init__(self, clf, nlabels_per_class=('ALL',), ntrials_per_config=(1,),
                 verbose=0, seed=None):
        if clf == 'nearest_neighbor':
            clf_constructor = lambda: NearestNeighborClassifier('cosine')
        elif clf == 'logistic_regression':
            clf_constructor = lambda: LogisticRegressionCV(
                verbose=(verbose==3), max_iter=100, retrain=False)
        else:
            clf_constructor = clf
            clf = 'unknown clf'

        self.clf_constructor = clf_constructor
        self.clf_name = clf

        # Work on private copies: the previous in-place `*=` grew the shared
        # default list (and any list passed by the caller) on every
        # construction, breaking later instances.
        self.nlabels_per_class = list(nlabels_per_class)
        ntrials_per_config = list(ntrials_per_config)
        if len(ntrials_per_config) == 1:
            ntrials_per_config = ntrials_per_config * len(self.nlabels_per_class)
        else:
            assert len(ntrials_per_config) == len(self.nlabels_per_class), \
                'ntrials_per_config must have one entry, or one per nlabels_per_class entry'
        self.ntrials_per_config = ntrials_per_config

        self.verbose = verbose
        self.seed = seed

    def _sample_indices(self, x_train, labels_train, nlabels, n_classes):
        """Draw the training indices for one trial of one configuration."""
        if nlabels == 'ALL':
            # Every sample, in random order.
            return torch.randperm(x_train.shape[0])

        ixs = torch.zeros(nlabels * n_classes, dtype=torch.long, device=x_train.device)
        for i in range(n_classes):
            # Classes are addressed as 0..n_classes-1.
            ixs_ = torch.where(labels_train == i)[0]
            ixs[i*nlabels:(i+1)*nlabels] = ixs_[torch.randperm(len(ixs_))[:nlabels]]
        return ixs

    def evaluate(self, x_train, labels_train, x_test, labels_test):
        """Run every configuration and aggregate the scores.

        Returns:
            avg_scores: numpy array, mean score per configuration
            err_scores: numpy array, standard error per configuration
                (std / sqrt(ntrials))
            info: dict keyed by the ``nlabels`` entry with ``'scores'``,
                ``'avg_score'``, ``'std_score'`` and ``'err_score'``
        """
        # preprocess inputs (tovec comes from utils; presumably flattens each
        # sample to a vector -- confirm there)
        x_train = tovec(x_train)
        x_test = tovec(x_test)

        if self.seed is not None: torch.random.manual_seed(self.seed)

        info = {}
        t0 = time.time()
        n_classes = len(labels_train.unique())
        for nlabels, ntrials in zip(self.nlabels_per_class, self.ntrials_per_config):
            scores = []
            info[nlabels] = {'scores': scores, 'avg_score': None}
            for _ in range(ntrials):
                ixs = self._sample_indices(x_train, labels_train, nlabels, n_classes)

                # train & score classifier
                clf = self.clf_constructor()
                clf.fit(x_train[ixs], labels_train[ixs])
                scores.append(clf.score(x_test, labels_test))

            # average results
            std = np.std(scores)
            info[nlabels]['avg_score'] = np.mean(scores)
            info[nlabels]['std_score'] = std
            info[nlabels]['err_score'] = std / np.sqrt(ntrials)

            # print
            if self.verbose:
                if nlabels == 'ALL': nlabelsp = len(labels_train) // n_classes
                else: nlabelsp = nlabels
                print('{}, nlabels: {}, ntrials: {}, avg_score: {:.4f}, err_score: {:.4f}, time: {:.2f}s'.format(
                        self.clf_name, nlabelsp, ntrials, info[nlabels]['avg_score'], info[nlabels]['err_score'], time.time()-t0))

        avg_scores = np.array([info[nlabels]['avg_score'] for nlabels in self.nlabels_per_class])
        err_scores = np.array([info[nlabels]['err_score'] for nlabels in self.nlabels_per_class])

        return avg_scores, err_scores, info
    
class OmniglotEvaluator:
    """N-way / k-shot nearest-neighbour evaluation over (alphabet, character) labels.

    Args:
        n_way: number of distinct characters per trial.
        n_shot: number of training instances sampled per character.
        n_trials: number of random trials.
        within_class: if true, all characters of a trial are drawn from one
            randomly chosen alphabet; otherwise from all characters.

    Trials are sampled with the global ``random`` module, so seed it with
    ``random.seed`` for reproducible results.
    """

    def __init__(self, n_way=20, n_shot=1, n_trials=1000, within_class=True):
        self.n_way = n_way
        self.n_shot = n_shot
        self.n_trials = n_trials
        self.within_class = within_class

    def evaluate(self, x, labels):
        """Return ``(mean accuracy, standard error)`` over all trials.

        ``labels[i]`` must be an ``(alphabet_id, char_id)`` pair; the character
        id (element 1) is the classification target. Each trial fits a cosine
        nearest-neighbour classifier on the sampled support set and scores it
        on the remaining instances of one of the support characters.
        """
        x = tovec(x)
        scores = []
        for train_idx, test_idx in self._generate_trials(labels):
            labels_train = torch.tensor([labels[idx][1] for idx in train_idx]).to(x.device)
            labels_test = torch.tensor([labels[idx][1] for idx in test_idx]).to(x.device)

            clf = NearestNeighborClassifier('cosine')
            clf.fit(x[train_idx], labels_train)
            scores.append(clf.score(x[test_idx], labels_test))

        return np.mean(scores), np.std(scores) / np.sqrt(len(scores))

    def _generate_trials(self, labels):
        """Sample ``n_trials`` pairs ``[train_ids, test_ids]`` of sample indices.

        ``labels`` is iterated as ``(alphabet_id, char_id)`` pairs whose
        elements expose ``.item()``. Character ids are used as keys on their
        own, i.e. they are assumed to be unique across alphabets.
        """
        # create maps
        alph_id_set, char_id_set = set(), set()
        alph_char_map, char_idx_map = {}, {}
        for idx, (alph_id, char_id) in enumerate(labels):
            alph_id, char_id = alph_id.item(), char_id.item()

            # alphabet -> set of its character ids
            alph_char_map.setdefault(alph_id, set()).add(char_id)
            # character id -> positions in `labels` (append is O(1); the old
            # `get(...) + [idx]` copied the list on every sample)
            char_idx_map.setdefault(char_id, []).append(idx)

            char_id_set.add(char_id)
            alph_id_set.add(alph_id)

        # random.sample / random.choice need sequences, not sets
        char_ids = list(char_id_set)
        alph_ids = list(alph_id_set)
        alph_char_map = {k: list(v) for k, v in alph_char_map.items()}

        # sample
        trials = []
        for _ in range(self.n_trials):
            # sample n_way unique characters for training
            if self.within_class:
                alph_id = random.choice(alph_ids)
                train_char_ids = random.sample(alph_char_map[alph_id], self.n_way)
            else:
                train_char_ids = random.sample(char_ids, self.n_way)

            # sample n_shot unique instances for each character
            train_ids = []
            for char_id in train_char_ids:
                train_ids.extend(random.sample(char_idx_map[char_id], self.n_shot))

            # sample a test character; its non-support instances are the queries
            test_char_id = random.choice(train_char_ids)
            support = set(train_ids)
            test_ids = [idx for idx in char_idx_map[test_char_id] if idx not in support]

            trials.append([train_ids, test_ids])

        return trials

# class NearestNeighborClassifier:
#     def __init__(self, distance_fn='cosine', chunk_size=1024, device='cpu'):
#         if distance_fn == 'cosine':
#             self.distance_fn = lambda x,y: -cos_mat(x,y)
#         elif distance_fn == 'euclidean':
#             self.distance_fn = euc_mat
#         else:
#             raise ValueError('Unrecognized distance')
            
#         self.chunk_size = chunk_size
#         self.device = device
        
#     def fit(self, x_train, y_train):
#         assert(len(x_train.shape)==2)
#         self.x_train = x_train
#         self.y_train = y_train
        
#     def score(self, x_test, y_test):
#         assert(len(x_test.shape)==2)
#         ds = torch.zeros(self.x_train.shape[0], x_test.shape[0], dtype=x_test.dtype, device=x_test.device)
#         for i, x_train_ in enumerate(torch.split(self.x_train,self.chunk_size)):
#             x_train_ = x_train_.to(self.device)
#             for j, x_test_ in enumerate(torch.split(x_test,self.chunk_size)):
#                 x_test_ = x_test_.to(self.device)
#                 ds_ = self.distance_fn(x_train_, x_test_).to(x_test.device)
#                 ds[i*self.chunk_size:(i+1)*self.chunk_size, j*self.chunk_size:(j+1)*self.chunk_size] = ds_

#         ps = self.y_train[ds.min(dim=0)[1]]
#         acc = (y_test == ps).float().mean().item()
#         return acc

class GridSearchCV:
    """Select the ``alpha`` of a classifier on a held-out stratified fold.

    Only the FIRST split produced by ``StratifiedKFold`` is used: each grid
    value is trained on that split's training part and scored on its test part.

    Args:
        clf: callable/class invoked as ``clf(alpha=value)``; the result must
            provide ``fit(X, y)`` and ``score(X, y)``.
        param_grid: indexable sequence of candidate ``alpha`` values.
        retrain: if true, a fresh ``clf(alpha=best)`` is refit on all data;
            otherwise the classifier trained on the first fold is kept.
        verbose: truthy values print progress.
        n_splits: ``n_splits`` of the ``StratifiedKFold`` splitter.

    After ``fit``: ``scores``, ``best_C`` (the best grid value) and ``clf``
    (the fitted, selected classifier) are available.
    """

    def __init__(self, clf, param_grid, retrain=True, verbose=1, n_splits=5):
        self.clf = clf
        # `self.clf` is replaced by the fitted estimator in fit(); keep the
        # factory separately so fit() can be called again.
        self._clf_factory = clf
        self._fitted = False
        self.param_grid = param_grid
        self.retrain = retrain
        self.verbose = verbose
        self.n_splits = n_splits

    def fit(self, X, y):
        """Evaluate every grid point and store the selected classifier."""
        # split (only the first fold is consumed)
        skf = StratifiedKFold(n_splits=self.n_splits)
        train_idx, test_idx = next(skf.split(y.cpu(), y.cpu()))

        # train and evaluate all grid points
        scores = []
        clfs = []
        for i, C_ in enumerate(self.param_grid):
            clf = self._clf_factory(alpha=C_)
            clf.fit(X[train_idx], y[train_idx])
            score = clf.score(X[test_idx], y[test_idx])
            scores.append(score)
            clfs.append(clf)

            if self.verbose:
                print('iter: {}/{}, C: {:.2e}, score: {:.3f}'.format(i+1,len(self.param_grid),C_, score))

        self.scores = scores
        best = int(np.argmax(scores))  # first maximum wins on ties
        self.best_C = self.param_grid[best]

        if self.verbose:
            print('best C: {:.2e}, best score: {:.3f}'.format(self.best_C, max(self.scores)))

        # retrain using full
        if self.retrain:
            if self.verbose: print('retraining with all data')
            # Refit the model that was actually searched. The previous code
            # built LogisticRegression(C=best), i.e. a different model with
            # the selected alpha misused as an inverse-regularisation weight.
            self.clf = self._clf_factory(alpha=self.best_C)
            self.clf.fit(X, y)
        else:
            self.clf = clfs[best]
        self._fitted = True

    def predict(self, X):
        """Predict with the selected classifier.

        Raises:
            RuntimeError: if ``fit`` has not been called.
        """
        if not self._fitted:
            raise RuntimeError('GridSearchCV.predict called before fit')
        return self.clf.predict(X)

    def score(self, X, y):
        """Score the selected classifier.

        Raises:
            RuntimeError: if ``fit`` has not been called.
        """
        if not self._fitted:
            raise RuntimeError('GridSearchCV.score called before fit')
        return self.clf.score(X, y)
            
class MLP(nn.Module):
    """A bare bones PyTorch MLP. Uses LBFGS optimizer.

    Args:
        hidden_layer_sizes: width of each hidden layer (ReLU after each one).
        alpha: L2 penalty strength applied to all ``weight`` matrices
            (biases are not penalised).
        max_iter: ``max_iter`` passed to ``torch.optim.LBFGS``.
    """

    def __init__(self, hidden_layer_sizes=(100,), alpha=.010, max_iter=300):
        super().__init__()
        self.hidden_layer_sizes = hidden_layer_sizes
        self.alpha = alpha
        self.max_iter = max_iter
        self.model = None  # nn.Sequential, created by fit()

    def fit(self, X, y):
        """Build the network and fit it with a single LBFGS ``step``.

        X: tensor of shape (n_samples, n_features)
        y: tensor of shape (n_samples,) (y[i] must be an integer in [0,1,2,3,...,n_classes-1])

        The minimised objective is
        ``alpha/2 * sum_l ||W_l||^2 + mean_cross_entropy``. Layers use
        PyTorch's default random initialisation, so seed torch for
        reproducible fits.
        """
        # Features are treated as constants: without this, a feature tensor
        # that is still attached to an autograd graph would receive gradients
        # from every closure evaluation.
        X = X.detach()

        # init network: Linear+ReLU per hidden layer, then a Linear output layer
        n_inputs = X.shape[1]
        n_classes = len(y.unique())
        sizes = [n_inputs] + list(self.hidden_layer_sizes)

        layers = []
        for n_in, n_out in zip(sizes[:-1], sizes[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(sizes[-1], n_classes))
        model = nn.Sequential(*layers)
        if X.is_floating_point():
            # Match the feature dtype so that e.g. float64 inputs work.
            model = model.to(device=X.device, dtype=X.dtype)
        else:
            model = model.to(X.device)
        self.model = model

        # init lbfgs
        opt = torch.optim.LBFGS(self.model.parameters(), max_iter=self.max_iter)
        weights = self.get_weights_()  # same Parameter objects on every closure call

        def closure():
            opt.zero_grad()
            # Grad mode is forced on so fit() also works when the caller is
            # inside a torch.no_grad() block.
            with torch.enable_grad():
                wd = sum((w**2).sum() for w in weights)
                loss = self.alpha/2 * wd + F.cross_entropy(self.model(X), y)
            loss.backward()
            return loss

        # run lbfgs
        opt.step(closure)

    def predict(self, X):
        """Return the arg-max class index for every row of ``X``."""
        with torch.no_grad():
            return self.model(X).max(dim=1)[1]

    def score(self, X, y):
        """Compute the accuracy (python float) of ``predict(X)`` against ``y``."""
        return (self.predict(X) == y).float().mean().item()

    def get_weights_(self):
        """Return the ``weight`` parameters (no biases) of all layers."""
        return [p for n, p in self.model.named_parameters() if n.endswith('.weight')]