
import numpy as np
import torch, torchvision
from torch import nn
from torch.autograd import Variable
from tqdm import tqdm
import random
import argparse
# %matplotlib inline
import csv
import os
import seaborn as sns
import pandas as pd
import json

import matplotlib.pyplot as plt
import torch.nn.functional as F

from dataset.syn_env import CausalControlDataset, AntiCausalControlDataset, CausalControlDescentDataset, AntiCausalControlDatasetMultiClass

from dataset.pacs import PACS
from models.bag import BAG

from trainers.bag_trainer import BAGTrainer

from misc import create_DF, standalone_tunning_test, fine_tunning_test, BaseLoss, initialize_torchvision_model, FolderDataset
import optuna
from optuna.trial import FrozenTrial, TrialState
from optuna.distributions import IntDistribution, FloatDistribution, CategoricalDistribution



from rich import pretty
# Install rich's pretty-printing hook for this process.
pretty.install()


# Error-display style name kept for seaborn plotting
# (NOTE(review): not referenced anywhere in the visible part of this file).
err_sty = 'band'
# Disable the cuDNN backend globally as soon as this module is imported.
torch.backends.cudnn.enabled = False

def set_seed(seed):
    """Seed every random number generator used here and pin cuDNN flags.

    Seeds, in order: Python's ``random``, PyTorch (CPU), PyTorch (all CUDA
    devices) and NumPy. It then stores the seed in the ``PYTHONHASHSEED``
    environment variable, and finally sets ``cudnn.deterministic = True``
    and ``cudnn.benchmark = False``.

    Args:
        seed: Seed value handed unchanged to each generator.
    """
    seeders = (
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    os.environ['PYTHONHASHSEED'] = str(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

class Identity(nn.Module):
    """Pass-through module whose ``forward`` hands back its input untouched."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # No transformation: the very same object is returned.
        return x

class ResNet(torch.nn.Module):
    """Wrap a backbone network and keep its ``BatchNorm2d`` layers in eval mode.

    The wrapped ``model`` is stored as ``self.network`` and called as-is in
    ``forward``. Every ``nn.BatchNorm2d`` submodule is switched to eval mode
    at construction and again after each ``train()`` call, so those layers
    stay in eval mode even while the rest of the network trains. Only the
    module mode is changed; ``requires_grad`` of the batch-norm parameters is
    left untouched.
    """

    def __init__(self, model):
        """Store ``model`` as the wrapped network and freeze its batch norm."""
        super(ResNet, self).__init__()
        self.network = model
        self.freeze_bn()

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.network(x)

    def train(self, mode=True):
        """Set the training mode, then put batch-norm layers back in eval mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching the ``nn.Module.train`` contract. This also
            makes ``eval()`` return the module, since ``nn.Module.eval``
            returns the result of ``self.train(False)``.
        """
        super().train(mode)
        # super().train() flips every submodule, so re-freeze afterwards.
        self.freeze_bn()
        return self

    def freeze_bn(self):
        """Put every ``nn.BatchNorm2d`` inside the wrapped network in eval mode."""
        for module in self.network.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()









      
def main(args):
    """Build the PACS environment and a BAG model, train it, return the result.

    Attributes read directly from ``args``: ``log``, ``nb_workers``,
    ``resnet_dim``, ``n_envs``, ``reg_lambda`` and ``batch_size``.
    Attributes written onto ``args``: ``device``, ``model_save_dir``,
    ``num_workers``, ``torch_loader``, ``num_class``, ``model_kwargs`` and
    ``phi_odim``. The namespace is also passed whole to ``PACS``, ``BAG`` and
    ``BAGTrainer``, which may read further attributes not visible here.

    Args:
        args: Mutable namespace of run settings (e.g. ``argparse.Namespace``).

    Returns:
        Whatever ``BAGTrainer.train`` returns (presumably a test loss, going
        by how the value is named at the call site).
    """
    # Pick the compute device.
    if torch.cuda.is_available():
        args.device = "cuda"
    else:
        args.device = "cpu"
    pretty.pprint(f"Using {args.device} device")

    # The log directory doubles as the model output directory.
    args.model_save_dir = args.log
    os.makedirs(args.model_save_dir, exist_ok=True)
    pretty.pprint(f"Model save directory: {args.model_save_dir}")

    # Settings written onto the namespace before the environment is built.
    args.num_workers = args.nb_workers
    args.torch_loader = True

    pacs_env = PACS(args)
    train_envs = pacs_env.train_data_list
    # Read as in the original flow; not used further in this function.
    _val_envs = pacs_env.val_data_list
    # Only the third element of the sampled triple is used below.
    _finetune_split, _unlabelled_split, test_split = pacs_env.sample_envs(
        train_val_test=2
    )

    pretty.pprint(vars(args))

    in_features = pacs_env.input_dim
    n_classes = pacs_env.num_class
    args.num_class = n_classes
    args.model_kwargs = {'pretrained': True}

    backbone = initialize_torchvision_model(
        name='resnet18',
        d_out=args.resnet_dim,
        **args.model_kwargs
    )
    args.phi_odim = backbone.d_out
    print(f"Phi output dimension: {args.phi_odim}, Phi: {backbone}")
    # Wrap only after d_out has been read from the raw backbone.
    featurizer = ResNet(backbone)

    bag_model = BAG(
        args.n_envs,
        in_features,
        featurizer,
        args,
        out_dim=n_classes,
        phi_dim=args.phi_odim,
    )
    bag_model.to(args.device)

    loss_fn = torch.nn.CrossEntropyLoss()
    bag_trainer = BAGTrainer(
        bag_model, loss_fn, args.reg_lambda, args, causal_dir=False
    )

    return bag_trainer.train(
        train_envs,
        args.batch_size,
        test_split,
        log_dir=args.model_save_dir,
    )



        
        

def parse_cli_args(argv=None):
    """Parse the command-line options that ``main`` reads directly.

    Only the attributes ``main`` itself accesses are declared here.
    NOTE(review): ``main`` also hands the whole namespace to ``PACS``,
    ``BAG`` and ``BAGTrainer``, which may read further attributes; extend
    this parser with whatever those classes expect. The option types below
    are inferred from how ``main`` uses each value -- confirm them.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(
        description="Train a BAG model on the PACS dataset."
    )
    parser.add_argument('--log', type=str, required=True,
                        help="directory for logs and saved models")
    parser.add_argument('--nb_workers', type=int, required=True,
                        help="value copied to args.num_workers")
    parser.add_argument('--resnet_dim', type=int, required=True,
                        help="d_out passed to the resnet18 initializer")
    parser.add_argument('--n_envs', type=int, required=True,
                        help="number of environments given to BAG")
    parser.add_argument('--reg_lambda', type=float, required=True,
                        help="regularization weight given to BAGTrainer")
    parser.add_argument('--batch_size', type=int, required=True,
                        help="batch size given to the trainer")
    return parser.parse_args(argv)


if __name__ == '__main__':
    # main() requires a namespace; calling it bare raised TypeError.
    main(parse_cli_args())


      





     