import pyrootutils
# Resolve the project root by searching upward from this file for a ".git"
# marker. This runs before the project-local `from src...` import below on
# purpose. Per the pyrootutils docs, pythonpath=True puts the root on sys.path
# (which that import presumably relies on), and dotenv=True loads a `.env` file
# from the root if one exists. Confirm against the installed pyrootutils version.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".git"],
    pythonpath=True,
    dotenv=True,
)

import os
from tqdm import tqdm
import pandas as pd
import torch
from omegaconf import DictConfig, OmegaConf
import hydra
from hydra.utils import instantiate
from lightning.pytorch import seed_everything
from lightning.pytorch.utilities import rank_zero_only

import einops as E
import numpy as np

from src.utils.animate import load_config_and_model
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn import preprocessing


def get_features(dataloader, model, device='cuda:0', epoch=1):
    """Collect alignment deltas and their action labels as numpy arrays.

    For every batch, the model's pipeline
    ``preprocess -> encode -> evolver.actionable -> evolver.alignment`` is run.
    The returned ``deltas`` are flattened to one feature row per (seq, batch)
    pair. The labels are ``actions[:-1]``, flattened in the same (seq, batch)
    order.

    Args:
        dataloader: Iterable of batches accepted by ``model.preprocess``.
        model: Object exposing ``preprocess``, ``encode`` and an ``evolver``
            with ``actionable`` and ``alignment``.
        device: Device that the preprocessed sequence is moved to before encoding.
        epoch: Number of full passes over ``dataloader``.

    Returns:
        Tuple ``(features, labels)`` of numpy arrays. ``features`` has shape
        ``(N, slot * nacts * dim)`` and ``labels`` has shape ``(N,)``. This
        assumes the deltas' seq axis has length ``len(actions) - 1`` so that
        rows and labels line up (TODO: confirm).
    """
    all_features = []
    all_labels = []
    with torch.no_grad():
        for _epoch in range(epoch):
            for batch in tqdm(dataloader):
                fullseq, actions, _unused = model.preprocess(batch)

                latents = model.encode(fullseq.to(device))
                latents = model.evolver.actionable(latents)
                # Only the deltas are probed; the aligned latents are discarded.
                _aligned, deltas = model.evolver.alignment(latents)

                deltas = E.rearrange(deltas, 'seq batch slot nacts dim -> (seq batch) (slot nacts dim)')
                labels = E.rearrange(actions[:-1], 'seq batch -> (seq batch)')

                # Move each batch off the accelerator right away. Accumulating on
                # the device would grow GPU memory with dataset size x epochs.
                all_features.append(deltas.cpu())
                all_labels.append(labels.cpu())

    return torch.cat(all_features).numpy(), \
           torch.cat(all_labels).numpy()

def linear_probe(train_loader, test_loader, model, device, train_epoch, C=1.0, max_iter=1000):
    """Fit a logistic-regression probe on alignment deltas and return test accuracy.

    Args:
        train_loader: Batches used to fit the probe.
        test_loader: Batches used to score the probe (a single pass).
        model: Model passed to ``get_features``. It is moved to ``device`` here.
        device: Device used for feature extraction.
        train_epoch: Number of passes over ``train_loader`` when collecting
            training features.
        C: Inverse regularization strength for ``LogisticRegression``.
        max_iter: Solver iteration cap for ``LogisticRegression``.

    Returns:
        Percentage of test samples whose predicted action equals the label
        (0-100).
    """
    model.to(device)
    # Probe inputs are the model's flattened alignment deltas; targets are the
    # actions taken between consecutive steps.
    train_features, train_labels = get_features(train_loader, model, device, train_epoch)
    test_features, test_labels = get_features(test_loader, model, device)

    # Standardize using statistics from the training split only, so no test
    # information leaks into the probe.
    scaler = preprocessing.StandardScaler().fit(train_features)
    train_features = scaler.transform(train_features)
    test_features = scaler.transform(test_features)

    # random_state fixed so repeated runs produce the same probe.
    classifier = LogisticRegression(random_state=0, C=C, max_iter=max_iter)
    classifier.fit(train_features, train_labels)
    predictions = classifier.predict(test_features)
    accuracy = np.mean((test_labels == predictions).astype(float)) * 100.
    return accuracy

@hydra.main(version_base=None, config_path='../config', config_name='eval')
def main(config: DictConfig):
    """Run the action-accuracy linear probe for every dataset variant and save a CSV.

    The model and its training config are loaded from ``config.model_path``.
    The eval config is then merged on top of the model config, so eval values
    win on conflicts. For each variant, a probe is fit on the 'train' split and
    scored on the 'eval' split. Results go to
    ``{config.save_dir}/calm_actionacc.csv`` with columns env, method, metric
    and value.

    Returns:
        The loaded model, moved to ``config.device`` and set to eval mode.
    """
    config_model, model = load_config_and_model(config.model_path, finetune=True)
    # Clear struct mode on both configs, presumably so the merge can add keys
    # that exist in only one of them.
    OmegaConf.set_struct(config, None)
    OmegaConf.set_struct(config_model, None)
    config = OmegaConf.merge(config_model, config)

    if config.seed is not None:
        seed_everything(config.seed, workers=True)

    # This dataset instance is only used to enumerate the variants to probe.
    dataset = instantiate(config.data.dataset, split='eval')

    device = config.device
    model = model.to(device)
    model.eval()

    rows = []
    with torch.no_grad():
        for variant in dataset.variants:
            print('variant:', variant)
            train_loader = instantiate(config.data, dataset={'split': 'train', 'env': variant, 'fix_start': False})
            test_loader = instantiate(config.data, dataset={'split': 'eval', 'env': variant})
            accuracy = linear_probe(train_loader, test_loader, model, device, config.linearprobe.train_epoch)
            rows.append({'env': variant, 'method': 'calm', 'metric': 'ActionACC', 'value': accuracy})

    save_dir = config.save_dir
    # Create the output directory (including parents) if it does not exist yet.
    os.makedirs(save_dir, exist_ok=True)
    # Pin the column schema so the CSV still gets its header when no variant ran.
    df = pd.DataFrame(rows, columns=['env', 'method', 'metric', 'value'])
    df.to_csv(os.path.join(save_dir, 'calm_actionacc.csv'), index=False)

    return model


# Script entry point. main() is called with no arguments because the
# @hydra.main decorator supplies `config` (config_name='eval' under
# config_path='../config', plus any command-line overrides).
if __name__ == '__main__':
    main()