import torch
import numpy as np
from itertools import combinations, product
from sklearn.kernel_ridge import KernelRidge
from time import perf_counter
from tqdm import tqdm
import os
import uuid
import pickle

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist




class RFFModel(nn.Module):
    """Linear readout on cosine and sine features of a fixed frequency set.

    For an input ``x`` the model computes ``x @ frequencies.T``, takes the
    cosine and the sine of the result, divides both by
    ``sqrt(number_of_frequencies)`` and feeds each into its own
    ``nn.Linear(number_of_frequencies, 1)`` layer; the two outputs are summed.

    ``frequencies`` is kept as a plain tensor attribute (a private copy of
    the tensor passed in), not as a parameter or buffer, so it is not
    trained and ``Module.to`` does not move it; use ``move_freq_device``.
    """

    def __init__(self, frequencies):
        super().__init__()
        n_features = len(frequencies)
        # Cosine layer first, then sine layer: the creation order determines
        # which random numbers each layer's initialisation consumes.
        self.linear_cos = nn.Linear(n_features, 1)
        self.linear_sin = nn.Linear(n_features, 1)
        self.frequencies = frequencies.clone()

    def forward(self, x):
        """Return the sum of the cosine-branch and sine-branch predictions."""
        phases = torch.matmul(x, self.frequencies.T)
        norm = np.sqrt(self.frequencies.shape[0])
        cos_branch = self.linear_cos(torch.cos(phases) / norm)
        sin_branch = self.linear_sin(torch.sin(phases) / norm)
        return cos_branch + sin_branch

    def move_freq_device(self, device):
        """Move the stored frequency tensor to ``device``."""
        self.frequencies = self.frequencies.to(device)
    

def training_loop_single(rank, world_size, model, dataset, dataset_test, n_epochs=400, weight_decay=1e-8, lr=.001, batch_size=100, file_output=""):
    """Train ``model`` with DistributedDataParallel inside one worker process.

    Intended to be the target of ``torch.multiprocessing.spawn``; the
    rendezvous uses the default ``env://`` init method, so ``MASTER_ADDR``
    and ``MASTER_PORT`` must be set in the environment by the caller.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        world_size: Total number of processes in the group. Also used as the
            number of DataLoader workers for the training loader.
        model: Module to train. Must provide ``move_freq_device(device)``.
        dataset: Training dataset yielding ``(x, y)`` pairs; sharded across
            ranks with a ``DistributedSampler``.
        dataset_test: Test dataset; evaluated in full by every rank.
        n_epochs: Maximum number of epochs.
        weight_decay: Adam weight decay.
        lr: Adam learning rate.
        batch_size: Batch size for both loaders.
        file_output: Directory in which rank 0 writes ``output.npy`` (test
            predictions of the latest epoch) and ``losses.npy`` after every
            epoch. An empty string means the current working directory.

    Returns:
        NumPy array of length ``n_epochs`` with the mean squared test error
        per epoch (entries after an early stop remain zero).
    """
    use_cuda = torch.cuda.is_available()
    # NCCL only supports CUDA tensors; Gloo is the backend usable on CPU.
    dist.init_process_group("nccl" if use_cuda else "gloo", rank=rank, world_size=world_size)

    try:
        local_dev = 'cuda:' + str(rank) if use_cuda else 'cpu'
        device = torch.device(local_dev)
        if use_cuda:
            torch.cuda.set_device(device)

        def allocated():
            # Bytes currently allocated on this rank's GPU (0 on CPU runs).
            return torch.cuda.memory_allocated(rank) if use_cuda else 0

        model = model.to(local_dev)
        model.move_freq_device(local_dev)
        # device_ids must be left unset for CPU modules.
        ddp_model = DDP(model, device_ids=[rank] if use_cuda else None, find_unused_parameters=True)

        sampler = DistributedSampler(dataset=dataset)
        dataloader = DataLoader(
            dataset, sampler=sampler, batch_size=batch_size, drop_last=False, shuffle=False, num_workers=world_size)
        dataloader_test = DataLoader(dataset_test, batch_size=batch_size)

        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=lr, weight_decay=weight_decay)

        losses = np.zeros(n_epochs)
        print(rank, allocated())

        for epoch in range(n_epochs):
            # Without this the sampler reuses the same permutation every epoch.
            sampler.set_epoch(epoch)

            t0 = perf_counter()
            for x, y in dataloader:
                optimizer.zero_grad()
                loss = F.mse_loss(ddp_model(x.to(device)), y.to(device))
                loss.backward()
                optimizer.step()
            t1 = perf_counter()
            # Drop the reference to the last training graph before evaluating.
            loss = None
            if use_cuda:
                torch.cuda.empty_cache()

            # Single pass over the test set: every rank accumulates the loss,
            # rank 0 additionally keeps the predictions for saving. The
            # unwrapped module is used so evaluation triggers no collectives.
            outputs = []
            squared_error = torch.zeros((), device=device)
            num_test = 0
            with torch.no_grad():
                for x, y in dataloader_test:
                    pred = model(x.to(device))
                    squared_error += F.mse_loss(pred, y.to(device), reduction='sum')
                    num_test += len(x)
                    if rank == 0:
                        outputs.append(pred.cpu().numpy())
            pred = None
            losses[epoch] = squared_error.item() / num_test if num_test else float('nan')
            t2 = perf_counter()
            if use_cuda:
                torch.cuda.empty_cache()

            print(rank, epoch, allocated())
            if rank == 0:
                if outputs:
                    np.save(os.path.join(file_output, 'output.npy'), np.concatenate(outputs, axis=0))
                np.save(os.path.join(file_output, 'losses.npy'), losses)
            print(f'Epoch {epoch} | Loss test {losses[epoch]} | Time {t1-t0} {t2-t1}')

            # Early stop once the test loss of the ten epochs preceding the
            # current one varies by less than 1e-5.
            if epoch > 10:
                window = losses[epoch-10:epoch]
                if np.abs(np.min(window) - np.max(window)) < 1e-5:
                    break
            print(rank, epoch, 'end', allocated())
        return losses
    finally:
        # Release the group's resources even if training raised.
        dist.destroy_process_group()

def generate_ouput(d, N, n_data, n_sample, n_epochs=50, n_processes=1, lr=.01):
    """Build a synthetic regression task and train an ``RFFModel`` on it.

    A target function is defined as a random combination of cosines and
    sines over the full frequency grid ``{0, ..., N-1}^d`` and evaluated on
    ``n_data`` points drawn uniformly from ``[-1, 1)^d``. The last
    ``max(n_data // 10, 1)`` points form the test set. An ``RFFModel`` using
    ``n_sample`` frequencies drawn without replacement from the grid is then
    trained through ``mp.spawn(training_loop_single, ...)``.

    Args:
        d: Input dimension.
        N: Number of frequency values per dimension.
        n_data: Number of data points (cast to ``int``).
        n_sample: Number of frequencies given to the model (cast to ``int``).
        n_epochs: NOTE(review): currently ignored — the epoch count is
            recomputed from ``N`` and ``d`` below. Confirm whether the
            argument should take precedence.
        n_processes: Number of worker processes to spawn.
        lr: Learning rate forwarded to the training loop.

    Returns:
        Path of the ``results/<id>`` directory holding ``config.pickle`` and
        whatever the training loop writes.
    """
    # Seeds torch only; the frequency subset below is drawn with NumPy's
    # global RNG, which this function does not seed.
    torch.random.manual_seed(98)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    n_data, n_sample = int(n_data), int(n_sample)
    # Plain ints for shape arithmetic (callers may pass NumPy integers, and
    # ``list * np.int64`` would not repeat the list).
    dim, n_per_axis = int(d), int(N)

    # Full frequency grid in itertools.product order, built without going
    # through a Python list of N**d tuples.
    axis = torch.arange(n_per_axis, dtype=torch.float)
    frequencies = torch.cartesian_prod(*([axis] * dim)).reshape(-1, dim)
    grid = torch.rand(n_data, dim) * 2 - 1

    true_function = torch.zeros((len(grid), 1))
    batch_size = max(1, int(1e9/(8*dim * n_per_axis**dim)))

    n_freq = len(frequencies)
    coefs_cos = torch.rand(n_freq, 1).float() / np.sqrt(n_freq)
    coefs_sin = torch.rand(n_freq, 1).float() / np.sqrt(n_freq)
    print(batch_size)

    # Loop-invariant transfers are done once rather than per chunk.
    freq_t_dev = frequencies.T.to(device)
    cos_dev = coefs_cos.to(device)
    sin_dev = coefs_sin.to(device)
    n_chunks = max(len(grid)//(batch_size*12), 1)
    for i, batch in enumerate(np.array_split(np.arange(len(grid)), n_chunks)):
        print(i)
        dot_products = torch.matmul(grid[batch].to(device), freq_t_dev)
        true_function[batch] += torch.matmul(torch.cos(dot_products), cos_dev).cpu()
        true_function[batch] += torch.matmul(torch.sin(dot_products), sin_dev).cpu()
    # Free device memory held by this (parent) process before workers start.
    dot_products = freq_t_dev = cos_dev = sin_dev = None
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    n_test = max(n_data//10, 1)
    n_train = len(grid) - n_test

    dataset = TensorDataset(grid[:n_train], true_function[:n_train])
    dataset_test = TensorDataset(grid[n_train:], true_function[n_train:])

    # Overrides the ``n_epochs`` argument (see docstring note).
    n_epochs = int(min(400, max(50, 400 * 5e4 / (n_per_axis**dim))))
    w = 1e-6

    exp_id = uuid.uuid4().hex[:10]
    file_output = f'results/{exp_id}'

    freq_samples = frequencies[np.random.choice(len(frequencies), size=(n_sample, ), replace=False)]
    model = RFFModel(freq_samples)

    # makedirs also creates the parent ``results`` directory on first use.
    os.makedirs(file_output)

    config = {
        "N": N,
        "d": d,
        "n_sample": n_sample,
        "w": w,
        "n_epochs": n_epochs,
        "lr": lr,
        "n_data": n_data,
    }
    with open(f'{file_output}/config.pickle', 'wb') as f:
        pickle.dump(config, f)

    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '1909'
    mp.spawn(training_loop_single,
        args=(n_processes, model, dataset, dataset_test, n_epochs, w, lr, batch_size, file_output),
        nprocs=n_processes,
        join=True)
    return file_output
                    
                    
if __name__ == '__main__':
    # Experiment sweep as (d, N) pairs, run in this order:
    #   1. one-dimensional inputs with N = 1e4 ... 1e8,
    #   2. d = 4..7 (outer) crossed with N = 5..11 (inner).
    sweep = [(1, int(big_n)) for big_n in (1e4, 1e5, 1e6, 1e7, 1e8)]
    sweep += [(d, N) for d in np.arange(4, 8) for N in np.arange(5, 12)]

    for d, N in sweep:
        # Eleven model sizes spread evenly between 1 and N frequencies.
        for n_sample in np.linspace(1, N, 11).astype(int):
            generate_ouput(d, N, int(1e5), n_sample, 50, n_processes=1, lr=.001)

    # Larger multi-process configuration, kept for reference:
    #generate_ouput(7, 8, 1e7, 4.5e5, 50, n_processes=4)
    
    
    
    
