#CREATE CLASSIFICATION HEAD
from sklearn.linear_model import LogisticRegression
import warnings
warnings.filterwarnings("ignore")
import torch
import torch.nn as nn
import torch.nn.functional as F

import pandas as pd
import numpy as np
from functools import reduce

# Two raw result tables for the timm models (no header row). Column 0 is the
# model identifier (it is sorted on and sliced as a string below); the other
# columns are averaged row by row.
t1 = pd.read_csv('timm_models_raw.csv', header=None)
t2 = pd.read_csv('timm_models2_raw.csv', header=None)
# numeric_only=True keeps the string identifier in column 0 out of the
# row-wise mean; without it pandas >= 2.0 raises TypeError on the mixed rows.
t1['mean'] = t1.mean(axis=1, numeric_only=True)
t2['mean'] = t2.mean(axis=1, numeric_only=True)
t_dims = pd.read_csv('timm_model_dims.csv', header=None)
# Sort every table by identifier so the rows can be combined positionally.
# NOTE(review): this assumes all three files list the same models - confirm.
t1 = t1.sort_values(by=[0])
t2 = t2.sort_values(by=[0])
t_dims = t_dims.sort_values(by=[0])
# Element-wise average of the two per-model means.
timm_means = np.mean([t1['mean'], t2['mean']], axis=0)
# Drop the first five characters of every identifier.
timm_names = [t[5:] for t in t1[0].values]
timm_dims = t_dims[1].values

# One row per model: name, averaged score, and the value from column 1 of
# timm_model_dims.csv.
t3 = pd.DataFrame([timm_names, timm_means, timm_dims]).T
t3.columns = ['model', 'mean', 'dims']
# The transposed frame is object-typed, so cast the numeric columns back.
t3['mean'] = t3['mean'].astype('float64') * 100
t3['dims'] = t3['dims'].astype('float64')

# Keep only the models whose averaged score is strictly positive.
t3 = t3[t3['mean'] > 0]

# Mini-batch size handed to the DataLoaders of the neural-network heads.
batch_size = 1000
# Dataset name; it selects the embedding directory and prefixes result files.
dataset = 'mnist'
# Directory holding the '<model>.npy' / '<model>_targets.npy' embedding files.
embed_path = f'/data/HAMFSL/embeds/{dataset}/'

class OneLayerNet(nn.Module):
    """Single linear classification head with dropout on its input.

    Args:
        in_size: Number of input features per sample.
        num_classes: Number of output logits. Defaults to 100.
        dropout_p: Probability used by the dropout layer that is applied to
            the input before the linear layer. Defaults to 0.8.
    """

    def __init__(self, in_size, num_classes=100, dropout_p=0.8):
        super().__init__()
        self.fc1 = nn.Linear(in_size, num_classes)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(self, x):
        """Return the raw logits ``fc1(dropout(x))`` (no softmax applied)."""
        return self.fc1(self.dropout(x))
    
class TwoLayerNet(nn.Module):
    """Two-layer classification head: dropout -> linear -> ReLU -> linear.

    The hidden layer has the same width as the input.

    Args:
        in_size: Number of input features per sample (also the hidden width).
        num_classes: Number of output logits. Defaults to 100.
        dropout_p: Probability used by the dropout layer that is applied to
            the input before the first linear layer. Defaults to 0.8.
    """

    def __init__(self, in_size, num_classes=100, dropout_p=0.8):
        super().__init__()
        self.fc1 = nn.Linear(in_size, in_size)
        self.fc2 = nn.Linear(in_size, num_classes)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(self, x):
        """Return the raw logits ``fc2(relu(fc1(dropout(x))))``."""
        hidden = F.relu(self.fc1(self.dropout(x)))
        return self.fc2(hidden)
    
    
#TRAIN AND EVALUATION EXP MODULE
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm



def _resolve_head(head):
    """Return the module-level object named ``head`` (a plain identifier)."""
    try:
        return globals()[head]
    except KeyError:
        raise NameError(f'unknown classification head {head!r}') from None


def _split_per_class(embeds, targets, n, random_state=1):
    """Split the data class by class into train and test portions.

    Every class present in ``targets`` contributes
    ``floor(class_count * num_classes * n / len(targets))`` training samples,
    which is exactly ``n`` when the classes are balanced. The count is
    computed with integer arithmetic and passed to ``train_test_split`` as an
    absolute ``train_size``; a floating-point fraction can round the training
    count down by one (to zero when ``n`` is 1).

    Returns:
        ``(data_train, data_test, labels_train, labels_test)`` as
        concatenated NumPy arrays.

    Raises:
        ValueError: If ``targets`` is empty, or a class would end up with an
            empty training or test portion.
    """
    total = len(targets)
    if total == 0:
        raise ValueError('targets is empty; nothing to split')
    classes = np.unique(targets)
    # Buckets follow train_test_split's return order:
    # data_train, data_test, labels_train, labels_test.
    buckets = ([], [], [], [])
    for cls in classes:
        idx = np.where(targets == cls)
        cls_embeds = embeds[idx]
        cls_targets = targets[idx]
        count = len(cls_targets)
        n_train = int(count * len(classes) * n // total)
        if not 0 < n_train < count:
            raise ValueError(
                f'n={n} gives {n_train} training samples for class {cls!r} '
                f'with {count} samples; both train and test parts must be '
                'non-empty'
            )
        pieces = train_test_split(
            cls_embeds, cls_targets,
            train_size=n_train, random_state=random_state,
        )
        for bucket, piece in zip(buckets, pieces):
            bucket.append(piece)
    return tuple(np.concatenate(bucket, axis=0) for bucket in buckets)


def _train_torch_head(net, data_train, labels_train, data_test, labels_test,
                      epochs, verbose):
    """Train ``net`` with Adam and return ``(train_acc, test_acc)``.

    ``train_acc`` is the accuracy accumulated over the last training epoch
    (with the network in training mode); ``test_acc`` is measured on the test
    split after that epoch with the network in eval mode.
    """
    if epochs < 1:
        # With no epoch there is no accuracy to report.
        raise NameError('epochs must be >= 1 to train a neural-network head')
    trainset = TensorDataset(
        torch.Tensor(data_train),
        torch.from_numpy(np.asarray(labels_train, dtype=np.int64)),
    )
    testset = TensorDataset(
        torch.Tensor(data_test),
        torch.from_numpy(np.asarray(labels_test, dtype=np.int64)),
    )
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.002)
    train_loader = DataLoader(trainset, batch_size, shuffle=True)
    # Evaluation does not depend on sample order, so no shuffling is needed.
    test_loader = DataLoader(testset, batch_size, shuffle=False)
    n_train, n_test = len(trainset), len(testset)

    train_correct = 0
    test_correct = 0
    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        train_correct = 0
        for ins, targs in train_loader:
            optimizer.zero_grad()
            output = net(ins)
            loss = criterion(output, targs)
            loss.backward()
            optimizer.step()
            train_correct += (output.argmax(dim=-1) == targs).sum().item()
            running_loss += loss.item()
        if verbose:
            print(f'[{epoch + 1}] loss: {running_loss:.3f} train_acc: {train_correct/n_train}')

        net.eval()
        test_loss = 0.0
        test_correct = 0
        # No gradients are needed for evaluation; skip building the graph.
        with torch.no_grad():
            for ins, targs in test_loader:
                output = net(ins)
                # The criterion returns a batch mean; weight it by the batch
                # size so that dividing by n_test gives a per-sample mean.
                test_loss += criterion(output, targs).item() * targs.size(0)
                test_correct += (output.argmax(dim=-1) == targs).sum().item()
        if verbose and epoch % 5 == 4:
            print(f'[{epoch + 1}] test_loss: {test_loss/n_test:.3f} test_acc: {test_correct/n_test}')
    return train_correct / n_train, test_correct / n_test


def experiment(heads, embed_model, n=None, epochs=10, verbose=False, embed_path=embed_path):
    """Fit classification heads on stored embeddings and report accuracies.

    Loads ``<embed_path><embed_model>.npy`` and
    ``<embed_path><embed_model>_targets.npy``, splits every class into train
    and test portions, then fits each head in ``heads`` on the train portion.

    Args:
        heads: Names of module-level classes to fit. Names containing
            ``'Net'`` are built as ``cls(embedding_dim)`` and trained as
            PyTorch modules; every other name is built as
            ``cls(max_iter=600, solver='lbfgs', n_jobs=-1, multi_class='ovr')``
            and used through ``fit`` / ``score``.
        embed_model: File-name stem of the embedding and target arrays.
        n: Training samples per class for a balanced dataset (required).
        epochs: Training epochs for the PyTorch heads.
        verbose: Print progress metrics while fitting.
        embed_path: Prefix that is string-concatenated with ``embed_model``.

    Returns:
        A flat list ``[train_acc, test_acc, ...]`` with one pair per head, in
        the order of ``heads``.

    Raises:
        TypeError: If ``n`` is None.
        ValueError: If the split leaves a class without train or test data.
        NameError: If a head name is not defined in this module.
    """
    if n is None:
        raise TypeError('n (training samples per class) must be a number, not None')
    embeds = np.load(embed_path + embed_model + '.npy')
    targets = np.load(embed_path + embed_model + '_targets.npy')
    data_train, data_test, labels_train, labels_test = _split_per_class(
        embeds, targets, n
    )

    results = []
    for head in tqdm(heads):
        head_cls = _resolve_head(head)
        if 'Net' in head:
            net = head_cls(embeds.shape[-1])
            train_acc, final_acc = _train_torch_head(
                net, data_train, labels_train, data_test, labels_test,
                epochs, verbose,
            )
        else:
            # NOTE(review): recent scikit-learn releases deprecate the
            # `multi_class` argument - confirm against the installed version.
            m = head_cls(max_iter=600, solver='lbfgs', n_jobs=-1, multi_class='ovr')
            m.fit(data_train, labels_train)
            train_acc = m.score(data_train, labels_train)
            final_acc = m.score(data_test, labels_test)
            if verbose:
                print(final_acc)
        results.append(train_acc)
        results.append(final_acc)

    return results

#RUN EXP MODULE ON EACH SET OF EMBEDDINGS
import os
# Make sure the output directory exists before it is listed and written to;
# on a fresh checkout both os.listdir and np.save would fail otherwise.
os.makedirs('results', exist_ok=True)
# Result files written by earlier runs, used to skip finished models.
cache = os.listdir('results/')
# Embedding files available on disk.
model_cache = os.listdir(embed_path)

for m in tqdm(t3.model.values):
    print(m)
    if f'{m}.npy' not in model_cache:
        print('Excluded model.')
        continue
    if f'{m}_targets.npy' not in model_cache:
        print('Bad embeds.')
        continue
    if f'{dataset}_FSL_results_{m}.npy' in cache:
        print('Already processed.')
        continue

    # One row per training-set size: [model, n, train_acc, test_acc].
    rows = []
    for n in tqdm([1, 5, 10, 20, 40, 80]):
        accs = experiment(['LogisticRegression'], m, n=n, epochs=40, embed_path=embed_path)
        row = [m, n, *accs]
        print(row)
        rows.append(row)
    # np.save appends '.npy', which yields the file name checked against
    # `cache` above.
    np.save(f'results/{dataset}_FSL_results_{m}', rows)
