import sys
import os
import copy
cwd = os.getcwd()
sys.path.append(cwd)

from tqdm import tqdm
import torch.nn as nn
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import wandb
import random
import numpy as np
import argparse
import datetime
from aesthetic_scorer import MLPDiff


class BaseNetwork(nn.Module):
    """Feed-forward regressor that maps a feature vector to a single scalar.

    The network is three ``Linear -> ReLU -> Dropout`` stages
    (``input_dim -> 256 -> 64 -> 16``) followed by a final ``Linear(16, 1)``.
    All layers live in ``self.layers`` (an ``nn.Sequential``).
    """

    def __init__(self, input_dim):
        """Build the layer stack for inputs whose last dimension is ``input_dim``."""
        super().__init__()
        self.input_dim = input_dim

        # One (in_features, out_features, dropout_probability) entry per hidden stage.
        hidden_stages = (
            (self.input_dim, 256, 0.2),
            (256, 64, 0.2),
            (64, 16, 0.1),
        )

        stack = []
        for fan_in, fan_out, drop_p in hidden_stages:
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
            stack.append(nn.Dropout(drop_p))
        # Output projection down to one value per input row.
        stack.append(nn.Linear(16, 1))

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        prediction = self.layers(x)
        return prediction

class BootstrappedNetwork(nn.Module):
    """Ensemble of ``num_heads`` independent ``BaseNetwork`` instances.

    The heads are stored in ``self.models`` (an ``nn.ModuleList``) and can be
    evaluated all at once or one at a time via ``forward``.
    """

    def __init__(self, input_dim, num_heads=4):
        """Create ``num_heads`` heads, each a ``BaseNetwork(input_dim)``."""
        super(BootstrappedNetwork, self).__init__()
        heads = []
        for _ in range(num_heads):
            heads.append(BaseNetwork(input_dim))
        self.models = nn.ModuleList(heads)

    def forward(self, inputs, head_idx=None):
        """Evaluate one head or every head.

        With ``head_idx=None`` every head is applied to the full ``inputs``
        and a list of per-head outputs is returned. With an integer
        ``head_idx`` only that head runs, and it receives the slice
        ``inputs[:, head_idx, :]``, so ``inputs`` must be indexable along
        its second axis in that mode.
        """
        if head_idx is not None:
            # Single-head mode: pick the head, then feed it its own slice.
            assert isinstance(head_idx, int)
            chosen_head = self.models[head_idx]
            return chosen_head(inputs[:, head_idx, :])

        # All-heads mode: each head sees the same, unsliced inputs.
        return [head(inputs) for head in self.models]

def bootstrapping(dataset, n_datasets=10):
    """Create ``n_datasets`` bootstrap resamples of ``dataset``.

    Each resample is a list with ``len(dataset)`` items, drawn uniformly at
    random with replacement using the global ``random`` module.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing.
        n_datasets: Number of resamples to generate.

    Returns:
        A list of ``n_datasets`` lists of items taken from ``dataset``.
    """
    resamples = []
    for _ in range(n_datasets):
        size = len(dataset)
        # Draw all indices first (with replacement), then gather the items.
        drawn = [random.randint(0, size - 1) for _ in range(size)]
        resamples.append([dataset[pos] for pos in drawn])
    return resamples

class BootstrappedDataset(Dataset):
    """Dataset that serves the same index from several resamples at once.

    ``bootstrapped_data`` is a sequence of indexable datasets whose items are
    ``(input, target)`` tensor pairs. Item ``idx`` of this dataset stacks the
    ``idx``-th pair of every resample along a new leading dimension.
    """

    def __init__(self, bootstrapped_data):
        """Keep a reference to the list of resampled datasets."""
        self.bootstrapped_data = bootstrapped_data

    def __len__(self):
        """Return the length of the first resample (all are assumed equal in size)."""
        first_resample = self.bootstrapped_data[0]
        return len(first_resample)

    def __getitem__(self, idx):
        """Return ``(stacked_inputs, stacked_targets)`` for position ``idx``."""
        # One (input, target) pair per resample, in resample order.
        per_resample = tuple(resample[idx] for resample in self.bootstrapped_data)
        features, labels = zip(*per_resample)
        stacked_features = torch.stack(features)
        stacked_labels = torch.stack(labels)
        return stacked_features, stacked_labels

if __name__ == "__main__":
    # Train a BootstrappedNetwork so that each head regresses the output of a
    # frozen MLPDiff model (loaded from a checkpoint below) on its own
    # bootstrap resample of the input features. Training targets get Gaussian
    # noise of scale --noise; validation targets are noise-free.
    parser = argparse.ArgumentParser()

    # Add arguments
    parser.add_argument('--num_epochs', type=int, default=100, help='number of epochs to train')
    parser.add_argument('--train_bs', type=int, default=256)
    parser.add_argument('--val_bs', type=int, default=512)
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('--noise', type=float, default=0.1, help='noise level')
    parser.add_argument('--num_heads', type=int, default=4, help='number of heads')
    parser.add_argument('--run_name', type=str, default='test')

    args = parser.parse_args()

    # Make every run name unique by appending (or substituting) a timestamp.
    unique_id = datetime.datetime.now().strftime("%Y.%m.%d_%H.%M.%S")
    if not args.run_name:
        args.run_name = unique_id
    else:
        args.run_name += "_" + unique_id

    wandb.init(project="offline_reward_aesthetic", name=args.run_name,
        config={
        'lr': args.lr,
        'num_epochs':args.num_epochs,
        'train_batch_size':args.train_bs,
        'val_batch_size':args.val_bs,
        # Also record the remaining hyperparameters that shape the run.
        'noise': args.noise,
        'num_heads': args.num_heads,
    })

    x = np.load("./reward_aesthetic/data/ava_x_openclip_l14.npy")

    y = np.load("./reward_aesthetic/data/ava_y_openclip_l14.npy")

    val_percentage = 0.05 # 5% of the training data will be used for validation

    train_border = int(x.shape[0] * (1 - val_percentage) )

    train_tensor_x = torch.Tensor(x[:train_border]) # transform to torch tensor
    train_tensor_y = torch.Tensor(y[:train_border])

    train_dataset = TensorDataset(train_tensor_x,train_tensor_y) # create your dataset

    # One bootstrap resample per head; the loader yields all of them per index.
    bootstrapped_traindata = bootstrapping(train_dataset, n_datasets=args.num_heads)
    bootstrapped_trainset = BootstrappedDataset(bootstrapped_traindata)

    train_loader = DataLoader(bootstrapped_trainset, batch_size=args.train_bs
                , shuffle=True,  num_workers=16) # create your dataloader

    val_tensor_x = torch.Tensor(x[train_border:]) # transform to torch tensor
    val_tensor_y = torch.Tensor(y[train_border:])
    val_dataset = TensorDataset(val_tensor_x,val_tensor_y)

    bootstrapped_valdata = bootstrapping(val_dataset, n_datasets=args.num_heads)
    bootstrapped_valset = BootstrappedDataset(bootstrapped_valdata)

    val_loader = DataLoader(bootstrapped_valset, batch_size=args.val_bs
                , num_workers=16)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = BootstrappedNetwork(input_dim=768, num_heads=args.num_heads).to(device)

    # Pass the parsed learning rate through; previously --lr was logged but
    # the optimizer silently used Adam's built-in default.
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)

    # choose the loss you want to optimize for
    criterion = nn.MSELoss()
    criterion2 = nn.L1Loss()

    model.train()
    best_loss = float("inf")

    # Frozen model whose outputs are used as regression targets.
    eval_model = MLPDiff().to(device)
    eval_model.requires_grad_(False)
    eval_model.eval()
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    s = torch.load("./reward_aesthetic/backup/sac+logos+ava1-l14-linearMSE.pth",
                   map_location=device)   # load the model you trained previously or the model available in this repo
    eval_model.load_state_dict(s)

    # Ensure the checkpoint directory exists so torch.save cannot fail on it.
    models_dir = './reward_aesthetic/models'
    os.makedirs(models_dir, exist_ok=True)

    for epoch in tqdm(range(args.num_epochs), desc="Epochs"):
        model.train()
        losses = []
        save_name = f'{models_dir}/{args.run_name}_{epoch+1}.pth'

        # The stored labels are ignored: targets come from eval_model instead.
        for batch_num, (inputs,_) in enumerate(tqdm(train_loader,
                                desc=f"Epoch {epoch+1}/{args.num_epochs}")):
            optimizer.zero_grad()

            inputs = inputs.to(device).float()

            loss = 0
            for i in range(args.num_heads):
                output = model(inputs, head_idx=i)
                raw_target = eval_model(inputs[:,i,:])
                noisy_target = raw_target + torch.randn_like(raw_target, device=device)*args.noise
                loss += criterion(output, noisy_target.detach())

            # Average over heads so the loss scale does not depend on num_heads.
            loss /= args.num_heads
            loss.backward()
            optimizer.step()

            losses.append(loss.item())

            if batch_num % 1000 == 0:
                print('\tEpoch %d | Batch %d | Loss %6.2f' % (epoch, batch_num, loss.item()))
                wandb.log({"batch_loss": loss.item()})

        mean_train_loss = sum(losses)/len(losses)
        print('Epoch %d | Loss %6.2f' % (epoch, mean_train_loss))
        wandb.log({"epoch": epoch, "mean_batch_loss": mean_train_loss})

        # Validation: switch to eval mode once and disable autograd so no
        # graph is built and dropout is inactive.
        model.eval()
        val_mse_losses = []
        val_mae_losses = []
        with torch.no_grad():
            for batch_num, (inputs,_) in enumerate(val_loader):
                inputs = inputs.to(device).float()

                loss = 0
                lossMAE = 0

                for i in range(args.num_heads):
                    output = model(inputs, head_idx=i)
                    target = eval_model(inputs[:,i,:])

                    loss += criterion(output, target)
                    lossMAE += criterion2(output, target)
                loss /= args.num_heads
                lossMAE /= args.num_heads

                val_mse_losses.append(loss.item())
                val_mae_losses.append(lossMAE.item())

                if batch_num % 1000 == 0:
                    print('\tValidation - Epoch %d | Batch %d | MSE Loss %6.4f' % (epoch, batch_num, loss.item()))
                    print('\tValidation - Epoch %d | Batch %d | MAE Loss %6.4f' % (epoch, batch_num, lossMAE.item()))

        mean_val_mse = sum(val_mse_losses)/len(val_mse_losses)
        mean_val_mae = sum(val_mae_losses)/len(val_mae_losses)
        print('Validation - Epoch %d | MSE Loss %6.4f' % (epoch, mean_val_mse))
        print('Validation - Epoch %d | MAE Loss %6.4f' % (epoch, mean_val_mae))

        # Checkpoint selection is driven by validation MAE ...
        if mean_val_mae < best_loss:
            print("Best MAE Val loss so far. Saving model")
            best_loss = mean_val_mae
            print( best_loss )

            torch.save(model.state_dict(), save_name )

        # ... while the LR schedule follows validation MSE.
        scheduler.step(mean_val_mse)

    # Always save the final weights under the last epoch's file name. The
    # path is built explicitly so this also works when --num_epochs is 0.
    final_save_name = f'{models_dir}/{args.run_name}_{args.num_epochs}.pth'
    torch.save(model.state_dict(), final_save_name)

    print( best_loss )

    print("training done")
